# Source code for detoxai.methods.posthoc.reject_option_classification

import logging
from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from ...metrics.bias_metrics import calculate_bias_metric_torch
from ...metrics.metrics import balanced_accuracy_torch
from ...utils.dataloader import DetoxaiDataLoader
from .posthoc_base import PosthocBase

logger = logging.getLogger(__name__)


class ROCModelWrapper(nn.Module):
    """Inference-time wrapper that applies a fitted Reject Option Classification rule.

    The wrapped model's output is converted to probabilities with a softmax
    over dim 1. Samples whose highest class probability is ``<= theta`` fall in
    the "critical region"; for those samples the prediction is replaced by the
    label ``L_values[g]`` assigned to the sample's protected-attribute value
    ``g``. All other samples keep their arg-max prediction.

    Args:
        base_model: Model whose output (or the first element of a tuple
            output) holds class scores along dim 1.
        theta: Confidence threshold defining the critical region.
        L_values: Mapping from protected-attribute value to the label given to
            that group's critical-region samples.
    """

    def __init__(self, base_model: nn.Module, theta: float, L_values: Dict[int, int]):
        super().__init__()
        self.base_model = base_model
        self.theta = theta
        self.L_values = L_values

    def _is_in_critical_region(self, probs: torch.Tensor) -> torch.Tensor:
        """Return a boolean mask of rows whose top probability is ``<= self.theta``.

        Args:
            probs: Class probabilities with classes along dim 1.

        Returns:
            Boolean tensor with one entry per row of ``probs``.
        """
        max_probs, _ = torch.max(probs, dim=1)
        return max_probs <= self.theta

    def forward(self, input, sensitive_features):
        """Return ROC-corrected class predictions.

        Args:
            input: Batch forwarded unchanged to ``base_model``.
            sensitive_features: Protected-attribute value per sample. A tensor,
                sequence or scalar is accepted; it is moved to the predictions'
                device and flattened, so ``(N,)`` and ``(N, 1)`` shapes both work.

        Returns:
            Tensor of predicted class indices (from ``argmax``), one per row of
            the model output.
        """
        output = self.base_model(input)
        # Some models return (scores, extras); only the first element is used.
        if isinstance(output, tuple):
            output = output[0]

        probs = F.softmax(output, dim=1)
        critical_mask = self._is_in_critical_region(probs)
        predictions = torch.argmax(probs, dim=1)

        # Align the protected attribute with the predictions: same device, and
        # 1-D so that an (N, 1) column cannot broadcast against the (N,) mask
        # into an (N, N) matrix.
        groups = torch.as_tensor(
            sensitive_features, device=predictions.device
        ).reshape(-1)

        # Critical-region samples of each group get that group's label.
        for prot_value, L in self.L_values.items():
            mask = (groups == prot_value) & critical_mask
            predictions[mask] = L
        return predictions


