Source code for detoxai.methods.savani.utils

import torch

from ...metrics.bias_metrics import (
    BiasMetrics,
    calculate_bias_metric_torch,
)
from ...metrics.metrics import balanced_accuracy_torch


def phi_torch(
    Y_true: torch.Tensor,
    Y_pred: torch.Tensor,
    ProtAttr: torch.Tensor,
    epsilon: float = 0.05,
    bias_metric: BiasMetrics | str = BiasMetrics.TPR_GAP,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Calculate phi as in the paper.

    ``phi = balanced_accuracy(Y_true, Y_pred)`` if ``bias < epsilon`` else ``0``.

    Args:
        Y_true: Reference values; must have the same shape as ``Y_pred`` and
            ``ProtAttr``.
        Y_pred: Predictions to evaluate.
        ProtAttr: Protected-attribute value for each element.
        epsilon: Bias threshold. Phi is nonzero only when the bias is
            strictly below this value. Defaults to 0.05.
        bias_metric: Bias metric to compute, forwarded unchanged to
            ``calculate_bias_metric_torch``. Defaults to
            ``BiasMetrics.TPR_GAP``.

    Returns:
        A ``(phi, bias)`` tuple. ``phi`` is the output of
        ``balanced_accuracy_torch(Y_true, Y_pred)`` when ``bias < epsilon``,
        otherwise a float32 zero tensor on ``Y_true``'s device. ``bias`` is the
        value returned by ``calculate_bias_metric_torch``.

    Raises:
        ValueError: If ``Y_true``, ``Y_pred`` and ``ProtAttr`` do not all have
            the same shape.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not (Y_true.shape == Y_pred.shape == ProtAttr.shape):
        raise ValueError(
            f"Y_true {Y_true.shape}, Y_pred {Y_pred.shape}, ProtAttr {ProtAttr.shape} must have the same shape"
        )

    # Compute the bias metric
    bias = calculate_bias_metric_torch(bias_metric, Y_pred, Y_true, ProtAttr)

    # Compute phi. `bias` is presumably a single-element tensor (bool() on a
    # multi-element tensor would raise). A NaN bias compares False, so it
    # yields phi == 0.
    phi = (
        balanced_accuracy_torch(Y_true, Y_pred)
        if bias < epsilon
        else torch.tensor(0, dtype=torch.float32, device=Y_true.device)
    )
    return phi, bias