Commit 80e40650 authored by connor.hainje@pnnl.gov's avatar connor.hainje@pnnl.gov
Browse files

Move utilities to utils, move contribution stuff to contribution file, etc.

parent 53313549
......@@ -4,6 +4,7 @@ import matplotlib.pyplot as plt
from . import const
from .data import make_detector_dataframes
from .utils import better_xlabel, better_ylabel, outer_legend
from .compute import (
compute_confusion,
compute_efficiency_in_all_bins,
......@@ -105,84 +106,6 @@ def _disable_subplot_legends_and_labels(subplot_kw):
return subplot_kw
def get_center_coordinates(fig):
x_lo = fig.axes[0].get_position().xmin
y_hi = fig.axes[0].get_position().ymax
x_hi = fig.axes[-1].get_position().xmax
y_lo = fig.axes[-1].get_position().ymin
x_center = 0.5 * (x_lo + x_hi)
y_center = 0.5 * (y_lo + y_hi)
return x_center, y_center
def better_xlabel(fig, xlabel, loc="bottom", x=None, y=None, fontsize=16):
x_center, _ = get_center_coordinates(fig)
x_val = x if x else x_center
y_val = y if y else (1.0 if loc == "top" else 0.0)
vert_align = "top" if loc == "bottom" else ("bottom" if loc == "top" else "center")
fig.text(x_val, y_val, xlabel, fontsize=fontsize, ha="center", va=vert_align)
def better_ylabel(fig, xlabel, loc="left", x=None, y=None, fontsize=16):
_, y_center = get_center_coordinates(fig)
x_val = x if x else (1.0 if loc == "right" else 0.0)
y_val = y if y else y_center
horiz_align = "right" if loc == "left" else ("left" if loc == "right" else "center")
fig.text(
x_val,
y_val,
xlabel,
fontsize=fontsize,
ha=horiz_align,
va="center",
rotation="vertical",
)
def outer_legend(fig, labels, loc="top", x=None, y=None, ncol=None, fontsize=16):
x_center, y_center = get_center_coordinates(fig)
defaults = {
"top": {
"x": x_center,
"y": 1.0,
"ncol": 6,
"anchor": "lower center",
},
"bottom": {
"x": x_center,
"y": 0.0,
"ncol": 6,
"anchor": "upper center",
},
"left": {
"x": 0.0,
"y": y_center,
"ncol": 1,
"anchor": "center right",
},
"right": {
"x": 1.0,
"y": y_center,
"ncol": 1,
"anchor": "center left",
},
}
x_val = x if x else defaults[loc]["x"]
y_val = y if y else defaults[loc]["y"]
n_col = ncol if ncol else defaults[loc]["ncol"]
fig.legend(
labels,
frameon=False,
fontsize=fontsize,
ncol=n_col,
loc=defaults[loc]["anchor"],
bbox_to_anchor=(x_val, y_val),
)
def add_p_theta_fig_labels(fig, axs, xlabel=None, ylabel=None):
x_pos = 0.0
for i, p in enumerate(const.P_BIN_LABELS):
......@@ -622,25 +545,7 @@ def plot_top_wrong_by_particle(df, **kwargs):
return fig
# CONTRIBUTION METRIC
def subplot_contribution(ax, df, nbins=30, yscale="log", legend=True):
for det in const.DETECTORS:
ax.hist(
df[f"contrib_{det}"], bins=np.linspace(-1, 1, nbins + 1), histtype="step"
)
ax.set_xlabel("Contribution metric")
ax.set_ylabel("Frequency")
ax.set_yscale(yscale)
if legend:
ax.legend(const.DETECTORS, loc="upper left", frameon=False)
def plot_contribution(df, nbins=30, yscale="log", legend=True, figsize=(12, 4)):
fig, ax = plt.subplots(figsize=figsize)
subplot_contribution(ax, df, nbins=nbins, yscale=yscale, legend=legend)
return fig
# BLAME PLOTS
def subplot_bar_radar(ax, fs, radar=False, series_l=None, axis_l=None):
......@@ -669,36 +574,6 @@ def subplot_bar_radar(ax, fs, radar=False, series_l=None, axis_l=None):
ax.set_xlabel("True particle type")
def subplot_avg_contributions(ax, df, radar=False, legend=True):
ctrbs = []
for i in range(const.N_PARTICLES):
d = df[df["labels"] == i]
ctrbs.append([d[f"contrib_{det}"].mean() for det in const.DETECTORS])
ctrbs = np.array(ctrbs).T
subplot_bar_radar(
ax, ctrbs, radar=radar, series_l=const.DETECTORS, axis_l=const.PART_PLOT_LABELS
)
ax.set_ylabel("Average detector contribution")
if legend:
if radar:
ax.legend(frameon=False, loc="center left", bbox_to_anchor=(1.05, 0.5))
else:
ax.legend(frameon=False, loc="upper right", ncol=3)
def plot_avg_contributions(df, radar=False, figsize=(7, 6)):
if radar:
fig, ax = plt.subplots(figsize=figsize, subplot_kw={"projection": "polar"})
else:
fig, ax = plt.subplots(figsize=figsize)
subplot_avg_contributions(ax, df, radar=radar)
return fig
# BLAME PLOTS
def subplot_blame_numbers(ax, df, frac=False, radar=False, legend=True):
blfqs = [
compute_blame_numbers(df[df["labels"] == i]) for i in range(const.N_PARTICLES)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment