Source code for detoxai.methods.clarcs.hooks

import logging
import random

import torch
from torch import nn

logger = logging.getLogger(__name__)


def stabilize(x: torch.Tensor, epsilon: float = 1e-8) -> torch.Tensor:
    """Offset every element of ``x`` by a small constant.

    Args:
        x: Tensor to offset.
        epsilon: Constant added to each element. (Default value = 1e-8)

    Returns:
        A new tensor holding ``x + epsilon``; ``x`` itself is left untouched
        because the addition is out-of-place.
    """
    shifted = x + epsilon
    return shifted
def mass_mean_probe_hook(probe: torch.Tensor, alpha: float):
    """Build a forward hook that subtracts a scaled probe from a layer output.

    Args:
        probe: Tensor subtracted (after scaling by ``alpha``) from the layer
            output. It must broadcast against the output once that output is
            flattened to ``(batch, features)``.
        alpha: Scaling factor applied to ``probe`` before the subtraction.

    Returns:
        A callable with the ``(module, input, output)`` signature, suitable
        for ``nn.Module.register_forward_hook``.
    """

    def hook(module: nn.Module, input: tuple, output: torch.Tensor) -> torch.Tensor:
        """Return ``output - probe * alpha`` with the shape of ``output``.

        Args:
            module: Module the hook is attached to (unused).
            input: Positional inputs the module received (unused).
            output: Output produced by the module; needs at least two
                dimensions because it is flattened from dimension 1 onwards.

        Returns:
            A new tensor of the same shape as ``output``; ``output`` itself is
            not modified.
        """
        # The subtraction is out-of-place and therefore already produces a
        # fresh tensor, so no defensive clone of ``output`` is required.
        flat = output.flatten(start_dim=1)
        return (flat - probe * alpha).reshape(output.shape)

    return hook
def add_mass_mean_probe_hook(
    model: nn.Module, probe: torch.Tensor, layer_names: list, alpha: float = 1.0
) -> list:
    """Register a mass-mean probe hook on the named layers of a PyTorch model.

    Args:
        model: The PyTorch model to be probed.
        probe: The probe tensor that the hook subtracts (scaled by ``alpha``)
            from each selected layer's output.
        layer_names: Names of the layers to hook, as reported by
            ``model.named_modules()``. A single string is accepted and treated
            as one layer name.
        alpha: Scaling factor for the probe. (Default value = 1.0)

    Returns:
        list: A list of hook handles. Keep them to remove hooks later if
        needed. Names that match no module are skipped (a warning is logged).
    """
    # A bare string would otherwise be tested by substring: the root module's
    # name "" is a substring of every string, so the wrong modules got hooked.
    if isinstance(layer_names, str):
        layer_names = [layer_names]
    wanted = set(layer_names)

    hooks = []
    matched = set()
    for name, module in model.named_modules():
        if name not in wanted:
            continue
        handle = module.register_forward_hook(mass_mean_probe_hook(probe, alpha))
        hooks.append(handle)
        matched.add(name)
        logger.debug("Added mass mean probe hook to layer: %s", name)

    # Surface typos in layer names instead of silently hooking nothing.
    missing = wanted - matched
    if missing:
        logger.warning(
            "No module found for layer name(s): %s", sorted(missing, key=str)
        )
    return hooks
def clarc_hook(cav: torch.Tensor, mean_length: torch.Tensor, alpha: float):
    """Create a forward hook that adjusts layer activations along the CAV.

    For the flattened output ``x`` of shape ``(batch, features)`` the hook
    returns ``x - alpha * (x - z) v v^T`` reshaped to the original output
    shape, where ``v`` is the stabilized CAV and ``z`` the stabilized
    ``mean_length``.

    Args:
        cav: Concept Activation Vector. After ``squeeze(0)`` it must be a 1-D
            tensor with one entry per flattened feature of the hooked output.
        mean_length: Reference activation subtracted from the output before
            projecting onto the CAV; must broadcast against
            ``(batch, features)`` after ``unsqueeze(0)``.
        alpha: Scaling factor of the correction.

    Returns:
        function: A hook function to be registered with a PyTorch module via
        ``register_forward_hook``.
    """

    def hook(module: nn.Module, input: tuple, output: torch.Tensor) -> torch.Tensor:
        """Return the CAV-corrected output, keeping the shape of ``output``.

        Args:
            module: Module the hook is attached to (unused).
            input: Positional inputs the module received (unused).
            output: Output produced by the module; flattened from dimension 1.

        Returns:
            The adjusted activations, same shape as ``output``.
        """
        original_shape = output.shape
        v = stabilize(cav).squeeze(0)
        z = stabilize(mean_length).unsqueeze(0)

        # The correction is computed from a detached view, so it acts as a
        # constant w.r.t. the activations in the backward pass.
        x_detached = output.detach().flatten(start_dim=1)
        flat_output = output.flatten(start_dim=1)

        # (x - z) @ (v v^T) evaluated as ((x - z) @ v) v: identical result up
        # to floating-point rounding, without materializing the
        # (features x features) outer-product matrix on every forward pass.
        coefficients = torch.matmul(x_detached - z, v)  # (batch,)
        A = torch.outer(coefficients, v)  # (batch, features)

        # The random draw comes first so the global RNG stream is consumed
        # exactly as before; the norms are only computed when DEBUG is on.
        if random.random() < 0.1 and logger.isEnabledFor(logging.DEBUG):
            logger.debug(f"DEBUG: Magnitude of A: {A.detach().norm()}")
            logger.debug(f"DEBUG: Magnitude of cav: {cav.norm()}")
            logger.debug(f"DEBUG: Magnitude of mean_length: {mean_length.norm()}")
            logger.debug(
                f"DEBUG: Magnitude of x_copy_detached: {x_detached.norm()}"
            )
            # Frobenius norm of v v^T equals ||v||^2.
            logger.debug(f"DEBUG: Magnitude of vvt: {v.norm() ** 2}")

        return (flat_output - A * alpha).reshape(original_shape)

    return hook
def add_clarc_hook(
    model: nn.Module,
    cav: torch.Tensor,
    mean_length: torch.Tensor,
    layer_name: str,
    alpha: float = 1.0,
) -> list:
    """Register a CLARC hook on the layer of ``model`` named ``layer_name``.

    Args:
        model: The PyTorch model to be debiased.
        cav: The Concept Activation Vector passed on to ``clarc_hook``.
        mean_length: Reference activation passed on to ``clarc_hook``.
        layer_name: Name of the layer to hook, as reported by
            ``model.named_modules()``.
        alpha: Scaling factor for the debiasing. (Default value = 1.0)

    Returns:
        list: A list of hook handles. Keep them to remove hooks later if
        needed. The list is empty (and a warning is logged) when no module is
        named ``layer_name``.
    """
    hooks = []
    for name, module in model.named_modules():
        if name != layer_name:
            continue
        handle = module.register_forward_hook(clarc_hook(cav, mean_length, alpha))
        hooks.append(handle)
        logger.debug(f"Added CLARC hook to layer: {name}")

    # An unmatched name would otherwise leave the model silently un-hooked.
    if not hooks:
        logger.warning(
            "No module named %r found; no CLARC hook was added", layer_name
        )
    return hooks