# -*- coding: utf-8 -*-
# @Author: Weisen Pan
# Load the CIFAR-10 dataset and preprocess it
import random

import numpy as np
import torch
import torchvision
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

# Set random seeds for reproducibility
np.random.seed(68)  # Ensures that random operations produce consistent outputs
random.seed(68)

def get_cifar10(data_dir):
"""Return CIFAR-10 train/test data and labels as numpy arrays"""
# Download CIFAR-10 dataset
data_train = torchvision.datasets.CIFAR10(data_dir, train=True, download=True)
data_test = torchvision.datasets.CIFAR10(data_dir, train=False, download=True)
# Preprocess the train and test data to the correct format (channels first)
x_train, y_train = data_train.data.transpose((0, 3, 1, 2)), np.array(data_train.targets)
x_test, y_test = data_test.data.transpose((0, 3, 1, 2)), np.array(data_test.targets)
return x_train, y_train, x_test, y_test
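
# For reference, get_cifar10 returns the standard CIFAR-10 arrays:
#   x_train: (50000, 3, 32, 32) uint8 in [0, 255], y_train: (50000,) labels in 0..9
#   x_test:  (10000, 3, 32, 32) uint8 in [0, 255], y_test:  (10000,) labels in 0..9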

def display_data_statistics(data, labels, dataset_type):
"""Print statistics of the dataset"""
print(f"\n{dataset_type} Set: ({data.shape}, {labels.shape}), Range: [{np.min(data):.3f}, {np.max(data):.3f}], "
f"Labels: {np.min(labels)},..,{np.max(labels)}")

def randomize_client_distribution(train_len, n_clients):
"""
Distribute data among clients with a random distribution
Returns a list with the number of samples for each client
"""
# Randomly assign a number of samples to each client, ensuring the total matches the train_len
client_sizes = [random.randint(10, 100) for _ in range(n_clients - 1)]
total = sum(client_sizes)
client_sizes = np.array(client_sizes)
client_distributions = ((client_sizes / total) * train_len).astype(int) # Normalize to match the train_len
client_distributions = list(client_distributions)
client_distributions.append(train_len - sum(client_distributions)) # Ensure all data is allocated
return client_distributions
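
# Sanity-check sketch (the exact sizes are hypothetical, since they depend on the
# random seed), e.g. randomize_client_distribution(50000, 4) -> [11890, 14560, 9870, 13680];
# the returned sizes always sum to train_len.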

def divide_into_sections(n, m):
"""Return 'm' random integers that sum to 'n'"""
# Break the number 'n' into 'm' random parts that sum to 'n'
partitions = [1] * m
for _ in range(n - m):
partitions[random.randint(0, m - 1)] += 1
return partitions
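
# Illustrative example (actual output varies with the random seed):
# divide_into_sections(10, 3) might return [4, 3, 3]; each part is >= 1 and the
# parts always sum to n.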

def split_data_real_world_scenario(data, labels, n_clients=100):
"""Split data among clients simulating real-world non-IID distribution"""
n_classes = len(set(labels)) # Determine number of unique classes
class_indices = [np.where(labels == class_)[0] for class_ in range(n_classes)] # Indices for each class
client_classes = [np.random.randint(1, 10) for _ in range(n_clients)] # Random number of classes per client
total_partitions = sum(client_classes)
class_partition = divide_into_sections(total_partitions, len(class_indices)) # Partition classes to distribute
class_partition_split = {cls: np.array_split(class_indices[cls], n) for cls, n in enumerate(class_partition)}
clients_split = []
for client in client_classes:
selected_indices = []
for class_ in range(n_classes):
if class_partition_split[class_]:
selected_indices.extend(class_partition_split[class_].pop())
client -= 1
if client <= 0:
break
clients_split.append([data[selected_indices], labels[selected_indices]])
    # Return a plain list: per-client arrays differ in length, so stacking them
    # into a single ndarray would create a ragged array (rejected by recent NumPy)
    return clients_split

def split_data_iid(data, labels, n_clients=100, classes_per_client=10, shuffle=True):
    """Split data among clients by filling each client's quota from consecutive
    classes; small quotas therefore cover few classes, producing label skew.
    Note: classes_per_client is accepted for API compatibility but not used here."""
    data_per_client = randomize_client_distribution(len(data), n_clients)
label_indices = [np.where(labels == label)[0] for label in range(np.max(labels) + 1)]
if shuffle:
for indices in label_indices:
np.random.shuffle(indices)
clients_split = []
for client_data in data_per_client:
client_indices = []
class_ = np.random.randint(len(label_indices))
while client_data > 0:
take = min(client_data, len(label_indices[class_]))
client_indices.extend(label_indices[class_][:take])
label_indices[class_] = label_indices[class_][take:]
client_data -= take
class_ = (class_ + 1) % len(label_indices)
clients_split.append([data[client_indices], labels[client_indices]])
    # Return a plain list for the same ragged-array reason as above
    return clients_split

def randomize_data_order(data):
"""Shuffle data while maintaining the mapping between inputs and labels"""
for i in range(len(data)):
index = np.arange(len(data[i][0]))
np.random.shuffle(index)
data[i][0], data[i][1] = data[i][0][index], data[i][1][index]
return data

class CustomImageDataset(Dataset):
    """Custom Dataset class for image data"""

    def __init__(self, inputs, labels, transform=None, split_factor=1):
        # Keep the images as uint8 tensors so that ToPILImage interprets the
        # 0-255 range correctly (float tensors are assumed to lie in [0, 1])
        self.inputs = torch.as_tensor(inputs, dtype=torch.uint8)
        self.labels = labels
        self.transform = transform
        self.split_factor = split_factor

    def __getitem__(self, index):
        img, label = self.inputs[index], self.labels[index]
        # Apply the transform pipeline split_factor times and concatenate the
        # resulting views along the channel dimension
        imgs = [self.transform(img) for _ in range(self.split_factor)] if self.transform else [img]
        return torch.cat(imgs, dim=0), label

    def __len__(self):
        return len(self.inputs)
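
# Minimal usage sketch (shapes assume the default transforms defined below):
# with split_factor=2, two independently augmented views of the same image are
# concatenated along the channel dimension, so a 3x32x32 input yields 6x32x32.
#   transform_train, _ = get_default_transforms(verbose=False)
#   ds = CustomImageDataset(x_train[:8], y_train[:8], transform_train, split_factor=2)
#   img, label = ds[0]  # img.shape == torch.Size([6, 32, 32])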

def get_default_transforms(verbose=True):
    """Return the default transform pipelines for training and evaluation"""
    transform_train = transforms.Compose([
        transforms.ToPILImage(),  # Convert uint8 tensor to PIL image
        transforms.RandomCrop(32, padding=4),  # Randomly crop to 32x32 with padding
        transforms.RandomHorizontalFlip(),  # Randomly flip images horizontally
        transforms.ToTensor(),  # Convert image back to a float tensor in [0, 1]
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))  # Normalize with CIFAR-10 mean and std
    ])
    transform_eval = transforms.Compose([
        transforms.ToPILImage(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))  # Same normalization for evaluation
    ])
    if verbose:
        print("\nData preprocessing steps:")
        for transform in transform_train.transforms:
            print(f"  - {transform}")
    return transform_train, transform_eval

def obtain_data_loaders(data_dir, n_clients, batch_size, classes_per_client=10, non_iid=None, split_factor=1):
    """Return DataLoader objects for clients with either a quantity- or label-skewed data split"""
    x_train, y_train, _, _ = get_cifar10(data_dir)
    display_data_statistics(x_train, y_train, "Train")
    # Split data according to the requested skew ('quantity_skew' or 'label_skew')
    if non_iid == 'quantity_skew':
        clients_data = split_data_real_world_scenario(x_train, y_train, n_clients)
    elif non_iid == 'label_skew':
        clients_data = split_data_iid(x_train, y_train, n_clients, classes_per_client)
    else:
        # Guard against leaving clients_data silently undefined for other values
        raise ValueError(f"Unsupported non_iid option: {non_iid!r}")
    shuffled_clients_data = randomize_data_order(clients_data)
    transform_train, _ = get_default_transforms(verbose=False)
    client_loaders = [DataLoader(CustomImageDataset(x, y, transform_train, split_factor=split_factor),
                                 batch_size=batch_size, shuffle=True) for x, y in shuffled_clients_data]
    return client_loaders

def get_test_data_loader(data_dir, batch_size):
    """Return DataLoader for the test data"""
    _, _, x_test, y_test = get_cifar10(data_dir)
    display_data_statistics(x_test, y_test, "Test")
    _, transform_eval = get_default_transforms(verbose=False)
    test_loader = DataLoader(CustomImageDataset(x_test, y_test, transform_eval), batch_size=batch_size, shuffle=False)
    return test_loader
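

if __name__ == "__main__":
    # Minimal smoke test; a sketch only -- the data directory, client count, and
    # batch size are illustrative values, not part of the original pipeline.
    client_loaders = obtain_data_loaders("./data", n_clients=10, batch_size=32,
                                         non_iid="label_skew", split_factor=1)
    test_loader = get_test_data_loader("./data", batch_size=32)
    images, targets = next(iter(client_loaders[0]))
    print(f"\nFirst client batch: images {tuple(images.shape)}, targets {tuple(targets.shape)}")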