# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import logging

import numpy as np
import torch.utils.data as data
from PIL import Image
from torchvision.datasets import CIFAR10

# Set up logging: attach a default handler to the root logger and emit INFO+.
logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Image file extensions the loaders below are intended to support.
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')


def load_accimage(path):
    """Load an image with accimage, falling back to PIL on I/O errors.

    Args:
        path: Path to the image file.

    Returns:
        An ``accimage.Image`` on success; otherwise the RGB PIL image
        produced by :func:`load_image_pil`.
    """
    # Optional dependency: only imported when this backend is actually used.
    import accimage
    try:
        return accimage.Image(path)
    except IOError:
        # accimage could not read the file; retry with PIL.
        return load_image_pil(path)


def load_image_pil(path):
    """Load an image with PIL and return it converted to RGB (3 channels).

    Args:
        path: Path to the image file.

    Returns:
        A ``PIL.Image.Image`` in ``'RGB'`` mode.
    """
    # Open in binary mode inside ``with`` so the file handle is always closed
    # (avoids ResourceWarning). ``convert`` loads the pixel data before the
    # file is closed.
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')


def basic_loader(path):
    """Load an image using torchvision's configured image backend.

    Uses :func:`load_accimage` when ``torchvision.get_image_backend()``
    returns ``'accimage'``; otherwise uses :func:`load_image_pil`.

    Args:
        path: Path to the image file.
    """
    from torchvision import get_image_backend
    if get_image_backend() == 'accimage':
        return load_accimage(path)
    return load_image_pil(path)


class CIFAR10Truncated(data.Dataset):
    """CIFAR-10 dataset optionally restricted to a subset of sample indices.

    Raw images and labels are loaded through ``torchvision.datasets.CIFAR10``
    and then filtered by ``dataidxs``. Transformations are applied lazily,
    per sample, in :meth:`__getitem__`.

    Args:
        root: Root directory of the CIFAR-10 dataset.
        dataidxs: Optional indices used to select a subset of samples. Any
            value that can index a numpy array works, e.g. a list or an
            integer array. ``None`` keeps every sample.
        train: If True, load the training split; otherwise the test split.
        apply_transformation: Optional callable applied to each image.
        target_apply_transformation: Optional callable applied to each label.
        download: Forwarded to torchvision; downloads the data if missing.
    """

    def __init__(self, root, dataidxs=None, train=True, apply_transformation=None,
                 target_apply_transformation=None, download=False):
        self.root = root  # Root directory for the dataset
        self.dataidxs = dataidxs  # Subset of data indices (optional)
        self.train = train  # True for the training split, False for the test split
        self.apply_transformation = apply_transformation  # Transformation applied to images (optional)
        self.target_apply_transformation = target_apply_transformation  # Transformation applied to labels (optional)
        self.download = download  # Download the dataset if it is not available locally

        # Build the (possibly truncated) dataset from the provided indices.
        self.data, self.target = self._build_truncated_dataset()

    def _build_truncated_dataset(self):
        """Load CIFAR-10 and return ``(data, target)`` restricted to ``dataidxs``.

        Returns:
            A tuple ``(data, target)``. ``data`` is torchvision's ``.data``
            image array, ``target`` is a numpy array of integer labels. Both
            are indexed by ``self.dataidxs`` when it is not None.
        """
        logger.info("Download: %s", self.download)

        # Only the raw ``.data``/``.targets`` arrays are used here, and
        # transformations are applied in ``__getitem__``. So the transforms are
        # not forwarded to torchvision. (Its constructor takes ``transform`` /
        # ``target_transform``; the previous ``apply_transformation=`` kwargs
        # raised TypeError.)
        cifar_data = CIFAR10(self.root, train=self.train, download=self.download)

        data = cifar_data.data
        target = np.array(cifar_data.targets)

        if self.dataidxs is not None:
            data = data[self.dataidxs]
            target = target[self.dataidxs]

        return data, target

    def truncate_channel(self, indices):
        """Zero out channels 1 and 2 of the selected images, in place.

        Assumes ``self.data`` is laid out as (N, H, W, C) with RGB channel
        order, which is how torchvision stores ``CIFAR10.data``. Under that
        assumption this removes the green and blue channels.

        Args:
            indices: Iterable of sample indices into ``self.data``.
        """
        for idx in indices:
            self.data[idx, :, :, 1] = 0.0  # Zero out the green channel
            self.data[idx, :, :, 2] = 0.0  # Zero out the blue channel

    def __getitem__(self, index):
        """
        Args:
            index (int): Index of the image

        Returns:
            tuple: (image, target) where target is the class label.
        """
        img, target = self.data[index], self.target[index]

        # Apply the image transformation, if one was specified.
        if self.apply_transformation is not None:
            img = self.apply_transformation(img)

        # Apply the label transformation, if one was specified.
        if self.target_apply_transformation is not None:
            target = self.target_apply_transformation(target)

        return img, target

    def __len__(self):
        """Return the number of samples in the (possibly truncated) dataset."""
        return len(self.data)