# -*- coding: utf-8 -*-
# @Author: Weisen Pan

#### Load CIFAR-10 dataset and preprocess it

import torchvision
import numpy as np
import random
import torch
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset

# Set random seeds for reproducibility
np.random.seed(68)  # Ensures that random numpy operations produce consistent outputs
random.seed(68)
torch.manual_seed(68)  # Also seed torch so tensor-level randomness (e.g., DataLoader shuffling) is reproducible


def get_cifar10(data_dir):
    """Return CIFAR-10 train/test data and labels as numpy arrays."""
    # Download the CIFAR-10 dataset
    data_train = torchvision.datasets.CIFAR10(data_dir, train=True, download=True)
    data_test = torchvision.datasets.CIFAR10(data_dir, train=False, download=True)

    # Rearrange the train and test data to channels-first (NCHW) format
    x_train, y_train = data_train.data.transpose((0, 3, 1, 2)), np.array(data_train.targets)
    x_test, y_test = data_test.data.transpose((0, 3, 1, 2)), np.array(data_test.targets)

    return x_train, y_train, x_test, y_test
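
# For reference: CIFAR-10 yields x_train of shape (50000, 3, 32, 32), y_train (50000,),
# x_test (10000, 3, 32, 32), and y_test (10000,), with uint8 pixel values in [0, 255].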


def display_data_statistics(data, labels, dataset_type):
    """Print summary statistics of the dataset."""
    print(f"\n{dataset_type} Set: ({data.shape}, {labels.shape}), "
          f"Range: [{np.min(data):.3f}, {np.max(data):.3f}], "
          f"Labels: {np.min(labels)},..,{np.max(labels)}")


def randomize_client_distribution(train_len, n_clients):
    """
    Distribute data among clients with a random distribution.

    Returns a list with the number of samples for each client.
    """
    # Draw a random relative weight for each client, then scale the weights
    # so the assigned sample counts sum to train_len
    client_sizes = [random.randint(10, 100) for _ in range(n_clients - 1)]
    total = sum(client_sizes)
    client_sizes = np.array(client_sizes)
    client_distributions = ((client_sizes / total) * train_len).astype(int)  # Normalize to match train_len
    client_distributions = list(client_distributions)
    client_distributions.append(train_len - sum(client_distributions))  # Assign the remainder so all data is allocated
    return client_distributions
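
# Example (illustrative only; actual values depend on the seeded RNG state):
# randomize_client_distribution(50000, 4) returns four integers summing to
# 50000, e.g. something like [9850, 14920, 17360, 7870].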


def divide_into_sections(n, m):
    """Return 'm' random positive integers that sum to 'n' (assumes n >= m)."""
    # Start every part at 1, then hand out the remaining n - m units at random
    partitions = [1] * m
    for _ in range(n - m):
        partitions[random.randint(0, m - 1)] += 1
    return partitions
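
# Example (illustrative): divide_into_sections(10, 3) might return [4, 2, 4];
# every part is at least 1 and the parts always sum to n.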


def split_data_real_world_scenario(data, labels, n_clients=100):
    """Split data among clients, simulating a real-world non-IID distribution."""
    n_classes = len(set(labels))  # Number of unique classes
    class_indices = [np.where(labels == class_)[0] for class_ in range(n_classes)]  # Sample indices per class

    client_classes = [np.random.randint(1, 10) for _ in range(n_clients)]  # Random number of classes per client
    total_partitions = sum(client_classes)

    # Decide how many shards each class is cut into, then split its indices accordingly
    class_partition = divide_into_sections(total_partitions, len(class_indices))
    class_partition_split = {cls: np.array_split(class_indices[cls], n) for cls, n in enumerate(class_partition)}

    clients_split = []
    for client in client_classes:
        selected_indices = []
        for class_ in range(n_classes):
            if class_partition_split[class_]:
                selected_indices.extend(class_partition_split[class_].pop())
                client -= 1
                if client <= 0:
                    break
        clients_split.append([data[selected_indices], labels[selected_indices]])

    # dtype=object is required because clients hold differently sized arrays
    return np.array(clients_split, dtype=object)
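
# The result is an object array of shape (n_clients, 2); each row is an
# [images, labels] pair, so result[0][0] might have shape (512, 3, 32, 32)
# for a client that happened to receive 512 samples (sizes vary per client).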


def split_data_iid(data, labels, n_clients=100, classes_per_client=10, shuffle=True):
    """Split data among clients with an IID (independent and identically distributed) distribution."""
    # Note: classes_per_client is accepted for interface compatibility but is not used here
    data_per_client = randomize_client_distribution(len(data), n_clients)
    label_indices = [np.where(labels == label)[0] for label in range(np.max(labels) + 1)]

    if shuffle:
        for indices in label_indices:
            np.random.shuffle(indices)

    clients_split = []
    for client_data in data_per_client:
        client_indices = []
        class_ = np.random.randint(len(label_indices))  # Start from a random class
        while client_data > 0:
            # Take as many samples as possible from the current class, then move on
            take = min(client_data, len(label_indices[class_]))
            client_indices.extend(label_indices[class_][:take])
            label_indices[class_] = label_indices[class_][take:]
            client_data -= take
            class_ = (class_ + 1) % len(label_indices)

        clients_split.append([data[client_indices], labels[client_indices]])

    return np.array(clients_split, dtype=object)


def randomize_data_order(data):
    """Shuffle each client's data while maintaining the mapping between inputs and labels."""
    for i in range(len(data)):
        index = np.arange(len(data[i][0]))
        np.random.shuffle(index)
        data[i][0], data[i][1] = data[i][0][index], data[i][1][index]
    return data


class CustomImageDataset(Dataset):
    """Custom Dataset class for image data."""

    def __init__(self, inputs, labels, transform=None, split_factor=1):
        # Convert inputs to a float tensor scaled to [0, 1] so that
        # transforms.ToPILImage interprets the pixel values correctly
        self.inputs = torch.from_numpy(inputs).float().div(255.0)
        self.labels = labels
        self.transform = transform
        self.split_factor = split_factor

    def __getitem__(self, index):
        img, label = self.inputs[index], self.labels[index]
        # Apply the transform to the image multiple times if split_factor > 1
        imgs = [self.transform(img) for _ in range(self.split_factor)] if self.transform else [img]
        return torch.cat(imgs, dim=0), label

    def __len__(self):
        return len(self.inputs)
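
# Illustrative use (assuming x, y are one client's numpy arrays and
# transform_train comes from get_default_transforms()):
#   dataset = CustomImageDataset(x, y, transform_train, split_factor=2)
#   img, label = dataset[0]  # img has 6 channels: two augmented 3-channel views stacked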


def get_default_transforms(verbose=True):
    """Return default transforms for training and evaluation."""
    transform_train = transforms.Compose([
        transforms.ToPILImage(),  # Convert the tensor to a PIL image
        transforms.RandomCrop(32, padding=4),  # Randomly crop to 32x32 with padding
        transforms.RandomHorizontalFlip(),  # Randomly flip images horizontally
        transforms.ToTensor(),  # Convert the image back to a tensor
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))  # Normalize with CIFAR-10 mean and std
    ])

    transform_eval = transforms.Compose([
        transforms.ToPILImage(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))  # Same normalization for evaluation
    ])

    if verbose:
        print("\nData preprocessing steps:")
        for transform in transform_train.transforms:
            print(f"  - {transform}")

    return transform_train, transform_eval


def obtain_data_loaders(data_dir, n_clients, batch_size, classes_per_client=10, non_iid=None, split_factor=1):
    """Return per-client DataLoader objects for the requested non-IID data split."""
    x_train, y_train, _, _ = get_cifar10(data_dir)
    display_data_statistics(x_train, y_train, "Train")

    # Split the data based on the non-IID method specified
    # (either 'quantity_skew' or 'label_skew')
    if non_iid == 'quantity_skew':
        clients_data = split_data_real_world_scenario(x_train, y_train, n_clients)
    elif non_iid == 'label_skew':
        clients_data = split_data_iid(x_train, y_train, n_clients, classes_per_client)
    else:
        raise ValueError(f"Unknown non_iid option: {non_iid!r} (expected 'quantity_skew' or 'label_skew')")

    shuffled_clients_data = randomize_data_order(clients_data)

    transform_train, _ = get_default_transforms(verbose=False)
    client_loaders = [DataLoader(CustomImageDataset(x, y, transform_train, split_factor=split_factor),
                                 batch_size=batch_size, shuffle=True)
                      for x, y in shuffled_clients_data]

    return client_loaders


def get_test_data_loader(data_dir, batch_size):
    """Return a DataLoader for the test data."""
    _, _, x_test, y_test = get_cifar10(data_dir)
    display_data_statistics(x_test, y_test, "Test")

    _, transform_eval = get_default_transforms(verbose=False)
    test_loader = DataLoader(CustomImageDataset(x_test, y_test, transform_eval),
                             batch_size=batch_size, shuffle=False)

    return test_loader
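

# A minimal usage sketch (assumption: './data' is a writable download directory;
# n_clients and batch_size are illustrative values, not the project's defaults):
if __name__ == "__main__":
    client_loaders = obtain_data_loaders('./data', n_clients=10, batch_size=32,
                                         non_iid='label_skew')
    test_loader = get_test_data_loader('./data', batch_size=32)

    # Inspect one batch from the first client to confirm shapes
    images, targets = next(iter(client_loaders[0]))
    print(f"\nFirst client batch: images {tuple(images.shape)}, targets {tuple(targets.shape)}")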