Source code for detoxai.methods.savani.savani_base

import logging
from abc import ABC, abstractmethod

import lightning as L
import numpy as np
import torch
import torch.nn as nn
from torch.nn.functional import sigmoid, softmax
from torch.utils.data import DataLoader

from ...utils.dataloader import copy_data_loader

# Project imports
from ..model_correction import ModelCorrectionMethod
from .utils import phi_torch

# Module-level logger, named after this module so log records can be filtered per module.
logger = logging.getLogger(__name__)

# Triple of tensors that travels together through this module: ``sample_batch``
# unpacks it as ``(x, y, p)`` and ``get_pred_true_prot`` returns
# ``(predictions, true labels, protected attribute)`` under the same alias.
Batch = tuple[torch.Tensor, torch.Tensor, torch.Tensor]


class SavaniBase(ModelCorrectionMethod, ABC):
    """Shared machinery for the model correction methods in the ``savani`` package.

    Provides threshold (``tau``) search, evaluation of the ``phi`` objective,
    a soft-thresholding forward hook and batch sampling helpers. Concrete
    subclasses implement :meth:`apply_model_correction`.

    Besides what ``__init__`` sets, the methods here read the following
    attributes, which are not assigned anywhere in this class and are
    presumably provided by the parent class or by subclasses (verify against
    the concrete implementations): ``model``, ``device``, ``epsilon``,
    ``bias_metric``, ``outputs_are_logits``, ``last_layer_name`` and
    ``n_eval_batches``. ``internal_dl`` is set by :meth:`initialize_dataloader`.
    """

    def __init__(
        self,
        model: nn.Module | L.LightningModule,
        experiment_name: str,
        device: str,
        seed: int = 123,
    ) -> None:
        """Initialise the parent class and seed the random number generators.

        Args:
            model: Model to be corrected; forwarded to the parent constructor.
            experiment_name: Forwarded to the parent constructor.
            device: Forwarded to the parent constructor.
            seed: Seed stored on the instance and applied to both the torch
                and the NumPy *global* random number generators.
        """
        super().__init__(model, experiment_name, device)
        self.seed = seed
        # NOTE: these calls reseed process-wide RNG state, not a private generator.
        torch.manual_seed(seed)
        np.random.seed(seed)

    @abstractmethod
    def apply_model_correction(self) -> None:
        """Run the correction procedure; must be implemented by subclasses."""
        raise NotImplementedError

    def optimize_tau(
        self, tau_init: float, thresh_optimizer_maxiter: int
    ) -> tuple[float, float]:
        """Grid-search the decision threshold that maximises ``phi``.

        Evaluates ``thresh_optimizer_maxiter`` evenly spaced thresholds in
        ``[0.05, 0.95]`` against predictions that are sampled once up front.

        Args:
            tau_init: Threshold returned unchanged when no grid point scores
                above the initial best value of ``1e-6``.
            thresh_optimizer_maxiter: Number of grid points to evaluate.

        Returns:
            ``(tau, best_phi)``. When a grid point wins, ``tau`` is the 0-dim
            tensor taken from the grid and ``best_phi`` is the value produced
            by the objective (a NumPy value); otherwise they are ``tau_init``
            and ``1e-6``.
        """
        objective_fn = self.objective_thresh("torch", True, "max")

        best_phi = 1e-6
        tau = tau_init
        for candidate in torch.linspace(0.05, 0.95, thresh_optimizer_maxiter):
            phi = objective_fn(candidate)
            if phi > best_phi:
                best_phi, tau = phi, candidate
        return tau, best_phi

    def objective_thresh(
        self, backend: str, cache_preds: bool = True, direction: str = "min"
    ) -> callable:
        """Build a function mapping a threshold ``tau`` to a signed ``phi`` value.

        Args:
            backend: Only ``"torch"`` is implemented.
            cache_preds: When true, predictions are sampled once here and
                reused on every call of the returned function. When false,
                every call re-samples predictions via :meth:`phi_torch`.
            direction: ``"max"`` returns ``phi`` as is, ``"min"`` returns
                ``-phi``.

        Returns:
            A callable ``objective(tau)`` returning ``phi`` converted to NumPy
            and multiplied by the sign selected through ``direction``.

        Raises:
            ValueError: If ``direction`` or ``backend`` is not recognised.
            NotImplementedError: If ``backend`` is ``"np"``.
        """
        # Validate the cheap arguments first so that a typo does not cost a
        # full round of forward passes.
        if direction == "min":
            sign = -1
        elif direction == "max":
            sign = 1
        else:
            raise ValueError(f"Direction {direction} not supported")

        if backend == "np":
            raise NotImplementedError("Numpy backend not implemented")
        if backend != "torch":
            raise ValueError(f"Backend {backend} not supported")

        # ``None`` makes ``phi_torch`` sample fresh predictions on each call.
        # Previously the closure referenced variables that were only bound
        # when ``cache_preds`` was true, so ``cache_preds=False`` raised a
        # NameError on first use.
        cached = None
        if cache_preds:
            y_probs, y_true, prot_attr = self.get_pred_true_prot()
            cached = (y_probs[:, 1], y_true, prot_attr)

        def objective(tau):
            """Evaluate the signed ``phi`` value at threshold ``tau``."""
            phi, _ = self.phi_torch(tau, cached)
            return phi.detach().cpu().numpy() * sign

        return objective

    def phi_torch(
        self, tau: torch.Tensor | float, cached: tuple | None = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Calculate the phi metric for a given threshold ``tau``.

        Args:
            tau: Decision threshold applied to the column-1 predictions. May
                be a tensor or a plain Python number.
            cached: Optional ``(y_preds, y_true, prot_attr)`` triple to reuse.
                When ``None``, predictions are sampled with
                :meth:`get_pred_true_prot` and column 1 is used as ``y_preds``.

        Returns:
            Whatever the module-level ``phi_torch`` helper returns for the
            thresholded predictions; callers in this class unpack it as a
            pair whose first element is ``phi``.
        """
        if cached is None:
            y_probs, y_true, prot_attr = self.get_pred_true_prot()
            y_preds = y_probs[:, 1]
        else:
            y_preds, y_true, prot_attr = cached

        # ``as_tensor`` accepts floats as well as tensors, so a float
        # threshold (e.g. ``tau_init`` handed back by ``optimize_tau``) works.
        threshold = torch.as_tensor(tau, device=self.device)

        # Inside a method body the bare name resolves to the module-level
        # helper imported from ``.utils``, not to this method.
        return phi_torch(
            y_true,
            y_preds > threshold,
            prot_attr,
            self.epsilon,
            self.bias_metric,
        )

    def apply_hook(self, tau: float, temperature: float = 100) -> None:
        """Register a soft-thresholding forward hook on the last linear layer.

        The hook rewrites the layer output in place: column 1 becomes
        ``sigmoid((p1 - tau) * temperature)`` and column 0 becomes one minus
        that value, where ``p1`` is the softmax probability of column 1 when
        ``self.outputs_are_logits`` is true and the raw column 1 otherwise.
        A sigmoid is used instead of a hard comparison so that gradients can
        still flow. Binary classification is assumed.

        The handles of the registered hooks are stored in ``self.hooks``,
        replacing any previous value of that attribute.

        Args:
            tau: Decision threshold.
            temperature: Steepness of the sigmoid; larger values approximate
                a hard threshold more closely.
        """

        def hook(module, inputs, output):
            """Soft-threshold ``output`` in place and return it."""
            if self.outputs_are_logits:
                positive = softmax(output, dim=1)[:, 1]
            else:
                positive = output[:, 1]
            # The right-hand side is a fresh tensor, so it is fully computed
            # before ``output`` is overwritten.
            output[:, 1] = sigmoid((positive - tau) * temperature)
            output[:, 0] = 1 - output[:, 1]
            return output

        hooks = []
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Linear) and name == self.last_layer_name:
                hooks.append(module.register_forward_hook(hook))
                logger.debug("Hook registered on layer: %s", name)

        if not hooks:
            # Make the otherwise silent no-op visible.
            logger.warning(
                "No nn.Linear layer named %r found; no hook registered",
                self.last_layer_name,
            )
        self.hooks = hooks

    def get_pred_true_prot(self) -> Batch:
        """Collect predictions, labels and protected attributes.

        Runs the model without gradient tracking on ``self.n_eval_batches``
        batches drawn with :meth:`sample_batch`. Model outputs are passed
        through a softmax over dimension 1 when ``self.outputs_are_logits``
        is true and used as is otherwise.

        Returns:
            ``(predictions, true labels, protected attributes)``, each
            concatenated over the sampled batches and moved to
            ``self.device``.
        """
        preds, labels, prots = [], [], []
        with torch.no_grad():
            for _ in range(self.n_eval_batches):
                x, y_true, prot = self.sample_batch()
                y_out = self.model(x)
                if self.outputs_are_logits:
                    y_out = softmax(y_out, dim=1)
                preds.append(y_out)
                labels.append(y_true)
                prots.append(prot)

        return (
            torch.cat(preds).to(self.device),
            torch.cat(labels).to(self.device),
            torch.cat(prots).to(self.device),
        )

    def check_layer_name_exists(self, layer_name: str) -> bool:
        """Return whether ``self.model`` has a submodule named ``layer_name``."""
        return any(name == layer_name for name, _ in self.model.named_modules())

    def sample_batch(self) -> Batch:
        """Sample a single batch from the internal dataloader.

        The iterator over ``self.internal_dl`` is created on first use and
        recreated whenever it is exhausted, so this method can be called
        indefinitely.

        Returns:
            The batch as ``(x, y, p)`` with every tensor moved to
            ``self.device``.
        """
        # Check for a missing iterator explicitly instead of catching
        # AttributeError around ``next``: that would also swallow an
        # AttributeError raised from inside the data pipeline.
        if getattr(self, "dl_iter", None) is None:
            self.dl_iter = iter(self.internal_dl)

        try:
            batch: Batch = next(self.dl_iter)
        except StopIteration:
            # Dataloader exhausted: start a new pass.
            self.dl_iter = iter(self.internal_dl)
            batch = next(self.dl_iter)

        x, y, p = batch
        return x.to(self.device), y.to(self.device), p.to(self.device)

    def initialize_dataloader(self, dataloader: DataLoader, batch_size: int) -> None:
        """Create the dataloader that :meth:`sample_batch` draws from.

        Stores in ``self.internal_dl`` the result of ``copy_data_loader``
        called with ``shuffle=True`` and ``drop_last=True``.

        Args:
            dataloader: Dataloader to copy.
            batch_size: Batch size requested for the copy.
        """
        self.internal_dl = copy_data_loader(
            dataloader, batch_size=batch_size, shuffle=True, drop_last=True
        )