# -*- coding: utf-8 -*-
# @Author: Weisen Pan
import torch
import torch.nn as nn
import torch.nn.functional as F
from .mixup import mixup_loss_criterion, combine_mixup_data
from . import resnet_sl
# Exported members of the module
__all__ = ['CoreModelClient', 'coremodelProxyClient']
def _retrieve_network(arch='resnet_model_110sl'):
"""Retrieve the specific network architecture based on the provided name."""
available_networks = {
'resnet_model_110sl': resnet_sl.resnet_model_110sl,
'wide_resnetsl50_2': resnet_sl.wide_resnetsl50_2,
'wide_resnetsl16_8': resnet_sl.wide_resnetsl16_8,
}
# Ensure the architecture requested exists in the available networks
assert arch in available_networks, f"Architecture '{arch}' is not supported."
return available_networks[arch]
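# NOTE: each constructor in `available_networks` is assumed to return a *pair* of
# modules -- (main-client sub-network, proxy-client sub-network) -- which is why the
# classes below select the client half with [0] and the proxy half with [1].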
class CoreModelClient(nn.Module):
"""Main client model for training and inference, managing multiple sub-networks."""
def __init__(self, args, norm_layer=None, criterion=None, progress=True):
super(CoreModelClient, self).__init__()
# Parameters and configurations for the client model
self.split_factor = args.split_factor
self.arch = args.arch
self.loop_factor = args.loop_factor
self.is_train_sep = args.is_train_sep
self.epochs = args.epochs
self.num_classes = args.num_classes
self.is_diff_data_train = args.is_diff_data_train
self.is_mixup = args.is_mixup
self.mix_alpha = args.mix_alpha
# Model arguments
model_kwargs = {
'num_classes': self.num_classes,
'norm_layer': norm_layer,
'dataset': args.dataset,
'split_factor': self.split_factor,
'output_stride': args.output_stride
}
# Initialize multiple instances of the network architecture for the main client
if self.arch in ['resnet_model_110sl', 'wide_resnetsl50_2', 'wide_resnetsl16_8']:
self.main_client_models = nn.ModuleList(
                [_retrieve_network(self.arch)(models_pretrained=args.models_pretrained, **model_kwargs)[0]
for _ in range(self.loop_factor)]
)
else:
raise NotImplementedError(f"Architecture '{self.arch}' not implemented.")
# Identical initialization of the model if specified
if args.is_identical_init:
print("INFO:PyTorch: Using identical initialization.")
self._identical_init()
def forward(self, x, target=None, mode='train', epoch=0, streams=None):
"""Forward pass for the main client. Handles both training and evaluation modes."""
main_client_outputs = []
if self.arch in ['resnet_model_110sl', 'wide_resnetsl50_2', 'wide_resnetsl16_8']:
            if mode == 'train':
                # Apply mixup augmentation if enabled; otherwise fall back to the plain targets
                if self.is_mixup:
                    x, y_a, y_b, lam = combine_mixup_data(x, target, alpha=self.mix_alpha)
                else:
                    y_a, y_b, lam = target, target, 1.0
                # Split the input across the sub-networks: channel-wise chunks when each
                # sub-network trains on different data, otherwise the same batch for all
                all_x = torch.chunk(x, chunks=self.loop_factor, dim=1) if self.is_diff_data_train else [x] * self.loop_factor
                for i in range(self.loop_factor):
                    fx = self.main_client_models[i](all_x[i])
                    # Detach at the split point: the proxy client computes its losses on these
                    # activations, and their .grad fields carry the gradient back to the client side
                    main_client_outputs.append(fx.clone().detach().requires_grad_(True))
                return main_client_outputs, y_a, y_b, lam
elif mode in ['val', 'test']:
# Forward pass during evaluation or testing
for i in range(self.loop_factor):
fx = self.main_client_models[i](x)
main_client_outputs.append(fx.clone().detach().requires_grad_(True))
return main_client_outputs
else:
# Return a dummy tensor if the mode is unsupported
return torch.ones(1)
else:
raise NotImplementedError(f"Mode '{mode}' not supported for architecture '{self.arch}'.")
def _identical_init(self):
"""Ensure identical initialization of weights for sub-networks."""
with torch.no_grad():
# Copy weights from the first model to all subsequent models
            for i in range(1, len(self.main_client_models)):
for (name1, param1), (name2, param2) in zip(self.main_client_models[i].named_parameters(),
self.main_client_models[0].named_parameters()):
if 'weight' in name1:
param1.data.copy_(param2.data)
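# The two halves of the split model: CoreModelClient holds the early layers that are
# assumed to run on the edge device and emits detached activations at the split point,
# while coremodelProxyClient (below) holds the remaining layers, the ensemble logic,
# and the cross-entropy / co-training losses computed on the proxy side.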
class coremodelProxyClient(nn.Module):
"""Proxy client model to handle downstream processing and training logic."""
def __init__(self, args, norm_layer=None, criterion=None, progress=True):
super(coremodelProxyClient, self).__init__()
# Parameters and configurations for the proxy client model
self.split_factor = args.split_factor
self.arch = args.arch
self.loop_factor = args.loop_factor
self.epochs = args.epochs
self.num_classes = args.num_classes
self.criterion = criterion
self.is_mixup = args.is_mixup
self.is_ensembled_loss = args.is_ensembled_loss if self.split_factor > 1 else False
self.ensembled_loss_weight = args.ensembled_loss_weight
self.is_ensembled_after_softmax = args.is_ensembled_after_softmax if self.split_factor > 1 else False
self.is_max_ensemble = args.is_max_ensemble if self.split_factor > 1 else False
self.is_cot_loss = args.is_cot_loss if self.split_factor > 1 else False
self.cot_weight = args.cot_weight
self.is_cot_weight_warm_up = args.is_cot_weight_warm_up
self.cot_weight_warm_up_epochs = args.cot_weight_warm_up_epochs
self.cot_loss_choose = args.cot_loss_choose
# Model arguments for the proxy client
model_kwargs = {
'num_classes': self.num_classes,
'norm_layer': norm_layer,
'dataset': args.dataset,
'split_factor': self.split_factor,
'output_stride': args.output_stride
}
# Initialize multiple instances of the network architecture for the proxy client
if self.arch in ['resnet_model_110sl', 'wide_resnetsl50_2', 'wide_resnetsl16_8']:
self.proxy_clients_models = nn.ModuleList(
                [_retrieve_network(self.arch)(models_pretrained=args.models_pretrained, **model_kwargs)[1]
for _ in range(self.loop_factor)]
)
else:
raise NotImplementedError(f"Architecture '{self.arch}' not implemented.")
# Identical initialization of the model if specified
if args.is_identical_init:
print("INFO:PyTorch: Using identical initialization.")
self._identical_init()
def forward(self, main_client_outputs, y_a=None, y_b=None, lam=None, target=None, mode='train', epoch=0, streams=None):
"""Forward pass for the proxy client. Manages multiple sub-networks and ensemble outputs."""
outputs = []
ce_losses = []
if self.arch in ['resnet_model_110sl', 'wide_resnetsl50_2', 'wide_resnetsl16_8']:
if mode == 'train':
# Calculate loss and forward pass during training
for i in range(self.loop_factor):
output = self.proxy_clients_models[i](main_client_outputs[i])
loss = mixup_loss_criterion(self.criterion, output, y_a, y_b, lam) if self.is_mixup else self.criterion(output, target)
outputs.append(output)
ce_losses.append(loss)
ensemble_output = self._collect_ensemble_output(outputs)
ce_loss = torch.sum(torch.stack(ce_losses, dim=0))
# Calculate co-training loss if enabled
if self.is_cot_loss:
cot_loss = self._calculate_co_training_loss(outputs, epoch)
else:
cot_loss = torch.zeros_like(ce_loss)
return ensemble_output, torch.stack(outputs, dim=0), ce_loss, cot_loss
elif mode in ['val', 'test']:
# Forward pass during evaluation or testing
for i in range(self.loop_factor):
output = self.proxy_clients_models[i](main_client_outputs[i])
loss = self.criterion(output, target) if self.criterion else torch.zeros(1)
outputs.append(output)
ce_losses.append(loss)
ensemble_output = self._collect_ensemble_output(outputs)
ce_loss = torch.sum(torch.stack(ce_losses, dim=0))
return ensemble_output, torch.stack(outputs, dim=0), ce_loss
else:
# Return a dummy tensor if the mode is unsupported
return torch.ones(1)
else:
raise NotImplementedError(f"Mode '{mode}' not supported for architecture '{self.arch}'.")
def _collect_ensemble_output(self, outputs):
"""Calculate the ensemble output from multiple sub-networks."""
stacked_outputs = torch.stack(outputs, dim=0)
# Apply softmax to the outputs before ensembling if specified
if self.is_ensembled_after_softmax:
if self.is_max_ensemble:
ensemble_output, _ = torch.max(F.softmax(stacked_outputs, dim=-1), dim=0)
else:
ensemble_output = torch.mean(F.softmax(stacked_outputs, dim=-1), dim=0)
else:
if self.is_max_ensemble:
ensemble_output, _ = torch.max(stacked_outputs, dim=0)
else:
ensemble_output = torch.mean(stacked_outputs, dim=0)
return ensemble_output
def _calculate_co_training_loss(self, outputs, epoch):
"""Calculate the co-training loss between outputs of different sub-networks."""
        # Linearly warm up the co-training loss weight over the first warm-up epochs
        if self.is_cot_weight_warm_up and epoch < self.cot_weight_warm_up_epochs:
            weight_now = max(self.cot_weight * epoch / self.cot_weight_warm_up_epochs, 0.005)
        else:
            weight_now = self.cot_weight
# Different methods of calculating co-training loss
        if self.cot_loss_choose == 'js_divergence':
            # Jensen-Shannon-style loss: H(mean_k p_k) - mean_k H(p_k), averaged over the batch
            outputs_all = torch.stack(outputs, dim=0)
            p_all = F.softmax(outputs_all, dim=-1)
            p_mean = torch.mean(p_all, dim=0)
            # Clamp avoids log(0) when a class probability underflows to zero
            H_mean = (-p_mean * torch.log(p_mean.clamp(min=1e-8))).sum(-1).mean()
            H_sep = (-p_all * F.log_softmax(outputs_all, dim=-1)).sum(-1).mean()
            return weight_now * (H_mean - H_sep)
        elif self.cot_loss_choose == 'kl_separate':
            # Pairwise KL divergence: each sub-network's prediction is pulled towards the
            # (detached) predictions of every other sub-network
            outputs_all = torch.stack(outputs, dim=0)
            outputs_r1 = torch.repeat_interleave(outputs_all, self.split_factor - 1, dim=0)
            index_list = [j for i in range(self.split_factor) for j in range(self.split_factor) if j != i]
            outputs_r2 = torch.index_select(outputs_all, dim=0,
                                            index=torch.tensor(index_list, dtype=torch.long, device=outputs_all.device))
            kl_loss = F.kl_div(F.log_softmax(outputs_r1, dim=-1), F.softmax(outputs_r2, dim=-1).detach(), reduction='none')
            return weight_now * kl_loss.sum(-1).mean(-1).sum() / (self.split_factor - 1)
else:
raise NotImplementedError(f"Co-training loss '{self.cot_loss_choose}' not implemented.")
def _identical_init(self):
"""Ensure identical initialization of weights for sub-networks."""
with torch.no_grad():
# Copy weights from the first model to all subsequent models
            for i in range(1, len(self.proxy_clients_models)):
for (name1, param1), (name2, param2) in zip(self.proxy_clients_models[i].named_parameters(),
self.proxy_clients_models[0].named_parameters()):
if 'weight' in name1:
param1.data.copy_(param2.data)
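# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, guarded so it never runs on import).
# It assumes the resnet_sl constructors imported above and CIFAR-style 32x32
# inputs; every hyper-parameter value below is a placeholder, not a value taken
# from the original experiments.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from argparse import Namespace

    args = Namespace(
        arch='resnet_model_110sl', dataset='cifar10', num_classes=10,
        split_factor=2, loop_factor=2, is_train_sep=False, epochs=300,
        is_diff_data_train=False, is_mixup=True, mix_alpha=0.2,
        output_stride=8, models_pretrained=False, is_identical_init=False,
        is_ensembled_loss=False, ensembled_loss_weight=1.0,
        is_ensembled_after_softmax=False, is_max_ensemble=False,
        is_cot_loss=True, cot_weight=0.5, is_cot_weight_warm_up=True,
        cot_weight_warm_up_epochs=40, cot_loss_choose='js_divergence',
    )
    criterion = nn.CrossEntropyLoss()

    main_client = CoreModelClient(args)
    proxy_client = coremodelProxyClient(args, criterion=criterion)

    images = torch.randn(4, 3, 32, 32)
    labels = torch.randint(0, args.num_classes, (4,))

    # Client-side forward: detached activations act as the split point.
    main_outs, y_a, y_b, lam = main_client(images, target=labels, mode='train')
    # Proxy-side forward: ensemble prediction plus cross-entropy and co-training losses.
    ensemble_out, _, ce_loss, cot_loss = proxy_client(
        main_outs, y_a=y_a, y_b=y_b, lam=lam, target=labels, mode='train', epoch=0)
    (ce_loss + cot_loss).backward()
    # After backward, main_outs[i].grad holds the gradient at the split point; the
    # surrounding training loop is expected to relay it back to the client side.
    print(ensemble_out.shape, ce_loss.item(), cot_loss.item())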