Source code for detoxai.core.interface

import logging
import multiprocessing as mp
import signal
import traceback
from copy import deepcopy
from datetime import datetime

import torch.nn as nn
from torch.utils.data import DataLoader

# Project imports
from ..methods import (
    ACLARC,
    LEACE,
    PCLARC,
    RRCLARC,
    FineTune,
    ModelCorrectionMethod,
    NaiveThresholdOptimizer,
    SavaniAFT,
    SavaniLWO,
    SavaniRP,
    ZhangM,
)
from ..metrics.fairness_metrics import AllMetrics
from ..utils.dataloader import DetoxaiDataLoader, WrappedDataLoader
from ..utils.datasets import DetoxaiDataset
from .evaluation import evaluate_model
from .interface_helpers import construct_metrics_config, infer_layers
from .mcda_helpers import filter_pareto_front, select_best_method
from .model_wrappers import FairnessLightningWrapper
from .results_class import CorrectionResult

logger = logging.getLogger(__name__)


# Registry mapping each public (upper-case) method identifier to the class that
# implements it. ``run_correction`` looks classes up here by ``method.upper()``.
_method_mapping = dict(
    SAVANIRP=SavaniRP,
    SAVANILWO=SavaniLWO,
    SAVANIAFT=SavaniAFT,
    ZHANGM=ZhangM,
    RRCLARC=RRCLARC,
    PCLARC=PCLARC,
    ACLARC=ACLARC,
    LEACE=LEACE,
    NT=NaiveThresholdOptimizer,
    FINETUNE=FineTune,
)

# Method identifiers accepted by ``debias``, in registry order.
SUPPORTED_METHODS = [*_method_mapping]

# Default configuration, one section per method plus a "global" section.
# ``parse_methods_config`` only fills in keys the caller did not provide, and
# ``debias`` merges the "global" section into every method's keyword arguments.
DEFAULT_METHODS_CONFIG = {
    "global": dict(
        last_layer_name="last",
        experiment_name="default",
        device="cpu",
        dataloader=None,
        test_dataloader=None,
        method_timeout=600,  # seconds
    ),
    # The three CLArC variants share the same defaults; each gets its own dict.
    **{
        clarc_variant: {
            "cav_type": "signal",
            "cav_layers": "penultimate",
            "use_cache": True,
        }
        for clarc_variant in ("PCLARC", "ACLARC", "RRCLARC")
    },
    "LEACE": dict(
        intervention_layers="penultimate",
        use_cache=True,
    ),
    "SAVANIRP": {},
    "SAVANILWO": dict(n_layers_to_optimize=4),
    "SAVANIAFT": {},
    "ZHANGM": {},
    # NOTE(review): "ROC" has no entry in ``_method_mapping``, so it is not a
    # runnable method in this module -- confirm whether this section is still used.
    "ROC": dict(
        theta_range=(0.55, 0.95),
        theta_steps=20,
        metric="EO_GAP",
        objective_function="lambda fairness, accuracy: fairness * accuracy",  # ruff: noqa
    ),
    "NT": dict(
        threshold_range=(0.1, 0.9),
        threshold_steps=20,
        metric="EO_GAP",
        objective_function="lambda fairness, accuracy: -fairness",  # ruff: noqa
    ),
    "FINETUNE": dict(
        fine_tune_epochs=1,
        lr=1e-4,
    ),
}


[docs] def parse_methods_config(methods_config: dict) -> dict: """Here we compare what was passed and overwrite the default configuration Args: methods_config: dict: Returns: """ for key, dic in DEFAULT_METHODS_CONFIG.items(): if key not in methods_config: methods_config[key] = dic else: # Else we overwrite common values with the passed ones # And add the missing ones for pk in dic: if pk not in methods_config[key]: methods_config[key][pk] = dic[pk] return methods_config
def _normalize_methods_arg(methods: list[str] | str) -> list[str]:
    """Turn the ``methods`` argument of ``debias`` into upper-case method names.

    Raises:
        ValueError: If a requested method is not in ``SUPPORTED_METHODS``.
    """
    if methods == "all":
        return list(SUPPORTED_METHODS)
    if isinstance(methods, str):
        # A single method name: wrap it so it is not iterated char by char.
        methods = [methods]
    for method in methods:
        if method.upper() not in SUPPORTED_METHODS:
            raise ValueError(f"Method {method} not supported")
    return [method.upper() for method in methods]


def _as_detoxai_dataloader(
    dataloader: DetoxaiDataLoader | DataLoader, num_of_classes: int | None
) -> DetoxaiDataLoader | WrappedDataLoader:
    """Return ``dataloader`` unchanged if it is Detoxai-native, else wrap it.

    A loader counts as native only if it is a ``DetoxaiDataLoader`` whose dataset
    is a ``DetoxaiDataset``. Anything else is wrapped in ``WrappedDataLoader``.
    When wrapping with ``num_of_classes`` None, the class count is inferred by
    iterating once over ``dataloader`` and counting the distinct labels seen.
    """
    if isinstance(dataloader, DetoxaiDataLoader) and isinstance(
        dataloader.dataset, DetoxaiDataset
    ):
        return dataloader

    if num_of_classes is None:
        logger.warning(
            "Detoxai will infer the number of classes from the dataloader"
        )
        unique_classes = set()
        for batch in dataloader:
            # Presumably batches are (input, labels, ...) with tensor labels.
            labels = batch[1]
            unique_classes.update(labels.unique().tolist())
        num_of_classes = len(unique_classes)
        logger.warning(f"Inferred number of classes: {num_of_classes}")
    return WrappedDataLoader(dataloader.dataset, num_of_classes)


def debias(
    model: nn.Module,
    dataloader: DetoxaiDataLoader | DataLoader,
    methods: list[str] | str = "all",
    metrics: list[str] | str = "all",
    methods_config: dict | None = None,
    pareto_metrics: list[str] | None = None,
    return_type: str = "all",
    device: str = "cpu",
    include_vanila_in_results: bool = True,
    test_dataloader: DetoxaiDataLoader | DataLoader | None = None,
    num_of_classes: int | None = None,
) -> CorrectionResult | dict[str, CorrectionResult]:
    """Run a suite of correction methods on the model and return the results.

    Args:
        model: Model to run the correction methods on.
        dataloader: Data used by the correction methods. It is also used for
            evaluation when ``test_dataloader`` is not given. A loader that is
            not a ``DetoxaiDataLoader`` over a ``DetoxaiDataset`` is wrapped in
            ``WrappedDataLoader``.
        methods: "all", a single method name, or a list of method names
            (case-insensitive); see ``SUPPORTED_METHODS``.
        metrics: Metrics to include in the configuration ("all" or a list).
        methods_config: Per-method configuration overriding
            ``DEFAULT_METHODS_CONFIG``. It is deep-copied and never modified.
        pareto_metrics: Metrics used for evaluation, the pareto front and the
            selection of the best method. Defaults to
            ``["balanced_accuracy", "equalized_odds"]``.
        return_type: Which results to return:
            "pareto-front" -- only the CorrectionResult objects on the pareto front;
            "all" -- results for all correction methods;
            "best" -- the best method, chosen with the ideal point method from the
            pareto front.
        device: Device to run the correction methods on.
        include_vanila_in_results: Also evaluate the uncorrected model and
            include it under the "Vanilla" key.
        test_dataloader: Data for the final evaluation. If not provided, the
            original dataloader is used.
        num_of_classes: Number of classes in the dataset. When None and the
            dataloader has to be wrapped, it is inferred from the dataloader.

    Returns:
        A dict of method name -> CorrectionResult for "all", otherwise whatever
        ``filter_pareto_front`` / ``select_best_method`` return.

    Raises:
        ValueError: If ``return_type`` or a requested method is not supported.
    """
    # Validate cheap arguments first, so a typo is reported before any dataloader
    # iteration or (potentially very long) correction run.
    if return_type not in ("pareto-front", "all", "best"):
        raise ValueError(f"Invalid return type {return_type}")
    methods = _normalize_methods_arg(methods)
    if methods_config is None:
        methods_config = {}
    if pareto_metrics is None:
        pareto_metrics = ["balanced_accuracy", "equalized_odds"]

    dataloader = _as_detoxai_dataloader(dataloader, num_of_classes)

    logger.debug(f"Received configuration:\n {methods_config}")
    # Deep-copy so the caller's configuration is never modified.
    config = parse_methods_config(deepcopy(methods_config))
    logger.debug(f"Resolved configuration to:\n {config}")

    # Work on a private copy of the "global" section before writing per-run
    # values into it, so they cannot leak into a dict shared with other calls.
    global_cfg = dict(config["global"])
    config["global"] = global_cfg
    global_cfg["device"] = device
    # Append a timestamp so each run gets a distinct experiment name.
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S%f")
    global_cfg["experiment_name"] = f"{global_cfg['experiment_name']}_{timestamp}"
    logger.info(f"Experiment name: {global_cfg['experiment_name']}")

    class_labels = dataloader.get_class_names()
    prot_attr_arity = 2  # TODO only supported binary protected attributes

    metrics_calculator = AllMetrics(
        construct_metrics_config(metrics),
        class_labels=class_labels,
        num_groups=prot_attr_arity,
    )

    model = FairnessLightningWrapper(
        model,
        performance_metrics=metrics_calculator.get_performance_metrics(),
        fairness_metrics=metrics_calculator.get_fairness_metrics(),
    )

    results = {}
    for method in methods:
        logger.info("=" * 50 + f" Running method {method} " + "=" * 50)
        # On key collisions the global settings win over per-method ones.
        method_kwargs = config[method] | global_cfg
        # Every method works on its own copy of the wrapped model.
        method_kwargs["model"] = deepcopy(model)
        method_kwargs["dataloader"] = dataloader
        method_kwargs["test_dataloader"] = test_dataloader
        results[method] = run_correction(method, method_kwargs, pareto_metrics)

    if include_vanila_in_results:
        results["Vanilla"] = CorrectionResult(
            method="Vanilla",
            model=model,
            metrics=evaluate_model(
                model,
                dataloader if test_dataloader is None else test_dataloader,
                pareto_metrics,
                device=device,
            ),
        )

    if return_type == "pareto-front":
        return filter_pareto_front(results)
    if return_type == "best":
        return select_best_method(results)
    return results
