Source code for detoxai.core.evaluation

import logging

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from ..metrics.metrics import comprehensive_metrics_torch

logger = logging.getLogger(__name__)


[docs] def evaluate_model( model: nn.Module, dataloader: DataLoader, pareto_metrics: list[str] | None = None, device: str | None = None, ) -> dict: """Evaluate the model on various metrics Args: - model: Model to evaluate - dataloader: DataLoader for the dataset - pareto_metrics: List of metrics to include in the pareto front - device: Device to use for evaluation ("cpu" or "cuda") Returns: """ logger.debug("Evaluating model") if device is not None: model_device = device model.to(device) else: model_device = next(model.parameters()).device logger.debug(f"Evaluating model on device: {model_device}") model.eval() preds = [] targets = [] protected_attributes = [] for batch in dataloader: x, y, prot_attr = batch x = x.to(model_device) y = y.to(model_device) prot_attr = prot_attr.to(model_device) with torch.no_grad(): pred = model(x).argmax(dim=1) preds.append(pred) targets.append(y) protected_attributes.append(prot_attr) preds = torch.cat(preds).to(model_device) targets = torch.cat(targets).to(model_device) protected_attributes = torch.cat(protected_attributes).to(model_device) raw_results = comprehensive_metrics_torch(targets, preds, protected_attributes) logger.debug(f"Raw results: {raw_results}") metrics = {"pareto": {}, "all": {}} for metric in raw_results: if pareto_metrics and metric in pareto_metrics: metrics["pareto"][metric] = raw_results[metric].cpu().detach().item() # Collect all metrics metrics["all"][metric] = raw_results[metric].cpu().detach().item() return metrics