# -*- coding: utf-8 -*-
# @Author: Weisen Pan

#### Get CIFAR-100 dataset in X and Y form

import random

import numpy as np
import torch
import torchvision
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

from .cifar10_non_iid import *

# Set random seeds for reproducibility
np.random.seed(68)
random.seed(68)


def get_cifar100(data_dir):
    """Load CIFAR-100 and return train/test images and labels as numpy arrays.

    The dataset is downloaded into ``data_dir`` when it is not already present.

    Parameters:
        data_dir (str): Directory where the CIFAR-100 dataset is downloaded/saved.

    Returns:
        x_train (ndarray): Training images with axes reordered by (0, 3, 1, 2).
        y_train (ndarray): Training labels.
        x_test (ndarray): Test images with axes reordered by (0, 3, 1, 2).
        y_test (ndarray): Test labels.
    """
    data_train = torchvision.datasets.CIFAR100(data_dir, train=True, download=True)
    data_test = torchvision.datasets.CIFAR100(data_dir, train=False, download=True)

    # Move the channel axis in front of the two spatial axes.
    x_train = data_train.data.transpose((0, 3, 1, 2))
    x_test = data_test.data.transpose((0, 3, 1, 2))
    y_train = np.array(data_train.targets)
    y_test = np.array(data_test.targets)

    return x_train, y_train, x_test, y_test


def split_cf100_real_world_images(data, labels, n_clients=100, verbose=True):
    """Split data and labels among ``n_clients`` to simulate a non-IID distribution.

    Every client is assigned a random number of classes (drawn from 1..99) and
    receives one chunk of samples from each of those classes. Labels are
    expected to be the integers ``0 .. n_classes - 1``.

    Parameters:
        data (ndarray): Dataset images [n_data x shape].
        labels (ndarray): Dataset labels [n_data].
        n_clients (int): Number of clients to split the data among.
        verbose (bool): Print the per-client class histogram if True.

    Returns:
        clients_split (ndarray): Object array of shape (n_clients, 2); row ``i``
            holds the image array and the label array of client ``i``.

    Raises:
        ValueError: If ``labels`` is empty, or if a client cannot be given the
            number of distinct classes that was drawn for it.
    """
    labels = np.asarray(labels)

    def divide_into_sections(n, m):
        """Return m random integers (each >= 1) that sum to n; all ones if n < m."""
        result = [1] * m
        for _ in range(n - m):
            result[random.randint(0, m - 1)] += 1
        return result

    n_classes = len(np.unique(labels))
    if n_classes == 0:
        raise ValueError("Cannot split an empty label array.")

    # Shuffle the class order, then collect the sample indices of each class.
    classes = list(range(n_classes))
    np.random.shuffle(classes)
    label_indices = [np.where(labels == class_)[0] for class_ in classes]

    # Number of classes each client will receive (randomized).
    classes_needed = [np.random.randint(1, 100) for _ in range(n_clients)]
    total_partition = sum(classes_needed)

    # Number of chunks each class is cut into; the chunk total matches the
    # total number of (client, class) slots. Largest first (greedy assignment).
    class_partition = sorted(divide_into_sections(total_partition, len(classes)), reverse=True)

    # class -> list of index chunks that are still unassigned
    class_partition_split = {
        class_: [list(chunk) for chunk in np.array_split(label_indices[idx], class_partition[idx])]
        for idx, class_ in enumerate(classes)
    }

    clients_split = []
    for wanted in classes_needed:
        remaining = wanted
        indices = []
        j = 0

        # Walk the classes in order, taking one chunk from every class that
        # still has one. The bound on j turns "ran out of classes" into the
        # ValueError below instead of an IndexError.
        while remaining > 0 and j < len(classes):
            class_ = classes[j]
            if class_partition_split[class_]:
                indices.extend(class_partition_split[class_].pop())
                remaining -= 1
            j += 1

        if remaining > 0:
            raise ValueError("Unable to fulfill the client partition criteria.")

        clients_split.append([data[indices], labels[indices]])

        # Serve the next client from the classes with the most chunks left.
        classes = sorted(classes, key=lambda c: len(class_partition_split[c]), reverse=True)

    if verbose:
        display_data_split(clients_split)

    # Clients hold different numbers of samples, so the result is ragged.
    # np.array() on such input raises on NumPy >= 1.24; fill an object array
    # explicitly to always get shape (n_clients, 2).
    split_array = np.empty((len(clients_split), 2), dtype=object)
    for i, (client_x, client_y) in enumerate(clients_split):
        split_array[i, 0] = client_x
        split_array[i, 1] = client_y
    return split_array


