# -*- coding: utf-8 -*-
# @Author: Weisen Pan

# Import necessary modules
import os  # For interacting with the file system

import torch  # PyTorch for tensor operations
from PIL import Image  # PIL for image loading and processing

from .vision import VisionDataset  # Base VisionDataset class

# Supported image file extensions (lowercase, leading dot included).
# Kept above is_image_file so the constant is defined before its first use.
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')


def validate_file_extension(filename, extensions):
    """
    Check if a file has an allowed extension.

    The comparison is done on the lower-cased filename, so the extensions
    themselves must be given in lowercase.

    Args:
        filename (str): Path to the file.
        extensions (str or iterable of str): Extension(s) to consider
            (in lowercase).

    Returns:
        bool: True if the filename ends with one of the given extensions.
    """
    if not isinstance(extensions, str):
        # str.endswith only accepts a str or a tuple; normalise lists/sets so
        # callers are not forced to pass a tuple.
        extensions = tuple(extensions)
    return filename.lower().endswith(extensions)


def is_image_file(filename):
    """
    Check if a file is an image based on its extension.

    Args:
        filename (str): Path to the file.

    Returns:
        bool: True if the filename ends with one of IMG_EXTENSIONS.
    """
    return validate_file_extension(filename, IMG_EXTENSIONS)


def generate_dataset(directory, class_to_idx, extensions=None, is_valid_file=None):
    """
    Create a list of file paths and their corresponding class indices.

    Exactly one of ``extensions`` and ``is_valid_file`` must be given.
    Class names whose folder does not exist under ``directory`` are skipped.
    Classes, walked directories and file names are each visited in sorted
    order, so the result is deterministic for a given directory tree.

    Args:
        directory (str): Root directory (a leading ``~`` is expanded).
        class_to_idx (dict): Mapping of class names to class indices.
        extensions (tuple, optional): Allowed file extensions.
        is_valid_file (callable, optional): Function taking a file path and
            returning True for files to keep.

    Returns:
        list: A list of (file_path, class_index) tuples.

    Raises:
        ValueError: If both or neither of ``extensions`` and
            ``is_valid_file`` are specified.
    """
    directory = os.path.expanduser(directory)

    # Both None or both set are equally invalid: exactly one must be provided.
    if (extensions is None) == (is_valid_file is None):
        raise ValueError("Specify either 'extensions' or 'is_valid_file', but not both.")

    # Derive the validation function from the extension list when given.
    if extensions is not None:
        def is_valid_file(x):
            return validate_file_extension(x, extensions)

    instances = []
    for target_class in sorted(class_to_idx):
        class_index = class_to_idx[target_class]
        target_dir = os.path.join(directory, target_class)
        if not os.path.isdir(target_dir):
            continue

        # followlinks=True so symlinked sub-folders are traversed as well.
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                if is_valid_file(path):
                    instances.append((path, class_index))

    return instances


