# -*- coding: utf-8 -*-
# @Author: Weisen Pan

# Import necessary libraries
from PIL import Image  # For image handling
import os  # For file path operations
import numpy as np  # For numerical operations
import pickle  # For loading serialized data
import torch  # For PyTorch operations

# Import custom classes and functions from the current package
from .vision import VisionDataset
from .utils import validate_integrity, fetch_and_extract_archive


# CIFAR10 dataset class
class CIFAR10(VisionDataset):
    """
    CIFAR10 Dataset class that handles loading, processing, and transformation
    of the CIFAR-10 dataset.

    Args:
        root (str): Directory where the dataset is stored or will be downloaded to.
        train (bool, optional): If True, load the training set. Otherwise, load the test set.
        apply_transformation (callable, optional): A function/transform that takes a PIL image
            and returns a transformed version.
        target_apply_transformation (callable, optional): A function/transform that takes the
            target and transforms it.
        download (bool, optional): If True, download the dataset if it's not found locally.
        split_factor (int, optional): Number of transformations applied to each image. Default is 1.
    """

    # Directory and URL details for downloading the CIFAR-10 dataset
    base_folder = 'cifar-10-batches-py'
    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    filename = "cifar-10-python.tar.gz"
    tgz_md5 = 'c58f30108f718f92721af3b95e74349a'  # MD5 checksum to verify the file's integrity

    # List of training batches with their corresponding MD5 checksums
    train_list = [
        ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
        ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
        ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
        ['data_batch_4', '634d18415352ddfa80567beed471001a'],
        ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
    ]

    # List of test batches with their corresponding MD5 checksums
    test_list = [
        ['test_batch', '40351d587109b95175f43aff81a1287e'],
    ]

    # Info map describing the metadata file that holds the label names
    info_map = {
        'filename': 'batches.meta',
        'key': 'label_names',
        'md5': '5ff9c542aee3614f3951f8cda6e48888',
    }

    # Initialization method
    def __init__(self, root, train=True, apply_transformation=None,
                 target_apply_transformation=None, download=False, split_factor=1):
        super(CIFAR10, self).__init__(root, apply_transformation=apply_transformation,
                                      target_apply_transformation=target_apply_transformation)
        self.train = train  # Whether to load the training set or test set
        self.split_factor = split_factor  # Number of transformations to apply

        # Download dataset if necessary
        if download:
            self.download()

        # Check if the dataset is already downloaded and valid
        if not self._validate_integrity():
            raise RuntimeError('Dataset not found or corrupted. '
                               'Use download=True to download it.')
        # Load the dataset
        self.data, self.targets = self._load_data()

        # Load the label info map (to get class names)
        self._load_info_map()

    # Load dataset from the files
    def _load_data(self):
        data, targets = [], []  # Initialize lists to hold data and labels
        files = self.train_list if self.train else self.test_list  # Choose train or test files

        # Load each file, deserialize with pickle, and append data and labels
        for file_name, _ in files:
            file_path = os.path.join(self.root, self.base_folder, file_name)
            with open(file_path, 'rb') as f:
                entry = pickle.load(f, encoding='latin1')  # Load file
                data.append(entry['data'])  # Append image data
                targets.extend(entry.get('labels', entry.get('fine_labels', [])))  # Append labels

        # Reshape and format the data to (num_samples, height, width, channels)
        data = np.vstack(data).reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1))  # Reshape to HWC format
        return data, targets

    # Load label names (info map)
    def _load_info_map(self):
        info_map_path = os.path.join(self.root, self.base_folder, self.info_map['filename'])  # Path to info map
        if not validate_integrity(info_map_path, self.info_map['md5']):  # Check integrity of info map
            raise RuntimeError('Metadata file not found or corrupted. '
                               'Use download=True to download it.')

        # Load the label names
        with open(info_map_path, 'rb') as info_map_file:
            info_map_data = pickle.load(info_map_file, encoding='latin1')  # Load label names
            self.classes = info_map_data[self.info_map['key']]  # Extract class labels
        self.class_to_idx = {label: idx for idx, label in enumerate(self.classes)}  # Map class names to indices

    # Get item (image and target) by index
    def __getitem__(self, index):
        """
        Get the item (image, target) at the specified index.

        Args:
            index (int): Index of the data.

        Returns:
            tuple: Transformed image and the target class.
        """
        img, target = self.data[index], self.targets[index]  # Get image and target label
        img = Image.fromarray(img)  # Convert numpy array to PIL image

        # Apply the transformation multiple times based on split_factor
        if self.apply_transformation is None:
            raise NotImplementedError('apply_transformation must be provided.')
        imgs = [self.apply_transformation(img) for _ in range(self.split_factor)]

        # Apply target transformation if available
        if self.target_apply_transformation:
            target = self.target_apply_transformation(target)

        return torch.cat(imgs, dim=0), target  # Return concatenated transformed images and the target

    # Return the number of items in the dataset
    def __len__(self):
        return len(self.data)

    # Check if the dataset files are valid and downloaded
    def _validate_integrity(self):
        files = self.train_list + self.test_list  # All files to check
        for file_name, md5 in files:
            file_path = os.path.join(self.root, self.base_folder, file_name)
            if not validate_integrity(file_path, md5):  # Verify integrity using MD5
                return False
        return True

    # Download the dataset if it's not available
    def download(self):
        if self._validate_integrity():
            print('Files already downloaded and verified')
        else:
            fetch_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)

    # Representation string to include the split type (Train/Test)
    def extra_repr(self):
        return f"Split: {'Train' if self.train else 'Test'}"
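
# A minimal usage sketch (illustrative, not part of the dataset API). It assumes
# torchvision's `transforms` module is available in this environment and that
# './data' is a writable directory; both are assumptions, not guaranteed by this
# file. The sketch shows how `split_factor` stacks transformed copies of the same
# image along the channel dimension, so each sample comes back with shape
# (3 * split_factor, 32, 32).
def _cifar10_usage_sketch():
    from torchvision import transforms  # assumed dependency for this sketch only
    apply_transformation = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
    ])
    dataset = CIFAR10(root='./data', train=True,
                      apply_transformation=apply_transformation,
                      download=True, split_factor=2)
    img, target = dataset[0]
    # Two transformed copies concatenated on dim 0 -> 6 channels
    assert img.shape == (6, 32, 32)
    return img, target
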
""" # Directory and URL details for downloading CIFAR-100 dataset base_folder = 'cifar-100-vision' url = "https://www.cs.toronto.edu/~kriz/cifar-100-vision.tar.gz" filename = "cifar-100-vision.tar.gz" tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85' # MD5 checksum # Training and test lists with their corresponding MD5 checksums for CIFAR-100 train_list = [ ['train', '16019d7e3df5f24257cddd939b257f8d'] ] test_list = [ ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'] ] # Info map to hold fine label names and their checksum info_map = { 'filename': 'info_map', 'key': 'fine_label_names', 'md5': '7973b15100ade9c7d40fb424638fde48' }