Commit 55f4079d authored by connor.hainje@pnnl.gov's avatar connor.hainje@pnnl.gov
Browse files

Add support for using same plotting routines for AUC scores

parent eb719174
......@@ -35,6 +35,21 @@ def compute_purity(d):
return out
def compute_auc(d):
from sklearn.metrics import roc_auc_score
out = np.zeros(len(const.PARTICLES))
for h in range(len(const.PARTICLES)):
true = d["labels"].values == h
score = d[f"lr_{h}"].values
if np.count_nonzero(true) == 0:
continue
if np.count_nonzero(~true) == 0:
continue
out[h] = roc_auc_score(true.astype(int), score)
return out
def _compute_1D(compute_fn, df, bin_by="p"):
if bin_by == "p":
gb = df.groupby(["p_bin"])
......@@ -59,6 +74,10 @@ def compute_purity_1D(df, bin_by="p"):
return _compute_1D(compute_purity, df, bin_by=bin_by)
def compute_auc_1D(df, bin_by="p"):
return _compute_1D(compute_auc, df, bin_by=bin_by)
def _compute_2D(compute_fn, df):
vals = np.zeros((const.N_P_BINS, const.N_THETA_BINS, len(const.PARTICLES)))
gb = df.groupby(["p_bin", "theta_bin"])
......@@ -75,6 +94,10 @@ def compute_purity_2D(df):
return _compute_2D(compute_purity, df)
def compute_auc_2D(df):
return _compute_2D(compute_auc, df)
# *** 1D PLOT METHODS ***
......@@ -105,6 +128,11 @@ def subplot_purity_1D(ax, df, bin_by="p", color_list=None, ls="-"):
_subplot_1D(ax, purs, bin_by=bin_by, color_list=color_list, ls=ls)
def subplot_auc_1D(ax, df, bin_by="p", color_list=None, ls="-"):
aucs = compute_auc_1D(df, bin_by=bin_by)
_subplot_1D(ax, aucs, bin_by=bin_by, color_list=color_list, ls=ls)
def _plot_1D(subplot_fn, df, bin_by="p", figsize=None, ylabel=None, title=None):
fig, ax = plt.subplots(figsize=figsize)
subplot_fn(ax, df, bin_by=bin_by)
......@@ -140,6 +168,17 @@ def plot_purity_1D(df, bin_by="p", figsize=None, title=None):
)
def plot_auc_1D(df, bin_by="p", figsize=None, title=None):
return _plot_1D(
subplot_auc_1D,
df,
bin_by=bin_by,
figsize=figsize,
ylabel="AUC",
title=title,
)
# *** 2D PLOT METHODS ***
......@@ -182,6 +221,13 @@ def subplot_purity_2D(ax, df, particle="max", color_list=None, cell_fontsize=10)
)
def subplot_auc_2D(ax, df, particle="max", color_list=None, cell_fontsize=10):
aucs = compute_auc_2D(df)
return _subplot_2D(
ax, aucs, particle=particle, color_list=color_list, cell_fontsize=cell_fontsize
)
def _plot_2D(
subplot_fn,
df,
......@@ -284,3 +330,36 @@ def plot_purity_2D(df, particle="max", figsize=None, title=None, cell_fontsize=1
else:
raise ValueError("argument particle not understood")
def plot_auc_2D(df, particle="max", figsize=None, title=None, cell_fontsize=10):
if particle in const.PARTICLES:
p_index = const.PARTICLES.index(particle)
p_label = const.PART_PLOT_LABELS[p_index]
label = f"{p_label} AUC"
return _plot_2D(
subplot_auc_2D,
df,
particle=particle,
figsize=figsize,
title=title,
legend_or_cbar="cbar",
legend_cb_label=label,
cell_fontsize=cell_fontsize,
)
elif particle in ["min", "max"]:
label = ("Highest" if particle == "max" else "Lowest") + "\nAUC"
return _plot_2D(
subplot_auc_2D,
df,
particle=particle,
figsize=figsize,
title=title,
legend_or_cbar="legend",
legend_cb_label=label,
cell_fontsize=cell_fontsize,
)
else:
raise ValueError("argument particle not understood")
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment