# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import os
import torch
import torch.utils.data as data


# VisionDataset is a custom dataset class inheriting from PyTorch's Dataset class.
# It handles the initialization and representation of a vision-related dataset,
# including optional transformation of input data and targets.
class VisionDataset(data.Dataset):
    """Base class for vision datasets with optional input/target transformations.

    Subclasses must implement ``__getitem__`` and ``__len__``; the base
    implementations raise ``NotImplementedError``.

    Args:
        root: Root location of the dataset. When it is a ``str``, a leading
            ``~`` is expanded to the user's home directory; any other value
            (including ``None``) is stored unchanged.
        apply_transformations: A single callable intended to act jointly on an
            ``(input, target)`` pair. Mutually exclusive with the two
            arguments below.
        apply_transformation: Optional callable for the input only.
        target_apply_transformation: Optional callable for the target only.

    Raises:
        ValueError: If ``apply_transformations`` is given together with
            ``apply_transformation`` or ``target_apply_transformation``.
    """

    _repr_indent = 4  # Number of spaces used to indent the body lines of __repr__

    def __init__(self, root, apply_transformations=None, apply_transformation=None,
                 target_apply_transformation=None):
        # Only string roots get "~" expansion; other values are kept as given.
        self.root = os.path.expanduser(root) if isinstance(root, str) else root

        # The joint form and the separate input/target form are mutually exclusive.
        has_apply_transformations = apply_transformations is not None
        has_separate_apply_transformation = (
            apply_transformation is not None or target_apply_transformation is not None
        )
        if has_apply_transformations and has_separate_apply_transformation:
            raise ValueError("Only one of 'apply_transformations' or 'apply_transformation/target_apply_transformation' can be provided.")

        self.apply_transformation = apply_transformation
        self.target_apply_transformation = target_apply_transformation

        # Separate input/target callables are wrapped into one joint callable, so
        # self.apply_transformations is populated whichever form the caller used.
        if has_separate_apply_transformation:
            apply_transformations = StandardTransform(apply_transformation, target_apply_transformation)
        self.apply_transformations = apply_transformations

    def __getitem__(self, index):
        """Return the sample at ``index``; must be implemented by subclasses."""
        raise NotImplementedError

    def __len__(self):
        """Return the number of samples; must be implemented by subclasses."""
        raise NotImplementedError

    def __repr__(self) -> str:
        """Describe the dataset: size, root location, extra info and transforms."""
        head = f"Dataset {self.__class__.__name__}"
        # len() (rather than calling __len__ directly) also validates the result
        # is a non-negative integer.
        body = [f"Number of datapoints: {len(self)}"]
        if self.root is not None:
            body.append(f"Root location: {self.root}")
        body += self.extra_repr().splitlines()  # Subclass-specific details, one per line
        # getattr with a default tolerates instances on which the attribute was
        # never set (e.g. a subclass that skipped VisionDataset.__init__).
        if getattr(self, "apply_transformations", None) is not None:
            body.append(repr(self.apply_transformations))
        lines = [head] + [" " * self._repr_indent + line for line in body]
        return '\n'.join(lines)

    def _format_apply_transformation_repr(self, apply_transformation, head) -> list:
        """Render ``apply_transformation``'s repr prefixed by ``head``.

        Continuation lines are indented by ``len(head)`` so they align under the
        first line's content. Returns a list of lines (``[head]`` if the repr is
        empty).
        """
        lines = repr(apply_transformation).splitlines()
        if not lines:
            # An empty repr would otherwise raise IndexError on lines[0].
            return [head]
        indent = " " * len(head)
        return [f"{head}{lines[0]}"] + [f"{indent}{line}" for line in lines[1:]]

    def extra_repr(self) -> str:
        """Hook for subclasses to add lines to ``__repr__``; empty by default."""
        return ""


# StandardTransform class handles the application of the input transformation and
# the target transformation during dataset iteration or data loading.
class StandardTransform:
    """Apply separate callables to the input and the target of a sample.

    Args:
        apply_transformation: Optional callable applied to the input.
        target_apply_transformation: Optional callable applied to the target.

    A callable left as ``None`` leaves the corresponding value unchanged.
    """

    def __init__(self, apply_transformation=None, target_apply_transformation=None):
        self.apply_transformation = apply_transformation
        self.target_apply_transformation = target_apply_transformation

    def __call__(self, input, target):
        """Return ``(input, target)`` with each configured callable applied."""
        # NOTE: ``input`` shadows the builtin but is kept so keyword callers still work.
        if self.apply_transformation is not None:
            input = self.apply_transformation(input)
        if self.target_apply_transformation is not None:
            target = self.target_apply_transformation(target)
        return input, target

    def _format_apply_transformation_repr(self, apply_transformation, head) -> list:
        """Render ``apply_transformation``'s repr prefixed by ``head``.

        Continuation lines are indented by ``len(head)`` so they align under the
        first line's content. Returns a list of lines (``[head]`` if the repr is
        empty).
        """
        lines = repr(apply_transformation).splitlines()
        if not lines:
            # An empty repr would otherwise raise IndexError on lines[0].
            return [head]
        indent = " " * len(head)
        return [f"{head}{lines[0]}"] + [f"{indent}{line}" for line in lines[1:]]

    def __repr__(self) -> str:
        """Show the class name followed by each configured callable's repr."""
        body = [self.__class__.__name__]
        if self.apply_transformation is not None:
            body += self._format_apply_transformation_repr(self.apply_transformation, "apply_transformation: ")
        if self.target_apply_transformation is not None:
            body += self._format_apply_transformation_repr(self.target_apply_transformation, "Target apply_transformation: ")
        return '\n'.join(body)