class DatasetFolder(VisionDataset):
    """
    A generic data loader where samples are arranged in subdirectories by class.

    Args:
        root (str): Root directory path.
        loader (callable): Function to load a sample from its file path.
        extensions (tuple[str]): Allowed file extensions.
        apply_transformation (callable, optional): Transformation applied to
            each sample.
        target_apply_transformation (callable, optional): Transformation
            applied to each target.
        is_valid_file (callable, optional): Function to validate files.
        split_factor (int, optional): Number of times the transformation is
            applied to each sample.

    Attributes:
        classes (list): Sorted list of class names.
        class_to_idx (dict): Mapping of class names to class indices.
        samples (list): List of (sample_path, class_index) tuples.
        targets (list): List of class indices corresponding to each sample.
    """

    def __init__(self, root, loader, extensions=None, apply_transformation=None,
                 target_apply_transformation=None, is_valid_file=None, split_factor=1):
        # NOTE(review): self.root, self.apply_transformation and
        # self.target_apply_transformation are presumably set by
        # VisionDataset.__init__ (defined in .vision) -- confirm there.
        super().__init__(root, apply_transformation=apply_transformation,
                         target_apply_transformation=target_apply_transformation)

        self.classes, self.class_to_idx = self._discover_classes(self.root)
        self.samples = generate_dataset(self.root, self.class_to_idx, extensions, is_valid_file)

        if not self.samples:
            message = f"Found 0 files in subfolders of: {self.root}."
            # extensions is None when the caller supplied is_valid_file;
            # joining None would raise TypeError and hide the real problem.
            if extensions is not None:
                message += f" Supported extensions are: {','.join(extensions)}"
            raise RuntimeError(message)

        self.loader = loader  # Function to load a sample
        self.extensions = extensions  # Allowed file extensions
        self.targets = [s[1] for s in self.samples]  # Target class indices
        self.split_factor = split_factor  # Number of transformed views per sample

    def _discover_classes(self, dir):
        """
        Discover class subdirectories in the root directory.

        Args:
            dir (str): Root directory.

        Returns:
            tuple: (classes, class_to_idx) where classes is the sorted list of
            subdirectory names of 'dir', and class_to_idx maps each class name
            to its position in that list.
        """
        # The with-block closes the scandir iterator deterministically.
        with os.scandir(dir) as entries:
            classes = sorted(entry.name for entry in entries if entry.is_dir())
        class_to_idx = {name: idx for idx, name in enumerate(classes)}
        return classes, class_to_idx

    def __getitem__(self, index):
        """
        Retrieve a sample and its target by index.

        The loaded sample is transformed ``split_factor`` times and the results
        are concatenated along dimension 0, so the transformation must return
        tensors that torch.cat can join.

        Args:
            index (int): Index of the sample.

        Returns:
            tuple: (sample, target), where sample is the concatenation of the
            transformed views and target is the (optionally transformed)
            class index.

        Raises:
            NotImplementedError: If no sample transformation was configured.
        """
        path, target = self.samples[index]
        sample = self.loader(path)

        # Raise explicitly: the un-transformed sample cannot be concatenated,
        # and assigning the exception class would only surface later as an
        # unrelated TypeError inside torch.cat.
        if self.apply_transformation is None:
            raise NotImplementedError(
                "DatasetFolder requires 'apply_transformation' to build the "
                "concatenated views returned by __getitem__."
            )

        imgs = [self.apply_transformation(sample) for _ in range(self.split_factor)]

        if self.target_apply_transformation is not None:
            target = self.target_apply_transformation(target)

        return torch.cat(imgs, dim=0), target

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.samples)


def load_image_pil(path):
    """
    Load an image from the given path using PIL.

    Args:
        path (str): Path to the image.

    Returns:
        Image: RGB image.
    """
    # convert() is called while the file is still open, so the returned image
    # does not depend on the handle after the with-block closes it.
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')


def load_accimage(path):
    """
    Load an image using the accimage library, falling back to PIL on failure.

    Args:
        path (str): Path to the image.

    Returns:
        Image: Image loaded with accimage, or an RGB PIL image if accimage
        raised an I/O error.
    """
    import accimage  # Imported lazily: only needed when this backend is used

    try:
        return accimage.Image(path)
    except OSError:  # IOError is an alias of OSError in Python 3
        return load_image_pil(path)


def basic_loader(path):
    """
    Load an image using the default image backend (accimage or PIL).

    Args:
        path (str): Path to the image.

    Returns:
        Image: Loaded image.
    """
    from torchvision import get_image_backend  # Lazy import of the backend query

    if get_image_backend() == 'accimage':
        return load_accimage(path)
    return load_image_pil(path)
""" from torchvision import get_image_backend # Get the default image backend return load_accimage(path) if get_image_backend() == 'accimage' else load_image_pil(path) # Load using the appropriate backend # ImageFolder class: A dataset loader for images arranged in subdirectories by class class ImageFolder(DatasetFolder): """ A dataset loader for images arranged in subdirectories by class. Args: root (str): Root directory path. apply_transformation (callable, optional): apply_transformation applied to each image. target_apply_transformation (callable, optional): apply_transformation applied to each target. loader (callable, optional): Function to load an image from its path. is_valid_file (callable, optional): Function to validate files. Attributes: classes (list): Sorted list of class names. class_to_idx (dict): Mapping of class names to class indices. imgs (list): List of (image_path, class_index) tuples. """ def __init__(self, root, apply_transformation=None, target_apply_transformation=None, loader=basic_loader, is_valid_file=None, split_factor=1): super().__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None, apply_transformation=apply_transformation, target_apply_transformation=target_apply_transformation, is_valid_file=is