A minimal code snippet: running detoxai.debias on a binary ResNet-18 classifier with synthetic data
[ ]:
import torch
import torchvision
import detoxai
# Synthetic smoke test for detoxai.debias: an ImageNet-pretrained ResNet-18
# re-headed for binary classification, fed random images, random binary
# labels and a random binary protected attribute (PA).

# `pretrained=True` is deprecated since torchvision 0.13; the weights enum is
# the supported spelling, and IMAGENET1K_V1 is the checkpoint that
# `pretrained=True` used to load for ResNet-18.
model = torchvision.models.resnet18(
    weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1
)
model.fc = torch.nn.Linear(model.fc.in_features, 2)  # Make it binary classification

# A dedicated, seeded generator makes the synthetic data reproducible from run
# to run without reseeding torch's global RNG.
num_samples = 128
gen = torch.Generator().manual_seed(0)
X = torch.rand(num_samples, 3, 224, 224, generator=gen)  # images, values in [0, 1)
Y = torch.randint(0, 2, size=(num_samples,), generator=gen)  # binary class labels
PA = torch.randint(0, 2, size=(num_samples,), generator=gen)  # binary protected attribute

# TensorDataset yields (x, y, pa) tuples per index -- the same items as
# list(zip(X, Y, PA)) -- without materialising a Python list of tuples.
dataloader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(X, Y, PA), batch_size=32
)

# NOTE(review): the dict keys presumably name the applied correction methods;
# confirm against the detoxai documentation.
results: dict[str, detoxai.CorrectionResult] = detoxai.debias(model, dataloader)