Source code for detoxai.cavs.mass_mean_probe

import numpy as np
import torch


def compute_mass_mean_probe(
    vecs: np.ndarray, targets: np.ndarray
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute the mass mean probe from the activations of a model.

    The probe is the difference between the mean activation of the artifact
    samples (``targets == 1``) and the mean activation of the non-artifact
    samples (``targets == 0``). Samples with any other label are ignored.

    Args:
        vecs: Activations of the model, shape (samples, features).
        targets: Target labels for the samples, shape (samples,).

    Returns:
        A tuple of three ``float32`` tensors, each of shape (features,):

        - mass_mean_probe: artifact mean minus non-artifact mean.
        - mean_activation_over_nonartifact_samples: mean activation over
          the samples labelled 0.
        - mean_activation_over_artifact_samples: mean activation over the
          samples labelled 1.

    Note:
        If one of the two classes has no samples, NumPy's mean over the
        empty selection yields NaN (with a ``RuntimeWarning``), so the
        corresponding mean and the probe are NaN.
    """
    # Normalise the inputs so that the elementwise comparison below also
    # works for array-likes (e.g. plain lists), where `targets == 1` would
    # otherwise evaluate to a single bool instead of a mask.
    activations = np.asarray(vecs)
    labels = np.asarray(targets)

    artifact_mean = activations[labels == 1].mean(0)
    nonartifact_mean = activations[labels == 0].mean(0)

    # The difference is taken in the input precision and only then cast
    # to float32.
    probe = artifact_mean - nonartifact_mean

    return (
        torch.tensor(probe, dtype=torch.float32),
        torch.tensor(nonartifact_mean, dtype=torch.float32),
        torch.tensor(artifact_mean, dtype=torch.float32),
    )