# -*- coding: utf-8 -*-
# @Author: Weisen Pan
import logging
import os
import shutil

import torch
from torch import nn, optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

from utils import metric
from fedml_service.decentralized.federated_gkt import utils

# List to store filenames of saved checkpoints
saved_ckpt_filenames = []


class GKTServerTrainer:
    def __init__(self, client_num, device, server_model, args, writer):
        # Initialize the trainer with the number of clients, the device (CPU/GPU),
        # the global server model, training arguments, and a writer for logging
        self.client_num = client_num
        self.device = device
        self.args = args
        self.writer = writer

        """
        Notes: Using data parallelism requires adjusting the batch size accordingly.
        For example, with a single GPU (batch_size = 64), an epoch takes 1:03;
        with 4 GPUs (batch_size = 256), it takes 38 seconds; and with
        4 GPUs (batch_size = 64), it takes 1:00. If the batch size is not adjusted,
        the communication between CPU and GPU may slow down training.
        """

        # Server model setup
        self.model_global = server_model
        self.model_global.train()  # Set model to training mode
        self.model_global.to(self.device)  # Move model to the specified device (CPU or GPU)

        # Model parameters for optimization
        self.model_params = self.master_params = self.model_global.parameters()
        optim_params = self.master_params

        # Choose optimizer based on arguments (SGD or Adam)
        if self.args.optimizer == "SGD":
            self.optimizer = optim.SGD(optim_params, lr=self.args.lr, momentum=0.9,
                                       nesterov=True, weight_decay=self.args.wd)
        elif self.args.optimizer == "Adam":
            self.optimizer = optim.Adam(optim_params, lr=self.args.lr,
                                        weight_decay=0.0001, amsgrad=True)

        # Learning rate scheduler to reduce the learning rate when the accuracy plateaus
        self.scheduler = ReduceLROnPlateau(self.optimizer, 'max')

        # Loss functions: CrossEntropy for classification, KL divergence for knowledge distillation
        self.criterion_CE = nn.CrossEntropyLoss()
        self.criterion_KL = utils.KL_Loss(self.args.temperature)

        # Best accuracy tracking
        self.best_acc = 0.0

        # Client data dictionaries to store features, logits, and labels
        self.client_extracted_feature_dict = {}
        self.client_logits_dict = {}
        self.client_labels_dict = {}
        self.server_logits_dict = {}

        # Testing data dictionaries
        self.client_extracted_feature_dict_test = {}
        self.client_labels_dict_test = {}

        # Miscellaneous dictionaries to store model info, sample numbers, training accuracy, and loss
        self.model_dict = {}
        self.sample_num_dict = {}
        self.train_acc_dict = {}
        self.train_loss_dict = {}
        self.test_acc_avg = 0.0
        self.test_loss_avg = 0.0

        # Dictionary to track whether each client's model has been uploaded
        self.flag_client_model_uploaded_dict = {idx: False for idx in range(self.client_num)}

    # Add results from a local client model after training
    def add_local_trained_result(self, index, extracted_feature_dict, logits_dict, labels_dict,
                                 extracted_feature_dict_test, labels_dict_test):
        logging.info(f"Adding model for client index = {index}")
        self.client_extracted_feature_dict[index] = extracted_feature_dict
        self.client_logits_dict[index] = logits_dict
        self.client_labels_dict[index] = labels_dict
        self.client_extracted_feature_dict_test[index] = extracted_feature_dict_test
        self.client_labels_dict_test[index] = labels_dict_test
        self.flag_client_model_uploaded_dict[index] = True

    # Check whether all clients have uploaded their models; if so, reset the flags
    def check_whether_all_receive(self):
        if all(self.flag_client_model_uploaded_dict.values()):
            self.flag_client_model_uploaded_dict = {idx: False for idx in range(self.client_num)}
            return True
        return False
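
    # Hedged sketch related to the data-parallelism note in __init__: this helper
    # is illustrative and is not called anywhere in the original flow. It shows one
    # way to wrap the global model in nn.DataParallel when multiple GPUs are
    # visible; remember to scale the batch size accordingly, as the note explains.
    def _maybe_enable_data_parallel(self):
        if torch.cuda.device_count() > 1:
            self.model_global = nn.DataParallel(self.model_global)
        self.model_global.to(self.device)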
    # Get logits from the global model for a specific client
    def get_global_logits(self, client_index):
        return self.server_logits_dict.get(client_index)

    # Main training entry point for a communication round
    def train(self, round_idx):
        if self.args.sweep == 1:  # Sweep mode
            self.sweep(round_idx)
        else:  # Normal training process
            if self.args.whether_training_on_client == 1:  # Training occurs on the client
                self.train_and_distill_on_client(round_idx)
            else:  # No training on the client; just train and evaluate on the server
                self.do_not_train_on_client(round_idx)

    # Training and knowledge distillation on the client side
    def train_and_distill_on_client(self, round_idx):
        # Use a single server epoch unless running in test mode
        epochs_server = 1 if not self.args.test else self.get_server_epoch_strategy_test()[0]
        self.train_and_eval(round_idx, epochs_server, self.writer, self.args)  # Train and evaluate
        self.scheduler.step(self.best_acc)  # Update the learning rate scheduler

    # Skip client-side training; only train and evaluate the server model
    def do_not_train_on_client(self, round_idx):
        self.train_and_eval(round_idx, 1, self.writer, self.args)
        self.scheduler.step(self.best_acc)

    # Training with the sweeping strategy
    def sweep(self, round_idx):
        self.train_and_eval(round_idx, self.args.epochs_server, self.writer, self.args)
        self.scheduler.step(self.best_acc)

    # Epoch strategy used in test mode: one server epoch, distillation enabled
    def get_server_epoch_strategy_test(self):
        return 1, True

    # Epoch strategy that decays the number of server epochs as rounds progress;
    # distillation back to the clients is disabled after round 150
    def get_server_epoch_strategy_reset56(self, round_idx):
        if round_idx < 20:
            epochs = 20
        elif round_idx < 30:
            epochs = 15
        elif round_idx < 40:
            epochs = 10
        elif round_idx < 50:
            epochs = 5
        elif round_idx < 150:
            epochs = 3
        else:
            epochs = 1
        whether_distill_back = round_idx < 150
        return epochs, whether_distill_back

    # Variant that simply uses the configured number of server epochs
    def get_server_epoch_strategy_reset56_2(self, round_idx):
        return self.args.epochs_server, True
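
    # Illustration (hedged) of the decaying schedule above; these sample values
    # follow directly from the branches in get_server_epoch_strategy_reset56:
    #   round 0   -> (20, True)
    #   round 25  -> (15, True)
    #   round 100 -> (3, True)
    #   round 200 -> (1, False)
    # A quick sanity check, assuming `trainer` is a GKTServerTrainer instance:
    #   assert trainer.get_server_epoch_strategy_reset56(200) == (1, False)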
    # Main training and evaluation loop on the server
    def train_and_eval(self, round_idx, epochs, val_writer, args):
        for epoch in range(epochs):
            logging.info(f"Train and evaluate. Round = {round_idx}, Epoch = {epoch}")
            train_metrics = self.train_large_model_on_the_server()  # Training step

            if epoch == epochs - 1:  # Log metrics and evaluate on the final epoch
                val_writer.add_scalar('average training loss', train_metrics['train_loss'],
                                      global_step=round_idx)

                test_metrics = self.eval_large_model_on_the_server()  # Evaluation step
                test_acc = test_metrics['test_accTop1']
                val_writer.add_scalar('test loss', test_metrics['test_loss'], global_step=round_idx)
                val_writer.add_scalar('test acc', test_metrics['test_accTop1'], global_step=round_idx)

                # Track the best accuracy; compute the flag before updating best_acc
                # so the checkpoint below is marked "best" correctly
                is_best = test_acc >= self.best_acc
                if is_best:
                    logging.info("- Found better accuracy")
                    self.best_acc = test_acc
                    val_writer.add_scalar('best_acc1', self.best_acc, global_step=round_idx)

                # Save model checkpoints, keeping at most args.max_ckpt_nums recent files
                if args.save_weight:
                    filename = f"checkpoint_{round_idx}.pth.tar"
                    saved_ckpt_filenames.append(filename)
                    if len(saved_ckpt_filenames) > args.max_ckpt_nums:
                        os.remove(os.path.join(args.model_dir, saved_ckpt_filenames.pop(0)))
                    ckpt_dict = {
                        'round': round_idx + 1,
                        'arch': args.arch,
                        'state_dict': self.model_global.state_dict(),
                        'best_acc1': self.best_acc,
                        'optimizer': self.optimizer.state_dict(),
                    }
                    metric.save_checkpoint(ckpt_dict, is_best, args.model_dir, filename=filename)

                # Print metrics for the current round
                print(f"{round_idx}-th round | Train Loss: {train_metrics['train_loss']:.3g} | "
                      f"Test Loss: {test_metrics['test_loss']:.3g} | "
                      f"Test Acc: {test_metrics['test_accTop1']:.3f}")

    # Train the global model on the server using the clients' extracted features
    def train_large_model_on_the_server(self):
        # Clear cached server logits and set the model to training mode
        self.server_logits_dict.clear()
        self.model_global.train()

        # Track running loss and accuracy
        loss_avg = utils.RollingAverage()
        accTop1_avg = utils.RollingAverage()
        accTop5_avg = utils.RollingAverage()

        # Iterate over each client's extracted features
        for client_index, extracted_feature_dict in self.client_extracted_feature_dict.items():
            logits_dict = self.client_logits_dict[client_index]
            labels_dict = self.client_labels_dict[client_index]

            s_logits_dict = {}
            self.server_logits_dict[client_index] = s_logits_dict

            # Iterate over batches of features for each client
            for batch_index, batch_feature_map_x in extracted_feature_dict.items():
                batch_feature_map_x = torch.from_numpy(batch_feature_map_x).to(self.device)
                batch_logits = torch.from_numpy(logits_dict[batch_index]).float().to(self.device)
                batch_labels = torch.from_numpy(labels_dict[batch_index]).long().to(self.device)

                # Forward pass
                output_batch = self.model_global(batch_feature_map_x)

                if self.args.whether_distill_on_the_server == 1:
                    # Knowledge distillation loss plus weighted cross-entropy
                    loss_kd = self.criterion_KL(output_batch, batch_logits).to(self.device)
                    loss_true = self.criterion_CE(output_batch, batch_labels).to(self.device)
                    loss = loss_kd + self.args.alpha * loss_true
                else:
                    # Standard cross-entropy loss
                    loss = self.criterion_CE(output_batch, batch_labels).to(self.device)

                # Backward pass and optimization
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # Compute accuracy metrics
                metrics = utils.accuracy(output_batch, batch_labels, topk=(1, 5))
                accTop1_avg.update(metrics[0].item())
                accTop5_avg.update(metrics[1].item())
                loss_avg.update(loss.item())

                # Store the server's logits for this batch (sent back to the client for distillation)
                s_logits_dict[batch_index] = output_batch.cpu().detach().numpy()

        # Aggregate and log training metrics
        train_metrics = {'train_loss': loss_avg.value(),
                         'train_accTop1': accTop1_avg.value(),
                         'train_accTop5': accTop5_avg.value()}
        logging.info(f"- Train metrics: {' ; '.join(f'{k}: {v:.3f}' for k, v in train_metrics.items())}")
        return train_metrics
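
    # Hedged sketch: utils.KL_Loss is project-specific and its definition is not
    # shown in this module. A temperature-scaled KL distillation loss is
    # conventionally computed as below (Hinton et al.'s formulation). This
    # illustrative method is an assumption about its behavior, not the project's
    # actual implementation, and is not called anywhere in the training flow.
    def _kl_distillation_loss_sketch(self, student_logits, teacher_logits):
        import torch.nn.functional as F
        T = self.args.temperature
        # KL(softmax(teacher/T) || softmax(student/T)), scaled by T^2 so the
        # gradient magnitude stays comparable across temperatures
        return F.kl_div(F.log_softmax(student_logits / T, dim=1),
                        F.softmax(teacher_logits / T, dim=1),
                        reduction='batchmean') * (T * T)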
    # Evaluate the global model on the server using the clients' test features
    def eval_large_model_on_the_server(self):
        # Set the model to evaluation mode
        self.model_global.eval()

        loss_avg = utils.RollingAverage()
        accTop1_avg = utils.RollingAverage()
        accTop5_avg = utils.RollingAverage()

        # Disable gradient computation for evaluation
        with torch.no_grad():
            # Iterate over each client's extracted test features
            for client_index, extracted_feature_dict in self.client_extracted_feature_dict_test.items():
                labels_dict = self.client_labels_dict_test[client_index]

                # Iterate over batches for each client
                for batch_index, batch_feature_map_x in extracted_feature_dict.items():
                    batch_feature_map_x = torch.from_numpy(batch_feature_map_x).to(self.device)
                    batch_labels = torch.from_numpy(labels_dict[batch_index]).long().to(self.device)

                    # Forward pass
                    output_batch = self.model_global(batch_feature_map_x)
                    loss = self.criterion_CE(output_batch, batch_labels)

                    # Compute accuracy metrics
                    metrics = utils.accuracy(output_batch, batch_labels, topk=(1, 5))
                    accTop1_avg.update(metrics[0].item())
                    accTop5_avg.update(metrics[1].item())
                    loss_avg.update(loss.item())

        # Aggregate and log test metrics
        test_metrics = {'test_loss': loss_avg.value(),
                        'test_accTop1': accTop1_avg.value(),
                        'test_accTop5': accTop5_avg.value()}
        logging.info(f"- Test metrics: {' ; '.join(f'{k}: {v:.3f}' for k, v in test_metrics.items())}")
        return test_metrics
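

# ---------------------------------------------------------------------------
# Minimal usage sketch (hedged): the snippet below illustrates how a federation
# coordinator might drive GKTServerTrainer for one communication round. The
# placeholder server model, the argument values, and the SummaryWriter path are
# illustrative assumptions, not part of this module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from argparse import Namespace
    from torch.utils.tensorboard import SummaryWriter

    # Hypothetical configuration; field names mirror those read by the trainer above
    args = Namespace(optimizer="SGD", lr=0.01, wd=5e-4, temperature=3.0, alpha=1.0,
                     sweep=0, whether_training_on_client=1, whether_distill_on_the_server=1,
                     epochs_server=1, test=False, save_weight=False,
                     max_ckpt_nums=5, model_dir="./ckpt", arch="resnet")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    server_model = nn.Sequential(nn.Flatten(), nn.Linear(16 * 8 * 8, 10))  # placeholder model
    writer = SummaryWriter("./runs/gkt_server")

    trainer = GKTServerTrainer(client_num=2, device=device,
                               server_model=server_model, args=args, writer=writer)

    # In a real deployment each client would upload its extracted feature maps,
    # local logits, and labels; once all have arrived, the server trains and
    # sends its logits back for client-side distillation:
    #   trainer.add_local_trained_result(idx, feats, logits, labels, feats_test, labels_test)
    #   if trainer.check_whether_all_receive():
    #       trainer.train(round_idx)
    #       server_logits = trainer.get_global_logits(client_index=0)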