import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from ..utils.dataloader import DetoxaiDataLoader
from .enums import ConditionOn
from .ImageVisualizer import ImageVisualizer
from .LRPHandler import LRPHandler
from .utils import get_nth_batch
[docs]
class HeatmapVisualizer(ImageVisualizer):
""" """
def __init__(
self,
data_loader: DetoxaiDataLoader | DataLoader,
model: nn.Module,
lrp_object: LRPHandler = None,
plot_config: dict = {},
draw_rectangles: bool = False,
rectangle_config: dict = {},
) -> None:
self.data_loader = data_loader
self.model = model
if not isinstance(data_loader, DetoxaiDataLoader):
# Check if the user passed an LRPHandler object with n_classes != None
if lrp_object is None or lrp_object.n_classes is None:
raise ValueError(
"If you pass a DataLoader that is not a subclass of `DetoxaiDataLoader`, you must pass an LRPHandler with `n_classes` set."
)
if lrp_object is None:
lrp_object = LRPHandler() # Default LRPHandler
self.lrp_object = lrp_object
self.init_rectangle_painter(draw_rectangles, rectangle_config)
self.set_up_plots_configuration(plot_config)
[docs]
def visualize_batch(
self,
batch_num: int,
condition_on: ConditionOn = ConditionOn.PROPER_LABEL,
show_cbar: bool = True,
max_images: int | None = 36,
return_fig: bool = False,
) -> None:
"""
Args:
batch_num: int:
condition_on: ConditionOn: (Default value = ConditionOn.PROPER_LABEL)
show_cbar: bool: (Default value = True)
max_images: int | None: (Default value = 36)
return_fig: bool: (Default value = False)
Returns:
"""
images = self._get_heatmaps(batch_num, condition_on, max_images)
if max_images is None:
max_images = images.shape[0]
images_to_show = min(images.shape[0], max_images)
rows = int(images_to_show**0.5)
cols = int(images_to_show**0.5)
fig, ax = self.get_canvas(
rows=rows,
cols=cols,
shape=(
int(rows) * self.plot_shape_multiplier,
int(cols) * self.plot_shape_multiplier,
),
)
for i, img in enumerate(images[:max_images]):
im = ax[i // cols, i % cols].imshow(img, cmap="seismic", vmin=0, vmax=1)
self.maybe_paint_rectangle(ax[i // cols, i % cols])
if show_cbar:
# Show colorbar at the bottom
fig.subplots_adjust(right=0.85)
cbar_ax = fig.add_axes([0.89, 0.15, 0.05, 0.7])
cbar = fig.colorbar(im, cax=cbar_ax)
# Modify ticks in the colorbar
cbar.set_ticks([0, 0.5, 1])
cbar.set_ticklabels(["-1", "0", "1"])
# Make cbar slimmer in width
cbar_ax.set_aspect(25)
if return_fig:
return fig, ax
[docs]
def visualize_agg(
self,
batch_num: int,
condition_on: ConditionOn = ConditionOn.PROPER_LABEL,
) -> None:
"""
Args:
batch_num: int:
condition_on: ConditionOn: (Default value = ConditionOn.PROPER_LABEL)
Returns:
"""
_, labels, prot_attr = get_nth_batch(self.data_loader, batch_num) # noqa
images = self._get_heatmaps(batch_num, condition_on, None)
if isinstance(labels, np.ndarray):
labels = torch.tensor(labels)
if isinstance(prot_attr, np.ndarray):
prot_attr = torch.tensor(prot_attr)
if isinstance(images, np.ndarray):
images = torch.tensor(images)
ulab = labels.unique()
uprot = prot_attr.unique()
fig, ax = self.get_canvas(
rows=len(ulab), cols=len(uprot), shape=(len(ulab) * 3, len(uprot) * 3)
)
for row, label in enumerate(ulab):
for col, prot_a in enumerate(uprot):
mask = (labels == label) & (prot_attr == prot_a)
img = images[mask].mean(dim=0).cpu().detach().numpy()
ax[row, col].imshow(img, cmap="seismic", vmin=0, vmax=1)
self.maybe_paint_rectangle(ax[row, col])
def _get_heatmaps(
self, batch_num: int, condition_on: ConditionOn, max_images: int | None
) -> np.ndarray:
"""
Args:
batch_num: int:
condition_on: ConditionOn:
max_images: int | None:
Returns:
"""
images, labels, prot_attr = get_nth_batch(self.data_loader, batch_num) # noqa
if max_images is None:
max_images = images.shape[0]
heatmaps = self.lrp_object.calculate(
self.model, self.data_loader, batch_num, max_images
)
conditioned = []
for i, label in enumerate(labels[:max_images]):
# Assuming binary classification
label = label if condition_on == ConditionOn.PROPER_LABEL else 1 - label
conditioned.append(heatmaps[label, i])
images: torch.Tensor = torch.stack(conditioned).to(dtype=float)
images = images.cpu().detach().numpy()
return images