Source code for detoxai.utils.datasets

import logging
import os
import random
from pathlib import Path
from typing import Callable, Dict, List, Optional, Set, Tuple, Union

import numpy as np
import pandas as pd
import PIL
import PIL.Image
import torch

# from torch.utils.data import Dataset
import yaml
from torchvision.datasets.folder import VisionDataset

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Root directory under which the datasets are stored. Taken from the
# DETOXAI_DATASET_PATH environment variable, falling back to ~/.detoxai.
# The value is a str when it comes from the environment and a Path when the
# default is used, which is why the functions below wrap it in Path(...).
DETOXAI_DATASET_PATH = os.environ.get("DETOXAI_DATASET_PATH", Path.home() / ".detoxai")

# NOTE: transforms and the combination of transform and target_transform are mutually exclusive

# Example dataset config in the shape read by get_detoxai_datasets and
# DetoxaiDataset: "name" selects the dataset directory under the root,
# "target" names the label column returned as the prediction target, and
# "splits" maps each split name to the fraction of all samples it receives.
CELEBA_DATASET_CONFIG = {
    "name": "celeba",
    "variant": "default",  # or None
    "target": "Male",  # target attribute that should be predicted
    "splits": {"train": 0.6, "test": 0.2, "unlearn": 0.2},
}

# Example variant config in the shape read by make_detoxai_datasets_variant.
# The top-level "fraction" selects how much of the label table is used at all.
# Each split then takes its own "fraction" of that reduced table and is
# re-sampled by balance_dataset according to its "balancing" rules: every rule
# names a combination of attribute/label conditions that must all hold and the
# "percentage" of the split that should consist of such rows.
CELEBA_VARIANT_CONFIG = {
    "dataset": "celeba",
    "variant": "default",
    "fraction": 1.0,
    "splits": {
        "train": {
            "fraction": 0.3,
            "balancing": [
                {
                    "attribute_combination": [
                        {"attribute": "Male", "label": 0},
                        {"attribute": "Smiling", "label": 1},
                    ],
                    "percentage": 0.1,
                },
                {
                    "attribute_combination": [
                        {"attribute": "Male", "label": 1},
                        {"attribute": "Smiling", "label": 0},
                    ],
                    "percentage": 0.9,
                },
            ],
        },
        "test": {
            "fraction": 0.5,
            "balancing": [
                {
                    "attribute_combination": [
                        {"attribute": "Male", "label": 1},
                        {"attribute": "Smiling", "label": 1},
                    ],
                    "percentage": 0.5,
                }
            ],
        },
        "unlearn": {
            "fraction": 0.2,
            "balancing": [
                {
                    "attribute_combination": [
                        {"attribute": "Male", "label": 1},
                        {"attribute": "Smiling", "label": 1},
                    ],
                    "percentage": 0.5,
                }
            ],
        },
    },
}


def calculate_max_samples(df: pd.DataFrame, config: dict) -> int:
    """Return the largest total sample count that every balancing rule can support.

    For each rule in ``config["balancing"]`` the rows of ``df`` that satisfy all
    of the rule's attribute/label conditions are counted. Dividing that count by
    the rule's target percentage gives the biggest total for which the rule can
    still be filled without reusing rows (e.g. 100 matching rows at 30% allow a
    total of ``int(100 / 0.3) == 333``). The smallest of these per-rule totals
    is the answer.

    Args:
        df: Label table; it must contain a column for every attribute named in
            the rules.
        config: Split configuration holding a ``"balancing"`` list. Every rule
            has an ``"attribute_combination"`` (list of dicts with
            ``"attribute"`` and ``"label"``) and a ``"percentage"``.

    Returns:
        The minimum of the per-rule totals, or ``len(df)`` when no rule has a
        percentage greater than zero.
    """
    per_rule_limits = []
    for rule in config["balancing"]:
        # Rows that satisfy every condition of this rule.
        matches = pd.Series([True] * len(df), index=df.index)
        for cond in rule["attribute_combination"]:
            matches = matches & (df[cond["attribute"]] == cond["label"])
        match_count = matches.sum()

        share = rule["percentage"]
        if not share > 0:
            # A rule without a positive share places no limit on the total.
            continue
        per_rule_limits.append(int(match_count / share))

    if not per_rule_limits:
        return len(df)
    return min(per_rule_limits)
