Source code for detoxai.methods.savani.lay_wis_opt

import logging
import sys
from copy import deepcopy

import lightning as L
import numpy as np
import torch
import torch.nn as nn
from skopt import forest_minimize  # noqa

# from sklearn.ensemble import RandomForestRegressor
from skopt.learning.forest import RandomForestRegressor
from skopt.space import Real
from torch.utils.data import DataLoader
from tqdm import tqdm

from ...metrics.bias_metrics import BiasMetrics

# Project imports
from .savani_base import SavaniBase

# Module-scoped logger; its name follows this module's import path so output
# can be filtered through standard logging configuration.
logger = logging.getLogger(name=__name__)


class SavaniLWO(SavaniBase):
    """SAVANI layer-wise optimization (LWO).

    One parameter tensor at a time, a random subset of its entries is searched
    with ``skopt.forest_minimize`` to maximize ``phi`` (as returned by
    ``self.phi_torch``, presumably defined in ``SavaniBase``) at the current
    best threshold. After a promising candidate, the threshold ``tau`` is
    re-optimized, and the change is kept only if ``phi`` improves.
    """

    def __init__(
        self,
        model: nn.Module | L.LightningModule,
        experiment_name: str,
        device: str,
        seed: int = 123,
        **kwargs,
    ) -> None:
        """Forward the model and run configuration to ``SavaniBase``.

        Args:
            model: Model to debias.
            experiment_name: Name of the experiment.
            device: Device string used for tensors created by this class.
            seed: Seed; also used as ``random_state`` for skopt.
            **kwargs: Accepted and ignored.
        """
        super().__init__(model, experiment_name, device, seed)

    def apply_model_correction(
        self,
        dataloader: DataLoader,
        last_layer_name: str,
        epsilon: float = 0.1,
        bias_metric: BiasMetrics | str = BiasMetrics.EO_GAP,
        n_layers_to_optimize: int | str = "all",
        thresh_optimizer_maxiter: int = 100,
        beta: float = 2.2,
        params_to_opt: int | float = 0.5,
        never_more_than: int = 50_000,
        tau_init: float = 0.5,
        outputs_are_logits: bool = True,
        n_eval_batches: int = 3,
        eval_batch_size: int = 128,
        skopt_verbose: bool = False,
        skopt_njobs: int = 4,
        skopt_npoints: int = 1000,
        skopt_maxiter: int = 10,
        soft_thresh_temperature: float = 10.0,
        **kwargs,
    ) -> None:
        """Layer-wise search for better weights and the best threshold tau.

        "Layers" here are parameter tensors as yielded by
        ``model.named_parameters()``, so weights and biases count separately.
        The last ``n_layers_to_optimize`` tensors are visited, always
        excluding the very last one. For each tensor, a subset of its entries
        is searched inside ``[x - beta*std, x + beta*std]``, where ``std`` is
        the standard deviation of the whole tensor.

        Args:
            dataloader: Passed to ``initialize_dataloader`` for evaluation.
            last_layer_name: Must exist in the model; stored on ``self``.
            epsilon: Stored as ``self.epsilon`` (presumably read by the
                base-class objective; TODO confirm).
            bias_metric: Stored as ``self.bias_metric``.
            n_layers_to_optimize: Number of trailing parameter tensors to
                visit, or ``"all"``.
            thresh_optimizer_maxiter: Budget passed to ``optimize_tau``.
            beta: Search half-width in units of the tensor's std.
            params_to_opt: Entries per tensor to search, either an int count
                or a float fraction.
            never_more_than: Hard cap on the searched entries per tensor.
            tau_init: Starting threshold for ``optimize_tau``.
            outputs_are_logits: Stored as ``self.outputs_are_logits``.
            n_eval_batches: Stored as ``self.n_eval_batches``.
            eval_batch_size: Passed to ``initialize_dataloader``.
            skopt_verbose: ``verbose`` for the regressor and
                ``forest_minimize``.
            skopt_njobs: ``n_jobs`` for the regressor and ``forest_minimize``.
            skopt_npoints: ``n_points`` for ``forest_minimize``.
            skopt_maxiter: ``n_calls`` for ``forest_minimize``.
            soft_thresh_temperature: Passed to ``apply_hook`` with the best
                tau.
            **kwargs: Accepted and ignored.

        Raises:
            AssertionError: If ``last_layer_name`` is not found, or if
                ``n_layers_to_optimize`` exceeds the number of parameter
                tensors.
        """
        assert self.check_layer_name_exists(last_layer_name), (
            f"Layer name {last_layer_name} not found in the model"
        )

        self.last_layer_name = last_layer_name
        self.tau_init = tau_init
        self.epsilon = epsilon
        self.bias_metric = bias_metric
        self.outputs_are_logits = outputs_are_logits
        self.n_eval_batches = n_eval_batches

        self.initialize_dataloader(dataloader, eval_batch_size)

        best_model = deepcopy(self.model)
        # optimize_tau is unpacked as (tau, phi), the same order used after
        # each layer update below (the original swapped them here).
        best_tau, best_phi = self.optimize_tau(tau_init, thresh_optimizer_maxiter)

        total_layers = len(list(self.model.parameters()))
        if n_layers_to_optimize == "all":
            n_layers_to_optimize = total_layers
        assert n_layers_to_optimize <= total_layers, (
            "n_layers_to_optimize must be less than the total number of layers"
        )

        # Visit parameter tensors with index in [first_idx, last_idx]: the
        # trailing n_layers_to_optimize tensors, never the very last one.
        first_idx = max(total_layers - n_layers_to_optimize - 1, 0)
        last_idx = total_layers - 2
        n_steps = max(last_idx - first_idx + 1, 0)

        with tqdm(
            desc=f"LWO layer -1 (global phi: {best_phi:.3f}, tau: {best_tau:.3f})",
            total=n_steps,
            file=sys.stdout,
        ) as pbar:
            for i, (name, o_params) in enumerate(self.model.named_parameters()):
                if not first_idx <= i <= last_idx:
                    continue

                logger.debug(f"Optimizing {name} layer ({i})")
                total_params_cnt = o_params.numel()
                n = self._n_params_to_select(
                    params_to_opt, total_params_cnt, never_more_than
                )
                logger.debug(f"Optimizing lay. {i} w. {n}/{total_params_cnt} params")

                sel_params, indices = self.flatten_select(o_params, n, total_params_cnt)
                logger.debug(f"Flattened parameters cnt: {sel_params.numel()}")
                # Advanced indexing returns a copy: these are the pre-search
                # values, used to undo the search if it is rejected.
                original_values = sel_params.detach().clone()

                space = self._search_space(sel_params, o_params, beta)
                if space is None:
                    logger.warning(
                        f"Skipping {name} ({i}): parameter std is zero or "
                        "undefined, so no search interval can be built"
                    )
                else:
                    logger.debug(f"Optimizing {len(space)} parameters")
                    regressor = RandomForestRegressor(
                        n_estimators=50,
                        n_jobs=skopt_njobs,
                        max_depth=10,
                        verbose=skopt_verbose,
                        min_samples_leaf=2,
                        random_state=self.seed,
                        min_impurity_decrease=1e-4,
                    )
                    res = forest_minimize(
                        self.objective_LWO(o_params, best_tau, indices),
                        dimensions=space,
                        base_estimator=regressor,
                        n_calls=skopt_maxiter,
                        n_jobs=skopt_njobs,
                        random_state=self.seed,
                        n_points=skopt_npoints,
                        verbose=skopt_verbose,
                    )

                    accepted = False
                    if -res.fun > best_phi:
                        best_p_t = torch.tensor(res.x, device=self.device)
                        with torch.no_grad():
                            o_params.data = self.unflatten(o_params, best_p_t, indices)
                        tau, phi = self.optimize_tau(tau_init, thresh_optimizer_maxiter)
                        if phi > best_phi:
                            best_phi = phi
                            best_tau = tau
                            best_model = deepcopy(self.model)
                            accepted = True
                            logger.debug(f"New best phi: {best_phi}, best tau: {best_tau}")

                    if not accepted:
                        # The objective writes every candidate into the live
                        # tensor. Restore the pre-search values so later layers
                        # build on the best-known model, not a rejected sample.
                        with torch.no_grad():
                            o_params.data = self.unflatten(
                                o_params, original_values, indices
                            )

                pbar.update(1)
                pbar.set_description(
                    f"LWO layer {i} (global phi: {best_phi:.3f}, tau: {best_tau:.3f})"
                )

        self.model = best_model
        self.best_tau = best_tau
        if hasattr(self, "lightning_model"):
            self.lightning_model.model = best_model

        # Add a hook with the best transformation
        self.apply_hook(best_tau, soft_thresh_temperature)

    @staticmethod
    def _n_params_to_select(
        params_to_opt: int | float, total_params_cnt: int, never_more_than: int
    ) -> int:
        """Resolve how many entries of a tensor to search, applying both caps.

        A float ``params_to_opt`` is a fraction of ``total_params_cnt``
        (at least 1). The result is capped at ``total_params_cnt`` and at
        ``never_more_than``, because skopt becomes impractical with very many
        dimensions.
        """
        if isinstance(params_to_opt, float):
            n = max(int(params_to_opt * total_params_cnt), 1)
        else:
            n = params_to_opt

        if n > total_params_cnt:
            n = total_params_cnt
            logger.info(
                f"Even though you asked for {params_to_opt} of the parameters, we're capping it to {total_params_cnt}"
            )
        if n > never_more_than:
            n = never_more_than
            logger.info(
                f"Even though you asked for {params_to_opt} of the parameters, we're capping it to {never_more_than}"
            )
        return n

    @staticmethod
    def _search_space(
        sel_params: torch.Tensor, o_params: torch.Tensor, beta: float
    ) -> list | None:
        """Build one ``Real`` interval per selected entry, or None if unusable.

        Each interval is centred on the entry's current value with half-width
        ``beta * std(o_params)``. Returns None when that std is zero or not
        finite (e.g. a constant or single-element tensor), because skopt
        requires ``low < high``.
        """
        std = float(o_params.detach().std().cpu())
        if not np.isfinite(std) or std <= 0.0:
            return None
        half_width = beta * std
        return [
            Real(x - half_width, x + half_width)
            for x in sel_params.detach().cpu().tolist()
        ]

    def objective_LWO(
        self, o_params: torch.Tensor, tau: float, indices: list
    ) -> callable:
        """Build the skopt objective for one parameter tensor.

        The returned function writes a candidate vector into ``o_params`` at
        ``indices`` (mutating the live model), evaluates ``self.phi_torch``
        at the fixed threshold ``tau``, and returns ``-phi``, so minimizing it
        maximizes ``phi``.

        Args:
            o_params: Parameter tensor being optimized; modified in place.
            tau: Threshold used for every evaluation.
            indices: Flat indices of the entries being optimized.

        Returns:
            A function mapping a list of candidate values to a float loss.
        """
        if not isinstance(tau, torch.Tensor):
            tau = torch.tensor(tau, device=self.device)

        def objective(new_params: list) -> float:
            """Apply ``new_params`` to the tensor and return ``-phi``."""
            with torch.no_grad():
                candidate = torch.tensor(new_params, device=self.device)
                o_params.data = self.unflatten(o_params, candidate, indices)
                phi, _ = self.phi_torch(tau)
            return -float(phi.detach().cpu())

        return objective

    def flatten_select(
        self, params: torch.Tensor, select_cnt: float | int, total_params: int
    ) -> tuple[torch.Tensor, list]:
        """Randomly select entries of ``params`` after flattening it.

        Indices are drawn without replacement from numpy's global RNG.

        Args:
            params: Tensor of any shape.
            select_cnt: Number of entries to select, or a float fraction of
                ``total_params``.
            total_params: Number of entries in ``params``.

        Returns:
            A 1-D tensor of the selected values (a copy), and the list of
            their flat indices.

        Raises:
            AssertionError: If ``select_cnt`` exceeds ``total_params``.
        """
        if isinstance(select_cnt, float):
            select_cnt = int(select_cnt * total_params)
        assert select_cnt <= total_params, (
            "select_cnt must be less than the total number of parameters"
        )
        indices = list(np.random.choice(total_params, select_cnt, replace=False))
        return params.flatten()[indices], indices

    def unflatten(
        self, o_params: torch.Tensor, f_params: torch.Tensor, indices: list
    ) -> torch.Tensor:
        """Return a copy of ``o_params`` with ``f_params`` written at ``indices``.

        ``o_params`` itself is not modified. ``f_params`` is cast to the
        destination's dtype and device, because indexed assignment requires
        matching dtypes (e.g. float32 candidates into a float16 tensor).

        Args:
            o_params: Tensor providing the shape and untouched values.
            f_params: 1-D values to write.
            indices: Flat indices into ``o_params``.

        Returns:
            A tensor with the shape of ``o_params``.
        """
        flat = o_params.detach().reshape(-1).clone()
        flat[indices] = f_params.to(device=flat.device, dtype=flat.dtype)
        return flat.reshape(o_params.shape)