Commit 9f0c418b by connor.hainje@pnnl.gov

### Add an implementation for blame fractions in 1D and 2D, as done with contribution

parent cc2f7788
 ... ... @@ -4,5 +4,6 @@ from .compute import * from .plot import * from .eff_pur import * from .contribution import * from .blame import * from .likelihood_ratios import * from . import binary
 import numpy as np import matplotlib.pyplot as plt from numpy.core.numeric import outer from . import const, colors from .utils import ( better_xlabel, connected_region_borders, imshow_2D_axis_labels, outer_legend, set_bin_by_xlabel, text_array_values, xticks_radians_to_degrees, ) def compute_blame(d): out = np.zeros(len(const.DETECTORS)) cols = [f"contrib_{det}" for det in const.DETECTORS] blames = np.argmax(d[cols].values, axis=1) unique, counts = np.unique(blames, return_counts=True) unique = unique.astype(int) counts = counts.astype(float) counts = counts / np.sum(counts) out[unique] = counts return out def compute_blame_1D(df, bin_by="p"): if bin_by == "p": gb = df.groupby(["p_bin"]) vals = np.zeros((const.N_P_BINS, len(const.DETECTORS))) elif bin_by == "theta": gb = df.groupby(["theta_bin"]) vals = np.zeros((const.N_THETA_BINS, len(const.DETECTORS))) else: raise ValueError() for i, g in gb: vals[i] = compute_blame(g) return vals def compute_blame_2D(df): vals = np.zeros((const.N_P_BINS, const.N_THETA_BINS, len(const.DETECTORS))) gb = df.groupby(["p_bin", "theta_bin"]) for (i, j), g in gb: vals[i, j] = compute_blame(g) return vals def _fill_label_list(label_list): if label_list is not None: return label_list else: return const.DETECTORS def subplot_blame_1D(ax, df, bin_by="p", color_list=None, label_list=None, ls="-"): blames = compute_blame_1D(df, bin_by=bin_by) color_list = colors.fill_color_list(color_list, detectors=True) label_list = _fill_label_list(label_list) bin_centers = const.P_BIN_CENTERS if bin_by == "p" else const.THETA_BIN_CENTERS for i in range(len(const.DETECTORS)): color = color_list[i] label = label_list[i] ax.plot(bin_centers, blames[:, i], c=color, ls=ls, label=label) def plot_blame_1D(df, bin_by="p", title=None, figsize=None): fig, ax = plt.subplots(figsize=figsize) subplot_blame_1D(ax, df, bin_by=bin_by) fig.tight_layout() set_bin_by_xlabel(ax, bin_by) if bin_by == "theta": xticks_radians_to_degrees(ax) ax.set_ylabel("Blame") outer_legend(fig, loc="top", ncol=len(const.DETECTORS), y=0.98) better_xlabel(fig, title, loc="top", y=1.08) return fig def subplot_blame_2D(ax, df, detector="max", color_list=None, cell_fontsize=10): blames = compute_blame_2D(df) color_list = colors.fill_color_list(color_list, detectors=True) if detector in ["min", "max"]: cmap = colors.make_colormap(color_list, detectors=True) idx = np.argmin(blames, 2) if detector == "min" else np.argmax(blames, 2) val = np.take_along_axis(blames, np.expand_dims(idx, 2), 2).squeeze() vmax = len(const.DETECTORS) - 0.5 im = ax.imshow(idx, cmap=cmap, vmin=-0.5, vmax=vmax, aspect="auto") connected_region_borders(ax, idx) elif detector in const.DETECTORS: index = const.DETECTORS.index(detector) cmap = colors.make_linear_colormap(color_list[index]) val = blames[:, :, index] im = ax.imshow(val, cmap=cmap, aspect="auto") else: raise ValueError("detector not understood") if cell_fontsize: text_array_values(ax, val, fontsize=cell_fontsize) return im def blame_2D_legend(fig, color_list=None, label_list=None, title=None): from matplotlib.patches import Patch label_list = _fill_label_list(label_list) color_list = colors.fill_color_list(color_list, detectors=True) handles = [Patch(color=color_list[i]) for i in range(len(const.DETECTORS))] outer_legend( fig, handles, label_list, loc="right", ncol=1, title=title, x=0.98, ) def plot_blame_2D(df, detector="max", title=None, figsize=None, cell_fontsize=10): fig, ax = plt.subplots(figsize=figsize) im = subplot_blame_2D(ax, df, detector=detector, cell_fontsize=cell_fontsize) fig.tight_layout() imshow_2D_axis_labels(ax) better_xlabel(fig, title, loc="top") if detector in const.DETECTORS: cb = fig.colorbar(im) cb.set_label(f"{detector} blame") else: title = ("Highest" if detector == "max" else "Lowest") + "\nblame" blame_2D_legend(fig, title=title) return fig
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!