# -*- coding: utf-8 -*-
# @Author: Weisen Pan
#### Load CIFAR-10 dataset and preprocess it

import random

import numpy as np
import torch
import torchvision
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

# Set random seeds for reproducibility
np.random.seed(68)  # Ensures that NumPy's random operations have consistent outputs
random.seed(68)
torch.manual_seed(68)  # Seed torch as well so DataLoader shuffling is reproducible


def get_cifar10(data_dir):
    """Return CIFAR-10 train/test data and labels as numpy arrays."""
    # Download the CIFAR-10 dataset
    data_train = torchvision.datasets.CIFAR10(data_dir, train=True, download=True)
    data_test = torchvision.datasets.CIFAR10(data_dir, train=False, download=True)

    # Convert the train and test data to channels-first format (N, C, H, W)
    x_train, y_train = data_train.data.transpose((0, 3, 1, 2)), np.array(data_train.targets)
    x_test, y_test = data_test.data.transpose((0, 3, 1, 2)), np.array(data_test.targets)

    return x_train, y_train, x_test, y_test


def display_data_statistics(data, labels, dataset_type):
    """Print shape, value range, and label range of the dataset."""
    print(f"\n{dataset_type} Set: ({data.shape}, {labels.shape}), "
          f"Range: [{np.min(data):.3f}, {np.max(data):.3f}], "
          f"Labels: {np.min(labels)},..,{np.max(labels)}")


def randomize_client_distribution(train_len, n_clients):
    """
    Distribute data among clients with a random distribution.
    Returns a list with the number of samples for each client.
    """
    # Randomly assign a relative weight to each client, then scale so the total matches train_len
    client_sizes = [random.randint(10, 100) for _ in range(n_clients - 1)]
    total = sum(client_sizes)
    client_sizes = np.array(client_sizes)
    client_distributions = ((client_sizes / total) * train_len).astype(int)
    client_distributions = list(client_distributions)
    client_distributions.append(train_len - sum(client_distributions))  # Ensure all data is allocated
    return client_distributions


def divide_into_sections(n, m):
    """Return 'm' random positive integers that sum to 'n'."""
    # Break the number 'n' into 'm' random parts, each at least 1
    partitions = [1] * m
    for _ in range(n - m):
        partitions[random.randint(0, m - 1)] += 1
    return partitions


def split_data_real_world_scenario(data, labels, n_clients=100):
    """Split data among clients, simulating a real-world non-IID distribution."""
    n_classes = len(set(labels))  # Number of unique classes
    class_indices = [np.where(labels == class_)[0] for class_ in range(n_classes)]  # Indices per class
    client_classes = [np.random.randint(1, 10) for _ in range(n_clients)]  # Random number of classes per client
    total_partitions = sum(client_classes)

    # Decide how many chunks each class is cut into, then cut the class indices accordingly
    class_partition = divide_into_sections(total_partitions, len(class_indices))
    class_partition_split = {cls: np.array_split(class_indices[cls], n)
                             for cls, n in enumerate(class_partition)}

    clients_split = []
    for client in client_classes:
        selected_indices = []
        for class_ in range(n_classes):
            if class_partition_split[class_]:
                selected_indices.extend(class_partition_split[class_].pop())
                client -= 1
            if client <= 0:
                break
        clients_split.append([data[selected_indices], labels[selected_indices]])

    # dtype=object is required because clients hold differently sized arrays
    return np.array(clients_split, dtype=object)


def split_data_iid(data, labels, n_clients=100, classes_per_client=10, shuffle=True):
    """Split data among clients with an IID (independent and identically distributed) distribution."""
    data_per_client = randomize_client_distribution(len(data), n_clients)
    label_indices = [np.where(labels == label)[0] for label in range(np.max(labels) + 1)]

    if shuffle:
        for indices in label_indices:
            np.random.shuffle(indices)

    clients_split = []
    for client_data in data_per_client:
        client_indices = []
        class_ = np.random.randint(len(label_indices))
        # Draw samples from consecutive classes until this client's quota is filled
        while client_data > 0:
            take = min(client_data, len(label_indices[class_]))
            client_indices.extend(label_indices[class_][:take])
            label_indices[class_] = label_indices[class_][take:]
            client_data -= take
            class_ = (class_ + 1) % len(label_indices)
        clients_split.append([data[client_indices], labels[client_indices]])

    # dtype=object is required because clients hold differently sized arrays
    return np.array(clients_split, dtype=object)
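
# Illustrative sketch (not part of the pipeline): exercising the two splitters on
# tiny synthetic data. The sizes here (600 samples, 10 classes, 5 clients) are
# arbitrary demo values, not values used elsewhere in this file.
def _demo_client_splits():
    rng = np.random.default_rng(0)
    data = rng.integers(0, 256, size=(600, 3, 32, 32), dtype=np.uint8)
    labels = np.repeat(np.arange(10), 60)  # 10 balanced classes

    iid_split = split_data_iid(data, labels, n_clients=5)
    non_iid_split = split_data_real_world_scenario(data, labels, n_clients=5)

    for name, split in [("IID", iid_split), ("non-IID", non_iid_split)]:
        sizes = [len(x) for x, _ in split]
        print(f"{name}: per-client sample counts = {sizes}")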

def randomize_data_order(data):
    """Shuffle each client's data while maintaining the mapping between inputs and labels."""
    for i in range(len(data)):
        index = np.arange(len(data[i][0]))
        np.random.shuffle(index)
        data[i][0], data[i][1] = data[i][0][index], data[i][1][index]
    return data


class CustomImageDataset(Dataset):
    """Custom Dataset class for image data."""

    def __init__(self, inputs, labels, transform=None, split_factor=1):
        # Keep images as uint8 tensors so ToPILImage interprets the 0-255 range correctly
        self.inputs = torch.from_numpy(inputs.astype(np.uint8))
        self.labels = labels
        self.transform = transform
        self.split_factor = split_factor

    def __getitem__(self, index):
        img, label = self.inputs[index], self.labels[index]
        # Apply the transform split_factor times to produce multiple augmented views,
        # concatenated along the channel dimension
        imgs = [self.transform(img) for _ in range(self.split_factor)] if self.transform else [img]
        return torch.cat(imgs, dim=0), label

    def __len__(self):
        return len(self.inputs)


def get_default_transforms(verbose=True):
    """Return default transforms for training and evaluation."""
    transform_train = transforms.Compose([
        transforms.ToPILImage(),  # Convert tensor to PIL image
        transforms.RandomCrop(32, padding=4),  # Randomly crop to 32x32 with padding
        transforms.RandomHorizontalFlip(),  # Randomly flip images horizontally
        transforms.ToTensor(),  # Convert image back to a float tensor in [0, 1]
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010))  # Normalize with CIFAR-10 mean and std
    ])
    transform_eval = transforms.Compose([
        transforms.ToPILImage(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010))  # Same normalization for evaluation
    ])

    if verbose:
        print("\nData preprocessing steps:")
        for transformation in transform_train.transforms:
            print(f"  - {transformation}")

    return transform_train, transform_eval
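
# Illustrative sketch (not part of the pipeline): with split_factor=2, each item
# yields two independently augmented views of the same image stacked along the
# channel axis, giving a (6, 32, 32) tensor for a 3-channel input.
def _demo_split_factor():
    rng = np.random.default_rng(0)
    inputs = rng.integers(0, 256, size=(4, 3, 32, 32), dtype=np.uint8)
    labels = np.zeros(4, dtype=np.int64)

    transform_train, _ = get_default_transforms(verbose=False)
    dataset = CustomImageDataset(inputs, labels, transform_train, split_factor=2)
    views, label = dataset[0]
    print(views.shape)  # torch.Size([6, 32, 32])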

def obtain_data_loaders(data_dir, n_clients, batch_size, classes_per_client=10,
                        non_iid=None, split_factor=1):
    """Return DataLoader objects for clients with either IID or non-IID data splits."""
    x_train, y_train, _, _ = get_cifar10(data_dir)
    display_data_statistics(x_train, y_train, "Train")

    # Split the data according to the requested non-IID method
    if non_iid == 'quantity_skew':
        clients_data = split_data_real_world_scenario(x_train, y_train, n_clients)
    elif non_iid == 'label_skew':
        clients_data = split_data_iid(x_train, y_train, n_clients, classes_per_client)
    else:
        raise ValueError(f"Unknown non_iid mode: {non_iid!r} "
                         "(expected 'quantity_skew' or 'label_skew')")

    shuffled_clients_data = randomize_data_order(clients_data)
    transform_train, _ = get_default_transforms(verbose=False)

    client_loaders = [DataLoader(CustomImageDataset(x, y, transform_train, split_factor=split_factor),
                                 batch_size=batch_size, shuffle=True)
                      for x, y in shuffled_clients_data]
    return client_loaders


def get_test_data_loader(data_dir, batch_size):
    """Return a DataLoader for the test data."""
    _, _, x_test, y_test = get_cifar10(data_dir)
    display_data_statistics(x_test, y_test, "Test")

    _, transform_eval = get_default_transforms(verbose=False)
    test_loader = DataLoader(CustomImageDataset(x_test, y_test, transform_eval),
                             batch_size=batch_size, shuffle=False)
    return test_loader
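
# Illustrative usage sketch (assumptions: CIFAR-10 is downloaded to "./data", and
# the client count / batch size below are arbitrary demo values, not project defaults).
if __name__ == "__main__":
    client_loaders = obtain_data_loaders("./data", n_clients=10, batch_size=32,
                                         non_iid='label_skew', split_factor=1)
    test_loader = get_test_data_loader("./data", batch_size=32)

    # Inspect one batch from the first client
    images, targets = next(iter(client_loaders[0]))
    print(images.shape, targets.shape)  # e.g. torch.Size([32, 3, 32, 32]) torch.Size([32])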