def balance_dataset(df: pd.DataFrame, config: dict) -> Tuple[np.ndarray, int]:
    """Sample row labels from ``df`` so that the balancing rules' shares are met.

    The target total comes from ``calculate_max_samples``. Rules are processed in
    the order given; for each one, ``int(percentage * total)`` rows are drawn
    without replacement (via ``np.random.choice``) from the matching rows that
    have not been picked by an earlier rule. If a rule has fewer free matching
    rows than requested, a warning is logged and the total is lowered so that
    this rule can be satisfied; rows picked for earlier rules are not revisited.

    When the rule percentages sum to less than 1, the remaining share is filled
    with rows that match none of the rules, limited by how many such rows exist.

    Args:
        df: Label table whose index values identify the rows.
        config: Split configuration with a ``"balancing"`` list of rules, each
            holding ``"attribute_combination"`` and ``"percentage"``.

    Returns:
        A tuple of (array of selected index labels, final total sample target).
        The array is built from a set, so its order is not meaningful.
    """
    target_total = calculate_max_samples(df, config)
    rules = config["balancing"]

    chosen: Set[int] = set()
    unchosen = set(df.index.tolist())
    covered_by_rules = set()

    for rule in rules:
        share = rule["percentage"]
        quota = int(share * target_total)

        # Rows that satisfy every condition of this rule.
        rule_mask = pd.Series([True] * len(df), index=df.index)
        for cond in rule["attribute_combination"]:
            rule_mask &= df[cond["attribute"]] == cond["label"]
        rule_matches = df[rule_mask].index.tolist()

        # Only rows not already taken by an earlier rule may be drawn.
        candidates = list(set(rule_matches) - chosen)

        if len(candidates) < quota:
            logger.warning(
                f"Warning: Reducing total samples. Not enough samples for combination "
                f"{rule['attribute_combination']}. "
                f"Requested {quota}, but only {len(candidates)} available."
            )
            # Shrink the overall target so this rule's share becomes feasible.
            target_total = int(len(candidates) / share)
            quota = int(share * target_total)

        if candidates:
            picked = np.random.choice(candidates, size=quota, replace=False)
            chosen.update(picked)
            unchosen -= set(picked)

        covered_by_rules.update(rule_matches)

    rule_share_sum = sum(entry["percentage"] for entry in rules)
    if rule_share_sum < 1 and unchosen:
        filler_quota = int((1 - rule_share_sum) * target_total)
        # The filler may only come from rows that no rule describes.
        unmatched = unchosen - covered_by_rules
        filler = np.random.choice(
            list(unmatched),
            size=min(filler_quota, len(unmatched)),
            replace=False,
        )
        chosen.update(filler)

    return np.array(list(chosen)), target_total
def make_detoxai_datasets_variant(variant_config):
    """Create the split index files for a dataset variant and store its config.

    The label table ``<root>/<dataset>/labels.csv`` is truncated to the leading
    ``variant_config["fraction"]`` of its rows. Each split then receives the next
    consecutive slice of that table (sized by the split's own ``"fraction"``),
    which is re-sampled with ``balance_dataset``. The selected index values are
    written one per line to ``<split_name>.txt``.

    Args:
        variant_config: Dict with the keys ``"dataset"``, ``"variant"``,
            ``"fraction"`` and ``"splits"`` (split name -> dict with
            ``"fraction"`` and ``"balancing"``).

    Returns:
        Path of the ``splits`` directory,
        ``<root>/<dataset>/variants/<variant>/splits``. Both the split files and
        ``variant_config.yaml`` are written into this directory.

    Raises:
        AssertionError: If the top-level fraction exceeds 1.0 or the split
            fractions sum to more than 1.0. The check runs after the output
            directory has been created and the label table has been read.
    """
    dataset_dir = Path(DETOXAI_DATASET_PATH) / variant_config["dataset"]
    variant_path = dataset_dir / "variants" / variant_config["variant"] / "splits"
    os.makedirs(variant_path, exist_ok=True)

    labels = pd.read_csv(dataset_dir / "labels.csv")
    kept_rows = int(variant_config["fraction"] * len(labels))
    labels_fraction = labels.iloc[:kept_rows]

    assert variant_config["fraction"] <= 1.0, (
        "Fraction should be less than or equal to 1.0"
    )
    fractions_total = sum(
        [split_cfg["fraction"] for split_cfg in variant_config["splits"].values()]
    )
    assert fractions_total <= 1.0, (
        "Fractions should add up to less than or equal to 1.0"
    )

    # Splits are carved out of the table back to back, in config order.
    offset = 0
    for split_name, split_config in variant_config["splits"].items():
        split_file = variant_path / f"{split_name}.txt"
        split_size = int(split_config["fraction"] * len(labels_fraction))

        window = labels_fraction.iloc[offset : offset + split_size]
        offset += split_size

        picked_labels, _ = balance_dataset(window, split_config)
        balanced = window.loc[picked_labels]
        np.savetxt(split_file, balanced.index.to_numpy(), fmt="%d", delimiter=",")

    config_file = variant_path / "variant_config.yaml"
    with open(str(config_file), "w") as handle:
        yaml.dump(variant_config, handle)

    return variant_path
def get_detoxai_datasets(
    config: dict,
    transform: Optional[Callable] = None,
    transforms: Optional[Callable] = None,
    target_transform: Optional[Callable] = None,
    download: bool = False,
    seed: Optional[int] = None,
    device: Union[str, None] = None,
    saved_variant: Optional[str] = None,
) -> Dict[str, "DetoxaiDataset"]:
    """Build one ``DetoxaiDataset`` per split.

    The split indices come from one of two sources:

    * ``saved_variant`` given: every ``splits/*.txt`` file below
      ``<root>/<config["name"]>/variants/<saved_variant>`` is loaded and the
      file stem is used as the split name.
    * otherwise: the row positions of ``labels.csv`` are shuffled with
      ``np.random.shuffle`` and cut into consecutive chunks according to
      ``config["splits"]`` (split name -> fraction of all rows).

    Args:
        config: Dataset config; ``"name"`` is always read, ``"splits"`` only
            when no saved variant is used. The dict is also passed on to every
            dataset.
        transform: Takes in a PIL image and returns a transformed version.
        transforms: Takes in an image and a label and returns the transformed
            versions of both.
        target_transform: Takes in the target and transforms it.
        download: Forwarded to ``DetoxaiDataset``.
        seed: If given (and no saved variant is used), seeds the global
            ``random``, ``numpy`` and ``torch`` generators before shuffling.
            It is also forwarded to every dataset.
        device: Forwarded to ``DetoxaiDataset``.
        saved_variant: Name of a previously stored variant to load the split
            indices from.

    Returns:
        Mapping from split name to dataset. It is empty when a saved variant
        has no split files; a warning is logged in that case.
    """
    detoxai_dataset_path = Path(DETOXAI_DATASET_PATH)
    dataset_dir = detoxai_dataset_path / config["name"]

    split_indices: Dict[str, np.ndarray] = {}
    if saved_variant is not None:
        variant_path = dataset_dir / "variants" / saved_variant
        # Sorted so that the order of the returned dict does not depend on the
        # file system's listing order.
        split_files = sorted(variant_path.glob("splits/*.txt"))
        if not split_files:
            logger.warning(
                "No split files found for saved variant %r under %s",
                saved_variant,
                variant_path / "splits",
            )
        for split_file in split_files:
            # ndmin=1 keeps a file holding a single index a 1-D array; without
            # it loadtxt returns a 0-d array that supports neither len() nor
            # integer indexing, both of which the dataset relies on.
            split_indices[split_file.stem] = np.loadtxt(
                split_file, dtype=int, delimiter=",", ndmin=1
            )
    else:
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)

        # Generate indices for all the splits randomly.
        labels = pd.read_csv(dataset_dir / "labels.csv")
        all_indices = np.arange(len(labels))
        np.random.shuffle(all_indices)

        start = 0
        for split_name, frac in config["splits"].items():
            end = start + int(frac * len(all_indices))
            split_indices[split_name] = all_indices[start:end]
            start = end

    return {
        split: DetoxaiDataset(
            config,
            detoxai_dataset_path,
            indices,
            transform=transform,
            transforms=transforms,
            target_transform=target_transform,
            download=download,
            seed=seed,
            device=device,
        )
        for split, indices in split_indices.items()
    }