def run_correction(
    method: str, method_kwargs: dict, pareto_metrics: list[str] | None = None
) -> CorrectionResult:
    """Run a single correction method and evaluate the corrected model.

    An unknown method name is logged and reported as a result with empty metrics.
    Errors raised while applying the correction (including timeouts) or while
    evaluating the corrected model are handled the same way. This keeps one
    failing method from aborting a whole ``debias`` run.

    Args:
        method: Correction method to run, looked up as ``method.upper()`` in
            ``_method_mapping``.
        method_kwargs: Keyword arguments for the correction method. Expected to
            contain "model", "dataloader", "test_dataloader" and "device". It is
            modified in place:
            - layer names are resolved;
            - "method_timeout" is popped;
            - "model" is replaced by the corrected model on success.
        pareto_metrics: Metrics passed to ``evaluate_model``.

    Returns:
        CorrectionResult: The method name, the model stored in
        ``method_kwargs["model"]``, and its metrics. Metrics are
        ``{"pareto": {}, "all": {}}`` when the method could not be applied.
    """
    metrics = {"pareto": {}, "all": {}}

    # Log the kwargs without the model. A filtered shallow view is enough for
    # printing; deep-copying would duplicate the model and the dataloaders
    # (including their datasets) just to emit a debug line.
    loggable_kwargs = {k: v for k, v in method_kwargs.items() if k != "model"}
    logger.debug(
        f"Running correction method {method} with kwargs: \n {loggable_kwargs}"
    )

    # Resolve the method. Keep the try narrow so that a KeyError raised inside
    # the constructor is not misreported as an unknown method.
    try:
        corrector_class = _method_mapping[method.upper()]
    except KeyError:
        logger.error(f"Correction method {method} not found")
        return CorrectionResult(
            method=method, model=method_kwargs["model"], metrics=metrics
        )
    corrector = corrector_class(**method_kwargs)

    # Resolve symbolic layer specifications into concrete layer names.
    if "intervention_layers" in method_kwargs:
        method_kwargs["intervention_layers"] = infer_layers(
            corrector, method_kwargs["intervention_layers"]
        )
        logger.info(
            f"Resolved intervention layers: {method_kwargs['intervention_layers']}"
        )
    if "cav_layers" in method_kwargs:
        method_kwargs["cav_layers"] = infer_layers(
            corrector, method_kwargs["cav_layers"]
        )
        logger.info(f"Resolved CAV layers: {method_kwargs['cav_layers']}")
    if "last_layer_name" in method_kwargs:
        method_kwargs["last_layer_name"] = infer_layers(
            corrector, method_kwargs["last_layer_name"]
        )[0]
        logger.info(f"Resolved last layer name: {method_kwargs['last_layer_name']}")

    # Precompute activations (and CAVs, if the method needs them).
    if corrector.requires_acts:
        if "intervention_layers" in method_kwargs:
            lays = method_kwargs["intervention_layers"]
        else:
            lays = method_kwargs["cav_layers"]
        corrector.extract_activations(method_kwargs["dataloader"], lays)
        logger.debug(f"Computing CAVs on layers: {lays}")
        if corrector.requires_cav:
            corrector.compute_cavs(method_kwargs["cav_type"], lays)

    logger.debug(f"Running correction method {method}")
    try:
        timeout = method_kwargs.pop("method_timeout", None)
        if timeout is not None and timeout > 0:
            logger.debug(f"Running {method} w {timeout} s timeout")
            applied = _apply_model_correction_w_timeout(
                corrector, method_kwargs, timeout
            )
        else:
            corrector.apply_model_correction(**method_kwargs)
            applied = True

        if not applied:
            # The helper already logged the traceback or the timeout.
            logger.error(f"Correction method {method} failed")
        else:
            logger.debug(f"Correction method {method} applied")
            corrected = corrector.get_lightning_model()
            method_kwargs["model"] = corrected

            # Remove gradients
            for param in corrected.parameters():
                param.requires_grad = False

            test_dl = method_kwargs["test_dataloader"]
            metrics = evaluate_model(
                corrected,
                method_kwargs["dataloader"] if test_dl is None else test_dl,
                pareto_metrics,
                device=method_kwargs["device"],
            )

            # Move to CPU
            corrected.to("cpu")
    except Exception as e:
        logger.error(traceback.format_exc())
        logger.error(f"Error running correction method {method}: {e}")

    return CorrectionResult(
        method=method, model=method_kwargs["model"], metrics=metrics
    )
