Commit ca5458b6 authored by connor.hainje@pnnl.gov's avatar connor.hainje@pnnl.gov
Browse files

Refactor efficiency, purity plotting methods to better match contribution 1D, 2D methods

parent c2a6aa22
import numpy as np
import matplotlib.pyplot as plt
from . import const, colors
from .utils import (
_ax_xlabel,
better_xlabel,
connected_region_borders,
imshow_2D_axis_labels,
outer_legend,
text_array_values,
xticks_radians_to_degrees,
)
# *** COMPUTE NICETIES ***
def compute_efficiency(d):
out = np.zeros(len(const.PARTICLES))
for h in range(len(const.PARTICLES)):
d_h = d.loc[d["labels"] == h]
if len(d_h) == 0:
continue
out[h] = np.count_nonzero(d_h["pid"] == h) / len(d_h)
return out
def compute_purity(d):
out = np.zeros(len(const.PARTICLES))
for h in range(len(const.PARTICLES)):
d_h = d.loc[d["pid"] == h]
if len(d_h) == 0:
continue
out[h] = np.count_nonzero(d_h["labels"] == h) / len(d_h)
return out
def _compute_1D(compute_fn, df, bin_by="p"):
if bin_by == "p":
gb = df.groupby(["p_bin"])
vals = np.zeros((const.N_P_BINS, len(const.PARTICLES)))
elif bin_by == "theta":
gb = df.groupby(["theta_bin"])
vals = np.zeros((const.N_THETA_BINS, len(const.PARTICLES)))
else:
raise ValueError()
for i, g in gb:
vals[i] = compute_fn(g)
return vals
def compute_efficiency_1D(df, bin_by="p"):
return _compute_1D(compute_efficiency, df, bin_by=bin_by)
def compute_purity_1D(df, bin_by="p"):
return _compute_1D(compute_purity, df, bin_by=bin_by)
def _compute_2D(compute_fn, df):
vals = np.zeros((const.N_P_BINS, const.N_THETA_BINS, len(const.PARTICLES)))
gb = df.groupby(["p_bin", "theta_bin"])
for (i, j), g in gb:
vals[i, j] = compute_fn(g)
return vals
def compute_efficiency_2D(df):
return _compute_2D(compute_efficiency, df)
def compute_purity_2D(df):
return _compute_2D(compute_purity, df)
# *** 1D PLOT METHODS ***
def _subplot_1D(ax, data, bin_by="p", ls="-", color_list=None):
color_list = colors.fill_color_list(color_list, detectors=False)
if bin_by == "p":
bin_centers = const.P_BIN_CENTERS
elif bin_by == "theta":
bin_centers = const.THETA_BIN_CENTERS
else:
raise ValueError("invalid bin_by")
for i in range(len(const.PARTICLES)):
color = color_list[i]
label = const.PART_PLOT_LABELS[i]
vals = data[:, i]
ax.plot(bin_centers, vals, c=color, ls=ls, label=label)
def subplot_efficiency_1D(ax, df, bin_by="p", color_list=None, ls="-"):
effs = compute_efficiency_1D(df, bin_by=bin_by)
_subplot_1D(ax, effs, bin_by=bin_by, color_list=color_list, ls=ls)
def subplot_purity_1D(ax, df, bin_by="p", color_list=None, ls="-"):
purs = compute_purity_1D(df, bin_by=bin_by)
_subplot_1D(ax, purs, bin_by=bin_by, color_list=color_list, ls=ls)
def _plot_1D(subplot_fn, df, bin_by="p", figsize=None, ylabel=None, title=None):
fig, ax = plt.subplots(figsize=figsize)
subplot_fn(ax, df, bin_by=bin_by)
fig.tight_layout()
_ax_xlabel(ax, bin_by)
if bin_by == "theta":
xticks_radians_to_degrees(ax)
ax.set_ylabel(ylabel)
outer_legend(fig, loc="top", ncol=len(const.PARTICLES), y=0.98)
better_xlabel(fig, title, loc="top", y=1.08)
return fig
def plot_efficiency_1D(df, bin_by="p", figsize=None, title=None):
return _plot_1D(
subplot_efficiency_1D,
df,
bin_by=bin_by,
figsize=figsize,
ylabel="Efficiency",
title=title,
)
def plot_purity_1D(df, bin_by="p", figsize=None, title=None):
return _plot_1D(
subplot_purity_1D,
df,
bin_by=bin_by,
figsize=figsize,
ylabel="Purity",
title=title,
)
# *** 2D PLOT METHODS ***
def _subplot_2D(ax, data, particle="max", cell_fontsize=10, color_list=None):
if particle == "min" or particle == "max":
cmap = colors.make_colormap(color_list, detectors=False)
idx = np.argmin(data, axis=2) if particle == "min" else np.argmax(data, axis=2)
val = np.take_along_axis(data, np.expand_dims(idx, 2), axis=2).squeeze()
vmax = len(const.PARTICLES) - 0.5
im = ax.imshow(idx, cmap=cmap, vmin=-0.5, vmax=vmax, aspect="auto")
connected_region_borders(ax, idx)
elif particle in const.PARTICLES:
index = const.PARTICLES.index(particle)
color_list = colors.fill_color_list(color_list, detectors=False)
cmap = colors.make_linear_colormap(color_list[index])
val = data[:, :, index]
im = ax.imshow(val, cmap=cmap, aspect="auto")
else:
raise ValueError("particle not understood")
if cell_fontsize:
text_array_values(ax, val, fontsize=cell_fontsize)
return im
def subplot_efficiency_2D(ax, df, particle="max", color_list=None, cell_fontsize=10):
effs = compute_efficiency_2D(df)
return _subplot_2D(
ax, effs, particle=particle, color_list=color_list, cell_fontsize=cell_fontsize
)
def subplot_purity_2D(ax, df, particle="max", color_list=None, cell_fontsize=10):
purs = compute_purity_2D(df)
return _subplot_2D(
ax, purs, particle=particle, color_list=color_list, cell_fontsize=cell_fontsize
)
def _plot_2D(
subplot_fn,
df,
particle="max",
figsize=None,
title=None,
legend_or_cbar=None,
legend_cb_label="",
cell_fontsize=10,
):
fig, ax = plt.subplots(figsize=figsize)
im = subplot_fn(ax, df, particle=particle, cell_fontsize=cell_fontsize)
fig.tight_layout()
imshow_2D_axis_labels(ax)
better_xlabel(fig, title, loc="top")
if legend_or_cbar == "legend":
from matplotlib.patches import Patch
color_list = colors.fill_color_list(None, detectors=False)
handles = [Patch(color=color_list[i]) for i in range(len(const.PARTICLES))]
outer_legend(
fig,
handles,
const.PART_PLOT_LABELS,
loc="right",
ncol=1,
title=legend_cb_label,
x=0.98,
)
elif legend_or_cbar in ["cbar", "colorbar"]:
cb = fig.colorbar(im)
cb.set_label(legend_cb_label)
return fig
def plot_efficiency_2D(df, particle="max", figsize=None, title=None, cell_fontsize=10):
if particle in const.PARTICLES:
p_index = const.PARTICLES.index(particle)
p_label = const.PART_PLOT_LABELS[p_index]
label = f"{p_label} efficiency"
return _plot_2D(
subplot_efficiency_2D,
df,
particle=particle,
figsize=figsize,
title=title,
legend_or_cbar="cbar",
legend_cb_label=label,
cell_fontsize=cell_fontsize,
)
elif particle in ["min", "max"]:
label = ("Highest" if particle == "max" else "Lowest") + "\nefficiency"
return _plot_2D(
subplot_efficiency_2D,
df,
particle=particle,
figsize=figsize,
title=title,
legend_or_cbar="legend",
legend_cb_label=label,
cell_fontsize=cell_fontsize,
)
else:
raise ValueError("argument particle not understood")
def plot_purity_2D(df, particle="max", figsize=None, title=None, cell_fontsize=10):
if particle in const.PARTICLES:
p_index = const.PARTICLES.index(particle)
p_label = const.PART_PLOT_LABELS[p_index]
label = f"{p_label} purity"
return _plot_2D(
subplot_purity_2D,
df,
particle=particle,
figsize=figsize,
title=title,
legend_or_cbar="cbar",
legend_cb_label=label,
cell_fontsize=cell_fontsize,
)
elif particle in ["min", "max"]:
label = ("Highest" if particle == "max" else "Lowest") + "\npurity"
return _plot_2D(
subplot_purity_2D,
df,
particle=particle,
figsize=figsize,
title=title,
legend_or_cbar="legend",
legend_cb_label=label,
cell_fontsize=cell_fontsize,
)
else:
raise ValueError("argument particle not understood")
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment