Source code for detoxai.cavs.extract_activations
import logging
import os
import sys
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
logger = logging.getLogger(__name__)
[docs]
def get_all_layers(model: nn.Module, prefix: str = "") -> dict:
    """Collect every descendant module of ``model`` keyed by its dotted path.

    The walk is depth-first and pre-order: each child is recorded before its
    own descendants, and siblings keep the order given by ``named_children``.
    The root ``model`` itself is not included.

    Args:
        model (nn.Module): The PyTorch model to walk.
        prefix (str): Path prepended (joined with ``"."``) to every name.
            An empty prefix yields bare child names such as ``"0"``.

    Returns:
        dict: Mapping from dotted layer name to the layer module.
    """
    collected = {}

    def _visit(parent, path):
        for child_name, child in parent.named_children():
            qualified = child_name if not path else f"{path}.{child_name}"
            collected[qualified] = child
            # Descend immediately so descendants follow their parent.
            _visit(child, qualified)

    _visit(model, prefix)
    return collected
[docs]
def get_layer_by_name(model: nn.Module, layer_name: str) -> nn.Module:
    """Retrieve a sub-module of ``model`` by its dot-separated name.

    Each component of ``layer_name`` is resolved with ``getattr``, so numeric
    names of ``nn.Sequential`` children (e.g. ``"features.0"``) work.

    Args:
        model (nn.Module): The PyTorch model.
        layer_name (str): Dot-separated path to the layer, e.g. ``"layer1.0"``.

    Returns:
        nn.Module: The resolved layer module.

    Raises:
        AttributeError: If any path component does not exist, or if the path
            resolves to something that is not an ``nn.Module`` (for example a
            parameter such as ``"fc.weight"``).
    """
    module = model
    for component in layer_name.split("."):
        module = getattr(module, component)
    # getattr also reaches parameters, buffers and plain attributes; reject
    # them so callers always get a module they can attach hooks to.
    if not isinstance(module, nn.Module):
        raise AttributeError(
            f"'{layer_name}' resolves to a {type(module).__name__}, "
            "not an nn.Module"
        )
    return module
[docs]
def load_activations(save_path: str) -> dict[str, np.ndarray]:
    """Load an ``.npz`` archive of activations into an in-memory dict.

    Args:
        save_path (str): Path to an ``.npz`` archive (as written by
            ``np.savez``).

    Returns:
        dict[str, np.ndarray]: Mapping from each array name stored in the
        archive to the fully loaded array.
    """
    # np.load on an .npz returns a lazily-reading NpzFile that keeps the file
    # open; the context manager closes it once every array is materialized.
    with np.load(save_path) as archive:
        activations = {key: archive[key] for key in archive.files}
    logger.info(f"Loaded activations from '{save_path}'")
    return activations
[docs]
def extract_activations(
    model: nn.Module,
    dataloader: DataLoader,
    experiment_name: str,
    save_dir: str,
    layers: list | dict | None = None,
    device: str = "cuda",
    use_cache: bool = True,
) -> dict[str, np.ndarray]:
    """Extract per-layer forward activations for all data in a dataloader.

    Forward hooks are attached to the requested layers. The model is run in
    eval mode under ``torch.no_grad()`` over ``dataloader``. The captured
    outputs are concatenated along dim 0 and saved to
    ``<save_dir>/<experiment_name>.npz``. If that file already exists and
    ``use_cache`` is True, it is loaded and returned instead.

    Each batch is indexed as ``batch[0]`` (model input), ``batch[1]``
    (labels) and ``batch[2]`` (protected attribute). Only ``batch[0]`` is
    moved to ``device``; the model itself is not moved here.

    Args:
        model (nn.Module): The PyTorch model. It is put into eval mode and
            left in eval mode.
        dataloader (DataLoader): Yields batches indexable as described above.
        experiment_name (str): Base name of the ``.npz`` cache file.
        save_dir (str): Directory for the cache file; created if missing.
        layers (list | dict | None): ``None`` hooks every sub-module (see
            ``get_all_layers``); a list of dotted layer names; or a
            ``{name: module}`` dict.
        device (str): Device the input tensors are moved to.
        use_cache (bool): Whether to return a previously saved ``.npz``.

    Returns:
        dict[str, np.ndarray]: ``"labels"`` maps to an array whose rows pair
        each sample's label with its protected attribute. This is ``(N, 2)``
        assuming ``batch[1]`` and ``batch[2]`` hold one value per sample.
        Every other key is a layer name mapping to its stacked outputs.

    Raises:
        ValueError: If a listed layer name is not found in the model, or if
            ``layers`` has an unsupported type.
    """
    save_path = os.path.join(save_dir, experiment_name + ".npz")
    if use_cache and os.path.exists(save_path):
        logger.debug(f"Loading activations from '{save_path}'")
        return load_activations(save_path)

    model.eval()
    # An empty save_dir means the current directory; os.makedirs("") raises.
    if save_dir and not os.path.exists(save_dir):
        logger.debug(f"Creating directory '{save_dir}' since it does not exist.")
        os.makedirs(save_dir, exist_ok=True)

    layers = _resolve_layers(model, layers)

    activations: dict[str, list[torch.Tensor]] = {}
    # Seed with an empty float64 (0, 2) array so the result keeps the same
    # dtype and shape as before, even for an empty dataloader.
    label_chunks = [np.empty((0, 2))]
    handles = []
    try:
        for name, layer in layers.items():
            handles.append(
                layer.register_forward_hook(_make_activation_hook(activations, name))
            )
        with torch.no_grad():
            for batch in tqdm(
                dataloader, desc="Extracting Activations", file=sys.stdout
            ):
                data = batch[0].to(device)
                labels = batch[1].cpu().detach().numpy()
                prota = batch[2].cpu().detach().numpy()
                # Pair each sample's label with its own protected attribute.
                # Reshaping np.array((labels, prota)) of shape (2, B) to
                # (-1, 2) would interleave different samples within a row.
                label_chunks.append(
                    np.column_stack((labels.reshape(-1), prota.reshape(-1)))
                )
                model(data)
    finally:
        # Always detach the hooks so a failed forward pass does not leave
        # the caller's model instrumented.
        for handle in handles:
            handle.remove()

    labels_np = np.concatenate(label_chunks, axis=0)
    num_samples = len(labels_np)

    activations_np = {"labels": labels_np}
    for name, acts in activations.items():
        np_acts = torch.cat(acts).numpy()
        # NOTE(review): this split looks unreachable or inconsistent. When
        # np_acts has num_samples // 2 rows, the "_post" slice below is
        # empty. If the intent is to separate a ReLU module that is called
        # more than once per forward pass, the hook fires once per call and
        # the rows are interleaved per batch. Confirm the intended condition.
        if (
            "resnet" in experiment_name
            and "relu" in name.lower()
            and np_acts.shape[0] == num_samples // 2
        ):
            activations_np[name + "_pre"] = np_acts[: num_samples // 2]
            activations_np[name + "_post"] = np_acts[num_samples // 2 :]
        else:
            activations_np[name] = np_acts

    np.savez(save_path, **activations_np)
    logger.debug(f"Saved all activations at '{save_path}'")
    return activations_np


def _resolve_layers(model: nn.Module, layers: list | dict | None) -> dict:
    """Normalize the ``layers`` argument of ``extract_activations`` to a ``{name: module}`` dict."""
    if layers is None:
        return get_all_layers(model)
    if isinstance(layers, list):
        resolved = {}
        for name in layers:
            try:
                resolved[name] = get_layer_by_name(model, name)
            except AttributeError as err:
                raise ValueError(f"Layer '{name}' not found in the model.") from err
        return resolved
    if isinstance(layers, dict):
        return layers
    raise ValueError(
        "layers must be None, a list of layer names, or a dict of {name: module}"
    )


def _make_activation_hook(store: dict, name: str):
    """Build a forward hook that appends the layer's detached CPU output to ``store[name]``."""

    def hook(module, inputs, output):
        # Assumes the layer returns a single tensor; a module returning a
        # tuple would raise here. TODO confirm for all hooked layers.
        store.setdefault(name, []).append(output.detach().cpu())

    return hook