# -*- coding: utf-8 -*-
# @Author: Weisen Pan
import logging

import torch
from torch import nn, optim

from fedml_service.decentralized.federated_gkt import utils


# Trains a GKT (Group Knowledge Transfer) client in a federated learning setup
class GKTTrainer:
    def __init__(self, client_index, local_training_data, local_test_data, device, client_model, args):
        # Initialize the client trainer with its data shard, model, and configuration
        self.client_index = client_index  # Index of the current client
        self.local_training_data = local_training_data[client_index]  # Training shard for this client
        self.local_test_data = local_test_data[client_index]  # Test shard for this client
        self.device = device  # Device (CPU/GPU) where computation takes place
        self.client_model = client_model.to(self.device)  # Model assigned to the client
        self.args = args  # Arguments configuring the training process

        logging.info(f"Client device = {self.device}")

        # Model parameters used for optimization
        self.model_params = self.master_params = self.client_model.parameters()
        optim_params = self.master_params

        # Configure the optimizer based on the provided arguments
        if self.args.optimizer == "SGD":
            # SGD with Nesterov momentum and weight decay
            self.optimizer = optim.SGD(optim_params, lr=self.args.lr, momentum=0.9,
                                       nesterov=True, weight_decay=self.args.wd)
        elif self.args.optimizer == "Adam":
            # Adam with weight decay and the AMSGrad variant
            self.optimizer = optim.Adam(optim_params, lr=self.args.lr,
                                        weight_decay=0.0001, amsgrad=True)

        # Loss functions: cross-entropy for the true labels,
        # KL divergence for knowledge distillation from the server
        self.criterion_CE = nn.CrossEntropyLoss()
        self.criterion_KL = utils.KL_Loss(self.args.temperature)

        # Logits received from the server (used for knowledge distillation)
        self.server_logits_dict = {}

        logging.info(f"Client device = {self.device} - Initialization Complete")

    # Update server logits for knowledge distillation
    def update_large_model_logits(self, logits):
        self.server_logits_dict = logits

    # Main training function for the client
    def train(self):
        # Dictionaries for extracted features, logits, and labels from train/test data
        extracted_feature_dict, logits_dict, labels_dict = {}, {}, {}
        extracted_feature_dict_test, labels_dict_test = {}, {}

        # Only train if training on the client is enabled
        if self.args.whether_training_on_client:
            self.client_model.train()  # Set model to training mode
            epoch_loss = []  # Track the average loss for each epoch

            # Loop over the specified number of federated epochs
            for epoch in range(self.args.fed_epochs):
                batch_loss = []  # Track loss for each batch

                # Loop through the local training data in batches
                for batch_idx, (images, labels) in enumerate(self.local_training_data):
                    # Move images and labels to the target device
                    images, labels = images.to(self.device), labels.to(self.device)

                    # Forward pass through the client model
                    log_probs, _ = self.client_model(images)

                    # Loss with respect to the true labels
                    loss_true = self.criterion_CE(log_probs, labels)

                    # If server logits are available, add the KL distillation loss
                    if self.server_logits_dict:
                        large_model_logits = torch.from_numpy(
                            self.server_logits_dict[batch_idx]).to(self.device)
                        loss_kd = self.criterion_KL(log_probs, large_model_logits)
                        # Combine true-label loss and distillation loss
                        loss = loss_true + self.args.alpha * loss_kd
                    else:
                        # Use only the true-label loss if no server logits are available
                        loss = loss_true
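                    # In GKT, the client (edge) objective combines the supervised
                    # loss with the distillation signal from the server:
                    #     L = CE(client_logits, labels)
                    #         + alpha * KL_T(client_logits, server_logits)
                    # where KL_T is the temperature-scaled KL divergence computed
                    # by utils.KL_Loss and alpha weights the distillation term.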
                    # Perform backpropagation and an optimization step
                    self.optimizer.zero_grad()  # Reset gradients
                    loss.backward()  # Backpropagate the loss
                    self.optimizer.step()  # Update model parameters

                    # Log progress for each batch
                    logging.info(f'Client {self.client_index} - Update Epoch: {epoch} '
                                 f'[{batch_idx * len(images)}/{len(self.local_training_data.dataset)} '
                                 f'({100. * batch_idx / len(self.local_training_data):.0f}%)]')
                    batch_loss.append(loss.item())  # Store the loss for the current batch

                # Calculate and store the average loss for the epoch
                epoch_loss.append(sum(batch_loss) / len(batch_loss))

        # Switch to evaluation mode after (optional) local training; extraction
        # runs regardless so the server always receives this client's tensors
        self.client_model.eval()

        # Extract features, logits, and labels from the training data to upload to the server
        for batch_idx, (images, labels) in enumerate(self.local_training_data):
            images, labels = images.to(self.device), labels.to(self.device)
            log_probs, extracted_features = self.client_model(images)

            # Store the extracted features, logits, and labels for this batch
            extracted_feature_dict[batch_idx] = extracted_features.cpu().detach().numpy()
            logits_dict[batch_idx] = log_probs.cpu().detach().numpy()
            labels_dict[batch_idx] = labels.cpu().detach().numpy()

        # Extract features and labels from the test data for server-side evaluation
        for batch_idx, (images, labels) in enumerate(self.local_test_data):
            test_images, test_labels = images.to(self.device), labels.to(self.device)
            _, extracted_features_test = self.client_model(test_images)

            # Store the extracted test features and labels for this batch
            extracted_feature_dict_test[batch_idx] = extracted_features_test.cpu().detach().numpy()
            labels_dict_test[batch_idx] = test_labels.cpu().detach().numpy()

        # Return the extracted features, logits, and labels from both training and test datasets
        return (extracted_feature_dict, logits_dict, labels_dict,
                extracted_feature_dict_test, labels_dict_test)
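
# ---------------------------------------------------------------------------
# Reference sketch: utils.KL_Loss is imported from fedml_service above. The
# class below is a hedged re-implementation of the temperature-scaled
# distillation loss it is expected to compute (softened softmax at temperature
# T, KL divergence rescaled by T^2), not the verbatim library source. It is
# illustrative only and is not used by GKTTrainer.
# ---------------------------------------------------------------------------
import torch.nn.functional as F


class KLLossSketch(nn.Module):
    """Temperature-scaled KL-divergence distillation loss (illustrative)."""

    def __init__(self, temperature=1.0):
        super().__init__()
        self.T = temperature

    def forward(self, student_logits, teacher_logits):
        # Student distribution in log space, teacher in probability space,
        # both softened by the temperature T
        log_p = F.log_softmax(student_logits / self.T, dim=1)
        q = F.softmax(teacher_logits / self.T, dim=1)
        # T^2 rescaling keeps gradient magnitudes comparable across temperatures
        return (self.T ** 2) * F.kl_div(log_p, q, reduction="batchmean")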
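
# ---------------------------------------------------------------------------
# Usage sketch: drives GKTTrainer for one round on synthetic data. The toy
# model, hyperparameter values, and data below are all hypothetical; they only
# mirror the interface the class expects (a model returning (logits, features)
# and an argparse-style `args` namespace), and running this still requires
# fedml_service to be importable for utils.KL_Loss.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    from torch.utils.data import DataLoader, TensorDataset

    logging.basicConfig(level=logging.INFO)

    class ToyClientModel(nn.Module):
        """Minimal model matching GKTTrainer's expected (logits, features) output."""

        def __init__(self, num_classes=10):
            super().__init__()
            self.extractor = nn.Sequential(nn.Flatten(),
                                           nn.Linear(3 * 32 * 32, 64), nn.ReLU())
            self.classifier = nn.Linear(64, num_classes)

        def forward(self, x):
            features = self.extractor(x)
            return self.classifier(features), features

    # Hypothetical hyperparameters mirroring the attributes GKTTrainer reads
    # (optimizer, lr, wd, temperature, alpha, fed_epochs, whether_training_on_client)
    args = SimpleNamespace(optimizer="SGD", lr=0.01, wd=5e-4, temperature=3.0,
                           alpha=1.0, fed_epochs=1, whether_training_on_client=True)

    # One synthetic data shard per client; client_index selects this client's shard
    shard = DataLoader(TensorDataset(torch.randn(8, 3, 32, 32),
                                     torch.randint(0, 10, (8,))), batch_size=4)

    trainer = GKTTrainer(client_index=0, local_training_data=[shard],
                         local_test_data=[shard], device=torch.device("cpu"),
                         client_model=ToyClientModel(), args=args)
    feats, logits, labels, feats_test, labels_test = trainer.train()
    logging.info(f"Uploading {len(feats)} train batches and {len(feats_test)} test batches")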