Source code for detoxai.core.xai

import os
import time
from abc import ABC, abstractmethod

import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as stats
import torch
import torch.nn as nn
from tqdm import tqdm

from ..utils.dataloader import DetoxaiDataLoader
from ..visualization import ConditionOn, LRPHandler


[docs] class XAIMetricsCalculator: """ """ def __init__(self, dataloader: DetoxaiDataLoader, lrphandler: LRPHandler) -> None: self.dataloader = dataloader self.lrphandler = lrphandler def _symmetrize( self, sailmaps: np.ndarray, neutral_point: float = 0.5 ) -> np.ndarray: """ Args: sailmaps: np.ndarray: neutral_point: float: (Default value = 0.5) Returns: """ return np.abs(sailmaps - neutral_point)
[docs] def calculate_metrics( self, model: nn.Module, rect_pos: tuple[int, int], rect_size: tuple[int, int], vanilla_model: nn.Module = None, sailmap_metrics: list[str] = [ "RRF", "HRF", "MRR", "DET", "ADR", "DIF", "RDDT", ], batches: int = 2, condition_on: str = ConditionOn.PROPER_LABEL.value, verbose: bool = False, # source_range: tuple[float, float] = (0, 1), neutral_point: float = 0.5, abs_on_neutral: bool = True, ) -> dict[str, float]: """Calculate the metrics for the given model and sailmaps Args: model: nn rect_pos: tuple rect_size: tuple vanilla_model: nn sailmap_metrics: list batches: int condition_on: str verbose: bool model: nn.Module: rect_pos: tuple[int: int]: rect_size: tuple[int: vanilla_model: nn.Module: (Default value = None) sailmap_metrics: list[str]: batches: int: (Default value = 2) condition_on: str: (Default value = ConditionOn.PROPER_LABEL.value) verbose: bool: (Default value = False) # source_range: tuple[float: float]: (Default value = (0)) neutral_point: float: (Default value = 0.5) abs_on_neutral: bool: (Default value = True) Returns: - `dict[str, float]`: The calculated metrics where the key is the metric name and the value is the calculated metric """ metrics_calcs: list["SailRectMetric"] = [] for metric in sailmap_metrics: if metric == "RRF": metrics_calcs.append(RRF()) elif metric == "HRF": metrics_calcs.append(HRF()) elif metric == "MRR": metrics_calcs.append(MRR()) elif metric == "DET": metrics_calcs.append(DET()) elif metric == "ADR": if vanilla_model is None: raise ValueError("ADR requires a vanilla model for comparison") metrics_calcs.append(ADR()) elif metric == "DIF": if vanilla_model is None: raise ValueError("DIF requires a vanilla model for comparison") metrics_calcs.append(DIF()) elif metric == "RDDT": if vanilla_model is None: raise ValueError("RDDT requires a vanilla model for comparison") metrics_calcs.append(RDDT()) else: raise ValueError(f"Metric {metric} is not supported") for i in tqdm(range(batches), disable=not verbose, desc="Calculating metrics"): lrpres = self.lrphandler.calculate(model, self.dataloader, batch_num=i) if vanilla_model is not None: vanilla_lrpres = self.lrphandler.calculate( vanilla_model, self.dataloader, batch_num=i ) _, labels, _ = self.dataloader.get_nth_batch(i) # noqa conditioned = [] for i, label in enumerate(labels): # Assuming binary classification label = ( label if condition_on == ConditionOn.PROPER_LABEL.value else 1 - label ) conditioned.append(lrpres[label, i]) sailmaps: torch.Tensor = torch.stack(conditioned).to(dtype=float) sailmaps = sailmaps.cpu().detach().numpy() if abs_on_neutral: sailmaps = self._symmetrize(sailmaps, neutral_point) if vanilla_model is not None: vanilla_sailmaps = torch.stack( [vanilla_lrpres[label, i] for i, label in enumerate(labels)] ).to(dtype=float) vanilla_sailmaps = vanilla_sailmaps.cpu().detach().numpy() if abs_on_neutral: vanilla_sailmaps = self._symmetrize(vanilla_sailmaps, neutral_point) for metric in metrics_calcs: if isinstance(metric, (ADR, DIF, RDDT)): if np.allclose(sailmaps, vanilla_sailmaps): # If the sailmaps are the same, the metric will be 0 metric.metvals.extend(np.zeros(sailmaps.shape[0])) else: metric.aggregate( sailmaps, rect_pos, rect_size, vanilla_sailmaps ) else: metric.aggregate(sailmaps, rect_pos, rect_size) ret = {} for metric in metrics_calcs: ret[str(metric)] = metric.reduce() return ret
[docs] class SailRectMetric(ABC): """ """ def __init__(self) -> None: self.sailmaps = None self.metvals: np.ndarray = [] def _sailmaps_rect( self, sailmaps: np.ndarray, rect_pos: tuple[int, int], rect_size: tuple[int, int], ) -> np.ndarray: """ Args: sailmaps: np.ndarray: rect_pos: tuple[int: int]: rect_size: tuple[int: Returns: """ assert isinstance(sailmaps, np.ndarray), "Sailmaps should be a numpy array" return sailmaps[ :, rect_pos[0] : rect_pos[0] + rect_size[0], rect_pos[1] : rect_pos[1] + rect_size[1], ]
[docs] def calculate_batch( self, sailmaps: np.ndarray, rect_pos: tuple[int, int], rect_size: tuple[int, int], ret_format: tuple[str] = ("mean", "std"), ) -> dict[str, float]: """Calculate the metric for a single batch of sailmaps Args: sailmaps: np.ndarray: rect_pos: tuple[int: int]: rect_size: tuple[int: ret_format: tuple[str]: (Default value = ("mean") "std"): Returns: """ c = self._core(sailmaps, rect_pos, rect_size) return self.structure_output(c, ret_format)
[docs] def reduce(self, ret_format: tuple[str] = ("mean", "std")) -> dict[str, float]: """Calculate the metric for already aggregated sailmaps Args: ret_format: tuple[str]: (Default value = ("mean") "std"): Returns: """ return self.structure_output(self.metvals, ret_format)
[docs] def aggregate( self, sailmaps: np.ndarray, rect_pos: tuple[int, int], rect_size: tuple[int, int], vanilla_sailmaps: np.ndarray = None, ): """Aggregate sailmaps for later calculation Args: sailmaps: np.ndarray: rect_pos: tuple[int: int]: rect_size: tuple[int: vanilla_sailmaps: np.ndarray: (Default value = None) Returns: """ if vanilla_sailmaps is not None: c = self._core(sailmaps, rect_pos, rect_size, vanilla_sailmaps) else: c = self._core(sailmaps, rect_pos, rect_size) assert isinstance(c, np.ndarray), "Output should be a numpy array" self.metvals.extend(c)
@abstractmethod def _core( self, sailmaps: np.ndarray, rect_pos: tuple[int, int], rect_size: tuple[int, int], ) -> np.ndarray: """ Args: sailmaps: np.ndarray: rect_pos: tuple[int: int]: rect_size: tuple[int: Returns: """ pass
[docs] def structure_output( self, per_sample: np.ndarray[float], ret_format: tuple[str] = ("mean", "std") ) -> dict[str, float]: """ Args: per_sample: np.ndarray[float]: ret_format: tuple[str]: (Default value = ("mean") "std"): Returns: """ ret = {} if "mean" in ret_format: ret["mean"] = np.mean(per_sample) if "std" in ret_format: ret["std"] = np.std(per_sample) if "min" in ret_format: ret["min"] = np.min(per_sample) if "max" in ret_format: ret["max"] = np.max(per_sample) if "median" in ret_format: ret["median"] = np.median(per_sample) return ret
def __str__(self) -> str: if hasattr(self, "name"): return self.name return self.__class__.__name__ def __repr__(self) -> str: return self.__str__()
[docs] class RRF(SailRectMetric): """Rectangle Relevance Fraction \begin{equation} \mathbf{RRF} = \frac{\displaystyle \sum_{(i,j) \in R} p_{ij}}{\displaystyle \sum_{i = 1}^N \sum_{j = 1}^M p_{ij}} \end{equation} Here, $\mathbf{RRF}$ measures the fraction of total relevance that falls within ROI. Args: Returns: """ def __init__(self, **kwargs) -> None: super().__init__(**kwargs) self.name = "RRF" def _core( self, sailmaps: np.ndarray, rect_pos: tuple[int, int], rect_size: tuple[int, int], ) -> np.ndarray: """ Args: sailmaps: np.ndarray: rect_pos: tuple[int: int]: rect_size: tuple[int: Returns: """ sm_rect = self._sailmaps_rect(sailmaps, rect_pos, rect_size) print(f"sm_rect shape: {sm_rect.shape}") # Sm_rectr is 32 x Width x Height, I want to matplotlib savefig this fig, ax = plt.subplots(4, 4, figsize=(20, 20)) ax = ax.flatten() for i in range(16): ax[i].imshow(sm_rect[i]) ax[i].axis("off") os.makedirs("/workspace/debug/xai_images", exist_ok=True) plt.savefig("/workspace/debug/xai_images" + str(time.time()) + ".png") plt.close() r_sum = sm_rect.reshape(len(sm_rect), -1).sum(axis=1) s_sum = sailmaps.reshape(len(sm_rect), -1).sum(axis=1) print(f"RRF shape: {r_sum.shape}, {s_sum.shape}") return r_sum / s_sum # safe bc s_sum > r_sum and never 0
[docs] class HRF(SailRectMetric): """\subsection{High-Relevance Fraction (HRF)} \begin{equation} \mathbf{HRF} = \displaystyle \frac{1}{\vert R \vert} \sum_{(i,j) \in R} \mathbbm{1}_{\{p_{ij} > \epsilon\}} \end{equation} $\mathbf{HRF}$ quantifies the proportion of pixels inside the ROI whose relevance exceeds a predefined threshold $\epsilon$, indicating how many pixels are highly important for prediction. Args: Returns: """ def __init__( self, epsilon: float = 0.05, **kwargs, ) -> None: super().__init__(**kwargs) self.epsilon = epsilon self.name = "HRF" def _core( self, sailmaps: np.ndarray, rect_pos: tuple[int, int], rect_size: tuple[int, int], ) -> np.ndarray: """ Args: sailmaps: np.ndarray: rect_pos: tuple[int: int]: rect_size: tuple[int: Returns: """ sm_rect = self._sailmaps_rect(sailmaps, rect_pos, rect_size) sm_rect = sm_rect.reshape(len(sm_rect), -1) rect_size = sm_rect.shape[1] high_relevance = np.sum(sm_rect > self.epsilon, axis=1) return high_relevance / rect_size
[docs] class MRR(SailRectMetric): """\subsection{Mean Relevance Ratio (MRR)} \begin{equation} \mathbf{MRR} = \frac{\displaystyle \frac{1}{\vert R \vert} \sum_{(i,j) \in R} p_{ij}}{\displaystyle \frac{1}{N M - \vert R \vert} \sum_{(i,j) \notin R} p_{ij}}, \end{equation} $\mathbf{MRR}$ quantifies the ratio of the mean pixel value inside the ROI to the mean pixel value outside it. $\mathbf{MRR} = 1$ indicates that the mean values are equal, while $\mathbf{MRR} > 1$ says the mean pixel within the ROI has a higher intensity. Args: Returns: """ def __init__(self, **kwargs) -> None: super().__init__(**kwargs) self.name = "MRR" def _core( self, sailmaps: np.ndarray, rect_pos: tuple[int, int], rect_size: tuple[int, int], ) -> np.ndarray: """ Args: sailmaps: np.ndarray: rect_pos: tuple[int: int]: rect_size: tuple[int: Returns: """ sm_rect = self._sailmaps_rect(sailmaps, rect_pos, rect_size) sm_outside = sailmaps.copy() sm_outside[ :, rect_pos[0] : rect_pos[0] + rect_size[0], rect_pos[1] : rect_pos[1] + rect_size[1], ] = 0 sm_outside_sum = sm_outside.reshape(len(sm_outside), -1).sum(axis=1) total_pixels = sm_outside[0].size rect_pixels = sm_rect[0].size sm_outside_mean = sm_outside_sum / (total_pixels - rect_pixels) sm_rect_mean = sm_rect.reshape(len(sm_rect), -1).sum(axis=1) / rect_pixels return sm_rect_mean / sm_outside_mean #
[docs] class DET(SailRectMetric): """\subsection{Distribution Equivalence Testing (DET)} The goal of the statistical test is to determine whether the pixels \textit{inside} the rectangle have higher intensity than those \textit{outside} the rectangle. Since the number of pixels and their intensity distributions inside and outside the ROI can vary, a non-parametric, unpaired statistical Mann-Whitney-Wilcoxon test is used. This permutation test assesses whether the intensity values from one group (inside) tend to be higher than those from the other (outside). The null hypothesis $H_0$ for the test is that the intensity distributions inside and outside the rectangle are equal: \begin{equation} \begin{split} H_0: F_{\text{inside}}(x) &= F_{\text{outside}}(x) \\ H_1: F_{\text{inside}}(x) &> F_{\text{outside}}(x) \end{split} \end{equation} To perform the test, all pixel intensities are ranked, and the sum of ranks for each group (inside and outside the ROI) is computed. The test then evaluates the probability that the intensity values inside the rectangle are statistically higher than those outside. The final outcome of the DET is a binary decision: \textbf{TRUE} indicates that the null hypothesis is rejected (i.e., there is statistically significant evidence that the pixels inside the rectangle have higher intensity), while \textbf{FALSE} signifies that we fail to reject the null hypothesis, meaning that the evidence is inconclusive regarding a higher intensity inside the rectangle. Args: Returns: """ def __init__(self, **kwargs) -> None: super().__init__(**kwargs) self.name = "DET" def _core( self, sailmaps: np.ndarray, rect_pos: tuple[int, int], rect_size: tuple[int, int], ) -> np.ndarray: """ Args: sailmaps: np.ndarray: rect_pos: tuple[int: int]: rect_size: tuple[int: Returns: """ sm_rect = self._sailmaps_rect(sailmaps, rect_pos, rect_size) sm_outside = sailmaps.copy() sm_outside[ :, rect_pos[0] : rect_pos[0] + rect_size[0], rect_pos[1] : rect_pos[1] + rect_size[1], ] = 0 aggregated_sm = sm_rect.reshape(len(sm_rect), -1) mean_sm = np.mean(aggregated_sm, axis=1) aggregated_outside = sm_outside.reshape(len(sm_outside), -1) mean_outside = np.mean(aggregated_outside, axis=1) return mean_sm - mean_outside
[docs] def reduce(self, ret_format: tuple[str] = ("mean", "std")) -> dict[str, float]: """Calculate the metric for already aggregated sailmaps Args: ret_format: tuple[str]: (Default value = ("mean") "std"): Returns: """ ret = dict() _, p = stats.ttest_1samp(self.metvals, 0, alternative="greater") ret["mean"] = p < 0.01 ret["result"] = p < 0.01 ret["std"] = p ret["p-value"] = p return ret
[docs] class ADR(SailRectMetric): """Average Difference in Region (ADR) ADR measures the mean pixel-wise difference between vanilla and debiased saliency maps within the region of interest (ROI). A positive value indicates that vanilla saliency values are generally higher than debiased ones in the region. Args: Returns: """ def __init__(self, **kwargs) -> None: super().__init__(**kwargs) self.name = "ADR" def _core( self, sailmaps: np.ndarray, rect_pos: tuple[int, int], rect_size: tuple[int, int], vanilla_sailmaps: np.ndarray = None, ) -> np.ndarray: """ Args: sailmaps: np.ndarray: rect_pos: tuple[int: int]: rect_size: tuple[int: vanilla_sailmaps: np.ndarray: (Default value = None) Returns: """ sm_rect = self._sailmaps_rect(sailmaps, rect_pos, rect_size) vanilla_sm_rect = self._sailmaps_rect(vanilla_sailmaps, rect_pos, rect_size) # Calculate mean difference per image diff = vanilla_sm_rect - sm_rect return diff.reshape(len(diff), -1).mean(axis=1)
[docs] class DIF(SailRectMetric): """Decreased Intensity Fraction (DIF) DIF measures the ratio of pixels showing decreased intensity in the debiased model compared to the vanilla model. It represents the fraction of pixels inside a rectangle that significantly flipped their saliency value. Args: Returns: """ def __init__(self, eps: float = 1e-3, **kwargs) -> None: super().__init__(**kwargs) self.name = "DIF" self.eps = eps def _core( self, sailmaps: np.ndarray, rect_pos: tuple[int, int], rect_size: tuple[int, int], vanilla_sailmaps: np.ndarray = None, ) -> np.ndarray: """ Args: sailmaps: np.ndarray: rect_pos: tuple[int: int]: rect_size: tuple[int: vanilla_sailmaps: np.ndarray: (Default value = None) Returns: """ sm_rect = self._sailmaps_rect(sailmaps, rect_pos, rect_size) vanilla_sm_rect = self._sailmaps_rect(vanilla_sailmaps, rect_pos, rect_size) # Calculate fraction of pixels where debiased < vanilla diff = vanilla_sm_rect - sm_rect decreased = (diff > self.eps).reshape(len(diff), -1) return decreased.sum(axis=1) / decreased.shape[1]
[docs] class RDDT(SailRectMetric): """Rectangle Difference Distribution Testing (RDDT) Performs a Wilcoxon signed rank test to determine if pixels from the vanilla model have significantly higher intensity than those from the debiased model within the ROI. Returns 1 if the test rejects the null hypothesis (indicating vanilla has higher intensity), 0 otherwise. Args: Returns: """ def __init__(self, **kwargs) -> None: super().__init__(**kwargs) self.name = "RDDT" def _core( self, sailmaps: np.ndarray, rect_pos: tuple[int, int], rect_size: tuple[int, int], vanilla_sailmaps: np.ndarray = None, ) -> np.ndarray: """ Args: sailmaps: np.ndarray: rect_pos: tuple[int: int]: rect_size: tuple[int: vanilla_sailmaps: np.ndarray: (Default value = None) Returns: """ sm_rect = self._sailmaps_rect(sailmaps, rect_pos, rect_size) vanilla_sm_rect = self._sailmaps_rect(vanilla_sailmaps, rect_pos, rect_size) aggregated_vanilla_sm = vanilla_sm_rect.reshape(len(vanilla_sm_rect), -1) aggregated_sm = sm_rect.reshape(len(sm_rect), -1) mean_vanilla_sm = np.mean(aggregated_vanilla_sm, axis=1) mean_sm = np.mean(aggregated_sm, axis=1) return mean_vanilla_sm - mean_sm
[docs] def reduce(self, ret_format: tuple[str] = ("mean", "std")) -> dict[str, float]: """ Args: ret_format: tuple[str]: (Default value = ("mean") "std"): Returns: """ ret = dict() _, p = stats.ttest_1samp(self.metvals, 0, alternative="greater") ret["mean"] = p < 0.01 ret["result"] = p < 0.01 ret["std"] = p ret["p-value"] = p return ret