# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import os
import shutil

import torch


def store_model(state, best_model, directory, filename='checkpoint.pth'):
    """
    Stores the model checkpoint in the specified directory.

    If it's the best model, another copy named 'best_model.pth' is written
    to the same directory.

    Args:
        state (dict): Object passed to ``torch.save`` (e.g. the model's state dictionary).
        best_model (bool): Flag indicating if the current model is the best.
        directory (str): Directory where the model is saved (must already exist).
        filename (str): Name of the file to save the checkpoint (default 'checkpoint.pth').
    """
    save_path = os.path.join(directory, filename)
    torch.save(state, save_path)
    if best_model:
        # Copy the file just written instead of serializing ``state`` a second time.
        shutil.copy(save_path, os.path.join(directory, 'best_model.pth'))


def _save_if_best(state, best_model, directory, filename, label):
    """Save ``state`` to ``directory/filename`` (announcing ``label``) only when ``best_model`` is truthy."""
    if best_model:
        print(f"Saving the best {label} model")
        torch.save(state, os.path.join(directory, filename))


def save_main_client_model(state, best_model, directory):
    """
    Saves the model for the main client if it's the best one.

    Writes 'main_client_best.pth' into ``directory``; does nothing otherwise.

    Args:
        state (dict): Model's state dictionary.
        best_model (bool): Flag indicating if the current model is the best.
        directory (str): Directory where the model is saved.
    """
    _save_if_best(state, best_model, directory, 'main_client_best.pth', 'main client')


def save_proxy_clients_model(state, best_model, directory):
    """
    Saves the model for proxy clients if it's the best one.

    Writes 'proxy_clients_best.pth' into ``directory``; does nothing otherwise.

    Args:
        state (dict): Model's state dictionary.
        best_model (bool): Flag indicating if the current model is the best.
        directory (str): Directory where the model is saved.
    """
    _save_if_best(state, best_model, directory, 'proxy_clients_best.pth', 'proxy client')


def save_individual_client_model(state, best_model, directory):
    """
    Saves the model for individual clients if it's the best one.

    Writes 'client_best.pth' into ``directory``; does nothing otherwise.

    Args:
        state (dict): Model's state dictionary.
        best_model (bool): Flag indicating if the current model is the best.
        directory (str): Directory where the model is saved.
    """
    _save_if_best(state, best_model, directory, 'client_best.pth', 'client')


def save_server_model(state, best_model, directory):
    """
    Saves the model for the server if it's the best one.

    Writes 'server_best.pth' into ``directory``; does nothing otherwise.

    Args:
        state (dict): Model's state dictionary.
        best_model (bool): Flag indicating if the current model is the best.
        directory (str): Directory where the model is saved.
    """
    _save_if_best(state, best_model, directory, 'server_best.pth', 'server')


class MetricTracker(object):
    """
    A helper class to track the latest value and the weighted running average of a metric.

    Args:
        metric_name (str): Name of the metric to track.
        fmt (str): ``str.format`` spec, including the leading colon, used when
            printing values (default ':f').
    """

    def __init__(self, metric_name, fmt=':f'):
        self.metric_name = metric_name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Resets all metric counters."""
        self.current_value = 0
        self.total_sum = 0
        self.count = 0
        self.average = 0

    def update(self, value, n=1):
        """
        Updates the metric value.

        Args:
            value (float): New value of the metric.
            n (int): Weight or count for the value (default 1).
        """
        self.current_value = value
        self.total_sum += value * n
        self.count += n
        self.average = self.total_sum / self.count

    def __str__(self):
        """Returns the formatted metric string showing current value and average."""
        # ``fmt`` (e.g. ':.3f') must be spliced into the replacement field itself;
        # an f-string cannot nest a format spec the way the old code attempted.
        template = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return template.format(name=self.metric_name, val=self.current_value, avg=self.average)


class ProgressLogger(object):
    """
    A class to log and display the progress of training/testing over multiple batches.

    Args:
        total_batches (int): Total number of batches.
        *metrics (MetricTracker): Metrics to log during the process.
        prefix (str): Prefix for the progress log (default "Progress:").
    """

    def __init__(self, total_batches, *metrics, prefix="Progress:"):
        self.batch_format = self._get_batch_format(total_batches)
        self.metrics = metrics
        self.prefix = prefix

    def log(self, batch_idx):
        """
        Prints the current progress of training/testing on one line.

        Args:
            batch_idx (int): The current batch index.
        """
        output = [self.prefix + self.batch_format.format(batch_idx)]
        output += [str(metric) for metric in self.metrics]
        print(' | '.join(output))

    def _get_batch_format(self, total_batches):
        """Creates a format string such as '[{:3d}/100]' for displaying the batch index."""
        num_digits = len(str(total_batches))
        # Build the template explicitly: calling ``.format`` on a concatenation
        # only binds to the last literal, which previously raised ValueError.
        return '[{:' + str(num_digits) + 'd}/' + str(total_batches) + ']'


def compute_accuracy(prediction, target, top_k=(1,)):
    """
    Computes the accuracy for the top-k predictions.

    Args:
        prediction (Tensor): Model predictions; scores are ranked along dim 1
            (presumably shape (batch, num_classes) -- confirm with callers).
        target (Tensor): Ground truth labels, one per row of ``prediction``.
        top_k (tuple): Tuple of top-k values to consider for accuracy (default (1,)).

    Returns:
        List[Tensor]: One 1-element tensor per k, holding the accuracy as a percentage.
    """
    with torch.no_grad():
        max_k = max(top_k)
        batch_size = target.size(0)

        # Indices of the max_k highest-scoring classes per sample, transposed to (max_k, batch).
        _, top_predictions = prediction.topk(max_k, 1, largest=True, sorted=True)
        top_predictions = top_predictions.t()
        # Compare top-k predictions with targets
        correct_predictions = top_predictions.eq(target.view(1, -1).expand_as(top_predictions))

        accuracy_results = []
        for k in top_k:
            # ``correct_predictions`` inherits the transposed (non-contiguous) layout,
            # so ``view`` fails for k > 1; ``reshape`` copies when necessary.
            correct_k = correct_predictions[:k].reshape(-1).float().sum(0, keepdim=True)
            accuracy_results.append(correct_k.mul_(100.0 / batch_size))
        return accuracy_results


def count_model_parameters(model, trainable_only=False):
    """
    Counts the total number of parameters in the model.

    Args:
        model (nn.Module): The PyTorch model.
        trainable_only (bool): Whether to count only trainable parameters (default False).

    Returns:
        int: Total number of parameters in the model.
    """
    if trainable_only:
        # Count only the parameters that require gradients (trainable parameters)
        return sum(p.numel() for p in model.parameters() if p.requires_grad)
    # Count all parameters (trainable and non-trainable)
    return sum(p.numel() for p in model.parameters())