Commit 66fd4d95 authored by Connor Hainje
Browse files

Add 2D efficiency, purity plots

parent df68e809
import numpy as np
import matplotlib.pyplot as plt
import warnings
from . import const, colors
from .compute import softmax
......@@ -9,8 +10,10 @@ from .utils import (
outer_legend,
_validate_bin_by_1D,
set_bin_by_xlabel,
text_array_values,
)
from .data import make_detector_dataframes
from .eff_pur import _plot_2D
BINARY = None
......@@ -69,10 +72,13 @@ def compute_confusion(df, norm=None, threshold=0.5):
conf[1, 0] = np.count_nonzero(other_events < threshold) # true `part`, pred pi
conf[1, 1] = np.count_nonzero(other_events > threshold) # true `part`, pred `part`
if norm in ["row", "eff", "efficiency"]:
conf = np.divide(conf, conf.sum(axis=1, keepdims=True), casting="unsafe")
elif norm in ["col", "column", "pur", "purity"]:
conf = np.divide(conf, conf.sum(axis=0, keepdims=True), casting="unsafe")
with warnings.catch_warnings():
warnings.simplefilter("ignore")
if norm in ["row", "eff", "efficiency"]:
conf = np.divide(conf, conf.sum(axis=1, keepdims=True), casting="unsafe")
elif norm in ["col", "column", "pur", "purity"]:
conf = np.divide(conf, conf.sum(axis=0, keepdims=True), casting="unsafe")
return conf
......@@ -118,7 +124,7 @@ def compute_purity_in_bins(df, threshold=0.5, bin_by="p"):
def subplot_efficiency(
ax,
df,
part,
particle,
threshold=0.5,
bin_by="p",
ls=None,
......@@ -143,7 +149,7 @@ def subplot_efficiency(
color_list = [color, color]
else:
if color_list is None:
if part == "pi":
if particle == "pi":
color_list = [colors.PARTICLES["pi"], colors.PARTICLES[BINARY]]
else:
color_list = [colors.PARTICLES[BINARY], colors.PARTICLES["pi"]]
......@@ -155,7 +161,7 @@ def subplot_efficiency(
else:
bincs = const.THETA_BIN_CENTERS
if part == "pi":
if particle == "pi":
ax.plot(bincs, effs[:, 0, 0], ls=ls_list[0], color=color_list[0])
ax.plot(bincs, effs[:, 1, 0], ls=ls_list[1], color=color_list[1])
else:
......@@ -165,15 +171,15 @@ def subplot_efficiency(
ax.grid(ls=":")
def plot_efficiency(df, particle, threshold=0.5, bin_by="p"):
    """Create a standalone figure with the binned efficiency curves.

    Builds a single 7x5 inch axes, delegates the drawing to
    ``subplot_efficiency`` and labels the x axis according to ``bin_by``.

    Parameters
    ----------
    df
        Event table forwarded unchanged to ``subplot_efficiency``.
    particle
        Particle name forwarded to ``subplot_efficiency`` (e.g. ``"pi"``).
    threshold
        Classification cut forwarded to ``subplot_efficiency``.
    bin_by
        Binning variable; also selects the x-axis label via
        ``set_bin_by_xlabel``.

    Returns
    -------
    The newly created matplotlib figure.
    """
    fig, ax = plt.subplots(figsize=(7, 5))
    # Only the renamed ``particle`` argument is used; the earlier ``part``
    # spelling of this signature and call has been dropped.
    subplot_efficiency(ax, df, particle, threshold=threshold, bin_by=bin_by)
    set_bin_by_xlabel(ax, bin_by)
    fig.tight_layout()
    return fig
def plot_efficiency_by_detector(df, part, threshold=0.5, bin_by="p", no_SVD=False):
def plot_efficiency_by_detector(df, particle, threshold=0.5, bin_by="p", no_SVD=False):
df_d = make_detector_dataframes(df)
fig, axs = plt.subplots(2, 3, figsize=(12, 8))
......@@ -181,20 +187,20 @@ def plot_efficiency_by_detector(df, part, threshold=0.5, bin_by="p", no_SVD=Fals
kwargs = dict(threshold=threshold, bin_by=bin_by)
if no_SVD and len(const.DETECTORS) == 5:
subplot_efficiency(axs[0, 0], df, part, **kwargs)
subplot_efficiency(axs[0, 0], df, particle, **kwargs)
axs[0, 0].set_title("Ensemble", fontsize=14)
for ax, det in zip(axs.flat[1:], const.DETECTORS):
subplot_efficiency(ax, df_d[det], part, **kwargs)
subplot_efficiency(ax, df_d[det], particle, **kwargs)
ax.set_title(det, fontsize=14)
else:
for ax, det in zip(axs.flat, const.DETECTORS):
if no_SVD and det == "SVD":
subplot_efficiency(ax, df, part, **kwargs)
subplot_efficiency(ax, df, particle, **kwargs)
ax.set_title("Ensemble", fontsize=14)
else:
subplot_efficiency(ax, df_d[det], part, **kwargs)
subplot_efficiency(ax, df_d[det], particle, **kwargs)
ax.set_title(det, fontsize=14)
fig.tight_layout()
......@@ -202,7 +208,7 @@ def plot_efficiency_by_detector(df, part, threshold=0.5, bin_by="p", no_SVD=Fals
return fig
def subplot_purity(ax, df, part, threshold=0.5, bin_by="p", color=None, ls="-"):
def subplot_purity(ax, df, particle, threshold=0.5, bin_by="p", color=None, ls="-"):
_validate_bin_by_1D(bin_by)
_validate_binary()
......@@ -214,18 +220,18 @@ def subplot_purity(ax, df, part, threshold=0.5, bin_by="p", color=None, ls="-"):
bincs = const.THETA_BIN_CENTERS
if color is None:
color = colors.PARTICLES[part]
color = colors.PARTICLES[particle]
if part == "pi":
if particle == "pi":
ax.plot(bincs, purs[:, 0, 0], ls=ls, c=color)
else:
ax.plot(bincs, purs[:, 1, 1], ls=ls, c=color)
ax.grid(ls=":")
def plot_purity(df, particle, threshold=0.5, bin_by="p"):
    """Create a standalone figure with the binned purity curve.

    Builds a single 7x5 inch axes, delegates the drawing to
    ``subplot_purity`` and labels the x axis according to ``bin_by``.

    Parameters
    ----------
    df
        Event table forwarded unchanged to ``subplot_purity``.
    particle
        Particle name forwarded to ``subplot_purity`` (e.g. ``"pi"``).
    threshold
        Classification cut forwarded to ``subplot_purity``.
    bin_by
        Binning variable; also selects the x-axis label via
        ``set_bin_by_xlabel``.

    Returns
    -------
    The newly created matplotlib figure.
    """
    fig, ax = plt.subplots(figsize=(7, 5))
    # Only the renamed ``particle`` argument is used; the earlier ``part``
    # spelling of this signature and call has been dropped.
    subplot_purity(ax, df, particle, threshold=threshold, bin_by=bin_by)
    set_bin_by_xlabel(ax, bin_by)
    fig.tight_layout()
    return fig
......@@ -233,7 +239,7 @@ def plot_purity(df, part, threshold=0.5, bin_by="p"):
def plot_purity_varying_threshold(
df,
part,
particle,
thresh_lims=(0.3, 0.7),
n_thresh=20,
bin_by="p",
......@@ -250,10 +256,10 @@ def plot_purity_varying_threshold(
cols = cmap(np.linspace(0, 1, num=len(cuts)))
for cut, col in zip(cuts, cols):
subplot_purity(ax, df, part, threshold=cut, bin_by=bin_by, color=col)
subplot_purity(ax, df, particle, threshold=cut, bin_by=bin_by, color=col)
# add 0.5 as a black dashed line
subplot_purity(ax, df, part, threshold=0.5, bin_by=bin_by, color="k", ls="--")
subplot_purity(ax, df, particle, threshold=0.5, bin_by=bin_by, color="k", ls="--")
ax.set_ylim(ylim)
set_bin_by_xlabel(ax, bin_by)
......@@ -270,6 +276,101 @@ def plot_purity_varying_threshold(
return fig
def _subplot_2D(ax, data, particle, cell_fontsize=None, color=None, vlims=None):
    """Render one diagonal entry of binned 2x2 matrices as a heat map.

    Parameters
    ----------
    ax
        Axes to draw on.
    data
        Array indexed as ``data[:, :, i, j]``; presumably the first two axes
        are the 2D bins and the last two a per-bin confusion-style matrix
        (TODO confirm against the ``compute_*_in_bins`` helpers).
    particle
        ``"pi"`` selects entry ``[0, 0]``; any other value selects ``[1, 1]``.
    cell_fontsize
        When truthy, each cell is annotated through ``text_array_values``
        using this font size.
    color
        Base colour for the linear colormap. When ``None`` the colour is
        looked up by the particle's position in ``const.PARTICLES``.
    vlims
        Optional ``(vmin, vmax)`` pair for the colour scale; ``None`` leaves
        both limits to matplotlib.

    Returns
    -------
    The image object returned by ``ax.imshow``.
    """
    _validate_binary()

    if color is None:
        # Resolve the particle's slot first so an unknown name fails before
        # the default colour list is built.
        slot = const.PARTICLES.index(particle)
        color = colors.fill_color_list(None, detectors=False)[slot]
    cmap = colors.make_linear_colormap(color)

    # Diagonal position: 0 for pions, 1 for the other particle.
    diag = 0 if particle == "pi" else 1
    val = data[:, :, diag, diag]

    vmin, vmax = (None, None) if vlims is None else vlims

    im = ax.imshow(
        val,
        cmap=cmap,
        aspect="auto",
        origin="lower",
        vmin=vmin,
        vmax=vmax,
    )

    if cell_fontsize:
        text_array_values(ax, val, fontsize=cell_fontsize)
    return im
def subplot_efficiency_2D(
    ax, df, particle, threshold=0.5, cell_fontsize=None, color=None, vlims=None
):
    """Draw the 2D-binned efficiency of ``particle`` onto ``ax``.

    Efficiencies come from ``compute_efficiency_in_bins`` with
    ``bin_by="both"``; the drawing itself is delegated to ``_subplot_2D``,
    which receives ``cell_fontsize``, ``color`` and ``vlims`` unchanged.

    Returns
    -------
    Whatever ``_subplot_2D`` returns (the image drawn on ``ax``).
    """
    binned = compute_efficiency_in_bins(df, threshold=threshold, bin_by="both")
    draw_opts = dict(cell_fontsize=cell_fontsize, color=color, vlims=vlims)
    return _subplot_2D(ax, binned, particle, **draw_opts)
def subplot_purity_2D(
    ax, df, particle, threshold=0.5, cell_fontsize=None, color=None, vlims=None
):
    """Draw the 2D-binned purity of ``particle`` onto ``ax``.

    Purities come from ``compute_purity_in_bins`` with ``bin_by="both"``;
    the drawing itself is delegated to ``_subplot_2D``, which receives
    ``cell_fontsize``, ``color`` and ``vlims`` unchanged.

    Returns
    -------
    Whatever ``_subplot_2D`` returns (the image drawn on ``ax``).
    """
    binned = compute_purity_in_bins(df, threshold=threshold, bin_by="both")
    draw_opts = dict(cell_fontsize=cell_fontsize, color=color, vlims=vlims)
    return _subplot_2D(ax, binned, particle, **draw_opts)
def plot_efficiency_2D(
    df, particle, threshold=0.5, figsize=None, title=None, cell_fontsize=10, **kwargs
):
    """Build a full figure of the 2D-binned efficiency for ``particle``.

    Thin wrapper around ``_plot_2D`` that uses ``subplot_efficiency_2D`` as
    the per-axes drawer and requests a colour bar labelled
    ``"<particle label> efficiency"``. The particle label is taken from
    ``const.PART_PLOT_LABELS`` at the particle's index in
    ``const.PARTICLES``. Extra keyword arguments are forwarded to
    ``_plot_2D``.

    Returns
    -------
    Whatever ``_plot_2D`` returns.
    """
    pretty_name = const.PART_PLOT_LABELS[const.PARTICLES.index(particle)]
    cbar_label = f"{pretty_name} efficiency"

    return _plot_2D(
        subplot_efficiency_2D,
        df,
        particle,
        threshold=threshold,
        cell_fontsize=cell_fontsize,
        figsize=figsize,
        title=title,
        legend_or_cbar="cbar",
        legend_cb_label=cbar_label,
        **kwargs,
    )
def plot_purity_2D(
    df, particle, threshold=0.5, figsize=None, title=None, cell_fontsize=10, **kwargs
):
    """Build a full figure of the 2D-binned purity for ``particle``.

    Thin wrapper around ``_plot_2D`` that uses ``subplot_purity_2D`` as the
    per-axes drawer and requests a colour bar labelled
    ``"<particle label> purity"``. The particle label is taken from
    ``const.PART_PLOT_LABELS`` at the particle's index in
    ``const.PARTICLES``. Extra keyword arguments are forwarded to
    ``_plot_2D``.

    Returns
    -------
    Whatever ``_plot_2D`` returns.
    """
    pretty_name = const.PART_PLOT_LABELS[const.PARTICLES.index(particle)]
    cbar_label = f"{pretty_name} purity"

    return _plot_2D(
        subplot_purity_2D,
        df,
        particle,
        threshold=threshold,
        cell_fontsize=cell_fontsize,
        figsize=figsize,
        title=title,
        legend_or_cbar="cbar",
        legend_cb_label=cbar_label,
        **kwargs,
    )
def subplot_contribution(ax, df, ls="-", bin_by="p", color_list=None):
_validate_bin_by_1D(bin_by)
if bin_by == "p":
......@@ -299,12 +400,14 @@ def plot_contribution(df, bin_by="p", color_list=None):
return fig
def plot_contribution_by_correctness(df, part, title=None, bin_by="p", color_list=None):
def plot_contribution_by_correctness(
df, particle, title=None, bin_by="p", color_list=None
):
_validate_bin_by_1D(bin_by)
plabel = const.PART_TO_LABEL[part]
plabel = const.PART_TO_LABEL[particle]
d = df.loc[df["labels"] == plabel]
if part == BINARY:
if particle == BINARY:
corr = d[f"binary_lr_{BINARY}"] > 0.5
else:
corr = d[f"binary_lr_{BINARY}"] < 0.5
......@@ -430,13 +533,13 @@ def plot_blame(df, bin_by="p", color_list=None):
return fig
def plot_blame_by_correctness(df, part, title=None, bin_by="p", color_list=None):
def plot_blame_by_correctness(df, particle, title=None, bin_by="p", color_list=None):
_validate_binary()
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
plabel = const.PART_TO_LABEL[part]
plabel = const.PART_TO_LABEL[particle]
d = df.loc[df["labels"] == plabel]
if part == BINARY:
if particle == BINARY:
corr = d[f"binary_lr_{BINARY}"] > 0.5
else:
corr = d[f"binary_lr_{BINARY}"] < 0.5
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment