Source code for detoxai.methods.savani.zhang

import logging

import lightning as L
import torch
import torch.nn as nn
from torch import autograd
from torch.utils.data import DataLoader
from tqdm import tqdm

from ...metrics.bias_metrics import (
    BiasMetrics,
)

# Project imports
from .savani_base import SavaniBase

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


class ZhangM(SavaniBase):
    """Adversarial fine-tuning for bias mitigation.

    Follows Brian Hu Zhang, Blake Lemoine, Margaret Mitchell -
    "Mitigating unwanted biases with adversarial learning".

    A small critic network is trained to predict the protected attribute
    from the model's logits (plus the true label for the EO metric). The
    model is then updated with its task gradient, minus the component of
    that gradient along the critic gradient, minus ``alpha`` times the
    critic gradient.
    """

    def __init__(
        self,
        model: nn.Module | L.LightningModule,
        experiment_name: str,
        device: str,
        seed: int = 123,
        **kwargs,
    ) -> None:
        """Forward the common arguments to ``SavaniBase``.

        Args:
            model: Model to be debiased.
            experiment_name: Name passed on to the base class.
            device: Device identifier passed on to the base class.
            seed: Random seed passed on to the base class.
            **kwargs: Accepted for signature compatibility; not used here.
        """
        super().__init__(model, experiment_name, device, seed)

    def apply_model_correction(
        self,
        dataloader: DataLoader,
        last_layer_name: str,
        epsilon: float = 0.1,
        bias_metric: BiasMetrics | str = BiasMetrics.EO_GAP,
        iterations: int = 5,
        critic_iterations: int = 5,
        model_iterations: int = 2,
        train_batch_size: int = 128,
        thresh_optimizer_maxiter: int = 100,
        tau_init: float = 0.5,
        critic_lr: float = 2e-4,
        model_lr: float = 1e-4,
        critic_linear: list[int] | None = None,
        outputs_are_logits: bool = True,
        n_eval_batches: int = 3,
        soft_thresh_temperature: float = 10.0,
        **kwargs,
    ) -> None:
        """Adversarially fine-tune the model, then fit and hook a threshold.

        Each outer iteration first trains the critic for
        ``critic_iterations`` steps with the model frozen, then (from the
        second iteration on) updates the model for ``model_iterations``
        steps with the critic frozen. Afterwards the threshold ``tau`` is
        optimised via ``optimize_tau`` and installed with ``apply_hook``.

        Args:
            dataloader: Data used for training batches; handed to
                ``initialize_dataloader``.
            last_layer_name: Name of a layer that must exist in the model.
            epsilon: Stored on ``self.epsilon`` (not used in this method).
            bias_metric: ``BiasMetrics.DP_GAP`` or ``BiasMetrics.EO_GAP``,
                or a string naming one of them.
            iterations: Number of model-update rounds; one extra
                critic-only warm-up round is run first.
            critic_iterations: Critic optimisation steps per round.
            model_iterations: Model optimisation steps per round.
            train_batch_size: Batch size handed to
                ``initialize_dataloader``.
            thresh_optimizer_maxiter: Iteration budget handed to
                ``optimize_tau``.
            tau_init: Initial threshold handed to ``optimize_tau``.
            critic_lr: Learning rate of the critic's Adam optimiser.
            model_lr: Base learning rate of the model's Adam optimiser;
                it is divided by the round number ``t``.
            critic_linear: Hidden layer widths of the critic. ``None``
                means ``[256, 256, 256]``.
            outputs_are_logits: Must be ``True``; only logits are
                supported at the moment.
            n_eval_batches: Stored on ``self.n_eval_batches``.
            soft_thresh_temperature: Temperature handed to ``apply_hook``.
            **kwargs: Ignored.

        Raises:
            AssertionError: If the layer name is unknown or
                ``outputs_are_logits`` is false.
            ValueError: If ``bias_metric`` is neither DP_GAP nor EO_GAP.
        """
        assert self.check_layer_name_exists(last_layer_name), (
            f"Layer name {last_layer_name} not found in the model"
        )
        assert outputs_are_logits, "Only logits are supported at the moment"

        # Resolve the sentinel here instead of using a shared mutable default.
        if critic_linear is None:
            critic_linear = [256, 256, 256]

        # The signature accepts strings, but everything below needs `.value`.
        if isinstance(bias_metric, str):
            bias_metric = self._coerce_bias_metric(bias_metric)

        self.last_layer_name = last_layer_name
        self.tau_init = tau_init
        self.epsilon = epsilon
        self.bias_metric = bias_metric
        self.outputs_are_logits = outputs_are_logits
        self.n_eval_batches = n_eval_batches

        self.initialize_dataloader(dataloader, train_batch_size)

        # NOTE(review): the input sizes below assume the model emits two
        # logits per sample (binary classification) - confirm.
        if bias_metric.value == BiasMetrics.DP_GAP.value:
            # 2 inputs: the critic only sees the model's logits
            self.critic = self.get_critic(2, critic_linear)
        elif bias_metric.value == BiasMetrics.EO_GAP.value:
            # 3 inputs: the model's logits plus the true label
            self.critic = self.get_critic(3, critic_linear)
        else:
            raise ValueError(f"Not supported: {bias_metric.value}")

        critic_criterion = nn.CrossEntropyLoss()
        critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)
        model_optimizer = torch.optim.Adam(self.model.parameters(), lr=model_lr)
        model_loss = nn.CrossEntropyLoss()

        for i in tqdm(range(iterations + 1), desc="Zhang: Adversarial Fine Tuning"):
            # Round 0 is a critic-only warm-up; model updates start at t = 1.
            t = i
            alpha = t**0.5  # as in the paper: alpha = sqrt(t)

            # In the paper lr = 1/t, but we want to be able to set a base lr,
            # so the base lr is multiplied by 1/t.
            if t > 0:
                new_model_lr = model_lr * (1 / t)
                for g in model_optimizer.param_groups:
                    g["lr"] = new_model_lr

            logger.debug(f"Minibatch no. {i}")

            # ---- Critic phase: model frozen, critic trainable ----
            for param in self.critic.parameters():
                param.requires_grad = True
            for param in self.model.parameters():
                param.requires_grad = False
            self.model.eval()
            self.critic.train()

            for j in range(critic_iterations):
                x, y_true, prot_attr = self.sample_batch()

                with torch.no_grad():
                    y_logits = self.model(x)

                c_pred = self._critic_forward(bias_metric, y_logits, y_true)
                c_loss = critic_criterion(c_pred, prot_attr.long())
                c_loss.backward()

                critic_optimizer.step()
                critic_optimizer.zero_grad()
                model_optimizer.zero_grad()

                logger.debug(f"[{j}] Critic loss: {c_loss.item()}")

            # ---- Model phase: critic frozen, model trainable ----
            for param in self.critic.parameters():
                param.requires_grad = False
            for param in self.model.parameters():
                param.requires_grad = True
            self.model.train()
            self.critic.eval()

            if i > 0:  # Skip the first iteration
                for j in range(model_iterations):
                    x, y_true, prot_attr = self.sample_batch()
                    y_logits = self.model(x)

                    # No squeeze here: the critic already returns
                    # (batch, 2), and squeezing would drop the batch
                    # dimension for a single-sample batch.
                    c_pred = self._critic_forward(bias_metric, y_logits, y_true)
                    c_loss = critic_criterion(c_pred, prot_attr.long())
                    m_loss = model_loss(y_logits, y_true.long())

                    for name, param in self.model.named_parameters():
                        try:
                            m_grad = autograd.grad(m_loss, param, retain_graph=True)[0]
                            c_grad = autograd.grad(c_loss, param, retain_graph=True)[0]
                        except RuntimeError as e:
                            logger.warning(
                                RuntimeError(
                                    f"[{i},{j}] Grad error in layer {name}: {e}"
                                )
                            )
                            continue

                        shape = c_grad.shape
                        m_grad = m_grad.flatten()
                        c_grad = c_grad.flatten()

                        # Project the task gradient onto the *direction* of
                        # the critic gradient. The unit vector is required
                        # for this to be a projection; `tiny` keeps an
                        # all-zero critic gradient from dividing by zero.
                        c_unit = c_grad / (
                            c_grad.norm() + torch.finfo(c_grad.dtype).tiny
                        )
                        m_grad_proj = torch.dot(m_grad, c_unit) * c_unit

                        grad = m_grad - m_grad_proj - alpha * c_grad
                        grad = grad.reshape(shape)

                        # Calling backward on the parameter with an explicit
                        # gradient accumulates `grad` into `param.grad`.
                        param.backward(grad)

                    model_optimizer.step()
                    model_optimizer.zero_grad()
                    critic_optimizer.zero_grad()

                    logger.debug(f"[{j}] Model loss: {m_loss.item()}")

        tau, phi = self.optimize_tau(tau_init, thresh_optimizer_maxiter)

        logger.info(f"Best tau: {tau}, Best phi: {phi}")

        if hasattr(self, "lightning_model"):
            self.lightning_model.model = self.model

        # Add a hook with the best transformation
        self.apply_hook(tau, soft_thresh_temperature)

    @staticmethod
    def _coerce_bias_metric(bias_metric: str) -> BiasMetrics:
        """Turn a string into a ``BiasMetrics`` member, by value then by name.

        NOTE(review): assumes ``BiasMetrics`` is a standard ``enum.Enum`` -
        confirm against its definition.

        Raises:
            ValueError: If the string matches neither a value nor a name.
        """
        try:
            return BiasMetrics(bias_metric)
        except ValueError:
            pass
        try:
            return BiasMetrics[bias_metric]
        except KeyError:
            raise ValueError(f"Not supported: {bias_metric}") from None

    def _critic_forward(
        self,
        bias_metric: BiasMetrics,
        y_logits: torch.Tensor,
        y_true: torch.Tensor,
    ) -> torch.Tensor:
        """Run the critic on the input that matches ``bias_metric``.

        DP_GAP feeds the logits alone; EO_GAP appends the true label as an
        extra column.

        Raises:
            ValueError: If ``bias_metric`` is neither DP_GAP nor EO_GAP.
        """
        if bias_metric.value == BiasMetrics.DP_GAP.value:
            return self.critic(y_logits)
        if bias_metric.value == BiasMetrics.EO_GAP.value:
            combined = torch.cat([y_logits, y_true.unsqueeze(1)], dim=1)
            return self.critic(combined)
        raise ValueError(f"Not supported: {bias_metric.value}")

    def get_critic(
        self,
        input_dim: int,
        critic_linear: list[int],
    ) -> nn.Module:
        """Build the critic MLP and move it to ``self.device``.

        Each hidden layer is ``Linear -> ReLU -> Dropout(0.2)``; the final
        layer is a ``Linear`` with 2 outputs. An empty ``critic_linear``
        yields a single ``Linear(input_dim, 2)``.

        Args:
            input_dim: Number of input features of the critic.
            critic_linear: Widths of the hidden layers, in order.

        Returns:
            The critic as an ``nn.Sequential`` on ``self.device``.
        """
        widths = [input_dim, *critic_linear]

        critic_layers: list[nn.Module] = []
        for in_features, out_features in zip(widths[:-1], widths[1:]):
            critic_layers += [
                nn.Linear(in_features, out_features),
                nn.ReLU(),
                nn.Dropout(0.2),
            ]
        critic_layers.append(nn.Linear(widths[-1], 2))

        return nn.Sequential(*critic_layers).to(self.device)