class DetoxaiDataset(VisionDataset):
    """Image dataset backed by a label table, restricted to a set of row positions.

    The dataset lives in ``<root>/<config["name"]>`` and consists of

    * ``labels.csv``: one row per image, with an ``image_id`` column holding
      the file name and one column per attribute,
    * ``labels_mapping.yaml``: a mapping from attribute name to a mapping
      whose values are used as class names (its keys are presumably the raw
      label values; confirm against the data files),
    * ``data/``: the image files.

    Each item is a triple ``(image, label, fairness_attributes)`` where the
    label is the value of the ``config["target"]`` column and the fairness
    attributes are the raw values of every attribute listed in the mapping.
    """

    def __init__(
        self,
        config: dict,
        root: Union[str, Path],
        split_indices: np.ndarray,
        transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
        seed: Optional[int] = None,
        device: Optional[str] = None,
    ) -> None:
        """Load the label table and mapping for one split.

        Args:
            config: Dataset config; ``"name"`` and ``"target"`` are read.
            root: Directory that contains the dataset directory.
            split_indices: Row positions (into ``labels.csv``) that make up
                this split.
            transform: Takes in a PIL image and returns a transformed version.
            transforms: Takes in an image and a label and returns the
                transformed versions of both. Mutually exclusive with the
                combination of ``transform`` and ``target_transform``.
            target_transform: Takes in the target and transforms it.
            download: If true, ``download()`` is called before the integrity
                check.
            seed: If given, seeds the global ``random``, ``numpy`` and
                ``torch`` generators.
            device: Stored on the instance; not used by this class itself.

        Raises:
            RuntimeError: If the dataset directory does not exist.
        """
        super().__init__(
            root,
            transform=transform,
            transforms=transforms,
            target_transform=target_transform,
        )
        self.config = config
        self.root = Path(root)
        self.device = device

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError(
                "Dataset not found or corrupted. You can use download=True to download it"
            )

        self.labels = self._read_labels_from_file()
        self.labels_mapping = self._read_labels_mapping_from_file()
        self.split_indices = split_indices

        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)

    def _read_labels_from_file(self) -> pd.DataFrame:
        """Read ``labels.csv`` of the configured dataset into a DataFrame."""
        return pd.read_csv(self.root / self.config["name"] / "labels.csv")

    def _read_labels_mapping_from_file(self) -> dict:
        """Parse ``labels_mapping.yaml`` of the configured dataset."""
        mapping_path = self.root / self.config["name"] / "labels_mapping.yaml"
        # Context manager so the file handle is closed once parsing is done.
        with mapping_path.open() as mapping_file:
            return yaml.safe_load(mapping_file)

    def download(self):
        """Placeholder; currently does nothing."""
        pass

    def _check_integrity(self) -> bool:
        """Return whether the dataset directory exists (its contents are not checked)."""
        return (self.root / self.config["name"]).exists()

    def __len__(self) -> int:
        """Number of samples in this split."""
        return len(self.split_indices)

    def __getitem__(self, idx: int) -> Tuple[PIL.Image.Image, int, dict]:
        """Return ``(image, label, fairness_attributes)`` for the idx-th sample of the split."""
        # Translate the position within the split into a row of the label table.
        row_idx = self.split_indices[idx]
        img = self._load_image(row_idx)
        label = self._load_label(row_idx)
        fairness_attributes = self._load_fairness_attributes(row_idx)

        if self.transforms is not None:
            img, label = self.transforms(img, label)
        else:
            if self.transform is not None:
                img = self.transform(img)
            if self.target_transform is not None:
                label = self.target_transform(label)

        return img, label, fairness_attributes

    def _load_image(self, idx: int) -> PIL.Image.Image:
        """Open the image file named in the ``image_id`` column of row ``idx``.

        Args:
            idx: Row position in the label table.

        Returns:
            The image as returned by ``PIL.Image.open``.
        """
        img_path = (
            self.root / self.config["name"] / "data" / self.labels.iloc[idx]["image_id"]
        )
        return PIL.Image.open(img_path)

    def _load_label(self, idx: int):
        """Return the value of the target column for row ``idx``.

        Args:
            idx: Row position in the label table.
        """
        return self.labels.iloc[idx][self.config["target"]]

    def _load_fairness_attributes(self, idx: int) -> dict:
        """Return the raw values of all mapped attributes for row ``idx``.

        Args:
            idx: Row position in the label table.

        Returns:
            Dict from attribute name (every key of the labels mapping) to the
            value stored in the label table; values are not translated
            through the mapping.
        """
        # Fetch the row once instead of rebuilding it for every attribute.
        row = self.labels.iloc[idx]
        return {key: row[key] for key in self.labels_mapping}

    def get_class_names(self) -> List[str]:
        """Return ``"<target>_<name>"`` for every class of the target attribute.

        The names are the values of the target's entry in the labels mapping,
        with spaces replaced by underscores.
        """
        target = self.config["target"]
        return [
            f"{target}_{str(item).replace(' ', '_')}"
            for item in self.labels_mapping[target].values()
        ]

    def get_num_classes(self) -> int:
        """Return the number of entries the labels mapping has for the target attribute."""
        return len(self.labels_mapping[self.config["target"]])

    def get_collate_fn(self, protected_attribute: str, protected_attribute_value: str):
        """Create a collate function that binarises one protected attribute.

        Args:
            protected_attribute: Name of the fairness attribute to look up in
                each sample's attribute dict.
            protected_attribute_value: Value that counts as a member of the
                protected group. It is compared with ``==`` against the raw
                value from the label table.

        Returns:
            A function that turns a batch of ``(image, label, attributes)``
            samples into ``(images, labels, protected_attributes)`` tensors.
        """

        def collate_fn(
            batch: List[Tuple[torch.Tensor, str, Dict[str, Union[str, int]]]],
        ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            """Stack a batch into tensors.

            Args:
                batch: Samples as produced by ``__getitem__``; the images must
                    already be tensors of equal shape so they can be stacked.

            Returns:
                Stacked images, labels, and a 0/1 tensor that is 1 where the
                sample's protected attribute equals the requested value.
            """
            images = torch.stack([item[0] for item in batch])
            labels = torch.tensor([item[1] for item in batch])
            protected_attributes = torch.tensor(
                [
                    int(item[2].get(protected_attribute) == protected_attribute_value)
                    for item in batch
                ]
            )
            return images, labels, protected_attributes

        return collate_fn