Source code for detoxai.visualization.SSVisualizer

import numpy as np
from torch import nn
from torch.utils.data import DataLoader

from ..utils.dataloader import DetoxaiDataLoader
from ..utils.datasets import DetoxaiDataset
from .DataVisualizer import DataVisualizer
from .HeatmapVisualizer import ConditionOn, HeatmapVisualizer
from .ImageVisualizer import ImageVisualizer
from .LRPHandler import LRPHandler
from .utils import get_nth_batch


[docs] class SSVisualizer(ImageVisualizer): """SS - Side by Side visualizer for Heatmaps and Data""" def __init__( self, data_loader: DetoxaiDataLoader | DataLoader, model: nn.Module, lrp_object: LRPHandler = None, plot_config: dict = {}, draw_rectangles: bool = False, rectangle_config: dict = {}, ) -> None: """ Args: data_loader: DetoxaiDataLoader | DataLoader: model: nn.Module: lrp_object: LRPHandler = None: plot_config: dict = {}: draw_rectangles: bool = False: rectangle_config: dict = {}: Returns: None If you pass a DataLoader that is not a subclass of `DetoxaiDataLoader`, you must pass an LRPHandler with `n_classes` set! """ self.data_loader = data_loader self.model = model self.set_up_plots_configuration(plot_config) self.init_rectangle_painter(draw_rectangles, rectangle_config) if not isinstance(data_loader, DetoxaiDataLoader): # Check if the user passed an LRPHandler object with n_classes != None if lrp_object is None or lrp_object.n_classes is None: raise ValueError( "If you pass a DataLoader that is not a subclass of `DetoxaiDataLoader`, you must pass an LRPHandler with `n_classes` set." ) if not isinstance(data_loader.dataset, DetoxaiDataset): # Check if the user passed an LRPHandler object with n_classes != None if lrp_object is None or lrp_object.n_classes is None: raise ValueError( "If you pass a DataLoader that has a dataset that is not a subclass of `DetoxaiDataset`, you must pass an LRPHandler with `n_classes` set." ) self.lrp_vis = HeatmapVisualizer( data_loader, model, lrp_object, plot_config, draw_rectangles, rectangle_config, ) self.data_vis = DataVisualizer( data_loader, plot_config, draw_rectangles, rectangle_config ) self.model_device = next(model.parameters()).device
[docs] def visualize_batch( self, batch_num: int, lrp_condition_on: ConditionOn = ConditionOn.PROPER_LABEL, lrp_show_cbar: bool = True, max_images: int | None = 36, show_labels: bool = True, ) -> None: """ Args: batch_num: int: lrp_condition_on: ConditionOn: (Default value = ConditionOn.PROPER_LABEL) lrp_show_cbar: bool: (Default value = True) max_images: int | None: (Default value = 36) show_labels: bool: (Default value = True) Returns: """ self.lrp_vis.visualize_batch( batch_num, lrp_condition_on, lrp_show_cbar, max_images ) if show_labels: data = get_nth_batch(self.data_loader, batch_num)[0].to(self.model_device) preds = self.model(data).argmax(dim=1) else: preds = None self.data_vis.visualize_batch( batch_num, max_images, batch_preds=preds, show_labels=show_labels ) self.__build_plot()
[docs] def visualize_agg(self, batch_num: int) -> None: """ Args: batch_num: int: Returns: """ self.lrp_vis.visualize_agg(batch_num) self.data_vis.visualize_agg(batch_num) self.__build_plot()
def __build_plot(self) -> None: f1 = self.lrp_vis.figure f2 = self.data_vis.figure def figure_to_array(fig): """Convert a Matplotlib figure to a numpy array. Args: fig: Returns: """ fig.canvas.draw() # Draw the figure # Get the image as an RGBA buffer and convert it to a numpy array img = np.asarray(fig.canvas.buffer_rgba()) return img[..., :3] # Convert RGBA to RGB if needed # Extract image data from the figures img1 = figure_to_array(f1) img2 = figure_to_array(f2) # Plot them side by side fig, axs = self.get_canvas(cols=2, shape=(10, 6)) axs[0].imshow(img2) axs[1].imshow(img1) # Close plots in the sub-visualizers self.lrp_vis.close_plot() self.data_vis.close_plot()