# -*- coding: utf-8 -*-
# @Author: Weisen Pan
import torch
from torchvision import transforms
from .cifar import CIFAR10, CIFAR100  # CIFAR-10 and CIFAR-100 dataset wrappers
from .autoaugment import CIFAR10Policy  # AutoAugment policy for CIFAR-10

# NOTE: obtain_data_loaders_train / obtain_data_loaders_test, used on the
# non-IID (quantity-skew) path below, are expected to be provided elsewhere
# in this package.

__all__ = ['obtain_data_loader']  # Public API of this module


def obtain_data_loader(
        data_dir,               # Directory where the data is stored
        split_factor=1,         # Data-partitioning factor, used mainly in federated learning
        batch_size=128,         # Batch size for loading data
        crop_size=32,           # Size to crop the input images to
        dataset='cifar10',      # Dataset to use (CIFAR-10 by default)
        split="train",          # Split type: 'train', 'val', or 'test'
        is_distributed=False,   # Whether to use distributed training
        is_autoaugment=1,       # Whether to use AutoAugment
        randaa=None,            # Placeholder for RandAugment-style policies
        is_cutout=True,         # Whether to apply cutout (random erasing)
        erase_p=0.5,            # Probability of applying random erasing
        num_workers=8,          # Number of data-loading workers
        pin_memory=True,        # Use pinned memory for faster host-to-GPU transfer
        is_fed=False,           # Whether to use federated learning
        num_clusters=20,        # Number of clients in federated learning
        cifar10_non_iid=False,  # Non-IID option for CIFAR-10
        cifar100_non_iid=False  # Non-IID option for CIFAR-100
):
    """Build and return the data loader(s) for the requested dataset and split."""
    # AutoAugment and RandAugment-style policies are mutually exclusive
    assert not (is_autoaugment and randaa is not None)
    assert split in ['train', 'val', 'test']  # Ensure a valid split

    # Loader settings for multiprocessing
    kwargs = {'num_workers': num_workers, 'pin_memory': pin_memory}

    if dataset == 'cifar10':
        # Non-IID 'quantity skew' case for CIFAR-10
        if cifar10_non_iid == 'quantity_skew':
            non_iid = 'quantity_skew'
            if 'train' in split:
                print(f"INFO:PyTorch: Using quantity_skew CIFAR10 dataset, "
                      f"batch size {batch_size} and crop size is {crop_size}.")
                traindir = data_dir
                # Data transformations for training
                train_transform = transforms.Compose([
                    transforms.ToPILImage(),
                    transforms.RandomCrop(32, padding=4),
                    transforms.RandomHorizontalFlip(),
                    CIFAR10Policy(),  # AutoAugment policy
                    transforms.ToTensor(),
                    # CIFAR-10 channel means and standard deviations
                    transforms.Normalize((0.4914, 0.4822, 0.4465),
                                         (0.2023, 0.1994, 0.2010)),
                    transforms.RandomErasing(p=erase_p, scale=(0.125, 0.2),
                                             ratio=(0.99, 1.0), value=0, inplace=False),
                ])
                train_sampler = None
                print('INFO:PyTorch: creating quantity_skew CIFAR10 train dataloader...')
                # The quantity-skew split is only defined for federated learning:
                # it produces one loader per client
                assert is_fed, "quantity_skew CIFAR-10 requires is_fed=True"
                train_loader = obtain_data_loaders_train(
                    traindir,
                    nclients=num_clusters * split_factor,  # Number of federated clients
                    batch_size=batch_size,
                    verbose=True,
                    transforms_train=train_transform,
                    non_iid=non_iid,  # Non-IID type
                    split_factor=split_factor
                )
                return train_loader, train_sampler
            else:
                # Validation or test split
                valdir = data_dir
                # Data transformations for validation/testing
                val_transform = transforms.Compose([
                    transforms.ToPILImage(),
                    transforms.ToTensor(),
                    transforms.Normalize((0.4914, 0.4822, 0.4465),
                                         (0.2023, 0.1994, 0.2010)),
                ])
                # Create the test loader
                val_loader = obtain_data_loaders_test(
                    valdir,
                    nclients=num_clusters * split_factor,  # Number of federated clients
                    batch_size=batch_size,
                    verbose=True,
                    transforms_eval=val_transform,
                    non_iid=non_iid,
                    split_factor=1
                )
                return val_loader
        else:
            # Standard (IID) CIFAR-10 case
            if 'train' in split:
                print(f"INFO:PyTorch: Using CIFAR10 dataset, "
                      f"batch size {batch_size} and crop size is {crop_size}.")
                traindir = data_dir
                # Data transformations for training
                train_transform = transforms.Compose([
                    transforms.RandomCrop(32, padding=4),
                    transforms.RandomHorizontalFlip(),
                    CIFAR10Policy(),  # AutoAugment policy
                    transforms.ToTensor(),
                    transforms.Normalize((0.4914, 0.4822, 0.4465),
                                         (0.2023, 0.1994, 0.2010)),
                    transforms.RandomErasing(p=erase_p, scale=(0.125, 0.2),
                                             ratio=(0.99, 1.0), value=0, inplace=False),
                ])
                # Create the CIFAR-10 training dataset
                train_dataset = CIFAR10(
                    traindir,
                    train=True,
                    transform=train_transform,
                    target_transform=None,
                    download=True,
                    split_factor=split_factor
                )
                train_sampler = None  # No sampler by default
                # Distributed training setup
                if is_distributed:
                    train_sampler = torch.utils.data.distributed.DistributedSampler(
                        train_dataset, shuffle=True)
                print('INFO:PyTorch: creating CIFAR10 train dataloader...')
                if is_fed:
                    # Federated learning: split the training set evenly across
                    # clients, giving the last client any remainder
                    n_clients = num_clusters * split_factor
                    images_per_client = train_dataset.data.shape[0] // n_clients
                    print(f"Images per client: {images_per_client}")
                    data_split = [images_per_client] * (n_clients - 1)
                    data_split.append(len(train_dataset) - images_per_client * (n_clients - 1))
                    traindata_split = torch.utils.data.random_split(
                        train_dataset, data_split,
                        generator=torch.Generator().manual_seed(68))
                    # One data loader per client
                    train_loader = [torch.utils.data.DataLoader(
                        x,
                        batch_size=batch_size,
                        shuffle=(train_sampler is None),
                        drop_last=True,
                        sampler=train_sampler,
                        **kwargs
                    ) for x in traindata_split]
                else:
                    # Standard single data loader
                    train_loader = torch.utils.data.DataLoader(
                        train_dataset,
                        batch_size=batch_size,
                        shuffle=(train_sampler is None),
                        drop_last=True,
                        sampler=train_sampler,
                        **kwargs
                    )
                return train_loader, train_sampler
            else:
                # Validation or test split
                valdir = data_dir
                # Data transformations for validation/testing
                val_transform = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.4914, 0.4822, 0.4465),
                                         (0.2023, 0.1994, 0.2010)),
                ])
                # Create the CIFAR-10 validation dataset
                val_dataset = CIFAR10(
                    valdir,
                    train=False,
                    transform=val_transform,
                    target_transform=None,
                    download=True,
                    split_factor=1
                )
                print('INFO:PyTorch: creating CIFAR10 validation dataloader...')
                val_loader = torch.utils.data.DataLoader(
                    val_dataset, batch_size=batch_size, shuffle=False, **kwargs)
                return val_loader
    else:
        # Loaders for CIFAR-100, other datasets, or additional distributed
        # setups can be added analogously.
        raise NotImplementedError(f"The DataLoader for {dataset} is not implemented.")
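

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the public API): how a training
# script might call obtain_data_loader() on the standard IID CIFAR-10 path.
# The './data' directory is an assumed placeholder; because this module uses
# relative imports, it must be run in its package context (e.g. via
# `python -m <package>.<module>`), not executed directly as a script.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    # Single-process training loader; train_sampler is None unless
    # is_distributed=True
    train_loader, train_sampler = obtain_data_loader(
        './data', dataset='cifar10', split='train',
        batch_size=128, num_workers=4, is_fed=False)

    # Matching validation loader (no augmentation, no shuffling)
    val_loader = obtain_data_loader(
        './data', dataset='cifar10', split='val',
        batch_size=128, num_workers=4)

    # Federated variant: returns a list with one loader per client,
    # num_clusters * split_factor loaders in total
    fed_loaders, _ = obtain_data_loader(
        './data', dataset='cifar10', split='train',
        is_fed=True, num_clusters=20, split_factor=1)
    print(f"number of client loaders: {len(fed_loaders)}")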