Commit 84ead851 authored by connor.hainje@pnnl.gov's avatar connor.hainje@pnnl.gov
Browse files

Add support for specifying color lims

parent 692656e1
......@@ -182,13 +182,19 @@ def plot_auc_1D(df, bin_by="p", figsize=None, title=None):
# *** 2D PLOT METHODS ***
def _subplot_2D(ax, data, particle="max", cell_fontsize=10, color_list=None):
def _subplot_2D(
ax, data, particle="max", cell_fontsize=10, color_list=None, vlims=None
):
if particle == "min" or particle == "max":
cmap = colors.make_colormap(color_list, detectors=False)
idx = np.argmin(data, axis=2) if particle == "min" else np.argmax(data, axis=2)
val = np.take_along_axis(data, np.expand_dims(idx, 2), axis=2).squeeze()
vmax = len(const.PARTICLES) - 0.5
im = ax.imshow(idx, cmap=cmap, vmin=-0.5, vmax=vmax, aspect="auto")
if vlims is None:
vmin = -0.5
vmax = len(const.PARTICLES) - 0.5
else:
vmin, vmax = vlims
im = ax.imshow(idx, cmap=cmap, vmin=vmin, vmax=vmax, aspect="auto")
connected_region_borders(ax, idx)
elif particle in const.PARTICLES:
......@@ -196,7 +202,13 @@ def _subplot_2D(ax, data, particle="max", cell_fontsize=10, color_list=None):
color_list = colors.fill_color_list(color_list, detectors=False)
cmap = colors.make_linear_colormap(color_list[index])
val = data[:, :, index]
im = ax.imshow(val, cmap=cmap, aspect="auto")
if vlims is None:
vmin, vmax = None, None
else:
vmin, vmax = vlims
im = ax.imshow(val, cmap=cmap, vmin=vmin, vmax=vmax, aspect="auto")
else:
raise ValueError("particle not understood")
......@@ -207,24 +219,45 @@ def _subplot_2D(ax, data, particle="max", cell_fontsize=10, color_list=None):
return im
def subplot_efficiency_2D(ax, df, particle="max", color_list=None, cell_fontsize=10):
def subplot_efficiency_2D(
ax, df, particle="max", color_list=None, cell_fontsize=10, vlims=None
):
effs = compute_efficiency_2D(df)
return _subplot_2D(
ax, effs, particle=particle, color_list=color_list, cell_fontsize=cell_fontsize
ax,
effs,
particle=particle,
color_list=color_list,
cell_fontsize=cell_fontsize,
vlims=vlims,
)
def subplot_purity_2D(ax, df, particle="max", color_list=None, cell_fontsize=10):
def subplot_purity_2D(
ax, df, particle="max", color_list=None, cell_fontsize=10, vlims=None
):
purs = compute_purity_2D(df)
return _subplot_2D(
ax, purs, particle=particle, color_list=color_list, cell_fontsize=cell_fontsize
ax,
purs,
particle=particle,
color_list=color_list,
cell_fontsize=cell_fontsize,
vlims=vlims,
)
def subplot_auc_2D(ax, df, particle="max", color_list=None, cell_fontsize=10):
def subplot_auc_2D(
ax, df, particle="max", color_list=None, cell_fontsize=10, vlims=None
):
aucs = compute_auc_2D(df)
return _subplot_2D(
ax, aucs, particle=particle, color_list=color_list, cell_fontsize=cell_fontsize
ax,
aucs,
particle=particle,
color_list=color_list,
cell_fontsize=cell_fontsize,
vlims=vlims,
)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment