Commit bc177fac authored by Connor Hainje's avatar Connor Hainje
Browse files

Update 'pidplots/globbal' notebook and pidplots package to work after rrefactor

parent d29d4ecc
This diff is collapsed.
This diff is collapsed.
......@@ -88,7 +88,9 @@ def subplot_blame_2D(ax, df, detector="max", color_list=None, cell_fontsize=10):
idx = np.argmin(blames, 2) if detector == "min" else np.argmax(blames, 2)
val = np.take_along_axis(blames, np.expand_dims(idx, 2), 2).squeeze()
vmax = len(const.DETECTORS) - 0.5
im = ax.imshow(idx, cmap=cmap, vmin=-0.5, vmax=vmax, aspect="auto", origin="lower")
im = ax.imshow(
idx, cmap=cmap, vmin=-0.5, vmax=vmax, aspect="auto", origin="lower"
)
connected_region_borders(ax, idx)
elif detector in const.DETECTORS:
......@@ -101,12 +103,12 @@ def subplot_blame_2D(ax, df, detector="max", color_list=None, cell_fontsize=10):
raise ValueError("detector not understood")
if cell_fontsize:
text_array_values(ax, val, fontsize=cell_fontsize)
text_array_values(ax, val, fmt="{:.0%}", fontsize=cell_fontsize, bbox_alpha=0)
return im
def blame_2D_legend(fig, color_list=None, label_list=None, title=None):
def blame_2D_legend(fig, x=0.98, color_list=None, label_list=None, title=None):
from matplotlib.patches import Patch
label_list = _fill_label_list(label_list)
......@@ -119,7 +121,7 @@ def blame_2D_legend(fig, color_list=None, label_list=None, title=None):
loc="right",
ncol=1,
title=title,
x=0.98,
x=x,
)
......@@ -138,3 +140,31 @@ def plot_blame_2D(df, detector="max", title=None, figsize=None, cell_fontsize=10
blame_2D_legend(fig, title=title)
return fig
def plot_blame_2D_by_particle(
df, detector="max", title=None, figsize=(12, 8), cell_fontsize=10
):
fig, axs = plt.subplots(2, 3, figsize=figsize)
for i, (ax, p_label) in enumerate(zip(axs.flat, const.PART_PLOT_LABELS)):
im = subplot_blame_2D(
ax,
df.loc[df["labels"] == i],
detector=detector,
cell_fontsize=cell_fontsize,
)
ax.set_title(f"True {p_label}", fontsize=16)
imshow_2D_axis_labels(ax)
fig.tight_layout()
better_xlabel(fig, title, loc="top")
if detector in const.DETECTORS:
cb = fig.colorbar(im)
cb.set_label(f"{detector} blame")
else:
title = ("Highest" if detector == "max" else "Lowest") + "\nblame"
blame_2D_legend(fig, x=1.0, title=title)
return fig
......@@ -99,6 +99,46 @@ def plot_avg_contribution(df, title=None, color_list=None, figsize=None):
return fig
def subplot_avg_contribution_by_particle(ax, df, color_list=None, legend_loc="best"):
color_list = colors.fill_color_list(color_list)
x_vals = np.arange(len(const.DETECTORS))
ctrb_cols = [f"contrib_{det}" for det in const.DETECTORS]
ctrbs = np.array(
[
np.mean(df.loc[df["labels"] == i, ctrb_cols].values, axis=0)
for i in range(len(const.PARTICLES))
]
).T
n_series = len(const.DETECTORS)
for i in range(n_series):
w = 0.8
adj = -(w / 2) + (w / 2 / n_series) + i * (w / n_series)
ax.bar(
x_vals + adj,
ctrbs[i],
width=w / n_series,
color=color_list[i],
label=const.DETECTORS[i],
)
ax.legend(frameon=False, loc=legend_loc, ncol=3)
ax.set_xticks(x_vals)
ax.set_xticklabels(const.PART_PLOT_LABELS)
ax.set_xlim(x_vals[0] - 0.5, x_vals[-1] + 0.5)
def plot_avg_contribution_by_particle(df, title=None, color_list=None, figsize=None):
fig, ax = plt.subplots(figsize=figsize)
subplot_avg_contribution_by_particle(ax, df, color_list=color_list)
ax.set_xlabel("Particle type")
ax.set_ylabel("Average contribution value")
ax.set_title(title)
fig.tight_layout()
return fig
def subplot_avg_contribution_1D(ax, df, bin_by="p", ls="-", color_list=None):
if bin_by == "p":
ctrbs = df.groupby(["p_bin"]).mean()
......@@ -163,7 +203,12 @@ def subplot_avg_contribution_2D(ax, df, det="max", color_list=None, cell_fontsiz
# TODO: make sure vlims are correct
im = ax.imshow(
idx, cmap=cmap, vmin=-0.5, vmax=len(const.DETECTORS) - 0.5, aspect="auto", origin="lower"
idx,
cmap=cmap,
vmin=-0.5,
vmax=len(const.DETECTORS) - 0.5,
aspect="auto",
origin="lower",
)
# add cell values
......
......@@ -15,6 +15,13 @@ from .utils import (
# *** COMPUTE NICETIES ***
def compute_num(d):
out = np.zeros(len(const.PARTICLES))
for h in range(len(const.PARTICLES)):
out[h] = np.count_nonzero(d["labels"].values == h)
return out
def compute_efficiency(d):
out = np.zeros(len(const.PARTICLES))
for h in range(len(const.PARTICLES)):
......@@ -66,6 +73,10 @@ def _compute_1D(compute_fn, df, bin_by="p"):
return vals
def compute_num_1D(df, bin_by="p"):
return _compute_1D(compute_num, df, bin_by=bin_by)
def compute_efficiency_1D(df, bin_by="p"):
return _compute_1D(compute_efficiency, df, bin_by=bin_by)
......@@ -86,6 +97,10 @@ def _compute_2D(compute_fn, df):
return vals
def compute_num_2D(df):
return _compute_2D(compute_num, df)
def compute_efficiency_2D(df):
return _compute_2D(compute_efficiency, df)
......@@ -183,7 +198,13 @@ def plot_auc_1D(df, bin_by="p", figsize=None, title=None):
def _subplot_2D(
ax, data, particle="max", cell_fontsize=10, color_list=None, vlims=None
ax,
data,
particle="max",
cell_fontsize=10,
color_list=None,
vlims=None,
num=None,
):
if particle == "min" or particle == "max":
cmap = colors.make_colormap(color_list, detectors=False)
......@@ -199,6 +220,14 @@ def _subplot_2D(
)
connected_region_borders(ax, idx)
elif particle == "avg":
cmap = "inferno"
val = np.average(data, axis=2, weights=num)
vmin, vmax = 0, 1
im = ax.imshow(
val, cmap=cmap, vmin=vmin, vmax=vmax, aspect="auto", origin="lower"
)
elif particle in const.PARTICLES:
index = const.PARTICLES.index(particle)
color_list = colors.fill_color_list(color_list, detectors=False)
......@@ -227,6 +256,7 @@ def subplot_efficiency_2D(
ax, df, particle="max", color_list=None, cell_fontsize=10, vlims=None
):
effs = compute_efficiency_2D(df)
num = compute_num_2D(df) if particle == "avg" else None
return _subplot_2D(
ax,
effs,
......@@ -234,6 +264,7 @@ def subplot_efficiency_2D(
color_list=color_list,
cell_fontsize=cell_fontsize,
vlims=vlims,
num=num,
)
......@@ -241,6 +272,7 @@ def subplot_purity_2D(
ax, df, particle="max", color_list=None, cell_fontsize=10, vlims=None
):
purs = compute_purity_2D(df)
num = compute_num_2D(df) if particle == "avg" else None
return _subplot_2D(
ax,
purs,
......@@ -248,6 +280,7 @@ def subplot_purity_2D(
color_list=color_list,
cell_fontsize=cell_fontsize,
vlims=vlims,
num=num,
)
......@@ -255,6 +288,7 @@ def subplot_auc_2D(
ax, df, particle="max", color_list=None, cell_fontsize=10, vlims=None
):
aucs = compute_auc_2D(df)
num = compute_num_2D(df) if particle == "avg" else None
return _subplot_2D(
ax,
aucs,
......@@ -262,6 +296,7 @@ def subplot_auc_2D(
color_list=color_list,
cell_fontsize=cell_fontsize,
vlims=vlims,
num=num,
)
......@@ -332,6 +367,19 @@ def plot_efficiency_2D(df, particle="max", figsize=None, title=None, cell_fontsi
cell_fontsize=cell_fontsize,
)
elif particle == "avg":
label = "Overall accuracy"
return _plot_2D(
subplot_efficiency_2D,
df,
particle=particle,
figsize=figsize,
title=title,
legend_or_cbar="cbar",
legend_cb_label=label,
cell_fontsize=cell_fontsize,
)
else:
raise ValueError("argument particle not understood")
......@@ -365,6 +413,19 @@ def plot_purity_2D(df, particle="max", figsize=None, title=None, cell_fontsize=1
cell_fontsize=cell_fontsize,
)
elif particle == "avg":
label = "Average purity"
return _plot_2D(
subplot_purity_2D,
df,
particle=particle,
figsize=figsize,
title=title,
legend_or_cbar="cbar",
legend_cb_label=label,
cell_fontsize=cell_fontsize,
)
else:
raise ValueError("argument particle not understood")
......@@ -398,5 +459,18 @@ def plot_auc_2D(df, particle="max", figsize=None, title=None, cell_fontsize=10):
cell_fontsize=cell_fontsize,
)
elif particle == "avg":
label = "Average AUC"
return _plot_2D(
subplot_auc_2D,
df,
particle=particle,
figsize=figsize,
title=title,
legend_or_cbar="cbar",
legend_cb_label=label,
cell_fontsize=cell_fontsize,
)
else:
raise ValueError("argument particle not understood")
......@@ -59,3 +59,43 @@ def plot_likelihood_ratio(
bin_lims=bin_lims,
n_bins=n_bins,
)
def plot_likelihood_ratio_by_particle(
df,
particle="all",
figsize=(16, 8),
yscale="log",
bin_lims=(0, 1),
n_bins=30,
):
from matplotlib.pyplot import subplots
from .utils import better_xlabel, outer_legend
if particle in ["all", "min", "max"]:
if particle == "all":
xlabel = "Likelihood ratios"
elif particle == "min":
xlabel = "Smallest likelihood ratio"
elif particle == "max":
xlabel = "Largest likelihood ratio"
else:
raise ValueError("argument particle not understood")
fig, axs = subplots(2, 3, figsize=figsize)
for i, (ax, p_label) in enumerate(zip(axs.flat, const.PART_PLOT_LABELS)):
subplot_likelihood_ratio(
ax,
df.loc[df["labels"] == i],
particle=particle,
bin_lims=bin_lims,
n_bins=n_bins,
)
ax.set_title(f"True {p_label}", fontsize=16)
ax.set_yscale(yscale)
fig.tight_layout()
better_xlabel(fig, xlabel, fontsize=16)
outer_legend(fig, const.PART_PLOT_LABELS, loc="top", fontsize=16)
return fig
......@@ -50,7 +50,7 @@ def add_confusion_axis_labels(ax):
ax.set_ylabel("True", fontsize=12)
def add_p_theta_axis_labels(ax):
def add_p_theta_axis_labels(ax, fontsize=12):
"""Sets the axis ticks, limits, and labels with the standard bins for theta
on the x-axis and p on the y-axis.
......@@ -60,15 +60,15 @@ def add_p_theta_axis_labels(ax):
# theta bins
ax.set_xticks(np.arange(len(const.THETA_BINS)) - 0.5)
ax.set_xticklabels(const.THETA_BIN_LABELS, fontsize=12)
ax.set_xticklabels(const.THETA_BIN_LABELS, fontsize=fontsize)
ax.set_xlim(-0.5, const.N_THETA_BINS - 0.5)
ax.set_xlabel(f"$\\theta$ [{const.THETA_BIN_UNIT}]", fontsize=12)
ax.set_xlabel(f"$\\theta$ [{const.THETA_BIN_UNIT}]", fontsize=fontsize)
# momentum bins
ax.set_yticks(np.arange(len(const.P_BINS)) - 0.5)
ax.set_yticklabels(const.P_BIN_LABELS, fontsize=12)
ax.set_yticklabels(const.P_BIN_LABELS, fontsize=fontsize)
ax.set_ylim(const.N_P_BINS - 0.5, -0.5)
ax.set_ylabel(f"$p$ [{const.P_BIN_UNIT}]", fontsize=12)
ax.set_ylabel(f"$p$ [{const.P_BIN_UNIT}]", fontsize=fontsize)
def add_cell_values(ax, x, fontsize=10):
......@@ -251,6 +251,7 @@ def _in_all_bins(
df,
xlabel=None,
ylabel=None,
label_fontsize=12,
legend_labels=[],
figsize=(16, 12),
xticks=True,
......@@ -471,7 +472,14 @@ def subplot_top_wrong(
cmap = colors.make_colormap(color_list, detectors=False)
wrong, freq = compute_top_wrong_in_all_bins(df)
ax.imshow(wrong, cmap=cmap, vmin=0, vmax=len(const.PARTICLES), origin="lower")
ax.imshow(
wrong,
cmap=cmap,
vmin=-0.5,
vmax=len(const.PARTICLES) - 0.5,
origin="lower",
alpha=0.7,
)
hatches = [".", "-", "/", "|", "\\", "x"]
ec = (0, 0, 0, hatch_alpha)
......@@ -500,7 +508,7 @@ def subplot_top_wrong(
# ax.text(c, r, lab, **text_kw)
connected_region_borders(ax, wrong)
text_array_values(ax, freq, fontsize=celltextsize)
text_array_values(ax, freq, fmt="{:.0%}", bbox_alpha=0, fontsize=celltextsize)
add_p_theta_axis_labels(ax)
if legend:
......
......@@ -155,11 +155,13 @@ def outer_legend(
)
def text_array_values(ax, arr, fmt="{:.2f}", fontsize=10):
text_kw = dict(ha="center", va="center", fontsize=fontsize)
def text_array_values(ax, arr, fmt="{:.2f}", fontsize=10, bbox_alpha=0.7):
bbox_kw = dict(lw=0, fc="white", alpha=bbox_alpha)
text_kw = dict(ha="center", va="center", fontsize=fontsize, bbox=bbox_kw)
for i in range(arr.shape[0]):
for j in range(arr.shape[1]):
ax.text(j, i, fmt.format(arr[i, j]), **text_kw)
txt = "N/A" if np.isnan(arr[i, j]) else fmt.format(arr[i, j])
ax.text(j, i, txt, **text_kw)
def connected_region_borders(ax, arr, line_fmt="k-", line_width=1.5):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment