Weisen Pan 4ec0a23e73 Edge Federated Learning for Improved Training Efficiency
Change-Id: Ic4e43992e1674946cb69e0221659b0261259196c
2024-09-18 18:39:43 -07:00


# -*- coding: utf-8 -*-
# @Author: Weisen Pan
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn import ensemble
from .mixup import mixup_loss_criterion, combine_mixup_data
from . import resnet, resnet_sl
__all__ = ['coremodel']


def _retrieve_network(arch='wide_resnet28_10'):
    """
    Get the network architecture based on the provided name.

    Args:
        arch (str): Name of the architecture.

    Returns:
        Callable: The network class or function corresponding to the given architecture.
    """
    networks = {
        'wide_resnet28_10': resnet.wide_resnet28_10,
        'wide_resnet16_8': resnet.wide_resnet16_8,
        'resnet110': resnet.resnet110,
        'wide_resnet_model_50_2': resnet.wide_resnet_model_50_2
    }
    if arch not in networks:
        raise ValueError(f"Architecture {arch} is not supported.")
    return networks[arch]
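
# A minimal usage sketch (hypothetical values, for illustration only): resolving an
# architecture by name and instantiating it with the same keyword arguments that
# coremodel passes via model_args below.
#
#   net_fn = _retrieve_network('wide_resnet28_10')
#   sub_net = net_fn(num_classes=10, norm_layer=None, dataset='cifar10',
#                    split_factor=2, output_stride=8)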


class coremodel(nn.Module):
    def __init__(self, args, norm_layer=None, criterion=None, progress=True):
        """
        Initialize the coremodel model with multiple sub-networks.

        Args:
            args (argparse.Namespace): Configuration arguments.
            norm_layer (callable, optional): Normalization layer.
            criterion (callable, optional): Loss function.
            progress (bool): Whether to show progress.
        """
        super(coremodel, self).__init__()

        # Configuration parameters
        self.split_factor = args.split_factor
        self.arch = args.arch
        self.loop_factor = args.loop_factor
        self.is_train_sep = args.is_train_sep
        self.epochs = args.epochs
        self.criterion = criterion
        self.is_diff_data_train = args.is_diff_data_train
        self.is_mixup = args.is_mixup
        self.mix_alpha = args.mix_alpha

        # Define model architectures
        valid_archs = [
            'resnet_model_50', 'resnet_model_101', 'resnet_model_152', 'resnet_model_200',
            'resnext50_32x4d', 'resnext101_32x4d', 'resnext101_64x4d',
            'resnext29_8x64d', 'resnext29_16x64d', 'resnet110', 'resnet164',
            'wide_resnet16_8', 'wide_resnet16_12', 'wide_resnet28_10', 'wide_resnet40_10',
            'wide_resnet52_8', 'wide_resnet_model_50_2', 'wide_resnet_model_50_3', 'wide_resnet_model_101_2'
        ]
        if self.arch not in valid_archs:
            raise NotImplementedError(f"Architecture {self.arch} is not implemented.")

        model_args = {
            'num_classes': args.num_classes,
            'norm_layer': norm_layer,
            'dataset': args.dataset,
            'split_factor': self.split_factor,
            'output_stride': args.output_stride
        }

        # Initialize multiple sub-models based on the loop factor
        self.models = nn.ModuleList([
            _retrieve_network(self.arch)(models_models_pretrained=args.models_models_pretrained, **model_args)
            for _ in range(self.loop_factor)
        ])

        if args.is_identical_init:
            print("INFO: Using identical initialization.")
            self._identical_init()

        # Ensemble settings
        self.is_ensembled_loss = args.is_ensembled_loss if self.split_factor > 1 else False
        self.ensembled_loss_weight = args.ensembled_loss_weight
        self.is_ensembled_after_softmax = args.is_ensembled_after_softmax if self.split_factor > 1 else False
        self.is_max_ensemble = args.is_max_ensemble if self.split_factor > 1 else False

        # Co-training settings
        self.is_cot_loss = args.is_cot_loss if self.split_factor > 1 else False
        self.cot_weight = args.cot_weight
        self.is_cot_weight_warm_up = args.is_cot_weight_warm_up
        self.cot_weight_warm_up_epochs = args.cot_weight_warm_up_epochs
        self.cot_loss_choose = args.cot_loss_choose
        print(f"INFO: The co-training loss is {self.cot_loss_choose}.")
        self.num_classes = args.num_classes
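
    # Construction sketch (assumed, hypothetical values): the args namespace must carry
    # every field read above, i.e. arch, num_classes, dataset, split_factor, loop_factor,
    # epochs, output_stride, is_train_sep, is_diff_data_train, is_mixup, mix_alpha,
    # models_models_pretrained, is_identical_init, is_ensembled_loss, ensembled_loss_weight,
    # is_ensembled_after_softmax, is_max_ensemble, is_cot_loss, cot_weight,
    # is_cot_weight_warm_up, cot_weight_warm_up_epochs, and cot_loss_choose.
    #
    #   args = argparse.Namespace(arch='wide_resnet28_10', num_classes=10, dataset='cifar10',
    #                             split_factor=2, loop_factor=2, ..., cot_loss_choose='js_divergence')
    #   model = coremodel(args, criterion=nn.CrossEntropyLoss())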

    def forward(self, x, target=None, mode='train', epoch=0, streams=None):
        """
        Forward pass through the model with optional mixup and co-training loss.

        Args:
            x (Tensor): Input tensor.
            target (Tensor, optional): Target tensor for loss computation.
            mode (str): Mode of operation ('train', 'val', or 'test').
            epoch (int): Current epoch.
            streams (optional): Additional data streams.

        Returns:
            Tuple:
                - ensemble_output (Tensor): The ensemble output of shape [batch_size, num_classes].
                - outputs (Tensor): Stack of individual outputs of shape [split_factor, batch_size, num_classes].
                - ce_loss (Tensor): Sum of cross-entropy losses for each model.
                - cot_loss (Tensor): Co-training loss (returned in training mode only).
        """
        outputs, ce_losses = [], []

        if 'train' in mode:
            if self.is_mixup:
                x, y_a, y_b, lam = combine_mixup_data(x, target, alpha=self.mix_alpha)

            # Split input data based on the loop factor; reuse the same input for
            # every sub-network when is_diff_data_train is disabled.
            all_x = torch.chunk(x, chunks=self.loop_factor, dim=1) if self.is_diff_data_train else [x] * self.loop_factor
            for i in range(self.loop_factor):
                x_input = all_x[i]
                output = self.models[i](x_input)
                loss = mixup_loss_criterion(self.criterion, output, y_a, y_b, lam) if self.is_mixup else self.criterion(output, target)
                outputs.append(output)
                ce_losses.append(loss)
        elif mode in ['val', 'test']:
            for i in range(self.loop_factor):
                output = self.models[i](x)
                loss = self.criterion(output, target) if self.criterion else torch.zeros(1)
                outputs.append(output)
                ce_losses.append(loss)
        else:
            return torch.ones(1), None, None, None

        # Calculate ensemble output and losses
        ensemble_output = self._collect_ensemble_output(outputs)
        ce_loss = torch.sum(torch.stack(ce_losses))

        if mode in ['val', 'test']:
            return ensemble_output, torch.stack(outputs, dim=0), ce_loss

        if self.is_cot_loss:
            cot_loss = self._calculate_co_training_loss(outputs, self.cot_loss_choose, epoch)
        else:
            cot_loss = torch.zeros_like(ce_loss)

        return ensemble_output, torch.stack(outputs, dim=0), ce_loss, cot_loss
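
    # Forward-pass sketch (hypothetical shapes): training mode returns four values,
    # while 'val'/'test' mode returns three (no co-training loss). When
    # is_diff_data_train is True, the channel dimension is chunked across the
    # sub-networks, so the example stacks split_factor views along dim=1.
    #
    #   images = torch.randn(8, 3 * args.split_factor, 32, 32)
    #   targets = torch.randint(0, args.num_classes, (8,))
    #   ens_out, all_out, ce_loss, cot_loss = model(images, targets, mode='train', epoch=5)
    #   ens_out, all_out, ce_loss = model(images, targets, mode='val')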

    def _collect_ensemble_output(self, outputs):
        """
        Calculate the ensemble output from a list of tensors.

        Args:
            outputs (list of tensors): A list where each tensor has shape [batch_size, num_classes].

        Returns:
            Tensor: The ensemble output with shape [batch_size, num_classes].
        """
        stacked_outputs = torch.stack(outputs, dim=0)
        if self.is_ensembled_after_softmax:
            softmax_outputs = F.softmax(stacked_outputs, dim=-1)
            if self.is_max_ensemble:
                ensemble_output, _ = torch.max(softmax_outputs, dim=0)
            else:
                ensemble_output = torch.mean(softmax_outputs, dim=0)
        else:
            if self.is_max_ensemble:
                ensemble_output, _ = torch.max(stacked_outputs, dim=0)
            else:
                ensemble_output = torch.mean(stacked_outputs, dim=0)
        return ensemble_output
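
    # Ensemble sketch (hypothetical shapes): with post-softmax averaging, each logits
    # tensor of shape [batch_size, num_classes] becomes a probability distribution and
    # the distributions are averaged element-wise; with is_max_ensemble=True the
    # element-wise maximum across sub-networks is taken instead.
    #
    #   logits = [torch.randn(4, 10), torch.randn(4, 10)]
    #   ens = model._collect_ensemble_output(logits)   # shape [4, 10]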

    def _calculate_co_training_loss(self, outputs, loss_choose, epoch=0):
        """
        Calculate the co-training loss between outputs of different networks.

        Args:
            outputs (list of tensors): A list where each tensor has shape [batch_size, num_classes].
            loss_choose (str): Type of co-training loss to compute ('js_divergence' or 'kl_seperate').
            epoch (int): Current epoch.

        Returns:
            Tensor: The computed co-training loss.
        """
        # Optionally warm up the co-training weight over the first epochs.
        weight_now = self.cot_weight
        if self.is_cot_weight_warm_up and epoch < self.cot_weight_warm_up_epochs:
            weight_now = max(self.cot_weight * epoch / self.cot_weight_warm_up_epochs, 0.005)

        stacked_outputs = torch.stack(outputs, dim=0)
        if loss_choose == 'js_divergence':
            # Jensen-Shannon-style divergence: entropy of the mean prediction
            # minus the mean entropy of the individual predictions.
            p_all = F.softmax(stacked_outputs, dim=-1)
            p_mean = torch.mean(p_all, dim=0)
            H_mean = (-p_mean * torch.log(p_mean + 1e-8)).sum(-1).mean()
            H_sep = (-p_all * F.log_softmax(stacked_outputs, dim=-1)).sum(-1).mean()
            cot_loss = weight_now * (H_mean - H_sep)
        elif loss_choose == 'kl_seperate':
            # Pairwise KL divergence between every ordered pair of sub-network outputs.
            outputs_r1 = torch.repeat_interleave(stacked_outputs, self.split_factor - 1, dim=0)
            index_list = [j for i in range(self.split_factor) for j in range(self.split_factor) if j != i]
            outputs_r2 = torch.index_select(stacked_outputs, dim=0, index=torch.tensor(index_list, dtype=torch.long, device=stacked_outputs.device))
            kl_loss = F.kl_div(F.log_softmax(outputs_r1, dim=-1), F.softmax(outputs_r2,