# -*- coding: utf-8 -*-
# @Author: Weisen Pan
import logging
import torch
from torch import nn, optim
from fedml_service.decentralized.federated_gkt import utils
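import torch.nn.functional as F

# ---------------------------------------------------------------------------
# Illustrative sketch only (assumption, not used by GKTTrainer below): the GKT
# distillation term is expected to be a temperature-scaled KL divergence
# between client and server logits. utils.KL_Loss is assumed to follow this
# standard formulation; the helper here simply makes that computation explicit
# for readers of this file.
# ---------------------------------------------------------------------------
def _kl_distillation_loss_sketch(student_logits, teacher_logits, temperature):
    """KL(teacher || student) on temperature-softened distributions, scaled by T^2."""
    student_log_probs = F.log_softmax(student_logits / temperature, dim=1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=1)
    return F.kl_div(student_log_probs, teacher_probs, reduction="batchmean") * (temperature ** 2)
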
# Class for training a GKT client in a federated learning setup
class GKTTrainer:
    def __init__(self, client_index, local_training_data, local_test_data, device, client_model, args):
        # Initialize the client trainer with various parameters
        self.client_index = client_index  # Index of the current client
        self.local_training_data = local_training_data[client_index]  # Local training dataset specific to the client
        self.local_test_data = local_test_data[client_index]  # Local test dataset specific to the client
        self.device = device  # Device (CPU/GPU) where the computation will take place
        self.client_model = client_model.to(self.device)  # Model assigned to the client
        self.args = args  # Arguments configuring the training process
        logging.info(f"Client device = {self.device}")

        # Model parameters used for optimization
        self.model_params = self.master_params = self.client_model.parameters()
        optim_params = self.master_params

        # Configure the optimizer based on the provided arguments
        if self.args.optimizer == "SGD":
            # SGD optimizer with learning rate, Nesterov momentum, and weight decay
            self.optimizer = optim.SGD(optim_params, lr=self.args.lr, momentum=0.9, nesterov=True, weight_decay=self.args.wd)
        elif self.args.optimizer == "Adam":
            # Adam optimizer with learning rate, weight decay, and the AMSGrad variant
            self.optimizer = optim.Adam(optim_params, lr=self.args.lr, weight_decay=0.0001, amsgrad=True)

        # Loss functions: CrossEntropy for true-label prediction, KL divergence for knowledge distillation
        self.criterion_CE = nn.CrossEntropyLoss()
        self.criterion_KL = utils.KL_Loss(self.args.temperature)

        # Dictionary holding logits received from the server (used for knowledge distillation)
        self.server_logits_dict = {}
        logging.info(f"Client device = {self.device} - Initialization Complete")
    # Update server logits for knowledge distillation
    def update_large_model_logits(self, logits):
        self.server_logits_dict = logits
    # Main training function for the client
    def train(self):
        # Dictionaries to store extracted features, logits, and labels during training and testing
        extracted_feature_dict, logits_dict, labels_dict = {}, {}, {}
        extracted_feature_dict_test, labels_dict_test = {}, {}

        # Only train if training on the client is enabled
        if self.args.whether_training_on_client:
            self.client_model.train()  # Set the model to training mode
            epoch_loss = []  # Track the average loss for each epoch

            # Loop over the specified number of federated epochs
            for epoch in range(self.args.fed_epochs):
                batch_loss = []  # Track the loss for each batch

                # Loop through the local training data in batches
                for batch_idx, (images, labels) in enumerate(self.local_training_data):
                    # Move images and labels to the specified device
                    images, labels = images.to(self.device), labels.to(self.device)

                    # Forward pass through the client model
                    log_probs, _ = self.client_model(images)

                    # Compute the loss with respect to the true labels
                    loss_true = self.criterion_CE(log_probs, labels)

                    # If server logits are available, add the distillation loss (KL divergence)
                    if self.server_logits_dict:
                        large_model_logits = torch.from_numpy(self.server_logits_dict[batch_idx]).to(self.device)
                        loss_kd = self.criterion_KL(log_probs, large_model_logits)
                        # Combine the true-label loss and the distillation loss
                        loss = loss_true + self.args.alpha * loss_kd
                    else:
                        # Use only the true-label loss if no server logits are available
                        loss = loss_true

                    # Perform backpropagation and an optimization step
                    self.optimizer.zero_grad()  # Reset gradients
                    loss.backward()             # Backpropagate the loss
                    self.optimizer.step()       # Update model parameters

                    # Log progress for each batch
                    logging.info(f'Client {self.client_index} - Update Epoch: {epoch} '
                                 f'[{batch_idx * len(images)}/{len(self.local_training_data.dataset)} '
                                 f'({100. * batch_idx / len(self.local_training_data):.0f}%)]')
                    batch_loss.append(loss.item())  # Store the loss for the current batch

                # Compute and store the average loss for the epoch
                epoch_loss.append(sum(batch_loss) / len(batch_loss))

        # Switch to evaluation mode after training
        self.client_model.eval()

        # Extract features, logits, and labels from the training data for evaluation
        for batch_idx, (images, labels) in enumerate(self.local_training_data):
            images, labels = images.to(self.device), labels.to(self.device)
            log_probs, extracted_features = self.client_model(images)

            # Store the extracted features, logits, and labels for this batch
            extracted_feature_dict[batch_idx] = extracted_features.cpu().detach().numpy()
            logits_dict[batch_idx] = log_probs.cpu().detach().numpy()
            labels_dict[batch_idx] = labels.cpu().detach().numpy()

        # Extract features and labels from the test data for evaluation
        for batch_idx, (images, labels) in enumerate(self.local_test_data):
            test_images, test_labels = images.to(self.device), labels.to(self.device)
            _, extracted_features_test = self.client_model(test_images)

            # Store the extracted test features and labels for this batch
            extracted_feature_dict_test[batch_idx] = extracted_features_test.cpu().detach().numpy()
            labels_dict_test[batch_idx] = test_labels.cpu().detach().numpy()

        # Return the extracted features, logits, and labels from both training and test datasets
        return extracted_feature_dict, logits_dict, labels_dict, extracted_feature_dict_test, labels_dict_test
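
# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the original trainer): a
# minimal, hypothetical client-side GKT round. The toy model, dummy data, and
# the SimpleNamespace fields below are assumptions chosen to mirror the
# attributes GKTTrainer reads (optimizer, lr, wd, temperature, alpha,
# fed_epochs, whether_training_on_client); the real pipeline supplies these
# from fedml_service data loaders, models, and parsed arguments.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace
    from torch.utils.data import DataLoader, TensorDataset

    class ToyClientModel(nn.Module):
        """Tiny stand-in model returning (logits, extracted_features), as GKTTrainer expects."""
        def __init__(self, num_classes=10):
            super().__init__()
            self.extractor = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 16), nn.ReLU())
            self.classifier = nn.Linear(16, num_classes)

        def forward(self, x):
            features = self.extractor(x)
            return self.classifier(features), features

    # Hypothetical hyper-parameters mirroring the fields accessed above.
    args = SimpleNamespace(optimizer="SGD", lr=0.01, wd=5e-4, temperature=3.0,
                           alpha=1.0, fed_epochs=1, whether_training_on_client=True)

    # One dummy client (index 0) with eight random 8x8 RGB images.
    dataset = TensorDataset(torch.randn(8, 3, 8, 8), torch.randint(0, 10, (8,)))
    loader = DataLoader(dataset, batch_size=4)

    trainer = GKTTrainer(client_index=0,
                         local_training_data=[loader],
                         local_test_data=[loader],
                         device=torch.device("cpu"),
                         client_model=ToyClientModel(),
                         args=args)

    # First local round: no server logits yet, so only the CE loss is used.
    feats, logits, lbls, feats_test, lbls_test = trainer.train()

    # In a full GKT round the server trains its larger model on (feats, logits,
    # lbls) and returns per-batch logits; the client installs them so the next
    # call to train() adds the alpha-weighted distillation loss.
    trainer.update_large_model_logits({b: logits[b] for b in logits})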