Weisen Pan 4ec0a23e73 Edge Federated Learning for Improved Training Efficiency
Change-Id: Ic4e43992e1674946cb69e0221659b0261259196c
2024-09-18 18:39:43 -07:00


# -*- coding: utf-8 -*-
# @Author: Weisen Pan
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.cuda.amp as amp
from torch.backends import cudnn
from tensorboardX import SummaryWriter
import warnings
import argparse
import os
import numpy as np
from tqdm import tqdm
from dataset import factory
from model import coremodel
from utils import metric, label_smoothing, lr_scheduler, prefetch
from params.train_params import save_hp_to_json
from params import train_params
# Global variable to track the best accuracy
best_accuracy = 0
def calculate_average(values):
    """Calculate the average of a list of values."""
    return sum(values) / len(values)
def initialize_processes(rank, world_size, args):
    """
    Initialize the worker processes.
    Sets up distributed training across multiple GPUs.
    """
    ngpus = torch.cuda.device_count()
    args.ngpus = ngpus
    args.is_decentralized = world_size > 1
    if args.multiprocessing_decentralized:
        # With multiple GPUs, spawn one training process per GPU
        mp.spawn(train_single_worker, nprocs=ngpus, args=(ngpus, args))
    else:
        print(f"INFO:PyTorch: Using {ngpus} GPUs")
        # With a single GPU, run the training worker directly
        train_single_worker(args.gpu, ngpus, args)
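# Note: mp.spawn() hands each worker its GPU index as the first argument, but no
# process group is initialized in this file. Below is a minimal sketch of the
# usual torch.distributed setup each worker would perform; the helper name, the
# NCCL backend, and the env:// init method are assumptions for a single-node
# run, not values confirmed by this repo:
def _init_process_group_sketch(gpu, args):
    dist.init_process_group(
        backend='nccl',        # Common backend for multi-GPU CUDA training
        init_method='env://',  # Reads MASTER_ADDR / MASTER_PORT from the environment
        world_size=args.world_size,
        rank=gpu,              # Single-node assumption: rank == local GPU index
    )
    torch.cuda.set_device(gpu)  # Bind this process to its GPU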
def client_training_step(args, current_round, model, optimizer, scheduler, dataloader, epochs=5, scaler=None):
    """
    Run local training for a single client model in the federated learning setup.
    Trains for the given number of local epochs and returns the final loss.
    """
    model.train()  # Set model to training mode
    for epoch in range(epochs):
        # Prefetch batches to overlap host-to-device transfers with compute
        prefetcher = prefetch.data_prefetcher(dataloader)
        images, targets = prefetcher.next()
        step = 0
        while images is not None:
            # Adjust the learning rate (scheduler is expected to be callable)
            scheduler(optimizer, step, current_round)
            # Mixed precision reduces memory use and speeds up computation
            with amp.autocast(enabled=args.is_amp):
                outputs, ce_loss, cot_loss = model(images, target=targets, mode='train')
            # Combine losses and normalize by the number of accumulation steps
            loss = (ce_loss + cot_loss) / args.accumulation_steps
            loss.backward()  # Accumulate gradients across iterations
            # Apply the optimizer update once enough gradients have accumulated
            if (step + 1) % args.accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()  # Clear gradients after the update
            images, targets = prefetcher.next()  # Fetch the next batch
            step += 1
    return loss.item()  # Return the final loss value
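# The `scaler` parameter of client_training_step() is accepted but unused. Below
# is a sketch of how a torch.cuda.amp.GradScaler would typically drive the same
# accumulation loop; the helper name and wiring are assumptions for
# illustration, not this repo's confirmed training path:
def _amp_accumulation_sketch(args, model, optimizer, scaler, images, targets, step):
    with amp.autocast(enabled=args.is_amp):
        outputs, ce_loss, cot_loss = model(images, target=targets, mode='train')
    loss = (ce_loss + cot_loss) / args.accumulation_steps
    scaler.scale(loss).backward()  # Backward pass on the scaled loss
    if (step + 1) % args.accumulation_steps == 0:
        scaler.step(optimizer)  # Unscales gradients, skips the step on inf/NaN
        scaler.update()         # Adjusts the loss scale for the next iteration
        optimizer.zero_grad()
    return loss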
def combine_model_parameters(global_model, client_models):
    """
    Aggregate the weights of the client models into the global model.
    This is the core of the Federated Averaging (FedAvg) algorithm.
    """
    global_state = global_model.state_dict()
    for key in global_state.keys():
        # Average the corresponding parameter tensor across all client models
        global_state[key] = torch.stack(
            [client.state_dict()[key].float() for client in client_models], dim=0
        ).mean(dim=0)
    # Load the averaged weights into the global model
    global_model.load_state_dict(global_state)
    # Synchronize every client with the new global weights
    for client in client_models:
        client.load_state_dict(global_model.state_dict())
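# A self-contained sketch of the FedAvg update that combine_model_parameters()
# performs, using toy linear models; the demo name is hypothetical. For each
# parameter tensor it computes w_global = (1/K) * sum_k w_k and then copies the
# result back into every client:
def _fedavg_demo():
    clients = [nn.Linear(4, 2) for _ in range(3)]
    global_model = nn.Linear(4, 2)
    combine_model_parameters(global_model, clients)
    # After aggregation, all clients hold identical averaged weights
    assert torch.allclose(clients[0].weight, clients[1].weight)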
def validate_model(validation_loader, model, args):
    """
    Run validation on the validation dataset.
    Returns the average top-1 accuracy across all batches.
    """
    model.eval()  # Set the model to evaluation mode
    accuracy_values = []
    with torch.no_grad():
        for images, targets in validation_loader:
            if args.gpu is not None:
                images, targets = images.cuda(args.gpu), targets.cuda(args.gpu)
            # Use mixed precision for inference as well
            with amp.autocast(enabled=args.is_amp):
                ensemble_output, outputs, ce_loss = model(images, target=targets, mode='val')
            # Top-1 accuracy for the current batch
            avg_acc1 = metric.accuracy(ensemble_output, targets, topk=(1,))
            accuracy_values.append(avg_acc1)
    return calculate_average(accuracy_values)  # Average accuracy over all batches
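# metric.accuracy() is called above with topk=(1,). Below is a sketch of the
# standard top-1 computation it presumably performs; this helper is an
# assumption for illustration, not the actual utils.metric implementation:
def _top1_accuracy_sketch(output, target):
    pred = output.argmax(dim=1)               # Highest-scoring class per sample
    correct = pred.eq(target).float().mean()  # Fraction of correct predictions
    return correct * 100.0                    # Expressed as a percentage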
def train_single_worker(gpu, ngpus, args):
    """
    Training worker that runs on a single GPU.
    Handles the full federated learning workflow for the assigned GPU.
    """
    global best_accuracy
    args.gpu = gpu
    cudnn.benchmark = True  # Let CuDNN auto-tune kernels for fixed input sizes
    # Optionally resume from a checkpoint if one is provided
    if args.resume:
        checkpoint = torch.load(args.resume)
        args.start_round = checkpoint['round']
        best_accuracy = checkpoint['best_acc1']
    # Initialize the global model and one model per client
    model = coremodel.coremodel(args).cuda()
    client_models = [coremodel.coremodel(args).cuda() for _ in range(args.num_clients)]
    optimizers = [torch.optim.SGD(client.parameters(), lr=args.lr) for client in client_models]
    # Training and validation loop over federated rounds
    for round_num in range(args.start_round, args.num_rounds):
        # Local training for each client model
        for client_num in range(args.num_clients):
            client_training_step(args, round_num, client_models[client_num],
                                 optimizers[client_num], lr_scheduler, args.train_loader)
        # Aggregate the client models into the global model (FedAvg)
        combine_model_parameters(model, client_models)
        # Validate the updated global model and track the best accuracy
        validation_accuracy = validate_model(args.val_loader, model, args)
        best_accuracy = max(best_accuracy, validation_accuracy)
        print(f"Round {round_num}: Best Accuracy: {best_accuracy:.2f}")
if __name__ == "__main__":
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description='FedAvg Distributed Training')
    args = train_params.add_parser_params(parser)
    initialize_processes(0, args.world_size, args)  # Start distributed training
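# Example launch (the flag names below are assumptions; the actual flags are
# defined in params/train_params.py, which is not shown here):
#   python train.py --world_size 1 --num_clients 4 --num_rounds 100 --lr 0.1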