import copy
from pathlib import Path
from collections import defaultdict
import astropy.units as u
import numpy as np
from matplotlib import cm
from matplotlib import pyplot as plt
from matplotlib.colors import LogNorm, Normalize
from matplotlib.patches import Circle, Patch
from matplotlib.widgets import Slider
from stixpy.io.readers import read_subc_params
SubCollimatorConfig = read_subc_params(
Path(__file__).parent.parent / "config" / "data" / "detector" / "stx_subc_params.csv"
)
__all__ = ["PixelPlotter", "SliderCustomValue"]
[docs]
class PixelPlotter:
"""
Plot individual pixel data for each detector.
Support three kinds of plots:
* 'pixel' which show counts as rectangular patches in the correct pixel locations using a color map
* 'errorbar' which shows the counts and error as error bar plots one per detector
* 'config' which display per sub-collimator configuration
Parameters
----------
prod : `Product`
Pixel data product to plot
kind : `string` optional
This sets the visualization type of the subplots the supported options are: 'pixel', 'errorbar', 'config'.
time_indices : `list` or `numpy.ndarray`
If an 1xN array will be treated as mask if 2XN array will sum data between given
indices. For example `time_indices=[0, 2, 5]` would return only the first, third and
sixth times while `time_indices=[[0, 2],[3, 5]]` would sum the data between.
energy_indices : `list` or `numpy.ndarray`
If an 1xN array will be treated as mask if 2XN array will sum data between given
indices. For example `energy_indices=[0, 2, 5]` would return only the first, third and
sixth times while `energy_indices=[[0, 2],[3, 5]]` would sum the data between.
fig : optional `matplotlib.figure`
The figure where to which the pixel plot will be added.
cmap : `string` | `colormap` optional
If the kind is `pixels` a colormap will be shown.
NOTE: If the color of the special detectors 'cfl', 'bkg' is way above
the imaging detectors, the color will be automatically set to white.
Returns
-------
`tuple[matplotlib.figure.Figure,matplotlib.axes.Axes]`
"""
def __init__(self, prod, time_indices=None, energy_indices=None):
self.time_indices = time_indices
self.energy_indices = energy_indices
from stixpy.product.sources import CompressedPixelData, RawPixelData, SummedCompressedPixelData
if not isinstance(prod, (RawPixelData, CompressedPixelData, SummedCompressedPixelData)):
raise ValueError(f"Can not create a pixel plot as {prod.__class__} does not contain pixel data.")
self.prod = prod
self.kind = "pixel"
self.fig = None
self.axes = None
self.containers = defaultdict(list)
self._prepare_data()
[docs]
def plot(self, kind="pixel", fig=None, cmap=None):
r"""
Generates and returns the main plot figure and axes
Parameters
----------
kind : str
The visualization type: 'pixels', 'errorbar', or 'config'.
fig : matplotlib.figure.Figure, optional
An existing figure to draw on.
cmap : str or colormap, optional
The colormap for the 'pixels' plot.
Returns
-------
"""
if kind not in ["pixel", "errorbar", "config"]:
raise ValueError(f"Kind must be 'pixel', 'errorbar' or 'config' not '{kind}'.")
self.kind = kind
if fig is None:
fig, axes = plt.subplots(nrows=4, ncols=8, sharex=True, sharey=True, figsize=(7, 7))
else:
axes = fig.subplots(nrows=4, ncols=8, sharex=True, sharey=True)
self.fig = fig
self.axes = axes
self._setup_plot_elements(cmap)
self._create_main_layout()
self._create_sliders()
self._connect_update_function()
return self.fig, self.axes
def _setup_plot_elements(self, cmap):
"""Sets up normalization, colormaps, and fonts."""
max_counts = np.max(self.counts[np.isfinite(self.counts)]).value
min_counts = np.min(self.counts[self.counts > 0]).value
self.norm = CountNorm(min_counts, max_counts)
self.det_font = {"weight": "regular", "size": 8}
self.axes_font = {"weight": "regular", "size": 7}
self.quadrant_font = {"weight": "regular", "size": 15}
if cmap is None:
self.clrmap = copy.copy(cm.get_cmap("viridis"))
self.clrmap.set_over("gray")
self.clrmap.set_under("white")
self.clrmap.set_bad("gray")
elif isinstance(cmap, str):
self.clrmap = copy.copy(cm.get_cmap(cmap))
else:
self.clrmap = cmap
def _prepare_data(self):
# Get the necessary data from the product
counts, count_err, times, durations, energies = self.prod.get_data(
time_indices=self.time_indices, energy_indices=self.energy_indices
)
nt, ndet, npix, ne = counts.shape
dmask = self.prod.detector_masks.masks[0].astype(bool)
counts_pad = []
count_err_pad = []
for i, pm in enumerate(self.prod.data["pixel_masks"].value):
tmp_counts = np.full((32, 12, ne), np.nan)
tmp_err = np.full((32, 12, ne), np.nan)
tmp_counts[np.ix_(dmask, pm.astype(bool))] = counts[i][:, pm.astype(bool)[: counts.shape[2]], :]
tmp_err[np.ix_(dmask, pm.astype(bool))] = count_err[i][:, pm.astype(bool)[: counts.shape[2]], :]
counts_pad.append(tmp_counts)
count_err_pad.append(tmp_err)
counts_pad = np.stack(counts_pad)
count_err_pad = np.stack(count_err_pad)
self.times = times
self.energies = energies
self.counts = counts_pad * counts.unit
self.count_err = count_err_pad * count_err.unit
def _create_main_layout(self):
"""Draws the instrument layout and the 32 detector subplots."""
self._draw_instrument_layout()
if self.kind == "pixel":
self._draw_colorbar()
xnorm = Normalize(SubCollimatorConfig["SC Xcen"].min() * 1.5, SubCollimatorConfig["SC Xcen"].max() * 1.5)
ynorm = Normalize(SubCollimatorConfig["SC Ycen"].min() * 1.4, SubCollimatorConfig["SC Ycen"].max() * 1.4)
pixel_ids = [slice(0, 4), slice(4, 8), slice(8, 12)]
if self.counts.shape[2] == 4:
pixel_ids = [slice(0, 4)]
for det_id in range(32):
row, col = divmod(det_id, 8)
ax = self.axes[row, col]
plot_container = None
if self.kind == "pixel":
plot_container = self._det_pixels_plot(self.counts[0, det_id, :, 0], ax, last=(det_id == 31))
elif self.kind == "errorbar":
plot_container = self._det_errorbar_plot(
self.counts[0, det_id, :, 0], self.count_err[0, det_id, :, 0], pixel_ids, det_id, ax
)
elif self.kind == "config":
plot_container = self._det_config_plot(SubCollimatorConfig[det_id], ax, det_id)
ax.set_zorder(100)
ax.set_position(
[
xnorm(SubCollimatorConfig["SC Xcen"][det_id]),
ynorm(SubCollimatorConfig["SC Ycen"][det_id]),
1 / 11.0,
1 / 11.0,
]
)
self.containers[row, col].append(plot_container)
resolutions = np.arctan2(0.5 * SubCollimatorConfig["Front Pitch"].to("um"), 545.30 * u.mm).to("arcsec")
ax.set_title(
f"{SubCollimatorConfig['Det #'][det_id]}"
f" {SubCollimatorConfig['Grid Label'][det_id]}"
f'{resolutions[det_id].value: 0.1f}"',
y=0.89,
**self.det_font,
)
def _create_sliders(self):
"""Creates the time and energy sliders."""
axcolor = "lightgoldenrodyellow"
axenergy = plt.axes([0.15, 0.05, 0.55, 0.03], facecolor=axcolor)
self.senergy = SliderCustomValue(
ax=axenergy,
label="Energy",
valmin=0,
valmax=len(self.energies) - 1,
format_func=self._format_energy,
valinit=0,
valstep=1,
)
axetime = plt.axes([0.15, 0.01, 0.55, 0.03], facecolor=axcolor)
self.stime = SliderCustomValue(
ax=axetime,
label="Time",
valmin=0,
valmax=self.counts.shape[0] - 1,
format_func=self._format_time,
valinit=1,
valstep=1,
)
# --- Formatting and Drawing Helpers ---
def _format_time(self, val):
return self.times[val].isot
def _format_energy(self, val):
return f"{self.energies[val]['e_low'].value}-{self.energies[val]['e_high']}"
def _draw_colorbar(self):
"""Creates a colormap on the left side of the figure."""
cax = self.fig.add_axes([0.05, 0.15, 0.025, 0.8])
cbar = self.fig.colorbar(cm.ScalarMappable(norm=self.norm, cmap=self.clrmap), orientation="vertical", cax=cax)
cbar.ax.set_title(f"{str(self.counts.unit)}", rotation=90, x=-0.8, y=0.4)
def _draw_instrument_layout(self):
"""Shows the layout of the instrument."""
x = [0, 2]
y = [1, 1]
ax = self.fig.add_axes([0.06, 0.055, 0.97, 0.97])
ax.plot(x, y, c="b")
ax.plot(y, x, c="b")
ax.axis("off")
ax = self.fig.add_axes([0.09, 0.08, 0.91, 0.92])
draw_circle_1 = Circle((0.545, 0.540), 0.443, color="b", alpha=0.1)
draw_circle_2 = Circle((0.545, 0.540), 0.07, color="#2b330b", alpha=0.95)
self.fig.add_artist(draw_circle_1)
self.fig.add_artist(draw_circle_2)
ax.axis("off")
ax = self.fig.add_axes([0, 0, 1, 1])
ax.text(0.19, 0.89, "Q1", **self.quadrant_font)
ax.text(0.19, 0.17, "Q2", **self.quadrant_font)
ax.text(0.86, 0.17, "Q3", **self.quadrant_font)
ax.text(0.86, 0.89, "Q4", **self.quadrant_font)
ax.axis("off")
# --- Per-Detector Plotting Logic ---
def _det_pixels_plot(self, counts, axes, last=False):
"""Shows a plot to visualize the pixel counts."""
x_pos, bar1, bar2, bar3 = ["A", "B", "C", "D"], [1] * 4, [-1] * 4, [0.2] * 4
counts = counts.reshape(3, 4)
top = axes.bar(
x_pos, bar1, color=self.clrmap(self.norm(counts[0, :])), width=1, zorder=1, edgecolor="k", linewidth=0.5
)
bottom = axes.bar(
x_pos, bar2, color=self.clrmap(self.norm(counts[1, :])), width=1, zorder=1, edgecolor="k", linewidth=0.5
)
small = axes.bar(
x_pos,
bar3,
color=self.clrmap(self.norm(counts[2, :])),
width=-0.5,
align="edge",
bottom=-0.1,
zorder=1,
edgecolor="k",
linewidth=0.5,
)
axes.axes.get_yaxis().set_visible(False)
if last:
axes.set_xticks(range(4))
axes.set_xticklabels(x_pos)
axes.axes.get_xaxis().set_visible(True)
else:
axes.set_xticks([])
axes.axes.get_xaxis().set_visible(False)
for i in range(4):
top[i].data = counts[0, i]
bottom[i].data = counts[1, i]
small[i].data = counts[2, i]
self._create_hover_tooltip(axes, [top, bottom, small], last)
return top, bottom, small
def _det_errorbar_plot(self, counts, count_err, pixel_ids, detector_id, axes):
"""Shows an errorbar plot of counts."""
plot_cont = [
axes.errorbar((0.5, 1.5, 2.5, 3.5), counts[pid], yerr=count_err[pid], xerr=0.5, ls="") for pid in pixel_ids
]
axes.set_xticks([])
if detector_id > 0:
axes.set_ylabel("")
return plot_cont
def _det_config_plot(self, detector_config, axes, detector_id):
"""Shows a plot with detector configurations."""
# Create Functions to convert 'Front' and 'Rear Orient'.
def mm2deg(x):
return x * 360.0 / 1
def deg2mm(x):
return x / 360.0 * 1
# get the information that will be plotted
if detector_config["Phase Sense"] > 0:
phase_sense = "+"
elif detector_config["Phase Sense"] < 0:
phase_sense = "-"
else:
phase_sense = "n"
y = [
detector_config["Slit Width"],
detector_config["Front Pitch"],
detector_config["Rear Pitch"],
0,
deg2mm(detector_config["Front Orient"]),
deg2mm(detector_config["Rear Orient"]),
]
x = np.arange(len(y))
color = ["black", "orange", "#1f77b4", "b", "orange", "#1f77b4"]
# plot the information on axes
axes.bar(x, y, color=color)
axes.text(x=0.8, y=0.7, s=f"Phase: {phase_sense}", **self.axes_font)
axes.set_ylim(0, 1)
axes.axes.get_xaxis().set_visible(False)
# Create secondary y axis
ax2 = axes.secondary_yaxis("right", functions=(mm2deg, deg2mm))
ax2.set_yticks([0, 90, 270, 360])
ax2.set_yticklabels(["0°", "90°", "270°", "360°"], fontsize=8)
ax2.set_visible(False)
axes.axes.get_yaxis().set_visible(False)
# Create axes labeling and legend
if detector_id == 0:
axes.set_yticks([0, 1])
axes.set_ylabel("mm", **self.axes_font)
axes.yaxis.set_label_coords(-0.1, 0.5)
axes.axes.get_yaxis().set_visible(True)
legend_bars = [Patch(facecolor="orange"), Patch(facecolor="#1f77b4")]
axes.legend(legend_bars, ["Front", "Rear"], loc="center right", bbox_to_anchor=(0, 2.5))
if detector_id == 31:
ax2.set_visible(True)
axes.axes.get_xaxis().set_visible(True)
axes.set_xticks([0, 1.5, 4.5])
axes.set_xticklabels(["Slit Width", "Pitch", "Orientation"], rotation=90)
# leave the spaces to set the correct x position of the label!
ax2.set_ylabel(" deg °", rotation=0, **self.axes_font)
# x parameter doesn't change anything because it's a secondary
# y axis (has only 1 x position).
ax2.yaxis.set_label_coords(x=1, y=0.55)
def _create_hover_tooltip(self, axes, artists_list, last):
"""Creates and manages the hover annotation for a subplot."""
annot = axes.annotate(
"",
xy=(0, 0),
xytext=(-60, 20),
textcoords="offset points",
bbox=dict(boxstyle="round", fc="w"),
arrowprops=dict(arrowstyle="-"),
zorder=33,
)
annot.set_visible(False)
def update_annot(artist):
center_x = artist.get_x() + artist.get_width() / 2
center_y = artist.get_y() + artist.get_height() / 2
annot.xy = (center_x, center_y)
annot.set_text(format(artist.data, ".2e"))
def hover(event):
annot.set_visible(False)
if event.inaxes == axes:
for artist_group in artists_list:
for artist in artist_group:
contains, _ = artist.contains(event)
if contains:
update_annot(artist)
annot.set_visible(True)
break
if last:
self.fig.canvas.draw_idle()
self.fig.canvas.mpl_connect("motion_notify_event", hover)
# --- Update Functions for Sliders ---
def _connect_update_function(self):
"""Connects the appropriate update function to the sliders."""
update_function = self._update_void
if self.kind == "pixel":
update_function = self._update_pixels
elif self.kind == "errorbar":
update_function = self._update_errorbar
self.senergy.on_changed(update_function)
self.stime.on_changed(update_function)
def _update_void(self, _):
"""Dummy update function for static plots."""
pass
def _update_pixels(self, _):
"""Updates the pixel colors based on slider values."""
energy_index, time_index = self.senergy.val, self.stime.val
for detector_id in range(32):
row, col = divmod(detector_id, 8)
top, bottom, small = self.containers[row, col][0]
cnts = self.counts[time_index, detector_id, :, energy_index].reshape([3, 4])
for idx in range(4):
norm_counts = self.norm(cnts[0][idx].value)
top[idx].set_color(self.clrmap(norm_counts))
top[idx].data = cnts[0][idx]
top[idx].set_edgecolor("k")
norm_counts = self.norm(cnts[1][idx].value)
bottom[idx].set_color(self.clrmap(norm_counts))
bottom[idx].data = cnts[1][idx]
bottom[idx].set_edgecolor("k")
norm_counts = self.norm(cnts[2][idx].value)
small[idx].set_color(self.clrmap(norm_counts))
small[idx].data = cnts[2][idx]
small[idx].set_edgecolor("k")
self.fig.canvas.draw_idle()
def _update_errorbar(self, _):
energy_index = self.senergy.val
time_index = self.stime.val
pids_ = [slice(0, 4), slice(4, 8), slice(8, 12)]
if self.counts.shape[2] == 4:
pids_ = [slice(0, 4)]
for did in range(32):
r, c = divmod(did, 8)
self.axes[r, c].set_ylim(0, np.nanmax(self.counts[time_index, :, :, energy_index]) * 1.2)
for i, pid in enumerate(pids_):
lines, caps, bars = self.containers[r, c][0][i]
lines.set_ydata(self.counts[time_index, did, pid, energy_index])
# horizontal bars at value
segs = np.array(bars[0].get_segments())
if segs.size > 0:
segs[:, 0, 0] = [0.0, 1.0, 2.0, 3.0]
segs[:, 1, 0] = [1.0, 2.0, 3.0, 4.0]
segs[:, 0, 1] = self.counts[time_index, did, pid, energy_index]
segs[:, 1, 1] = self.counts[time_index, did, pid, energy_index]
bars[0].set_segments(segs)
# vertical bars at +/- error
segs = np.array(bars[1].get_segments())
segs[:, 0, 0] = [0.5, 1.5, 2.5, 3.5]
segs[:, 1, 0] = [0.5, 1.5, 2.5, 3.5]
segs[:, 0, 1] = (
self.counts[time_index, did, pid, energy_index]
- self.count_err[time_index, did, pid, energy_index]
)
segs[:, 1, 1] = (
self.counts[time_index, did, pid, energy_index]
+ self.count_err[time_index, did, pid, energy_index]
)
bars[1].set_segments(segs)
self.fig.canvas.draw_idle()
[docs]
class SliderCustomValue(Slider):
"""
A slider with a customisable formatter
"""
def __init__(self, *args, format_func=None, **kwargs):
if format_func is not None:
self._format = format_func
super().__init__(*args, **kwargs)
class CountNorm(Normalize):
"""
A LogNorm but allows 0s to be kept and plotted with colormaps under color
"""
def __init__(self, vmin=None, vmax=None, **kwargs):
super().__init__(vmin, vmax, **kwargs)
self.lognorm = LogNorm(vmin=vmin, vmax=vmax)
def __call__(self, value):
tmp, is_scaler = self.process_value(value)
unit = tmp.unit if hasattr(tmp, "unit") else 1
zeros = np.nonzero(tmp == 0)
tmp[zeros] = np.finfo(np.float32).tiny * unit
res = self.lognorm(tmp)
return res if not is_scaler else res[0]
def inverse(self, value):
tmp, is_scaler = self.process_value(value)
unit = tmp.unit if hasattr(tmp, "unit") else 1
zeros = np.nonzero(tmp == 0)
tmp[zeros] = np.finfo(np.float32).tiny * unit
res = self.lognorm.inverse(tmp)
return res if not is_scaler else res[0]