#!/usr/bin/env python
# -*- coding: utf-8 -*-
import copy
import inspect
import string
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
# Names exported by ``from <module> import *``: the two plotting entry points.
__all__ = ['plot_chains', 'noise_flower']
def _select_params(core, pars=None, exclude=None):
    """Return the list of parameter names that ``plot_chains`` should draw.

    Parameters
    ----------
    core : core object or list of core objects
        Object(s) exposing a ``params`` sequence of parameter names.
    pars : list of str, optional
        Explicit parameters to plot; returned as a new list.
    exclude : list of str, optional
        Parameters to drop from the default selection. Names that are not
        present are ignored.

    Returns
    -------
    list of str
        ``pars`` if given; otherwise the parameters shared by every core, in
        the order of the first core, minus ``exclude``.

    Raises
    ------
    ValueError
        If both ``pars`` and ``exclude`` are given.
    """
    if pars is not None and exclude is not None:
        raise ValueError('Please remove excluded parameters from `pars`.')
    if pars is not None:
        return list(pars)
    cores = core if isinstance(core, list) else [core]
    # Keep the first core's ordering; only keep names present in every core.
    others = [set(c.params) for c in cores[1:]]
    params = [p for p in cores[0].params if all(p in s for s in others)]
    if exclude is not None:
        excluded = set(exclude)
        params = [p for p in params if p not in excluded]
    return params


def plot_chains(core, hist=True, pars=None, exclude=None,
                ncols=3, bins=40, suptitle=None, color='k',
                publication_params=False, titles=None,
                linestyle=None, plot_map=False, truths=None,
                save=False, show=True, linewidth=1,
                log=False, title_y=1.01, hist_kwargs=None,
                plot_kwargs=None, legend_labels=None, real_tm_pars=True,
                legend_loc=None, **kwargs):
    """Function to plot histograms or traces of chains from cores.

    Parameters
    ----------
    core : {`la_forge.core.Core`,
            `la_forge.core.HyperModelCore`,
            `la_forge.core.TimingCore`,
            `la_forge.slices.SlicedCore`} or list of these
        Core(s) whose chains are plotted. When a list is given, every core is
        drawn on the same axes, and by default only parameters common to all
        cores are plotted.
    hist : bool, optional
        Whether to plot histograms. If False then traces of the chains will be
        plotted.
    pars : list of str, optional
        List of the parameters to be plotted. Cannot be combined with
        `exclude`.
    exclude : list of str, optional
        List of the parameters to be excluded from plot. Unknown names are
        ignored.
    ncols : int, optional
        Number of columns of subplots to use.
    bins : int, optional
        Number of bins to use in histograms.
    suptitle : str, optional
        Title to use for the plots. If None, a title is guessed from the
        pulsar-name prefix of the first parameter.
    color : str or list of str, optional
        Currently unused; histogram colors follow the matplotlib color cycle
        (pass ``color`` through `hist_kwargs` to override).
    publication_params : bool, optional
        If True, create the figure with matplotlib's default size instead of
        ``[15, 4 * nrows]``.
    titles : list of str, optional
        One subplot title per plotted parameter.
    linestyle : str or list of str, optional
        Histogram line style. For a list of cores, either one style for all
        cores or one style per core (default ``'-'``).
    plot_map : bool or list of bool, optional
        Draw a dashed vertical line at each core's ``get_map_param`` value
        (histograms only). A list gives one flag per core.
    truths : list of float, optional
        One reference value per plotted parameter, drawn as a solid vertical
        line in the histogram's color (histograms only).
    save : str or bool, optional
        File path passed to ``savefig``; False disables saving.
    show : bool, optional
        Whether to call ``plt.show()``.
    linewidth : float, optional
        Line width of the histograms / traces.
    log : bool, optional
        Passed as ``log`` to ``Axes.hist``.
    title_y : float, optional
        Vertical position of the figure suptitle.
    hist_kwargs : dict, optional
        Extra keyword arguments for ``Axes.hist``.
    plot_kwargs : dict, optional
        Extra keyword arguments for ``Axes.plot`` (trace mode).
    legend_labels : list of str, optional
        Labels for a figure legend whose patches use colors ``C0, C1, ...``.
    real_tm_pars : bool, optional
        Passed as ``tm_convert`` to ``get_param`` for cores that accept it.
    legend_loc : str, optional
        ``loc`` argument of the figure legend.
    **kwargs
        ``xticks`` (tick positions applied to every subplot) and ``xlabel``
        (figure-level x label) are recognized.

    Raises
    ------
    ValueError
        If both `pars` and `exclude` are given, or no parameters remain to be
        plotted.
    """
    hist_kwargs = {} if hist_kwargs is None else hist_kwargs
    plot_kwargs = {} if plot_kwargs is None else plot_kwargs

    multi = isinstance(core, list)
    cores = core if multi else [core]
    ncores = len(cores)

    params = _select_params(core, pars, exclude)
    if not params:
        raise ValueError('No parameters selected for plotting.')
    fancy_par_names = cores[0].fancy_par_names

    # Broadcast per-core options so every core is handled by the same loop.
    if linestyle is None and multi:
        linestyle = '-'
    if linestyle is None or isinstance(linestyle, str):
        linestyles = [linestyle] * ncores
    else:
        linestyles = linestyle
    if multi and isinstance(plot_map, list):
        plot_maps = plot_map
    else:
        plot_maps = [plot_map] * ncores

    # Guess the pulsar name from the first parameter: keep its first 8
    # ('B...') or 10 ('J...') characters. Used for titles and the suptitle.
    psr_name = None
    if suptitle is None:
        psr_name = params[0]
        if psr_name.startswith('B'):
            psr_name = psr_name[:8]
        elif psr_name.startswith('J'):
            psr_name = psr_name[:10]

    n_params = len(params)
    nrows = (n_params + ncols - 1) // ncols  # ceiling division
    if publication_params:
        fig = plt.figure()
    else:
        fig = plt.figure(figsize=[15, 4 * nrows])

    xticks = kwargs.get('xticks')
    for ii, p in enumerate(params):
        axis = fig.add_subplot(nrows, ncols, ii + 1)
        for jj, c in enumerate(cores):
            gpar_kwargs = _get_gpar_kwargs(c, real_tm_pars)
            if not hist:
                axis.plot(c.get_param(p, to_burn=True, **gpar_kwargs),
                          lw=linewidth, **plot_kwargs)
                continue
            call_kwargs = {}
            if linestyles[jj] is not None:
                call_kwargs['linestyle'] = linestyles[jj]
            # User-supplied hist_kwargs take precedence over the style above.
            call_kwargs.update(hist_kwargs)
            phist = axis.hist(c.get_param(p, **gpar_kwargs),
                              bins=bins, density=True, log=log,
                              linewidth=linewidth, histtype='step',
                              **call_kwargs)
            if plot_maps[jj] or truths is not None:
                # Reuse the histogram's edge color for its reference lines.
                pcol = phist[-1][-1].get_edgecolor()
                if plot_maps[jj]:
                    axis.axvline(c.get_map_param(p), linewidth=1,
                                 color=pcol, linestyle='--')
                if truths is not None:
                    axis.axvline(truths[ii], linewidth=1,
                                 color=pcol, linestyle='-')

        if titles is not None:
            axis.set_title(titles[ii])
        elif fancy_par_names is not None:
            # NOTE(review): indexes fancy_par_names by plot position, which
            # assumes it is aligned with `params`; this may not hold when
            # `pars`/`exclude` select a subset -- confirm.
            axis.set_title(fancy_par_names[ii])
        elif psr_name is not None:
            axis.set_title(p.replace(psr_name + '_', ''))
        else:
            axis.set_title(p)
        axis.set_yticks([])
        if xticks is not None:
            axis.set_xticks(xticks)

    if suptitle is None:
        # Title the figure after the pulsar if most parameters mention it.
        frac = np.mean([psr_name in q for q in params])
        if frac > 0.5:
            suptitle = 'PSR {0} Noise Parameters'.format(psr_name)
        else:
            suptitle = 'Parameter Posteriors '

    if legend_labels is not None:
        patches = [mpatches.Patch(color='C{0}'.format(ii), label=lab)
                   for ii, lab in enumerate(legend_labels)]
        fig.legend(handles=patches, loc=legend_loc)

    fig.tight_layout(pad=0.4)
    fig.suptitle(suptitle, y=title_y, fontsize=18)
    xlabel = kwargs.get('xlabel')
    if xlabel is not None:
        fig.text(0.5, -0.02, xlabel, ha='center', usetex=False)
    if save:
        fig.savefig(save, dpi=150, bbox_inches='tight')
    if show:
        plt.show()
    plt.close(fig)
def _model_label(index):
    """Return a spreadsheet-style label: 0 -> 'A', 25 -> 'Z', 26 -> 'AA'."""
    label = ''
    index += 1
    while index > 0:
        index, rem = divmod(index - 1, 26)
        label = string.ascii_uppercase[rem] + label
    return label


def noise_flower(hmc,
                 colLabels=None,
                 cellText=None,
                 colWidths=None,
                 psrname=None, norm2max=False,
                 show=True, plot_path=None):
    """Plot a polar "flower" of model residencies next to a legend table.

    The left panel is a polar bar chart with one petal per model. Each
    petal's length is the fraction of ``nmodel`` samples (taken with
    ``to_burn=True``) that fall in that model. The right panel is a table
    describing the models.

    Parameters
    ----------
    hmc : la_forge.core.HyperModelCore
        Core providing ``nmodels``, ``params`` and an ``nmodel`` chain.
    colLabels : list, optional
        Table column headers for legend. Defaults to
        ``['Add', 'Your', 'Noise']``.
    cellText : nested list, 2d array, optional
        Table entries. Column number must match `colLabels`. Defaults to
        ``[['Model', 'Labels', 'Here']]``.
    colWidths : list of float, optional
        Table column widths, passed to ``Axes.table``.
    psrname : str, optional
        Name of pulsar. Only used in making the title of the plot. If None,
        it is taken from the prefix (before the first '_') of the first
        parameter starting with 'B' or 'J'.
    norm2max : bool, optional
        Whether to normalize the values to the maximum `nmodel` residency.
    show : bool, optional
        Whether to show the plot.
    plot_path : str, optional
        Enter a file path to save the plot to file.
    """
    if colLabels is None:
        colLabels = ['Add', 'Your', 'Noise']
    if cellText is None:
        cellText = [['Model', 'Labels', 'Here']]

    nmodels = hmc.nmodels
    if psrname is None:
        pos_names = [p.split('_')[0] for p in hmc.params
                     if p[:1] in ('B', 'J')]
        psrname = pos_names[0] if pos_names else ''

    # Letter labels A, B, ..., Z, AA, AB, ... (one per model).
    mod_letters = [_model_label(ii) for ii in range(nmodels)]
    mod_index = np.arange(nmodels)

    # Unit-width bins centred on the integer model indices, so with
    # density=True each entry is the fraction of samples in that model.
    n, _ = np.histogram(hmc.get_param('nmodel', to_burn=True),
                        bins=np.linspace(-0.5, nmodels - 0.5, nmodels + 1),
                        density=True)
    if norm2max:
        n /= n.max()

    fig = plt.figure(figsize=[8, 4])
    ax = fig.add_subplot(121, polar=True)
    theta = 2.0 * np.pi * mod_index / nmodels
    # Lift petals off the origin by half the second-smallest residency
    # (or half the only one, when there is a single model).
    sorted_n = np.sort(n)
    bottom = sorted_n[min(1, n.size - 1)] / 2.
    bars = ax.bar(theta, n, width=0.9 * 2 * np.pi / nmodels, bottom=bottom)
    # Shade each petal by its residency.
    for r, bar in zip(n, bars):
        bar.set_facecolor(plt.cm.Blues(r))

    ax.set_xticks(theta)
    labels = [ii + '=' + str(round(jj, 2)) for ii, jj in zip(mod_letters, n)]
    ax.set_xticklabels(labels, fontsize=11, rotation=0, color='grey')
    ax.grid(alpha=0.4)
    ax.tick_params(labelsize=10, labelcolor='k')
    ax.set_yticklabels([])
    plt.box(on=None)

    ax2 = fig.add_subplot(122)
    ax2.xaxis.set_visible(False)
    ax2.yaxis.set_visible(False)
    table = ax2.table(cellText=cellText,
                      colLabels=colLabels,
                      colWidths=colWidths,
                      loc='center')
    table.auto_set_font_size(False)
    table.set_fontsize(10)
    table.scale(1.05, 1.05)
    plt.box(on=None)

    if psrname:
        title = 'PSR ' + psrname + '\n Noise Model Selection'
    else:
        title = 'Noise Model Selection'
    ax2.set_title(title, color='k', y=0.8, fontsize=13,
                  bbox=dict(facecolor='C3', edgecolor='k', alpha=0.2))
    if plot_path is not None:
        plt.savefig(plot_path, bbox_inches='tight', dpi=150)
    if show:
        plt.show()
def _get_gpar_kwargs(core, real_tm_pars):
'''
Convenience function to return a kwargs dictionary if their is a call
to convert timing parameters.
'''
if 'tm_convert'in inspect.getfullargspec(core.get_param)[0]:
gpar_kwargs = {'tm_convert': real_tm_pars}
else:
gpar_kwargs = {}
return gpar_kwargs