# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import logging

import numpy as np
import torch.utils.data as data
from PIL import Image
from torchvision.datasets import CIFAR100

# Configure logging.
# NOTE: this configures the *root* logger at import time; kept as-is because
# importers may rely on INFO-level output being enabled by this module.
logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Supported image extensions for loading images
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')


def load_accimage(path):
    """
    Load an image with the accimage backend, falling back to PIL on I/O errors.

    Args:
        path (str): Path to the image file.

    Returns:
        An ``accimage.Image`` when accimage can decode the file; otherwise the
        RGB PIL image returned by ``load_image_pil``.

    Raises:
        ImportError: If the ``accimage`` package is not installed.
    """
    import accimage

    try:
        return accimage.Image(path)
    except OSError:  # IOError is an alias of OSError in Python 3
        # accimage could not read/decode the file; let PIL try instead.
        return load_image_pil(path)


def load_image_pil(path):
    """
    Load an image with PIL and convert it to RGB.

    The file is opened explicitly inside a ``with`` block so the handle is
    closed as soon as the pixel data has been read by ``convert``.

    Args:
        path (str): Path to the image file.

    Returns:
        Image: The image loaded using PIL, converted to RGB format.
    """
    with open(path, 'rb') as f:
        img = Image.open(f)
        # convert() forces the (lazy) decode while the file is still open.
        return img.convert('RGB')


def basic_loader(path):
    """
    Load an image with the backend currently configured in torchvision.

    Uses ``load_accimage`` when torchvision's image backend is 'accimage',
    and ``load_image_pil`` for any other backend.

    Args:
        path (str): Path to the image file.

    Returns:
        The loaded image.
    """
    from torchvision import get_image_backend

    if get_image_backend() == 'accimage':
        return load_accimage(path)
    return load_image_pil(path)


class CIFAR100_truncated(data.Dataset):
    """
    CIFAR100 dataset restricted to an optional subset of sample indices.

    The underlying image array and labels are copied out of torchvision's
    ``CIFAR100`` object at construction time. Samples can additionally be
    modified in place with ``truncate_channel``.
    """

    def __init__(self, root, dataidxs=None, train=True, apply_transformation=None,
                 target_apply_transformation=None, download=False):
        """
        Initializes the CIFAR100_truncated dataset.

        Args:
            root (str): The root directory where the dataset is stored.
            dataidxs (list or None): Indices of the samples to keep; ``None``
                keeps the whole split.
            train (bool): Whether to load the training set (True) or the test
                set (False).
            apply_transformation (callable, optional): Transformation applied
                to each image in ``__getitem__``.
            target_apply_transformation (callable, optional): Transformation
                applied to each target (label) in ``__getitem__``.
            download (bool): Whether to download the dataset if it is not
                found in the root directory.
        """
        self.root = root  # Root directory where dataset is stored
        self.dataidxs = dataidxs  # Indices used to truncate the dataset
        self.train = train  # Whether to load the training split
        self.apply_transformation = apply_transformation  # Optional image transformation
        self.target_apply_transformation = target_apply_transformation  # Optional label transformation
        self.download = download  # Whether to download the dataset if missing

        # Build the truncated dataset based on the provided indices
        self.data, self.target = self.__build_truncated_dataset__()

    def __build_truncated_dataset__(self):
        """
        Loads CIFAR100 and selects the samples listed in ``self.dataidxs``.

        Returns:
            tuple: ``(data, target)`` - the image array and the label array,
            restricted to ``self.dataidxs`` when it is not ``None``.
        """
        # Keyword arguments make the mapping onto CIFAR100's signature explicit.
        cifar_dataobj = CIFAR100(
            self.root,
            train=self.train,
            transform=self.apply_transformation,
            target_transform=self.target_apply_transformation,
            download=self.download,
        )

        # Load all data and targets
        data = cifar_dataobj.data
        target = np.array(cifar_dataobj.targets)

        # If specific indices are provided, truncate the dataset accordingly
        if self.dataidxs is not None:
            data = data[self.dataidxs]
            target = target[self.dataidxs]

        return data, target

    def truncate_channel(self, index):
        """
        Zeroes channels 1 and 2 (green and blue) of the selected images in place.

        Only channel 0 (red) of the selected images keeps its values. Images
        are addressed along the first axis of ``self.data`` and channels along
        the last axis.

        Args:
            index (array-like of int or int): Positions (within this truncated
                dataset) of the images to modify. An empty sequence is a no-op.
        """
        idx = np.atleast_1d(np.asarray(index))
        if idx.size == 0:
            # Nothing selected; also avoids indexing with an empty float array.
            return
        # One vectorized assignment replaces the per-image Python loop.
        self.data[idx, :, :, 1:3] = 0

    def __getitem__(self, index):
        """
        Retrieves an image and its corresponding target (label).

        Args:
            index (int): Index of the data point to retrieve.

        Returns:
            tuple: ``(image, target)`` where the image has been passed through
            ``apply_transformation`` and the target through
            ``target_apply_transformation`` when those are set.
        """
        img, target = self.data[index], self.target[index]

        # Apply the image transformation, if any
        if self.apply_transformation is not None:
            img = self.apply_transformation(img)

        # Apply the label transformation, if any
        if self.target_apply_transformation is not None:
            target = self.target_apply_transformation(target)

        return img, target

    def __len__(self):
        """
        Returns the total number of data points in the dataset.

        Returns:
            int: The number of samples in the dataset.
        """
        return len(self.data)