def display_data_split(clients_split):
    """Print, for each client, how many samples it holds of every class.

    Parameters:
        clients_split (sequence): Items whose element ``[1]`` is the client's
            array of non-negative integer labels.
    """
    print("Data split:")
    for i, client in enumerate(clients_split):
        # Counts for labels 0..max(label); an empty client yields an empty array.
        split = np.bincount(np.asarray(client[1], dtype=np.int64).ravel())
        print(f" - Client {i}: {split}")
    print()


def get_default_data_apply_transformations_cf100(train=True, verbose=True):
    """Return the default training and evaluation transforms for CIFAR-100.

    Parameters:
        train (bool): Unused; both pipelines are always returned. Kept for
            interface compatibility.
        verbose (bool): Print the steps of the training pipeline if True.

    Returns:
        apply_transformations_train (Compose): Training transforms
            (random crop and horizontal flip, then tensor conversion and normalization).
        apply_transformations_eval (Compose): Evaluation transforms
            (tensor conversion and normalization only).
    """
    # NOTE(review): these values match the widely quoted CIFAR-10 channel
    # statistics rather than CIFAR-100 ones -- confirm this is intentional.
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2023, 0.1994, 0.2010)

    apply_transformations_train = transforms.Compose([
        transforms.ToPILImage(),
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    apply_transformations_eval = transforms.Compose([
        transforms.ToPILImage(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    if verbose:
        print("\nData preprocessing:")
        # Compose keeps its steps in the ``transforms`` attribute.
        for step in apply_transformations_train.transforms:
            print(f' - {step}')
        print()

    return apply_transformations_train, apply_transformations_eval


def obtain_data_loaders_train_cf100(data_dir, n_clients, batch_size, classes_per_client=10,
                                    verbose=True, apply_transformations_train=None,
                                    apply_transformations_eval=None, non_iid=None,
                                    split_factor=1):
    """Return one training data loader per client for CIFAR-100.

    Parameters:
        data_dir (str): Directory where the CIFAR-100 dataset will be saved.
        n_clients (int): Number of clients for splitting the dataset.
        batch_size (int): Batch size for each data loader.
        classes_per_client (int): Unused by this function; kept for interface
            compatibility.
        verbose (bool): Print detailed information if True.
        apply_transformations_train (Compose): Transforms for training data.
        apply_transformations_eval (Compose): Unused by this function; kept for
            interface compatibility.
        non_iid (str): Strategy used to create the non-IID split. Only
            'quantity_skew' is implemented.
        split_factor (float): Forwarded unchanged to ``CustomImageDataset``.

    Returns:
        client_loaders (list): Data loaders, one for each client.

    Raises:
        ValueError: If ``non_iid`` is not a supported strategy.
    """
    # Fail early with a clear message instead of handing ``None`` to
    # shuffle_list further down.
    if non_iid != 'quantity_skew':
        raise ValueError(
            f"Unsupported non_iid strategy {non_iid!r}; only 'quantity_skew' is implemented."
        )

    x_train, y_train, _, _ = get_cifar100(data_dir)

    if verbose:
        print_image_data_stats_train(x_train, y_train)

    split = split_cf100_real_world_images(x_train, y_train, n_clients=n_clients, verbose=verbose)
    split_tmp = shuffle_list(split)

    # One shuffling loader per client.
    client_loaders = [
        DataLoader(
            CustomImageDataset(x, y, apply_transformations_train, split_factor=split_factor),
            batch_size=batch_size,
            shuffle=True,
        )
        for x, y in split_tmp
    ]

    return client_loaders


def obtain_data_loaders_test_cf100(data_dir, batch_size, verbose=True,
                                   apply_transformations_eval=None):
    """Return the data loader for testing on CIFAR-100.

    Parameters:
        data_dir (str): Directory where the CIFAR-100 dataset will be saved.
        batch_size (int): Batch size for the test data loader.
        verbose (bool): Print detailed information if True.
        apply_transformations_eval (Compose): Transforms for evaluation data.

    Returns:
        test_loader (DataLoader): Test data loader (not shuffled).
    """
    _, _, x_test, y_test = get_cifar100(data_dir)

    if verbose:
        print_image_data_stats_test(x_test, y_test)

    # Honour the caller's batch size (it was previously ignored in favour of 100).
    test_loader = DataLoader(
        CustomImageDataset(x_test, y_test, apply_transformations_eval, split_factor=1),
        batch_size=batch_size,
        shuffle=False,
    )

    return test_loader