159 lines
6.3 KiB
Python
159 lines
6.3 KiB
Python
# -*- coding: utf-8 -*-
|
|
# @Author: Weisen Pan
|
|
|
|
import argparse
import os
import warnings

import numpy as np
import torch
import torch.cuda.amp as amp
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from tensorboardX import SummaryWriter
from torch.backends import cudnn
from tqdm import tqdm

from dataset import factory
from model import coremodel
from params import train_params
from params.train_params import save_hp_to_json
from utils import metric, label_smoothing, lr_scheduler, prefetch
|
|
|
|
# Best validation accuracy seen so far; read and updated by train_single_worker (and restored from the checkpoint when resuming).
|
|
best_accuracy = 0
|
|
|
|
def calculate_average(values):
    """Return the arithmetic mean of ``values``.

    The mean is computed as ``sum(values) / len(values)``, so ``values`` must
    be a sized, non-empty collection; an empty one raises ``ZeroDivisionError``.
    """
    total = sum(values)
    count = len(values)
    return total / count
|
|
|
|
def initialize_processes(rank, world_size, args):
    """Record the GPU setup on ``args`` and launch the training worker(s).

    Args:
        rank: Accepted for interface compatibility; not used by this function.
        world_size: Number of participating processes; values above 1 set
            ``args.is_decentralized`` to True.
        args: Run configuration. This function writes ``args.ngpus`` and
            ``args.is_decentralized`` and reads
            ``args.multiprocessing_decentralized`` and ``args.gpu``.
    """
    ngpus = torch.cuda.device_count()
    args.ngpus = ngpus
    args.is_decentralized = world_size > 1

    # Single-process path: run the worker in this process on args.gpu.
    if not args.multiprocessing_decentralized:
        print(f"INFO:PyTorch: Using {ngpus} GPUs")
        train_single_worker(args.gpu, ngpus, args)
        return

    # Multi-process path: one worker per visible GPU. mp.spawn passes the
    # process index as the worker's first positional argument (its ``gpu``).
    mp.spawn(train_single_worker, nprocs=ngpus, args=(ngpus, args))
|
|
|
|
def _apply_optimizer_step(optimizer, scaler=None):
    """Apply the accumulated gradients, then clear them for the next group."""
    if scaler is not None:
        # GradScaler unscales the gradients, skips the step on inf/NaN and
        # then adjusts its scale factor.
        scaler.step(optimizer)
        scaler.update()
    else:
        optimizer.step()
    optimizer.zero_grad()


def client_training_step(args, current_round, model, optimizer, scheduler, dataloader, epochs=5, scaler=None):
    """Train one client model for ``epochs`` passes over ``dataloader``.

    Gradients are accumulated over ``args.accumulation_steps`` consecutive
    batches before each optimizer step; a trailing partial group at the end of
    an epoch is applied as well so no batch is silently dropped.

    Args:
        args: Run configuration; reads ``args.is_amp`` and
            ``args.accumulation_steps``.
        current_round: Index of the current communication round, forwarded to
            ``scheduler``.
        model: Client model. It is called as
            ``model(images, target=targets, mode='train')`` and must return
            ``(outputs, ce_loss, cot_loss)``.
        optimizer: Optimizer bound to ``model``'s parameters.
        scheduler: Callable invoked as ``scheduler(optimizer, step,
            current_round)`` before every batch, where ``step`` is the batch
            index within the current epoch.
        dataloader: Data source wrapped by ``prefetch.data_prefetcher``.
        epochs: Number of passes over ``dataloader``.
        scaler: Optional ``torch.cuda.amp.GradScaler``. When given, the loss is
            scaled before ``backward()`` and optimizer steps go through the
            scaler; when ``None`` the optimizer is stepped directly.

    Returns:
        The last batch's combined loss (already divided by
        ``args.accumulation_steps``) as a Python float, or ``nan`` when no
        batch was processed.
    """
    model.train()
    accumulation_steps = args.accumulation_steps
    last_loss = None

    for _ in range(epochs):
        prefetcher = prefetch.data_prefetcher(dataloader)
        images, targets = prefetcher.next()
        step = 0

        # Start each epoch from clean gradients. Zeroing must NOT happen per
        # batch, otherwise the accumulated gradients are thrown away.
        optimizer.zero_grad()

        # The prefetcher signals exhaustion by returning ``None`` for images.
        while images is not None:
            scheduler(optimizer, step, current_round)

            # Only the forward pass and loss arithmetic run under autocast.
            with amp.autocast(enabled=args.is_amp):
                _, ce_loss, cot_loss = model(images, target=targets, mode='train')
                # Dividing keeps the summed gradient of a full accumulation
                # group on the scale of a single batch.
                loss = (ce_loss + cot_loss) / accumulation_steps

            if scaler is not None:
                scaler.scale(loss).backward()
            else:
                loss.backward()
            last_loss = loss

            step += 1
            # Step only once a full group of batches has contributed.
            if step % accumulation_steps == 0:
                _apply_optimizer_step(optimizer, scaler)

            images, targets = prefetcher.next()

        # Apply gradients left over from an incomplete final group.
        if step % accumulation_steps != 0:
            _apply_optimizer_step(optimizer, scaler)

    if last_loss is None:
        # No batch was seen (empty loader or epochs == 0).
        return float("nan")
    return last_loss.item()
|
|
|
|
def combine_model_parameters(global_model, client_models):
    """Average the clients' state into ``global_model`` and sync it back.

    For every entry of ``global_model.state_dict()`` the corresponding client
    tensors are averaged element-wise (unweighted mean, computed in float32).
    The result is loaded into ``global_model`` and then into every client, so
    all models hold identical state on return.

    Args:
        global_model: Module that receives the averaged state.
        client_models: Non-empty sequence of modules whose ``state_dict()``
            contains every key of ``global_model``'s state dict with matching
            shapes.
    """
    # Snapshot each client's state once; building a state dict walks the whole
    # module tree, so doing it per key would be needlessly expensive.
    client_states = [client.state_dict() for client in client_models]
    global_state = global_model.state_dict()

    # Iterate over a copy of the keys because the values are reassigned below.
    for key in list(global_state):
        target_dtype = global_state[key].dtype
        stacked = torch.stack([state[key].float() for state in client_states], dim=0)
        # Average in float32 (mean() is undefined for integer tensors), then
        # return to the stored dtype, e.g. integer counter buffers.
        global_state[key] = stacked.mean(dim=0).to(target_dtype)

    global_model.load_state_dict(global_state)

    # Push the new global state to every client.
    synced_state = global_model.state_dict()
    for client in client_models:
        client.load_state_dict(synced_state)
|
|
|
|
def validate_model(validation_loader, model, args):
    """Evaluate ``model`` on ``validation_loader`` and return the mean accuracy.

    Args:
        validation_loader: Iterable yielding ``(images, targets)`` batches.
        model: Module called as ``model(images, target=targets, mode='val')``;
            the first element of its 3-tuple result is scored.
        args: Run configuration; reads ``args.gpu`` and ``args.is_amp``.

    Returns:
        The unweighted mean, over batches, of whatever
        ``metric.accuracy(..., topk=(1,))`` returns for each batch.
    """
    model.eval()
    per_batch_accuracy = []

    # No gradients are needed while evaluating.
    with torch.no_grad():
        for images, targets in validation_loader:
            # Move the batch to the worker's GPU when one is assigned.
            if args.gpu is not None:
                images = images.cuda(args.gpu)
                targets = targets.cuda(args.gpu)

            with amp.autocast(enabled=args.is_amp):
                ensemble_output, _, _ = model(images, target=targets, mode='val')

            # Top-1 accuracy of the ensemble prediction for this batch.
            batch_accuracy = metric.accuracy(ensemble_output, targets, topk=(1,))
            per_batch_accuracy.append(batch_accuracy)

    return calculate_average(per_batch_accuracy)
|
|
|
|
def train_single_worker(gpu, ngpus, args):
    """Run the federated training loop in one worker process.

    Each round trains every client model, averages the clients into the
    global model, validates the global model, and updates the module-level
    ``best_accuracy``.

    Args:
        gpu: CUDA device index for this worker, or ``None`` to stay on the
            current default device. Stored on ``args.gpu``.
        ngpus: Number of GPUs on the node; not used inside this function.
        args: Run configuration. Reads ``resume``, ``start_round``,
            ``num_rounds``, ``num_clients``, ``lr``, ``train_loader`` and
            ``val_loader``.
    """
    global best_accuracy
    args.gpu = gpu

    # Let cuDNN auto-tune its convolution algorithms. The real flag is
    # ``benchmark``; assigning any other attribute name has no effect.
    cudnn.benchmark = True

    # Bind this process to its GPU so the ``.cuda()`` calls below place the
    # models on the same device that validation batches are moved to.
    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)

    # Optionally resume the round counter and best accuracy from a checkpoint.
    # NOTE(review): only 'round' and 'best_acc1' are restored; no model or
    # optimizer state is loaded from the checkpoint here. Confirm that is
    # intended.
    if args.resume:
        checkpoint = torch.load(args.resume)
        args.start_round = checkpoint['round']
        best_accuracy = checkpoint['best_acc1']

    # One global model plus one model and one SGD optimizer per client.
    model = coremodel.coremodel(args).cuda()
    client_models = [coremodel.coremodel(args).cuda() for _ in range(args.num_clients)]
    optimizers = [torch.optim.SGD(client.parameters(), lr=args.lr) for client in client_models]

    for round_num in range(args.start_round, args.num_rounds):
        # Local training: every client uses the same args.train_loader.
        # NOTE(review): ``lr_scheduler`` is the name imported from ``utils``
        # and is invoked as ``scheduler(optimizer, step, round)``; confirm it
        # is callable with that signature.
        for client_model, optimizer in zip(client_models, optimizers):
            client_training_step(args, round_num, client_model, optimizer, lr_scheduler, args.train_loader)

        # Average the clients into the global model and sync the clients.
        combine_model_parameters(model, client_models)

        # Validate the aggregated model and track the best accuracy so far.
        validation_accuracy = validate_model(args.val_loader, model, args)
        best_accuracy = max(best_accuracy, validation_accuracy)
        print(f"Round {round_num}: Best Accuracy: {best_accuracy:.2f}")
|
|
|
|
if __name__ == "__main__":
    # Build the command-line parser; train_params registers the project's
    # options on it, and its return value is used as the argument namespace.
    parser = argparse.ArgumentParser(description='FedAvg decentralized Training')
    args = train_params.add_parser_params(parser)

    # Launch training from rank 0 with the configured world size.
    initialize_processes(rank=0, world_size=args.world_size, args=args)
|