# Source code for detoxai.methods.savani.adversarial

import logging

import lightning as L
import torch
import torch.nn as nn
from torch.nn.functional import softmax
from torch.utils.data import DataLoader
from tqdm import tqdm

from ...metrics.bias_metrics import (
    BiasMetrics,
    calculate_bias_metric_torch,
)

# Project imports
from .savani_base import SavaniBase

logger = logging.getLogger(__name__)


class SavaniAFT(SavaniBase):
    """Adversarial fine-tuning on top of ``SavaniBase``.

    A critic network is trained to regress the bias metric of the model's
    predictions on a batch, and the model is then fine-tuned with a
    cross-entropy loss that is scaled up by a penalty computed from the
    critic's output (see :meth:`fair_loss`).
    """

    def __init__(
        self,
        model: nn.Module | L.LightningModule,
        experiment_name: str,
        device: str,
        seed: int = 123,
        **kwargs,
    ) -> None:
        """Forward the common arguments to ``SavaniBase``.

        Args:
            model: Model to be corrected.
            experiment_name: Passed through to ``SavaniBase``.
            device: Passed through to ``SavaniBase``.
            seed: Passed through to ``SavaniBase``.
            **kwargs: Accepted and ignored.
        """
        super().__init__(model, experiment_name, device, seed)

    def apply_model_correction(
        self,
        dataloader: DataLoader,
        last_layer_name: str,
        epsilon: float = 0.1,
        bias_metric: BiasMetrics | str = BiasMetrics.EO_GAP,
        iterations: int = 10,
        critic_iterations: int = 5,
        model_iterations: int = 5,
        train_batch_size: int = 128,
        thresh_optimizer_maxiter: int = 100,
        tau_init: float = 0.5,
        lam: float = 1.0,
        delta: float = 0.01,
        critic_lr: float = 1e-4,
        model_lr: float = 1e-4,
        critic_filters: list[int] | None = None,
        critic_linear: list[int] | None = None,
        outputs_are_logits: bool = True,
        n_eval_batches: int = 3,
        soft_thresh_temperature: float = 10.0,
        **kwargs,
    ) -> None:
        """Fine-tune the model against a bias critic, then fit a threshold.

        For ``iterations`` rounds the method alternates between
        ``critic_iterations`` critic updates (MSE between the critic's output
        and the bias metric of the model's current predictions) and
        ``model_iterations`` model updates using :meth:`fair_loss`. Afterwards
        ``optimize_tau`` is run and the resulting threshold is installed with
        ``apply_hook``.

        Args:
            dataloader: Data source handed to ``initialize_dataloader``.
            last_layer_name: Name of a layer that must exist in the model;
                stored on the instance.
            epsilon: Value subtracted from the critic's output in
                :meth:`fair_loss`.
            bias_metric: Metric passed to ``calculate_bias_metric_torch`` to
                build the critic's regression target.
            iterations: Number of outer critic/model rounds.
            critic_iterations: Critic optimizer steps per round.
            model_iterations: Model optimizer steps per round.
            train_batch_size: Batch size passed to ``initialize_dataloader``
                and to :meth:`get_critic`.
            thresh_optimizer_maxiter: Passed to ``optimize_tau``.
            tau_init: Initial threshold; stored and passed to
                ``optimize_tau``.
            lam: Scale factor of the penalty in :meth:`fair_loss`.
            delta: Margin added to the critic's output in :meth:`fair_loss`.
            critic_lr: Adam learning rate for the critic.
            model_lr: Adam learning rate for the model.
            critic_filters: Output channels of the critic's convolutions.
                ``None`` means ``[8, 16, 32]``.
            critic_linear: Widths of the critic's hidden linear layers.
                ``None`` means ``[32]``.
            outputs_are_logits: If True, softmax is applied to the model
                output before the arg-max that yields hard predictions.
            n_eval_batches: Stored on the instance; not used directly here.
            soft_thresh_temperature: Passed to ``apply_hook``.
            **kwargs: Accepted and ignored.

        Raises:
            AssertionError: If ``last_layer_name`` is not found in the model.
        """
        # NOTE(review): an assert is stripped under ``python -O``; kept as-is
        # so that the raised exception type does not change for callers.
        assert self.check_layer_name_exists(last_layer_name), (
            f"Layer name {last_layer_name} not found in the model"
        )

        # Resolve the defaults here instead of using list literals in the
        # signature, which would be shared between calls.
        if critic_filters is None:
            critic_filters = [8, 16, 32]
        if critic_linear is None:
            critic_linear = [32]

        self.last_layer_name = last_layer_name
        self.tau_init = tau_init
        self.epsilon = epsilon
        self.bias_metric = bias_metric
        self.outputs_are_logits = outputs_are_logits
        self.lam = lam
        self.delta = delta
        self.n_eval_batches = n_eval_batches

        self.initialize_dataloader(dataloader, train_batch_size)

        # One batch is kept so that get_critic can infer its layer sizes.
        self.__sample_example, _, _ = self.sample_batch()
        channels = self.__sample_example.shape[1]

        self.critic = self.get_critic(
            channels, critic_filters, critic_linear, train_batch_size
        )

        critic_criterion = nn.MSELoss()
        critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)
        model_optimizer = torch.optim.Adam(self.model.parameters(), lr=model_lr)
        self.model_loss = nn.CrossEntropyLoss()

        for i in tqdm(range(iterations), desc="Savani: Adversarial Fine Tuning"):
            logger.debug(f"Minibatch no. {i}")

            # Train the critic
            for j in range(critic_iterations):
                self.model.eval()
                self.critic.train()

                x, y_true, prot_attr = self.sample_batch()

                # The regression target is built from hard predictions and
                # never needs gradients.
                with torch.no_grad():
                    # Assuming binary classification and logits
                    y_logits = self.model(x)
                    if self.outputs_are_logits:
                        y_pred = softmax(y_logits, dim=1)
                    else:  # probabilities
                        y_pred = y_logits
                    y_pred = torch.argmax(y_pred, dim=1)

                    bias = calculate_bias_metric_torch(
                        self.bias_metric, y_pred, y_true, prot_attr
                    )

                # The critic emits a single value for the whole batch; [0]
                # drops the length-1 dimension.
                # NOTE(review): assumes ``bias`` is a 0-dim tensor - confirm.
                c_loss = critic_criterion(self.critic(x)[0], bias)

                critic_optimizer.zero_grad()
                c_loss.backward()
                critic_optimizer.step()

                # .item() forces a device-to-host sync, so only pay for it
                # when the message will actually be emitted.
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(f"[{j}] Critic loss: {c_loss.item()}")

            # Train the model
            for j in range(model_iterations):
                self.model.train()
                self.critic.eval()

                x, y_true, prot_attr = self.sample_batch()

                y_logits = self.model(x)
                m_loss = self.fair_loss(y_logits, y_true, x)

                model_optimizer.zero_grad()
                m_loss.backward()
                model_optimizer.step()

                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(f"[{j}] Model loss: {m_loss.item()}")

        tau, phi = self.optimize_tau(tau_init, thresh_optimizer_maxiter)

        logger.info(f"Best tau: {tau}, Best phi: {phi}")

        if hasattr(self, "lightning_model"):
            self.lightning_model.model = self.model

        # Add a hook with the best transformation
        self.apply_hook(tau, soft_thresh_temperature)

    def fair_loss(self, y_logits, y_true, input):
        """Return the cross-entropy loss scaled by the critic's penalty.

        The multiplier is
        ``max(1, lam * (critic(input) - epsilon + delta) + 1)``, so the plain
        loss is never scaled down.

        Args:
            y_logits: Model outputs passed to ``self.model_loss``.
            y_true: Targets passed to ``self.model_loss``.
            input: Batch fed to the critic. The name shadows the builtin but
                is kept because it is part of the public signature.

        Returns:
            ``self.model_loss(y_logits, y_true)`` multiplied by the
            penalty factor.
        """
        fair = torch.max(
            torch.tensor(1, dtype=torch.float32, device=self.device),
            self.lam * (self.critic(input).squeeze() - self.epsilon + self.delta) + 1,
        )

        return self.model_loss(y_logits, y_true) * fair

    def get_critic(
        self,
        channels: int,
        critic_filters: list[int],
        critic_linear: list[int],
        batch_size: int,
    ) -> nn.Module:
        """Build the convolutional critic and move it to ``self.device``.

        The encoder ends in ``nn.Flatten(start_dim=0)``, which also flattens
        the batch dimension, so the critic maps a whole batch to one output
        value. The first linear layer is sized from a dry run on the stored
        sample batch, so later inputs must flatten to that same size.

        Args:
            channels: Number of input channels of the first convolution.
            critic_filters: Output channels of the successive convolutions;
                must contain at least one entry.
            critic_linear: Widths of the hidden linear layers; must contain
                at least one entry.
            batch_size: Number of leading samples of the stored sample batch
                used for the dry run.

        Returns:
            The critic as an ``nn.Sequential`` on ``self.device``.
        """
        encoder_layers = [
            nn.Conv2d(channels, critic_filters[0], 3, padding="same"),
            nn.ReLU(),
        ]
        # Every further stage halves the spatial resolution after the conv.
        for in_filters, out_filters in zip(critic_filters, critic_filters[1:]):
            encoder_layers += [
                nn.Conv2d(in_filters, out_filters, 3, padding="same"),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ]

        # Add adaptive pooling layer
        encoder_layers.append(nn.AdaptiveAvgPool2d(3))
        encoder_layers.append(nn.Flatten(start_dim=0))

        encoder = nn.Sequential(*encoder_layers).to(self.device)

        # Dry run to learn how many features the flattened batch has.
        with torch.no_grad():
            size_after = encoder(self.__sample_example[:batch_size]).shape[0]

        critic_layers = [encoder, nn.Linear(size_after, critic_linear[0]), nn.ReLU()]
        for in_features, out_features in zip(critic_linear, critic_linear[1:]):
            critic_layers += [
                nn.Linear(in_features, out_features),
                nn.ReLU(),
            ]

        critic_layers.append(nn.Linear(critic_linear[-1], 1))

        return nn.Sequential(*critic_layers).to(self.device)