# -*- coding: utf-8 -*-
# @Author: Weisen Pan
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.cuda.amp as amp
from torch.backends import cudnn
from tensorboardX import SummaryWriter
import warnings
import argparse
import os
import numpy as np
from tqdm import tqdm
from dataset import factory
from model import coremodel
from utils import metric, label_smoothing, lr_scheduler, prefetch
from params.train_params import save_hp_to_json
from params import train_params

# Global variable to track the best accuracy
best_accuracy = 0


def calculate_average(values):
    """Calculate the average of a list of values."""
    return sum(values) / len(values)


def initialize_processes(rank, world_size, args):
    """
    Initialize decentralized processes.

    Sets up distributed training across multiple GPUs.
    """
    ngpus = torch.cuda.device_count()
    args.ngpus = ngpus
    args.is_decentralized = world_size > 1

    if args.multiprocessing_decentralized:
        # With multiple GPUs, spawn one training process per GPU
        mp.spawn(train_single_worker, nprocs=ngpus, args=(ngpus, args))
    else:
        print(f"INFO:PyTorch: Using {ngpus} GPUs")
        # With a single GPU, start the training worker directly
        train_single_worker(args.gpu, ngpus, args)


def client_training_step(args, current_round, model, optimizer, scheduler, dataloader, epochs=5, scaler=None):
    """
    Train a single client model in the federated learning setup
    for the given number of local epochs.
    """
    model.train()  # Set model to training mode

    for epoch in range(epochs):
        # Prefetch data to improve efficiency
        prefetcher = prefetch.data_prefetcher(dataloader)
        images, targets = prefetcher.next()
        step = 0

        while images is not None:
            # Update the learning rate using the scheduler
            scheduler(optimizer, step, current_round)
            optimizer.zero_grad()  # Clear the gradients

            # Use mixed precision to reduce memory use and speed up computation
            with amp.autocast(enabled=args.is_amp):
                outputs, ce_loss, cot_loss = model(images, target=targets, mode='train')
                # Combine losses and normalize by the number of accumulation steps
                loss = (ce_loss + cot_loss) / args.accumulation_steps

            # Backpropagate, scaling the loss when a GradScaler is provided
            if scaler is not None:
                scaler.scale(loss).backward()
            else:
                loss.backward()

            # Perform an optimization step after enough gradient accumulation
            if step % args.accumulation_steps == 0:
                if scaler is not None:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()
                optimizer.zero_grad()  # Clear gradients after the step

            images, targets = prefetcher.next()  # Get the next batch of images and targets
            step += 1

    return loss.item()  # Return the final loss value


def combine_model_parameters(global_model, client_models):
    """
    Aggregate the weights of the client models to update the global model.
    This is the core of the Federated Averaging (FedAvg) algorithm: each
    global parameter becomes the element-wise mean of the corresponding
    client parameters.
    """
    global_state = global_model.state_dict()
    for key in global_state.keys():
        # Average the weights of the corresponding layers from all client models
        global_state[key] = torch.stack(
            [client.state_dict()[key].float() for client in client_models], dim=0
        ).mean(dim=0)

    # Load the averaged weights into the global model
    global_model.load_state_dict(global_state)

    # Update the client models with the new global model weights
    for client in client_models:
        client.load_state_dict(global_model.state_dict())


def validate_model(validation_loader, model, args):
    """
    Run validation on the validation dataset and return the average
    top-1 accuracy across all batches.
    """
    model.eval()  # Set the model to evaluation mode
    accuracy_values = []

    with torch.no_grad():
        for images, targets in validation_loader:
            if args.gpu is not None:
                images, targets = images.cuda(args.gpu), targets.cuda(args.gpu)

            # Use mixed precision for inference
            with amp.autocast(enabled=args.is_amp):
                ensemble_output, outputs, ce_loss = model(images, target=targets, mode='val')

            # Record the top-1 accuracy for the current batch
            avg_acc1 = metric.accuracy(ensemble_output, targets, topk=(1,))
            accuracy_values.append(avg_acc1)

    return calculate_average(accuracy_values)  # Return the average accuracy


def train_single_worker(gpu, ngpus, args):
    """
    Training worker that runs on a single GPU. Handles the entire
    federated learning workflow for the assigned GPU.
    """
    global best_accuracy
    args.gpu = gpu
    cudnn.benchmark = True  # Enable CuDNN auto-tuning for fixed-size inputs

    # Optionally, resume from a checkpoint if provided
    if args.resume:
        checkpoint = torch.load(args.resume)
        args.start_round = checkpoint['round']
        best_accuracy = checkpoint['best_acc1']

    # Initialize the global model and the client models
    model = coremodel.coremodel(args).cuda()
    client_models = [coremodel.coremodel(args).cuda() for _ in range(args.num_clients)]
    optimizers = [torch.optim.SGD(client.parameters(), lr=args.lr) for client in client_models]

    # Training and validation loop over communication rounds
    for round_num in range(args.start_round, args.num_rounds):
        # Train each client model locally
        for client_num in range(args.num_clients):
            client_training_step(args, round_num, client_models[client_num],
                                 optimizers[client_num], lr_scheduler, args.train_loader)

        # Aggregate the client models to update the global model
        combine_model_parameters(model, client_models)

        # Validate the updated global model and track the best accuracy
        validation_accuracy = validate_model(args.val_loader, model, args)
        best_accuracy = max(best_accuracy, validation_accuracy)
        print(f"Round {round_num}: Best Accuracy: {best_accuracy:.2f}")


if __name__ == "__main__":
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description='FedAvg decentralized Training')
    args = train_params.add_parser_params(parser)
    initialize_processes(0, args.world_size, args)  # Initialize distributed training