Commit 9f0c418b authored by connor.hainje@pnnl.gov's avatar connor.hainje@pnnl.gov
Browse files

Add an implementation for blame fractions in 1D and 2D, as done with contribution

parent cc2f7788
......@@ -4,5 +4,6 @@ from .compute import *
from .plot import *
from .eff_pur import *
from .contribution import *
from .blame import *
from .likelihood_ratios import *
from . import binary
import numpy as np
import matplotlib.pyplot as plt
from numpy.core.numeric import outer
from . import const, colors
from .utils import (
better_xlabel,
connected_region_borders,
imshow_2D_axis_labels,
outer_legend,
set_bin_by_xlabel,
text_array_values,
xticks_radians_to_degrees,
)
def compute_blame(d):
out = np.zeros(len(const.DETECTORS))
cols = [f"contrib_{det}" for det in const.DETECTORS]
blames = np.argmax(d[cols].values, axis=1)
unique, counts = np.unique(blames, return_counts=True)
unique = unique.astype(int)
counts = counts.astype(float)
counts = counts / np.sum(counts)
out[unique] = counts
return out
def compute_blame_1D(df, bin_by="p"):
if bin_by == "p":
gb = df.groupby(["p_bin"])
vals = np.zeros((const.N_P_BINS, len(const.DETECTORS)))
elif bin_by == "theta":
gb = df.groupby(["theta_bin"])
vals = np.zeros((const.N_THETA_BINS, len(const.DETECTORS)))
else:
raise ValueError()
for i, g in gb:
vals[i] = compute_blame(g)
return vals
def compute_blame_2D(df):
vals = np.zeros((const.N_P_BINS, const.N_THETA_BINS, len(const.DETECTORS)))
gb = df.groupby(["p_bin", "theta_bin"])
for (i, j), g in gb:
vals[i, j] = compute_blame(g)
return vals
def _fill_label_list(label_list):
if label_list is not None:
return label_list
else:
return const.DETECTORS
def subplot_blame_1D(ax, df, bin_by="p", color_list=None, label_list=None, ls="-"):
blames = compute_blame_1D(df, bin_by=bin_by)
color_list = colors.fill_color_list(color_list, detectors=True)
label_list = _fill_label_list(label_list)
bin_centers = const.P_BIN_CENTERS if bin_by == "p" else const.THETA_BIN_CENTERS
for i in range(len(const.DETECTORS)):
color = color_list[i]
label = label_list[i]
ax.plot(bin_centers, blames[:, i], c=color, ls=ls, label=label)
def plot_blame_1D(df, bin_by="p", title=None, figsize=None):
fig, ax = plt.subplots(figsize=figsize)
subplot_blame_1D(ax, df, bin_by=bin_by)
fig.tight_layout()
set_bin_by_xlabel(ax, bin_by)
if bin_by == "theta":
xticks_radians_to_degrees(ax)
ax.set_ylabel("Blame")
outer_legend(fig, loc="top", ncol=len(const.DETECTORS), y=0.98)
better_xlabel(fig, title, loc="top", y=1.08)
return fig
def subplot_blame_2D(ax, df, detector="max", color_list=None, cell_fontsize=10):
blames = compute_blame_2D(df)
color_list = colors.fill_color_list(color_list, detectors=True)
if detector in ["min", "max"]:
cmap = colors.make_colormap(color_list, detectors=True)
idx = np.argmin(blames, 2) if detector == "min" else np.argmax(blames, 2)
val = np.take_along_axis(blames, np.expand_dims(idx, 2), 2).squeeze()
vmax = len(const.DETECTORS) - 0.5
im = ax.imshow(idx, cmap=cmap, vmin=-0.5, vmax=vmax, aspect="auto")
connected_region_borders(ax, idx)
elif detector in const.DETECTORS:
index = const.DETECTORS.index(detector)
cmap = colors.make_linear_colormap(color_list[index])
val = blames[:, :, index]
im = ax.imshow(val, cmap=cmap, aspect="auto")
else:
raise ValueError("detector not understood")
if cell_fontsize:
text_array_values(ax, val, fontsize=cell_fontsize)
return im
def blame_2D_legend(fig, color_list=None, label_list=None, title=None):
from matplotlib.patches import Patch
label_list = _fill_label_list(label_list)
color_list = colors.fill_color_list(color_list, detectors=True)
handles = [Patch(color=color_list[i]) for i in range(len(const.DETECTORS))]
outer_legend(
fig,
handles,
label_list,
loc="right",
ncol=1,
title=title,
x=0.98,
)
def plot_blame_2D(df, detector="max", title=None, figsize=None, cell_fontsize=10):
fig, ax = plt.subplots(figsize=figsize)
im = subplot_blame_2D(ax, df, detector=detector, cell_fontsize=cell_fontsize)
fig.tight_layout()
imshow_2D_axis_labels(ax)
better_xlabel(fig, title, loc="top")
if detector in const.DETECTORS:
cb = fig.colorbar(im)
cb.set_label(f"{detector} blame")
else:
title = ("Highest" if detector == "max" else "Lowest") + "\nblame"
blame_2D_legend(fig, title=title)
return fig
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment