# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import logging
import os
import shutil

import torch
from torch import nn, optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

from utils import metric
from fedml_service.decentralized.federated_gkt import utils

# List to store filenames of saved checkpoints
saved_ckpt_filenames = []


class GKTServerTrainer:
    def __init__(self, client_num, device, server_model, args, writer):
        # Initialize the trainer with the number of clients, the device
        # (CPU/GPU), the global server model, training arguments, and a
        # writer for logging
        self.client_num = client_num
        self.device = device
        self.args = args
        self.writer = writer

        """
        Note: data parallelism requires scaling the batch size with the number
        of GPUs. For example, one epoch takes 1:03 on a single GPU with
        batch_size = 64; on 4 GPUs it takes 0:38 with batch_size = 256, but
        still 1:00 with batch_size = 64. Without a larger batch size, the
        CPU-GPU communication overhead can erase the speedup. (A gated sketch
        follows the model setup below.)
        """

        # Server model setup
        self.model_global = server_model
        self.model_global.train()  # Set model to training mode
        self.model_global.to(self.device)  # Move model to the target device
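
        # A minimal sketch of the data parallelism discussed in the note above,
        # gated by a hypothetical `args.data_parallel` flag (an assumption, not
        # part of the original flow) so single-GPU behavior is unchanged; the
        # batch size should be scaled with the GPU count as noted.
        if getattr(args, "data_parallel", False) and torch.cuda.device_count() > 1:
            self.model_global = nn.DataParallel(self.model_global)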

        # Model parameters for optimization, materialized as a list because
        # .parameters() returns a one-shot generator
        self.model_params = self.master_params = list(self.model_global.parameters())
        optim_params = self.master_params

        # Choose the optimizer based on arguments (SGD or Adam)
        if self.args.optimizer == "SGD":
            self.optimizer = optim.SGD(optim_params, lr=self.args.lr, momentum=0.9,
                                       nesterov=True, weight_decay=self.args.wd)
        elif self.args.optimizer == "Adam":
            self.optimizer = optim.Adam(optim_params, lr=self.args.lr,
                                        weight_decay=0.0001, amsgrad=True)

        # Reduce the learning rate when the monitored accuracy plateaus
        self.scheduler = ReduceLROnPlateau(self.optimizer, 'max')
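        # Note: in 'max' mode the scheduler lowers the LR only when the
        # monitored value stops improving; since self.best_acc (passed to
        # scheduler.step() after each round) is monotonically non-decreasing,
        # the LR drops after PyTorch's default patience of 10 rounds without
        # a new best accuracy.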

        # Loss functions: cross-entropy for classification, KL divergence for
        # knowledge distillation
        self.criterion_CE = nn.CrossEntropyLoss()
        self.criterion_KL = utils.KL_Loss(self.args.temperature)

        # Best accuracy tracking
        self.best_acc = 0.0

        # Per-client training data: extracted features, logits, and labels
        self.client_extracted_feature_dict = {}
        self.client_logits_dict = {}
        self.client_labels_dict = {}
        self.server_logits_dict = {}

        # Per-client testing data
        self.client_extracted_feature_dict_test = {}
        self.client_labels_dict_test = {}

        # Miscellaneous dictionaries for model info, sample counts, training
        # accuracy, and loss
        self.model_dict = {}
        self.sample_num_dict = {}
        self.train_acc_dict = {}
        self.train_loss_dict = {}
        self.test_acc_avg = 0.0
        self.test_loss_avg = 0.0

        # Track whether each client's model has been uploaded this round
        self.flag_client_model_uploaded_dict = {idx: False for idx in range(self.client_num)}

    # Store the results uploaded by a client after local training
    def add_local_trained_result(self, index, extracted_feature_dict, logits_dict, labels_dict,
                                 extracted_feature_dict_test, labels_dict_test):
        logging.info(f"Adding model for client index = {index}")
        self.client_extracted_feature_dict[index] = extracted_feature_dict
        self.client_logits_dict[index] = logits_dict
        self.client_labels_dict[index] = labels_dict
        self.client_extracted_feature_dict_test[index] = extracted_feature_dict_test
        self.client_labels_dict_test[index] = labels_dict_test
        self.flag_client_model_uploaded_dict[index] = True
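
    # Expected format of the uploaded dictionaries (inferred from how they are
    # consumed in train/eval below): each maps batch_index -> numpy array, with
    # float32 features and logits and integer class labels.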

    # Check whether all clients have uploaded their results; if so, reset the
    # upload flags for the next round and return True
    def check_whether_all_receive(self):
        if all(self.flag_client_model_uploaded_dict.values()):
            self.flag_client_model_uploaded_dict = {idx: False for idx in range(self.client_num)}
            return True
        return False

    # Get the server-computed logits for a specific client (None if absent)
    def get_global_logits(self, client_index):
        return self.server_logits_dict.get(client_index)

    # Main entry point for one training round
    def train(self, round_idx):
        if self.args.sweep == 1:  # Sweep mode
            self.sweep(round_idx)
        elif self.args.whether_training_on_client == 1:  # Clients train locally
            self.train_and_distill_on_client(round_idx)
        else:  # No client-side training; the server only trains and evaluates
            self.do_not_train_on_client(round_idx)

    # Server-side training and distillation when clients also train locally
    def train_and_distill_on_client(self, round_idx):
        # One server epoch per round, unless test mode supplies its own strategy
        epochs_server = 1 if not self.args.test else self.get_server_epoch_strategy_test()[0]
        self.train_and_eval(round_idx, epochs_server, self.writer, self.args)  # Train and evaluate
        self.scheduler.step(self.best_acc)  # Update the plateau scheduler

    # Skip client-side training; the server still trains for one epoch per round
    def do_not_train_on_client(self, round_idx):
        self.train_and_eval(round_idx, 1, self.writer, self.args)
        self.scheduler.step(self.best_acc)

    # Sweep mode: run the configured number of server epochs per round
    def sweep(self, round_idx):
        self.train_and_eval(round_idx, self.args.epochs_server, self.writer, self.args)
        self.scheduler.step(self.best_acc)

    # Epoch strategy used in test mode (one server epoch; the second value
    # mirrors whether_distill_back in the other strategies)
    def get_server_epoch_strategy_test(self):
        return 1, True

    # Staircase epoch schedule: more server epochs in early rounds, tapering to
    # one after round 150; distill back to clients only while round_idx < 150
    def get_server_epoch_strategy_reset56(self, round_idx):
        epochs = (20 if round_idx < 20 else
                  15 if round_idx < 30 else
                  10 if round_idx < 40 else
                  5 if round_idx < 50 else
                  3 if round_idx < 150 else 1)
        whether_distill_back = round_idx < 150
        return epochs, whether_distill_back

    # Variant that always uses the configured number of server epochs
    def get_server_epoch_strategy_reset56_2(self, round_idx):
        return self.args.epochs_server, True

    # Main training and evaluation loop
    def train_and_eval(self, round_idx, epochs, val_writer, args):
        for epoch in range(epochs):
            logging.info(f"Train and evaluate. Round = {round_idx}, Epoch = {epoch}")
            train_metrics = self.train_large_model_on_the_server()  # Training step

            if epoch == epochs - 1:
                # Log metrics for the final epoch
                val_writer.add_scalar('average training loss', train_metrics['train_loss'], global_step=round_idx)
                test_metrics = self.eval_large_model_on_the_server()  # Evaluation step
                test_acc = test_metrics['test_accTop1']

                val_writer.add_scalar('test loss', test_metrics['test_loss'], global_step=round_idx)
                val_writer.add_scalar('test acc', test_metrics['test_accTop1'], global_step=round_idx)

                # Track the best accuracy seen so far
                is_best = test_acc >= self.best_acc
                if is_best:
                    logging.info("- Found better accuracy")
                    self.best_acc = test_acc

                val_writer.add_scalar('best_acc1', self.best_acc, global_step=round_idx)

                # Save model checkpoints, keeping at most args.max_ckpt_nums files
                if args.save_weight:
                    filename = f"checkpoint_{round_idx}.pth.tar"
                    saved_ckpt_filenames.append(filename)
                    if len(saved_ckpt_filenames) > args.max_ckpt_nums:
                        os.remove(os.path.join(args.model_dir, saved_ckpt_filenames.pop(0)))

                    ckpt_dict = {
                        'round': round_idx + 1,
                        'arch': args.arch,
                        'state_dict': self.model_global.state_dict(),
                        'best_acc1': self.best_acc,
                        'optimizer': self.optimizer.state_dict(),
                    }
                    metric.save_checkpoint(ckpt_dict, is_best, args.model_dir, filename=filename)

                # Log metrics for the current round
                logging.info(f"{round_idx}-th round | Train Loss: {train_metrics['train_loss']:.3g} | "
                             f"Test Loss: {test_metrics['test_loss']:.3g} | "
                             f"Test Acc: {test_metrics['test_accTop1']:.3f}")

    # Train the global model on the server using client-extracted features
    def train_large_model_on_the_server(self):
        # Clear the server logits from the previous round and set training mode
        self.server_logits_dict.clear()
        self.model_global.train()

        # Track running loss and accuracy
        loss_avg = utils.RollingAverage()
        accTop1_avg = utils.RollingAverage()
        accTop5_avg = utils.RollingAverage()

        # Iterate over each client's extracted features
        for client_index, extracted_feature_dict in self.client_extracted_feature_dict.items():
            logits_dict = self.client_logits_dict[client_index]
            labels_dict = self.client_labels_dict[client_index]

            s_logits_dict = {}
            self.server_logits_dict[client_index] = s_logits_dict

            # Iterate over the client's batches of features
            for batch_index, batch_feature_map_x in extracted_feature_dict.items():
                batch_feature_map_x = torch.from_numpy(batch_feature_map_x).to(self.device)
                batch_logits = torch.from_numpy(logits_dict[batch_index]).float().to(self.device)
                batch_labels = torch.from_numpy(labels_dict[batch_index]).long().to(self.device)

                # Forward pass through the server model
                output_batch = self.model_global(batch_feature_map_x)

                if self.args.whether_distill_on_the_server == 1:
                    # Knowledge distillation: temperature-scaled KL against the
                    # client logits plus alpha-weighted cross-entropy on labels
                    loss_kd = self.criterion_KL(output_batch, batch_logits).to(self.device)
                    loss_true = self.criterion_CE(output_batch, batch_labels).to(self.device)
                    loss = loss_kd + self.args.alpha * loss_true
                else:
                    # Standard cross-entropy loss
                    loss = self.criterion_CE(output_batch, batch_labels).to(self.device)

                # Backward pass and optimization step
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # Compute accuracy metrics
                metrics = utils.accuracy(output_batch, batch_labels, topk=(1, 5))
                accTop1_avg.update(metrics[0].item())
                accTop5_avg.update(metrics[1].item())
                loss_avg.update(loss.item())

                # Store the server logits for this batch, to be sent back to the client
                s_logits_dict[batch_index] = output_batch.detach().cpu().numpy()

        # Aggregate and log training metrics
        train_metrics = {'train_loss': loss_avg.value(),
                         'train_accTop1': accTop1_avg.value(),
                         'train_accTop5': accTop5_avg.value()}
        logging.info(f"- Train metrics: {' ; '.join(f'{k}: {v:.3f}' for k, v in train_metrics.items())}")
        return train_metrics

    # Evaluate the global model on the server using client-extracted test features
    def eval_large_model_on_the_server(self):
        # Set the model to evaluation mode
        self.model_global.eval()
        loss_avg = utils.RollingAverage()
        accTop1_avg = utils.RollingAverage()
        accTop5_avg = utils.RollingAverage()

        # Disable gradient computation for evaluation
        with torch.no_grad():
            # Iterate over each client's test features
            for client_index, extracted_feature_dict in self.client_extracted_feature_dict_test.items():
                labels_dict = self.client_labels_dict_test[client_index]

                # Iterate over the client's test batches
                for batch_index, batch_feature_map_x in extracted_feature_dict.items():
                    batch_feature_map_x = torch.from_numpy(batch_feature_map_x).to(self.device)
                    batch_labels = torch.from_numpy(labels_dict[batch_index]).long().to(self.device)

                    # Forward pass
                    output_batch = self.model_global(batch_feature_map_x)
                    loss = self.criterion_CE(output_batch, batch_labels)

                    # Compute accuracy metrics
                    metrics = utils.accuracy(output_batch, batch_labels, topk=(1, 5))
                    accTop1_avg.update(metrics[0].item())
                    accTop5_avg.update(metrics[1].item())
                    loss_avg.update(loss.item())

        # Aggregate and log test metrics
        test_metrics = {'test_loss': loss_avg.value(),
                        'test_accTop1': accTop1_avg.value(),
                        'test_accTop5': accTop5_avg.value()}
        logging.info(f"- Test metrics: {' ; '.join(f'{k}: {v:.3f}' for k, v in test_metrics.items())}")
        return test_metrics
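

# A minimal smoke-test sketch of one federation round (an illustration, not
# part of the original module), assuming the project's `utils` and
# `fedml_service` packages are importable. It fabricates one client's
# feature/logit/label batches with numpy, uses a tiny linear head in place of
# the real server network, and a stub writer instead of TensorBoard's
# SummaryWriter. The argument names mirror the attributes this file reads
# from `args`.
if __name__ == "__main__":
    from argparse import Namespace
    import numpy as np

    logging.basicConfig(level=logging.INFO)

    class _StubWriter:
        def add_scalar(self, tag, value, global_step=None):
            logging.info(f"[writer] {tag} = {value} (step {global_step})")

    args = Namespace(optimizer="SGD", lr=0.01, wd=5e-4, temperature=3.0,
                     sweep=0, whether_training_on_client=1, test=False,
                     whether_distill_on_the_server=1, alpha=0.5,
                     save_weight=False, epochs_server=1)

    num_classes, feat_dim = 10, 32
    server_model = nn.Linear(feat_dim, num_classes)  # stand-in for the server network
    trainer = GKTServerTrainer(client_num=1, device=torch.device("cpu"),
                               server_model=server_model, args=args,
                               writer=_StubWriter())

    # Fake upload from client 0: two batches of 8 samples each
    feats = {b: np.random.randn(8, feat_dim).astype(np.float32) for b in range(2)}
    logits = {b: np.random.randn(8, num_classes).astype(np.float32) for b in range(2)}
    labels = {b: np.random.randint(0, num_classes, size=8) for b in range(2)}
    trainer.add_local_trained_result(0, feats, logits, labels, feats, labels)

    if trainer.check_whether_all_receive():
        trainer.train(round_idx=0)
        print("server logits for client 0, batch 0:", trainer.get_global_logits(0)[0].shape)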