class _CorrectionTimeout(Exception):
    """Raised from the SIGALRM handler when a correction exceeds its time budget."""


def _apply_model_correction_w_timeout(
    corrector: ModelCorrectionMethod, method_kwargs: dict, timeout: float
) -> bool:
    """Execute ``corrector.apply_model_correction`` with a wall-clock timeout.

    The timeout is enforced with ``SIGALRM`` / ``signal.setitimer``. These only
    exist on POSIX platforms, and the handler can only be installed from the main
    thread. When either condition is not met, the correction runs without a
    timeout and a warning is logged, instead of failing outright. The previously
    installed SIGALRM handler is restored afterwards.

    Args:
        corrector: Object with an ``apply_model_correction`` method.
        method_kwargs: Keyword arguments forwarded to ``apply_model_correction``.
        timeout: Maximum execution time in seconds (fractions are honoured).

    Returns:
        bool: True if successful, False on error or timeout.
    """
    import threading

    name = corrector.__class__.__name__
    use_alarm = (
        hasattr(signal, "SIGALRM")
        and hasattr(signal, "setitimer")
        and threading.current_thread() is threading.main_thread()
    )

    def handler(signum, frame):
        """SIGALRM handler: abort the running correction."""
        raise _CorrectionTimeout("Timeout")

    previous_handler = None
    if use_alarm:
        previous_handler = signal.signal(signal.SIGALRM, handler)
    else:
        logger.warning(
            "SIGALRM-based timeouts are unavailable (non-POSIX platform or not "
            f"on the main thread); running {name} without a timeout"
        )

    try:
        if use_alarm:
            # setitimer accepts float seconds; int(timeout) would turn
            # sub-second timeouts into "no alarm at all".
            signal.setitimer(signal.ITIMER_REAL, timeout)
        corrector.apply_model_correction(**method_kwargs)
        return True
    except _CorrectionTimeout:
        logger.error(traceback.format_exc())
        logger.error(f"Correction method {name} timed out")
        return False
    except Exception as e:
        logger.error(traceback.format_exc())
        logger.error(f"Error running correction method {name}: {e}")
        return False
    finally:
        if use_alarm:
            signal.setitimer(signal.ITIMER_REAL, 0)  # disable the alarm
            # signal.signal returns None for handlers not installed from Python;
            # such a value cannot be re-installed.
            if previous_handler is not None:
                signal.signal(signal.SIGALRM, previous_handler)


def _mp_correction_target(
    corrector: ModelCorrectionMethod, method_kwargs: dict, timeout: float
) -> None:
    """Child-process entry point: exit with status 1 if the correction failed."""
    if not _apply_model_correction_w_timeout(corrector, method_kwargs, timeout):
        raise SystemExit(1)


def _mp_apply_model_correction_w_timeout(
    corrector: ModelCorrectionMethod, method_kwargs: dict, timeout: float
) -> bool:
    """Run ``apply_model_correction`` in a separate process with a timeout.

    NOTE: the correction runs in a child process, so changes it makes to the
    model are not visible in the parent unless the model's tensors live in
    shared memory.

    Args:
        corrector: Object with an ``apply_model_correction`` method.
        method_kwargs: Keyword arguments forwarded to ``apply_model_correction``.
        timeout: Maximum execution time in seconds.

    Returns:
        bool: True if the child finished successfully (exit code 0), False on
        error or timeout.
    """
    name = corrector.__class__.__name__
    try:
        p = mp.Process(
            target=_mp_correction_target,
            args=(corrector, method_kwargs, timeout),
        )
        p.start()
        p.join(timeout)
        if p.is_alive():
            p.kill()
            p.join()
            logger.error(f"Correction method {name} timed out")
            return False
        # A finished child is not necessarily a successful one.
        return p.exitcode == 0
    except Exception as e:
        logger.error(traceback.format_exc())
        logger.error(f"Error running correction method {name}: {e}")
        return False
def get_supported_methods() -> list[str]:
    """Get a list of supported methods.

    A new list is returned on every call. Callers may therefore modify it
    without affecting the module-level ``SUPPORTED_METHODS``, which ``debias``
    uses to validate requested methods.

    Returns:
        list[str]: List of supported method identifiers.
    """
    return list(SUPPORTED_METHODS)