Weisen Pan 4ec0a23e73 Edge Federated Learning for Improved Training Efficiency
Change-Id: Ic4e43992e1674946cb69e0221659b0261259196c
2024-09-18 18:39:43 -07:00

275 lines
13 KiB
Python

# -*- coding: utf-8 -*-
# @Author: Weisen Pan
import logging
import os
import shutil
import torch
from torch import nn, optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from utils import metric
from fedml_service.decentralized.federated_gkt import utils
# Filenames of the checkpoints written so far, oldest first. Shared at module
# level (i.e. by every trainer in the process) and used to prune old files.
saved_ckpt_filenames = []


class GKTServerTrainer:
    """Server-side trainer that fits the global model on client-uploaded features.

    Clients upload per-batch extracted features, logits and labels (as NumPy
    arrays keyed by batch index). The server trains ``server_model`` on those
    features, optionally distilling from the client logits, evaluates on the
    uploaded test features and keeps its own logits per client so they can be
    handed back via :meth:`get_global_logits`.
    """

    def __init__(self, client_num, device, server_model, args, writer):
        """Set up the model, optimizer, scheduler, losses and bookkeeping dicts.

        Args:
            client_num: Number of clients expected to upload each round.
            device: Device the server model and batches are moved to.
            server_model: The global model trained on the server.
            args: Run configuration. Attributes read by this class include
                ``optimizer``, ``lr``, ``wd``, ``temperature``, ``alpha``,
                ``sweep``, ``test``, ``epochs_server``,
                ``whether_training_on_client``,
                ``whether_distill_on_the_server``, ``save_weight``,
                ``max_ckpt_nums``, ``model_dir`` and ``arch``.
            writer: Scalar logger exposing ``add_scalar(tag, value, global_step=...)``.

        Raises:
            ValueError: If ``args.optimizer`` is neither ``"SGD"`` nor ``"Adam"``.
        """
        self.client_num = client_num
        self.device = device
        self.args = args
        self.writer = writer

        # Notes: Using data parallelism requires adjusting the batch size accordingly.
        # For example, with a single GPU (batch_size = 64), an epoch takes 1:03;
        # using 4 GPUs (batch_size = 256), it takes 38 seconds, and with 4 GPUs
        # (batch_size = 64), it takes 1:00. If batch size is not adjusted, the
        # communication between CPU and GPU may slow down training.

        # Server model setup
        self.model_global = server_model
        self.model_global.train()
        self.model_global.to(self.device)

        # ``parameters()`` returns a one-shot iterator; materialise it so the
        # attributes are still usable after the optimizer has consumed them.
        self.model_params = self.master_params = list(self.model_global.parameters())
        optim_params = self.master_params

        if self.args.optimizer == "SGD":
            self.optimizer = optim.SGD(
                optim_params,
                lr=self.args.lr,
                momentum=0.9,
                nesterov=True,
                weight_decay=self.args.wd,
            )
        elif self.args.optimizer == "Adam":
            # NOTE: Adam uses a fixed weight decay of 1e-4 here, not ``args.wd``.
            self.optimizer = optim.Adam(
                optim_params,
                lr=self.args.lr,
                weight_decay=0.0001,
                amsgrad=True,
            )
        else:
            raise ValueError(
                f"Unsupported optimizer {self.args.optimizer!r}; expected 'SGD' or 'Adam'"
            )

        # 'max' mode: the scheduler is stepped with the best accuracy so far.
        self.scheduler = ReduceLROnPlateau(self.optimizer, 'max')

        # Loss functions: cross-entropy on labels, KL against client logits.
        self.criterion_CE = nn.CrossEntropyLoss()
        self.criterion_KL = utils.KL_Loss(self.args.temperature)

        self.best_acc = 0.0

        # Training data uploaded by clients, keyed by client index.
        self.client_extracted_feature_dict = {}
        self.client_logits_dict = {}
        self.client_labels_dict = {}
        # Server logits produced during training, keyed by client index.
        self.server_logits_dict = {}

        # Test data uploaded by clients, keyed by client index.
        self.client_extracted_feature_dict_test = {}
        self.client_labels_dict_test = {}

        # Miscellaneous bookkeeping (not modified elsewhere in this class).
        self.model_dict = {}
        self.sample_num_dict = {}
        self.train_acc_dict = {}
        self.train_loss_dict = {}
        self.test_acc_avg = 0.0
        self.test_loss_avg = 0.0

        # Per-client flag: has this client uploaded its result for the round?
        self.flag_client_model_uploaded_dict = {idx: False for idx in range(self.client_num)}

    def add_local_trained_result(self, index, extracted_feature_dict, logits_dict, labels_dict,
                                 extracted_feature_dict_test, labels_dict_test):
        """Store one client's uploaded train/test data and mark it as received."""
        logging.info(f"Adding model for client index = {index}")
        self.client_extracted_feature_dict[index] = extracted_feature_dict
        self.client_logits_dict[index] = logits_dict
        self.client_labels_dict[index] = labels_dict
        self.client_extracted_feature_dict_test[index] = extracted_feature_dict_test
        self.client_labels_dict_test[index] = labels_dict_test
        self.flag_client_model_uploaded_dict[index] = True

    def check_whether_all_receive(self):
        """Return True once every client has uploaded, resetting the flags when so."""
        if all(self.flag_client_model_uploaded_dict.values()):
            self.flag_client_model_uploaded_dict = {idx: False for idx in range(self.client_num)}
            return True
        return False

    def get_global_logits(self, client_index):
        """Return the server logits dict for ``client_index``, or None if absent."""
        return self.server_logits_dict.get(client_index)

    def train(self, round_idx):
        """Run one server round, dispatching on ``args.sweep`` and
        ``args.whether_training_on_client``."""
        if self.args.sweep == 1:
            self.sweep(round_idx)
        elif self.args.whether_training_on_client == 1:
            self.train_and_distill_on_client(round_idx)
        else:
            self.do_not_train_on_client(round_idx)

    def train_and_distill_on_client(self, round_idx):
        """Train for one server epoch (or the test strategy's count when
        ``args.test`` is set), then step the LR scheduler."""
        epochs_server = 1 if not self.args.test else self.get_server_epoch_strategy_test()[0]
        self.train_and_eval(round_idx, epochs_server, self.writer, self.args)
        self._step_scheduler()

    def do_not_train_on_client(self, round_idx):
        """Train for a single server epoch, then step the LR scheduler."""
        self.train_and_eval(round_idx, 1)
        self._step_scheduler()

    def sweep(self, round_idx):
        """Train for ``args.epochs_server`` epochs, then step the LR scheduler."""
        self.train_and_eval(round_idx, self.args.epochs_server)
        self._step_scheduler()

    def _step_scheduler(self):
        """Feed the best accuracy so far to the plateau scheduler."""
        # The ``epoch`` argument of ``step`` is deprecated and only affects the
        # scheduler's internal epoch counter, so it is intentionally omitted.
        self.scheduler.step(self.best_acc)

    def get_server_epoch_strategy_test(self):
        """Return ``(epochs, whether_distill_back)`` used in test mode."""
        return 1, True

    def get_server_epoch_strategy_reset56(self, round_idx):
        """Return ``(epochs, whether_distill_back)`` on a schedule that decays
        the epoch count as ``round_idx`` grows; distillation stops at round 150."""
        if round_idx < 20:
            epochs = 20
        elif round_idx < 30:
            epochs = 15
        elif round_idx < 40:
            epochs = 10
        elif round_idx < 50:
            epochs = 5
        elif round_idx < 150:
            epochs = 3
        else:
            epochs = 1
        whether_distill_back = round_idx < 150
        return epochs, whether_distill_back

    def get_server_epoch_strategy_reset56_2(self, round_idx):
        """Return ``(args.epochs_server, True)`` regardless of ``round_idx``."""
        return self.args.epochs_server, True

    def train_and_eval(self, round_idx, epochs, val_writer=None, args=None):
        """Train for ``epochs`` epochs; evaluate, log and checkpoint after the last.

        Args:
            round_idx: Current communication round, used as the logging step.
            epochs: Number of server training epochs; 0 makes this a no-op.
            val_writer: Scalar logger; defaults to ``self.writer``.
            args: Run configuration; defaults to ``self.args``.
        """
        # Callers inside this class pass only (round_idx, epochs); fall back to
        # the instance's writer/config so those calls work.
        if val_writer is None:
            val_writer = self.writer
        if args is None:
            args = self.args

        last_epoch = epochs - 1
        for epoch in range(epochs):
            logging.info(f"Train and evaluate. Round = {round_idx}, Epoch = {epoch}")
            train_metrics = self.train_large_model_on_the_server()
            if epoch != last_epoch:
                continue

            # Only the final epoch of the round is evaluated and logged.
            val_writer.add_scalar('average training loss', train_metrics['train_loss'], global_step=round_idx)
            test_metrics = self.eval_large_model_on_the_server()
            test_acc = test_metrics['test_accTop1']
            val_writer.add_scalar('test loss', test_metrics['test_loss'], global_step=round_idx)
            val_writer.add_scalar('test acc', test_metrics['test_accTop1'], global_step=round_idx)

            # Decide "is best" once, before best_acc is updated.
            is_best = test_acc >= self.best_acc
            if is_best:
                logging.info("- Found better accuracy")
                self.best_acc = test_acc
            val_writer.add_scalar('best_acc1', self.best_acc, global_step=round_idx)

            if args.save_weight:
                self._save_checkpoint(round_idx, is_best, args)

            print(f"{round_idx}-th round | Train Loss: {train_metrics['train_loss']:.3g} | Test Loss: {test_metrics['test_loss']:.3g} | Test Acc: {test_metrics['test_accTop1']:.3f}")

    def _save_checkpoint(self, round_idx, is_best, args):
        """Write this round's checkpoint, pruning the oldest beyond ``args.max_ckpt_nums``."""
        filename = f"checkpoint_{round_idx}.pth.tar"
        saved_ckpt_filenames.append(filename)
        if len(saved_ckpt_filenames) > args.max_ckpt_nums:
            stale_path = os.path.join(args.model_dir, saved_ckpt_filenames.pop(0))
            try:
                os.remove(stale_path)
            except FileNotFoundError:
                # The file may have been removed externally, or never written
                # (e.g. max_ckpt_nums == 0); pruning is best-effort.
                logging.warning(f"Checkpoint to prune not found: {stale_path}")
        ckpt_dict = {
            'round': round_idx + 1,
            'arch': args.arch,
            'state_dict': self.model_global.state_dict(),
            'best_acc1': self.best_acc,
            'optimizer': self.optimizer.state_dict(),
        }
        metric.save_checkpoint(ckpt_dict, is_best, args.model_dir, filename=filename)

    def train_large_model_on_the_server(self):
        """Run one training pass over every client's uploaded feature batches.

        Rebuilds ``self.server_logits_dict`` with the server's output for each
        batch (as NumPy arrays).

        Returns:
            Dict with ``train_loss``, ``train_accTop1`` and ``train_accTop5``.
        """
        self.server_logits_dict.clear()
        self.model_global.train()

        loss_avg = utils.RollingAverage()
        accTop1_avg = utils.RollingAverage()
        accTop5_avg = utils.RollingAverage()
        distill = self.args.whether_distill_on_the_server == 1

        for client_index, extracted_feature_dict in self.client_extracted_feature_dict.items():
            logits_dict = self.client_logits_dict[client_index]
            labels_dict = self.client_labels_dict[client_index]
            s_logits_dict = {}
            self.server_logits_dict[client_index] = s_logits_dict

            for batch_index, batch_feature_map_x in extracted_feature_dict.items():
                batch_feature_map_x = torch.from_numpy(batch_feature_map_x).to(self.device)
                batch_logits = torch.from_numpy(logits_dict[batch_index]).float().to(self.device)
                batch_labels = torch.from_numpy(labels_dict[batch_index]).long().to(self.device)

                output_batch = self.model_global(batch_feature_map_x)

                # Losses are computed from tensors already on self.device.
                loss_true = self.criterion_CE(output_batch, batch_labels)
                if distill:
                    loss_kd = self.criterion_KL(output_batch, batch_logits)
                    loss = loss_kd + self.args.alpha * loss_true
                else:
                    loss = loss_true

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                metrics = utils.accuracy(output_batch, batch_labels, topk=(1, 5))
                accTop1_avg.update(metrics[0].item())
                accTop5_avg.update(metrics[1].item())
                loss_avg.update(loss.item())

                # Keep the server's logits for this batch.
                s_logits_dict[batch_index] = output_batch.detach().cpu().numpy()

        train_metrics = {'train_loss': loss_avg.value(),
                         'train_accTop1': accTop1_avg.value(),
                         'train_accTop5': accTop5_avg.value()}
        logging.info(f"- Train metrics: {' ; '.join(f'{k}: {v:.3f}' for k, v in train_metrics.items())}")
        return train_metrics

    def eval_large_model_on_the_server(self):
        """Evaluate the global model on every client's uploaded test batches.

        Returns:
            Dict with ``test_loss``, ``test_accTop1`` and ``test_accTop5``.
        """
        self.model_global.eval()

        loss_avg = utils.RollingAverage()
        accTop1_avg = utils.RollingAverage()
        accTop5_avg = utils.RollingAverage()

        with torch.no_grad():
            for client_index, extracted_feature_dict in self.client_extracted_feature_dict_test.items():
                labels_dict = self.client_labels_dict_test[client_index]

                for batch_index, batch_feature_map_x in extracted_feature_dict.items():
                    batch_feature_map_x = torch.from_numpy(batch_feature_map_x).to(self.device)
                    batch_labels = torch.from_numpy(labels_dict[batch_index]).long().to(self.device)

                    output_batch = self.model_global(batch_feature_map_x)
                    loss = self.criterion_CE(output_batch, batch_labels)

                    metrics = utils.accuracy(output_batch, batch_labels, topk=(1, 5))
                    accTop1_avg.update(metrics[0].item())
                    accTop5_avg.update(metrics[1].item())
                    loss_avg.update(loss.item())

        test_metrics = {'test_loss': loss_avg.value(),
                        'test_accTop1': accTop1_avg.value(),
                        'test_accTop5': accTop5_avg.value()}
        logging.info(f"- Test metrics: {' ; '.join(f'{k}: {v:.3f}' for k, v in test_metrics.items())}")
        return test_metrics