# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import logging

import torch
from torch import nn, optim

from fedml_service.decentralized.federated_gkt import utils


# Class for training a GKT client in a federated learning setup
class GKTTrainer:
    def __init__(self, client_index, local_training_data, local_test_data, device, client_model, args):
        # Initialize the client trainer with its data partitions, model, and configuration
        self.client_index = client_index  # Index for the current client
        self.local_training_data = local_training_data[client_index]  # Local training dataset specific to the client
        self.local_test_data = local_test_data[client_index]  # Local test dataset specific to the client
        self.device = device  # Device (CPU/GPU) where the computation will take place
        self.client_model = client_model.to(self.device)  # Model assigned to the client
        self.args = args  # Arguments configuring the training process

        logging.info(f"Client device = {self.device}")

        # Model parameters used for optimization
        self.model_params = self.master_params = self.client_model.parameters()
        optim_params = self.master_params

        # Configure the optimizer based on the provided arguments
        if self.args.optimizer == "SGD":
            # SGD optimizer with learning rate, Nesterov momentum, and weight decay
            self.optimizer = optim.SGD(
                optim_params, lr=self.args.lr, momentum=0.9, nesterov=True, weight_decay=self.args.wd
            )
        elif self.args.optimizer == "Adam":
            # Adam optimizer with learning rate, weight decay, and the AMSGrad variant
            self.optimizer = optim.Adam(optim_params, lr=self.args.lr, weight_decay=0.0001, amsgrad=True)

        # Loss functions: cross-entropy for true-label prediction, KL divergence for knowledge distillation
        self.criterion_CE = nn.CrossEntropyLoss()
        self.criterion_KL = utils.KL_Loss(self.args.temperature)
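
        # Note (illustrative, not from the original source): utils.KL_Loss is
        # defined elsewhere in the repository. In GKT-style knowledge
        # distillation it is typically assumed to compute a temperature-scaled
        # KL divergence between the client's logits and the server's logits,
        # roughly  T^2 * KL( softmax(server_logits / T) || softmax(client_logits / T) ),
        # which train() below weights by `alpha` and adds to the cross-entropy loss.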

        # Dictionary to hold logits received from the server (used for knowledge distillation)
        self.server_logits_dict = {}

        logging.info(f"Client device = {self.device} - Initialization Complete")

    # Update server logits for knowledge distillation
    def update_large_model_logits(self, logits):
        self.server_logits_dict = logits

    # Main training function for the client
    def train(self):
        # Dictionaries to store extracted features, logits, and labels during training and testing
        extracted_feature_dict, logits_dict, labels_dict = {}, {}, {}
        extracted_feature_dict_test, labels_dict_test = {}, {}

        # Only train if training on the client is enabled
        if self.args.whether_training_on_client:
            self.client_model.train()  # Set model to training mode
            epoch_loss = []  # Track average loss for each epoch

            # Loop over the specified number of federated epochs
            for epoch in range(self.args.fed_epochs):
                batch_loss = []  # Track loss for each batch

                # Loop through the local training data in batches
                for batch_idx, (images, labels) in enumerate(self.local_training_data):
                    # Move images and labels to the specified device
                    images, labels = images.to(self.device), labels.to(self.device)

                    # Forward pass through the client model
                    log_probs, _ = self.client_model(images)

                    # Compute the loss with respect to the true labels
                    loss_true = self.criterion_CE(log_probs, labels)

                    # If server logits are available, calculate the distillation loss using KL divergence
                    if self.server_logits_dict:
                        large_model_logits = torch.from_numpy(self.server_logits_dict[batch_idx]).to(self.device)
                        loss_kd = self.criterion_KL(log_probs, large_model_logits)
                        # Combine true-label loss and distillation loss
                        loss = loss_true + self.args.alpha * loss_kd
                    else:
                        # Use only the true-label loss if no server logits are available
                        loss = loss_true

                    # Perform backpropagation and an optimization step
                    self.optimizer.zero_grad()  # Reset gradients
                    loss.backward()  # Backpropagate the loss
                    self.optimizer.step()  # Update model parameters

                    # Log progress for each batch
                    logging.info(f'Client {self.client_index} - Update Epoch: {epoch} '
                                 f'[{batch_idx * len(images)}/{len(self.local_training_data.dataset)} '
                                 f'({100. * batch_idx / len(self.local_training_data):.0f}%)]')
                    batch_loss.append(loss.item())  # Store the loss for the current batch

                # Calculate and store the average loss for the epoch
                epoch_loss.append(sum(batch_loss) / len(batch_loss))

        # Switch to evaluation mode after training
        self.client_model.eval()

        # Extract features, logits, and labels from the training data for the server
        for batch_idx, (images, labels) in enumerate(self.local_training_data):
            images, labels = images.to(self.device), labels.to(self.device)
            log_probs, extracted_features = self.client_model(images)

            # Store the extracted features, logits, and labels for this batch
            extracted_feature_dict[batch_idx] = extracted_features.cpu().detach().numpy()
            logits_dict[batch_idx] = log_probs.cpu().detach().numpy()
            labels_dict[batch_idx] = labels.cpu().detach().numpy()

        # Extract features and labels from the test data for evaluation
        for batch_idx, (images, labels) in enumerate(self.local_test_data):
            test_images, test_labels = images.to(self.device), labels.to(self.device)
            _, extracted_features_test = self.client_model(test_images)

            # Store the extracted test features and labels for this batch
            extracted_feature_dict_test[batch_idx] = extracted_features_test.cpu().detach().numpy()
            labels_dict_test[batch_idx] = test_labels.cpu().detach().numpy()

        # Return the extracted features, logits, and labels from both training and test datasets
        return extracted_feature_dict, logits_dict, labels_dict, extracted_feature_dict_test, labels_dict_test
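
# ---------------------------------------------------------------------------
# The block below is not part of the original trainer. It is a minimal,
# hypothetical smoke-test sketch showing how GKTTrainer might be driven for a
# single client: the dummy model, data loaders, and argument values are
# assumptions made for illustration, not the project's actual configuration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    from torch.utils.data import DataLoader, TensorDataset

    class _DummyClientModel(nn.Module):
        """Toy stand-in for the real client model: forward() returns (logits, features)."""

        def __init__(self, num_classes=10, feature_dim=8):
            super().__init__()
            self.features = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, feature_dim))
            self.classifier = nn.Linear(feature_dim, num_classes)

        def forward(self, x):
            feats = self.features(x)
            return self.classifier(feats), feats

    def _make_loader(num_samples=16, batch_size=8):
        # Random images/labels standing in for a client's local data partition
        images = torch.randn(num_samples, 3, 8, 8)
        labels = torch.randint(0, 10, (num_samples,))
        return DataLoader(TensorDataset(images, labels), batch_size=batch_size)

    # Hypothetical hyperparameters; a real run supplies these through the framework's args
    args = SimpleNamespace(
        optimizer="SGD", lr=0.01, wd=5e-4, temperature=3.0, alpha=1.0,
        whether_training_on_client=True, fed_epochs=1,
    )

    trainer = GKTTrainer(
        client_index=0,
        local_training_data=[_make_loader()],
        local_test_data=[_make_loader()],
        device=torch.device("cpu"),
        client_model=_DummyClientModel(),
        args=args,
    )

    # One local round: train, then the returned features/logits/labels would be
    # sent to the GKT server, which replies with per-batch logits via
    # update_large_model_logits() before the next round.
    feats, logits, labels, feats_test, labels_test = trainer.train()
    print(f"Extracted {len(feats)} training batches and {len(feats_test)} test batches")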