Source code for detoxai.utils.dataloader

import itertools
import logging

from torch.utils.data import DataLoader, Dataset

from .datasets import DetoxaiDataset

logger = logging.getLogger(__name__)


[docs] class DetoxaiDataLoader(DataLoader): """ """ def __init__(self, dataset: DetoxaiDataset, **kwargs): super().__init__(dataset, **kwargs)
[docs] def get_class_names(self): """ """ assert isinstance(self.dataset, DetoxaiDataset), ( "Dataset must be an instance of DetoxaiDataset, as we rely on its internal structure" ) return self.dataset.get_class_names()
[docs] def get_nth_batch(self, n: int) -> tuple: """ Args: n: int: Returns: """ for i, batch in enumerate(self): if i == n: return batch return None
[docs] def get_nth_batch2(self, n: int) -> tuple: """ Args: n: int: Returns: """ if n < 0 or n >= len(self): return None # Create a new iterator dataiter = iter(self) # Use itertools.islice to get to the desired batch directly batch = next(itertools.islice(dataiter, n, n + 1), None) return batch
[docs] def get_num_classes(self) -> int: assert isinstance(self.dataset, DetoxaiDataset), ( "Dataset must be an instance of DetoxaiDataset, as we rely on its internal structure" ) return self.dataset.get_num_classes()
[docs] class WrappedDataLoader(DetoxaiDataLoader): def __init__(self, dataset: Dataset, num_of_classes: int, **kwargs): super().__init__(dataset, **kwargs) self.num_of_classes = num_of_classes
[docs] def get_num_classes(self) -> int: return self.num_of_classes
[docs] def get_class_names(self) -> list[str]: """Name them A B C ..""" return [chr(i) for i in range(65, 65 + self.num_of_classes)]
[docs] def copy_data_loader( dataloader: DataLoader, batch_size: int | None = None, shuffle: bool = False, drop_last: bool = False, ) -> DetoxaiDataLoader | WrappedDataLoader: """Copy the dataloader Args: dataloader: DataLoader: batch_size: int | None: (Default value = None) shuffle: bool: (Default value = False) drop_last: bool: (Default value = False) Returns: """ if batch_size is None: batch_size = dataloader.batch_size if isinstance(dataloader, DetoxaiDataLoader): return DetoxaiDataLoader( dataset=dataloader.dataset, batch_size=batch_size, num_workers=dataloader.num_workers, collate_fn=dataloader.collate_fn, shuffle=shuffle, drop_last=drop_last, ) elif isinstance(dataloader, WrappedDataLoader): return WrappedDataLoader( dataset=dataloader.dataset, num_of_classes=dataloader.num_of_classes, batch_size=batch_size, num_workers=dataloader.num_workers, shuffle=shuffle, drop_last=drop_last, ) else: raise ValueError( f"Unsupported DataLoader type: {type(dataloader)}. " "Please use DetoxaiDataLoader or WrappedDataLoader." )