class RejectOptionClassification(PosthocBase):
    """Reject Option Classification (ROC) post-hoc fairness correction.

    Predictions whose top-class probability is at most ``theta`` (the
    "critical region") are overwritten with a label ``L`` chosen separately
    for each protected-attribute value. ``theta`` and the per-group labels are
    selected by a grid search that maximises ``objective_function`` on the
    predictions obtained for ``dataloader``.

    Args:
        model: Model to be corrected.
        experiment_name: Forwarded to ``PosthocBase``.
        device: Forwarded to ``PosthocBase``.
        dataloader: Data used to fit ``theta`` and the per-group labels.
        theta_range: ``(low, high)`` thresholds to search; must satisfy
            ``0.5 <= low < high <= 1.0``.
        theta_steps: Number of evenly spaced thresholds in ``theta_range``
            (endpoints included); must be at least 1.
        metric: Bias-metric name passed to ``calculate_bias_metric_torch``.
        objective_function: ``f(fairness, accuracy) -> score`` to maximise.
            Defaults to ``fairness * accuracy``.
        **kwargs: Accepted for interface compatibility; unused.

    Raises:
        AssertionError: If ``theta_range`` or ``theta_steps`` is invalid.
    """

    def __init__(
        self,
        model: nn.Module,
        experiment_name: str,
        device: str,
        dataloader: DetoxaiDataLoader,
        theta_range: Tuple[float, float] = (0.55, 0.95),
        theta_steps: int = 20,
        metric: str = "EO_GAP",
        objective_function: Optional[Callable[[float, float], float]] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(model, experiment_name, device)

        # Explicit raises (not ``assert``) so validation survives ``python -O``;
        # AssertionError is kept for compatibility with the previous check.
        if not (
            theta_range[0] < theta_range[1]
            and theta_range[0] >= 0.5
            and theta_range[1] <= 1.0
        ):
            raise AssertionError(
                f"theta_range must satisfy 0.5 <= low < high <= 1.0, got {theta_range}"
            )
        # An empty grid would leave best_config unset (theta=None) and the
        # returned wrapper would only fail later, at inference time.
        if theta_steps < 1:
            raise AssertionError(f"theta_steps must be at least 1, got {theta_steps}")

        self.dataloader = dataloader
        self.theta_range = theta_range
        self.theta_steps = theta_steps
        self.hooks: List[Any] = []

        self.metric = metric
        self.objective_function = objective_function
        if self.objective_function is None:
            # NOTE(review): if ``metric`` is a gap (lower is fairer, as the name
            # "EO_GAP" suggests), maximising ``fairness * accuracy`` favours
            # *larger* gaps. Confirm the sign convention returned by
            # calculate_bias_metric_torch.
            self.objective_function = lambda fairness, accuracy: fairness * accuracy

        # Filled in by _optimize_parameters.
        self.best_config = {
            "theta": None,
            "L_values": {0: None, 1: None},  # L values for each protected attribute
        }

    def _evaluate_parameters(
        self,
        preds: torch.Tensor,
        targets: torch.Tensor,
        sensitive_features: torch.Tensor,
        theta: float,
        L_values: Dict[int, int],
    ) -> float:
        """Score a single ``(theta, L_values)`` configuration.

        Args:
            preds: Per-class scores ``(N, 2)``, treated as probabilities.
            targets: Ground-truth labels ``(N,)``.
            sensitive_features: Protected-attribute value per sample ``(N,)``.
            theta: Critical-region threshold; must lie in ``[0.5, 1.0]``.
            L_values: Label assigned to each protected group's critical samples.

        Returns:
            ``objective_function(fairness, accuracy)`` as a float, or ``0.0``
            when either metric is NaN.

        Raises:
            AssertionError: If ``theta`` is outside ``[0.5, 1.0]``.
            KeyError: If ``L_values`` lacks an entry for a protected value that
                occurs in ``sensitive_features``.
        """
        if not (0.5 <= theta <= 1.0):
            raise AssertionError(f"Theta must be between 0.5 and 1.0, got {theta}")

        # Every protected value present in the data needs a label.
        unique_protected = torch.unique(sensitive_features).tolist()
        for protected_value in unique_protected:
            if protected_value not in L_values:
                raise KeyError(
                    f"L_values missing value for protected attribute {protected_value}"
                )

        modified_preds = self._modified_prediction(
            theta, preds, sensitive_features, L_values
        )
        fairness_score = calculate_bias_metric_torch(
            self.metric, modified_preds, targets, sensitive_features
        )
        accuracy_score = balanced_accuracy_torch(modified_preds, targets)

        fairness_score = float(fairness_score.item())
        accuracy_score = float(accuracy_score.item())
        # NOTE(review): 0.0 for NaN metrics can outrank valid configurations
        # when the objective is negative; confirm this is intended.
        if np.isnan(fairness_score) or np.isnan(accuracy_score):
            return 0.0

        return float(self.objective_function(fairness_score, accuracy_score))

    def _optimize_parameters(self) -> Tuple[float, Dict[int, int]]:
        """Grid-search ``theta`` and the per-group labels on ``self.dataloader``.

        Each threshold from ``np.linspace(*theta_range, theta_steps)`` is paired
        with every assignment of labels {0, 1} to protected groups 0 and 1. The
        first configuration reaching the highest score is stored in
        ``self.best_config`` and returned.

        Returns:
            ``(theta, L_values)`` of the best configuration.

        Raises:
            AssertionError: If predictions, targets or protected attributes do
                not have the expected shapes.
            RuntimeError: If no configuration produced a score greater than
                ``-inf`` (e.g. the objective returned NaN everywhere).
        """
        thetas = np.linspace(self.theta_range[0], self.theta_range[1], self.theta_steps)
        best_score = float("-inf")

        # NOTE(review): the critical-region test below treats ``preds`` as
        # probabilities, whereas ROCModelWrapper applies a softmax to the raw
        # model output at inference. Confirm that _get_model_predictions
        # returns softmax probabilities so the fitted ``theta`` transfers.
        preds, targets, sensitive_features = self._get_model_predictions(
            self.dataloader
        )

        # Explicit raises instead of ``assert`` so the checks survive ``python -O``.
        if preds.shape[1] != 2:
            raise AssertionError(
                f"Expected binary classification, got {preds.shape[1]} classes"
            )
        if targets.dim() != 1:
            raise AssertionError(f"Expected 1D targets, got {targets.dim()}D")
        if sensitive_features.dim() != 1:
            raise AssertionError(
                f"Expected 1D protected features, got {sensitive_features.dim()}D"
            )

        # Same visiting order as nested loops over group 0's label, then group 1's.
        label_grid = [{0: l0, 1: l1} for l0 in (0, 1) for l1 in (0, 1)]

        for theta in thetas:
            for L_values in label_grid:
                score = self._evaluate_parameters(
                    preds, targets, sensitive_features, theta, L_values
                )
                if score > best_score:
                    best_score = score
                    # Plain float rather than np.float64; private copy of labels.
                    self.best_config["theta"] = float(theta)
                    self.best_config["L_values"] = dict(L_values)

        if best_score == float("-inf"):
            raise RuntimeError(
                "ROC grid search found no configuration with a finite score"
            )
        return self.best_config["theta"], self.best_config["L_values"]

    def _is_in_critical_region(self, theta: float, probs: torch.Tensor) -> torch.Tensor:
        """Return a mask of predictions whose confidence is ``<= theta``.

        Args:
            theta: Confidence threshold.
            probs: Prediction probabilities ``(batch_size, 2)``.

        Returns:
            Boolean mask ``(batch_size,)`` marking critical-region predictions.
        """
        assert probs.shape[1] == 2, (
            f"Expected binary classification, got {probs.shape[1]} classes"
        )
        max_probs, _ = torch.max(probs, dim=1)
        return max_probs <= theta

    def _modified_prediction(
        self,
        theta: float,
        probs: torch.Tensor,
        sensitive_features: torch.Tensor,
        L_values: Dict[int, int],
    ) -> torch.Tensor:
        """Apply the ROC rule to a batch of probabilities.

        Args:
            theta: Confidence threshold.
            probs: Prediction probabilities ``(batch_size, 2)``.
            sensitive_features: Protected attributes ``(batch_size,)``.
            L_values: Mapping from protected-attribute value to the label given
                to that group's critical-region predictions.

        Returns:
            Modified predictions ``(batch_size,)``: arg-max outside the critical
            region, the group's ``L`` inside it.
        """
        assert probs.shape[1] == 2, "Expected binary classification"
        assert sensitive_features.dim() == 1, "Expected 1D protected features"

        critical_mask = self._is_in_critical_region(theta, probs)
        predictions = torch.argmax(probs, dim=1)

        # Apply different L values based on protected attribute
        for prot_value, L in L_values.items():
            mask = (sensitive_features == prot_value) & critical_mask
            predictions[mask] = L

        return predictions

    def apply_model_correction(self, **kwargs) -> nn.Module:
        """Fit the ROC parameters and return the corrected model.

        Runs the grid search on ``self.dataloader`` and wraps ``self.model`` in
        a ``ROCModelWrapper`` configured with the best ``theta`` and per-group
        labels. The wrapper's ``forward`` takes ``(input, sensitive_features)``
        and returns class labels rather than scores.

        Args:
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The wrapped model.
        """
        theta, L_values = self._optimize_parameters()
        return ROCModelWrapper(self.model, theta, L_values)