Source code for detoxai.methods.model_correction

from abc import ABC, abstractmethod

import lightning as L
from torch import nn


[docs] class ModelCorrectionMethod(ABC): """ """ def __init__( self, model: nn.Module | L.LightningModule, experiment_name: str, device: str ) -> None: # Unwrap LightningModule if isinstance(model, L.LightningModule): self.lightning_model = model.to(device) self.model = model.model.to(device) else: self.model = model.to(device) self.experiment_name = experiment_name self.device = str(device) if "cuda" in self.device and ":" in self.device: self.devices_indices = [int(str(self.device).split(":")[1])] else: self.devices_indices = "auto" self.requires_cav: bool = False self.requires_acts: bool = False
[docs] @abstractmethod def apply_model_correction(self) -> None: """ """ raise NotImplementedError
[docs] def get_model(self) -> nn.Module: """ """ return self.model
[docs] def get_lightning_model(self) -> L.LightningModule: """ """ if hasattr(self, "lightning_model"): return self.lightning_model else: raise AttributeError("No Lightning model found")
[docs] def remove_hooks(self) -> None: """ """ if hasattr(self, "hooks"): self.hooks = list()