# Commit 4ec0a23e73 by Weisen Pan (2024-09-18 18:39:43 -07:00):
# Edge Federated Learning for Improved Training Efficiency
# Change-Id: Ic4e43992e1674946cb69e0221659b0261259196c
#
# File listing: 313 lines, 14 KiB, Python

# -*- coding: utf-8 -*-
# @Author: Weisen Pan
# Importing necessary PyTorch libraries
import torch
import torch.nn as nn
# Attempt to import model loading utilities from torch.hub; fall back to torch.utils.model_zoo if unavailable
try:
from torch.hub import load_state_dict_from_url
except ImportError:
from torch.utils.model_zoo import load_url as load_state_dict_from_url
# Specify all the modules and functions to export
# Public factory functions exported by ``from <module> import *``.
# Every entry must name a function actually defined in this file; the
# ResNet-110 factory is defined as ``resnet_model_110sl``.
__all__ = ['resnet_model_110sl', 'wide_resnetsl50_2', 'wide_resnetsl16_8']
# Function for 3x3 convolution with padding
def apply_3x3_convolution(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Build a bias-free 3x3 ``nn.Conv2d``.

    The padding is set equal to ``dilation``.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride.
        groups: Number of convolution groups.
        dilation: Dilation rate (also used as the padding).

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    conv_options = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=False,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_options)
# Function for 1x1 convolution, typically used to change the number of channels
def apply_1x1_convolution(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 ``nn.Conv2d``.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride.

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    pointwise = nn.Conv2d(in_planes, out_planes, 1, stride, bias=False)
    return pointwise
# Basic block class for ResNet (used in shallower networks such as ResNet-18/ResNet-34)
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions plus a shortcut.

    Args:
        inplanes: Number of input channels.
        planes: Number of channels produced by both convolutions.
        stride: Stride of the first convolution.
        downsample: Optional module applied to the input before it is added
            to the convolution branch.
        groups: Must be 1.
        base_width: Must be 64.
        dilation: Must not exceed 1.
        norm_layer: Normalization layer factory; ``nn.BatchNorm2d`` is used
            when a falsy value is given.

    Raises:
        ValueError: If ``groups != 1`` or ``base_width != 64``.
        NotImplementedError: If ``dilation > 1``.
    """

    expansion = 1  # Output channel multiplier relative to ``planes``

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super().__init__()
        if not norm_layer:
            norm_layer = nn.BatchNorm2d
        # Reject configurations this block does not implement
        if not (groups == 1 and base_width == 64):
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("BasicBlock does not support dilation greater than 1")
        # Convolution branch: conv-norm-relu followed by conv-norm
        self.conv1 = apply_3x3_convolution(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = apply_3x3_convolution(planes, planes)
        self.bn2 = norm_layer(planes)
        # Shortcut branch (identity when ``downsample`` is None)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(branch(x) + shortcut(x))``."""
        shortcut = x
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        if self.downsample is not None:
            shortcut = self.downsample(x)
        branch += shortcut
        return self.relu(branch)
# Bottleneck block class for deeper ResNet architectures (e.g., ResNet-50/ResNet-101)
class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions plus a shortcut.

    Args:
        inplanes: Number of input channels.
        planes: Base channel count; the block outputs ``planes * expansion``
            channels.
        stride: Stride of the 3x3 convolution.
        downsample: Optional module applied to the input before it is added
            to the convolution branch.
        groups: Number of groups for the 3x3 convolution.
        base_width: Width per group; the inner width is
            ``int(planes * (base_width / 64)) * groups``.
        dilation: Dilation of the 3x3 convolution.
        norm_layer: Normalization layer factory; ``nn.BatchNorm2d`` is used
            when a falsy value is given.
    """

    expansion = 4  # Output channels = planes * expansion

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        norm_layer = norm_layer or nn.BatchNorm2d
        # Channel count used inside the bottleneck
        width = int(planes * (base_width / 64.)) * groups
        self.conv1 = apply_1x1_convolution(inplanes, width)  # Reduce to the inner width
        self.bn1 = norm_layer(width)
        self.conv2 = apply_3x3_convolution(width, width, stride, groups, dilation)  # Spatial convolution
        self.bn2 = norm_layer(width)
        self.conv3 = apply_1x1_convolution(width, planes * self.expansion)  # Expand to the output width
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample  # Shortcut projection (identity when None)
        self.stride = stride

    def forward(self, x):
        """Return ``relu(branch(x) + shortcut(x))``.

        Each stage consumes the previous stage's output. Feeding ``x`` again
        into ``conv2``/``conv3`` would discard the earlier stages and fail on
        a channel mismatch whenever ``inplanes`` differs from the inner width.
        """
        identity = x  # Saved for the residual connection
        # 1x1 reduction
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        # 3x3 convolution on the reduced features
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        # 1x1 expansion (activation is applied after the residual add)
        out = self.conv3(out)
        out = self.bn3(out)
        # Project the shortcut when its shape differs from the branch output
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        out = self.relu(out)
        return out
# ResNet model for the main client (usually the primary model)
class PrimaryResNetClient(nn.Module):
    """Main-client part of the split ResNet: only the input stem ``layer0``.

    The stem width is looked up from ``dataset`` and ``split_factor`` and, for
    architectures whose name contains ``'wide_resnet'``, multiplied by a factor
    derived from the trailing number of ``arch``.

    Args:
        arch: Architecture name; a ``'wide_resnet..._<k>'`` name enables the
            widening factor ``k``.
        block, layers, num_classes, groups, output_stride, dropout_p:
            Accepted for signature parity; not used by this class.
        zero_init_residual: Forwarded to ``_init_model_weights``.
        width_per_group: Stored as ``self.base_width``.
        replace_stride_with_dilation: ``None`` or a 3-element sequence.
        norm_layer: Normalization layer factory; ``nn.BatchNorm2d`` is used
            when a falsy value is given.
        dataset: Key selecting the stem width table and the stem variant.
        split_factor: Key selecting the stem width within the table.

    Raises:
        KeyError: If ``dataset``/``split_factor`` is not in the width table.
        ValueError: If ``replace_stride_with_dilation`` does not have three
            elements.
    """

    def __init__(self, arch, block, layers, num_classes=1000, zero_init_residual=True,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None, dataset='cifar10', split_factor=1, output_stride=8, dropout_p=None):
        super().__init__()
        self._norm_layer = norm_layer if norm_layer else nn.BatchNorm2d
        # Stem width per dataset, keyed by split factor
        stem_widths = {
            'cifar10': {1: 16, 2: 12, 4: 8, 8: 6, 16: 4, 32: 3},
            'cifar100': {1: 16, 2: 12, 4: 8, 8: 6, 16: 4, 32: 3},
            'skin_dataset': {1: 64, 2: 44, 4: 32, 8: 24},
            'pill_base': {1: 64, 2: 44, 4: 32, 8: 24},
            'medical_images': {1: 64, 2: 44, 4: 32, 8: 24},
        }
        self.inplanes = stem_widths[dataset][split_factor]
        # Wide variants scale the stem by the factor encoded in the arch name
        if 'wide_resnet' in arch:
            widen_factor = int(arch.split('_')[-1])
            scale = max(widen_factor / (split_factor ** 0.5) + 0.4, 1.0)
            self.inplanes *= int(scale)
        self.base_width = width_per_group
        self.dilation = 1
        if not replace_stride_with_dilation:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation must either be None or a tuple with three elements")
        # These datasets get the strided stem with max-pooling; others get the stride-1 stem
        uses_large_stem = dataset in ['skin_dataset', 'pill_base', 'medical_images']
        build_stem = self._initialize_primary_layer_large if uses_large_stem else self._init_layer0_small
        self.layer0 = build_stem()
        self._init_model_weights(zero_init_residual)

    def _initialize_primary_layer_large(self):
        """Return the stem used for the larger datasets: stride-2 conv, norm, ReLU, max-pool."""
        stem = [
            nn.Conv2d(3, self.inplanes, kernel_size=3, stride=2, padding=1, bias=False),
            self._norm_layer(self.inplanes),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        ]
        return nn.Sequential(*stem)

    def _init_layer0_small(self):
        """Return the stem used for the CIFAR datasets: stride-1 conv, norm, ReLU."""
        stem = [
            nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False),
            self._norm_layer(self.inplanes),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*stem)

    def _init_model_weights(self, zero_init_residual):
        """Initialize conv, normalization and linear parameters in place.

        When ``zero_init_residual`` is true, the last normalization weight of
        every ``Bottleneck``/``BasicBlock`` submodule is additionally zeroed.
        """
        norm_types = (nn.BatchNorm2d, nn.GroupNorm, nn.SyncBatchNorm)
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                continue
            if isinstance(module, norm_types):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
                continue
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, std=1e-3)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        if not zero_init_residual:
            return
        # Second pass: must run after the norm weights were set to 1 above
        for module in self.modules():
            if isinstance(module, Bottleneck):
                nn.init.constant_(module.bn3.weight, 0)
            elif isinstance(module, BasicBlock):
                nn.init.constant_(module.bn2.weight, 0)

    def forward(self, x):
        """Apply the stem ``layer0`` to ``x`` and return the result."""
        return self.layer0(x)
# ResNet model for proxy clients (usually assisting the main model)
class ResNetProxies(nn.Module):
    """Proxy-client part of the split ResNet: three residual stages, pooling and classifier.

    The first stage takes as many input channels as the width looked up from
    ``dataset``/``split_factor`` (scaled for ``'wide_resnet'`` architectures),
    which is the same table ``PrimaryResNetClient`` uses for its stem output.

    Args:
        arch: Architecture name; a ``'wide_resnet..._<k>'`` name enables the
            widening factor ``k``.
        block: Residual block class exposing an ``expansion`` attribute and
            the ``(inplanes, planes, stride, downsample, ..., norm_layer)``
            constructor used by ``BasicBlock``/``Bottleneck``.
        layers: Sequence whose first three entries give the number of blocks
            in stages 1-3.
        num_classes: Output size of the final linear layer.
        zero_init_residual: Forwarded to ``_init_model_weights``.
        groups, replace_stride_with_dilation, output_stride, dropout_p:
            Accepted for signature parity; not used by this class.
        width_per_group: Stored as ``self.base_width``.
        norm_layer: Normalization layer factory; ``nn.BatchNorm2d`` is used
            when a falsy value is given.
        dataset: Key selecting the base width table.
        split_factor: Key selecting the base width within the table.

    Raises:
        KeyError: If ``dataset``/``split_factor`` is not in the width table.
    """

    def __init__(self, arch, block, layers, num_classes=1000, zero_init_residual=True,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None, dataset='cifar10', split_factor=1, output_stride=8, dropout_p=None):
        super(ResNetProxies, self).__init__()
        norm_layer = norm_layer or nn.BatchNorm2d
        self._norm_layer = norm_layer
        # Channel count of the incoming feature map
        self.inplanes = self._set_input_planes(arch, dataset, split_factor, width_per_group)
        # NOTE(review): base_width is stored but not forwarded to the blocks,
        # matching the original behaviour; confirm whether wide variants
        # should pass it through.
        self.base_width = width_per_group
        # ``_create_model_layer`` overwrites ``self.inplanes``, so capture the
        # base width once and derive every stage width (and the classifier
        # input size) from it.
        base_planes = self.inplanes
        self.layer1 = self._create_model_layer(block, base_planes, layers[0], stride=1)
        self.layer2 = self._create_model_layer(block, base_planes * 2, layers[1], stride=2)
        self.layer3 = self._create_model_layer(block, base_planes * 4, layers[2], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(base_planes * 4 * block.expansion, num_classes)
        self._init_model_weights(zero_init_residual)

    def _set_input_planes(self, arch, dataset, split_factor, width_per_group):
        """Return the base channel count for ``arch``/``dataset``/``split_factor``.

        The table mirrors the one in ``PrimaryResNetClient`` so that both
        halves of a split model accept the same configurations.
        ``width_per_group`` is accepted but not used.
        """
        inplanes_dict = {
            'cifar10': {1: 16, 2: 12, 4: 8, 8: 6, 16: 4, 32: 3},
            'cifar100': {1: 16, 2: 12, 4: 8, 8: 6, 16: 4, 32: 3},
            'skin_dataset': {1: 64, 2: 44, 4: 32, 8: 24},
            'pill_base': {1: 64, 2: 44, 4: 32, 8: 24},
            'medical_images': {1: 64, 2: 44, 4: 32, 8: 24},
        }
        inplanes = inplanes_dict[dataset][split_factor]
        # Wide variants scale the width by the factor encoded in the arch name
        if 'wide_resnet' in arch:
            widen_factor = float(arch.split('_')[-1])
            inplanes *= int(max(widen_factor / (split_factor ** 0.5) + 0.4, 1.0))
        return inplanes

    def _create_model_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks and update ``self.inplanes``.

        The first block receives ``stride`` and, when the stride or channel
        count changes, a 1x1-conv + norm shortcut so the residual addition has
        matching shapes. Remaining blocks keep the shape.
        """
        norm_layer = self._norm_layer
        out_planes = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes, kernel_size=1, stride=stride, bias=False),
                norm_layer(out_planes),
            )
        layers = [block(self.inplanes, planes, stride, downsample, norm_layer=norm_layer)]
        self.inplanes = out_planes  # Later blocks consume this stage's output width
        layers.extend(block(self.inplanes, planes, norm_layer=norm_layer) for _ in range(1, blocks))
        return nn.Sequential(*layers)

    def _init_model_weights(self, zero_init_residual):
        """Initialize conv, normalization and linear parameters in place.

        When ``zero_init_residual`` is true, the last normalization weight of
        every ``Bottleneck``/``BasicBlock`` submodule is additionally zeroed.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.SyncBatchNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=1e-3)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
        # Separate pass: must run after the norm weights were set to 1 above
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def forward(self, x):
        """Run the three stages, global average pooling and the classifier."""
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x
# Helper function to create the main ResNet client
def _resnetsl_primary_client_(arch, block, layers, models_pretrained, progress, **kwargs):
    """Construct the main-client model.

    ``models_pretrained`` and ``progress`` are accepted but not used; no
    weights are loaded. All other arguments are forwarded to
    ``PrimaryResNetClient``.
    """
    primary_model = PrimaryResNetClient(arch, block, layers, **kwargs)
    return primary_model
# Helper function to create the proxy ResNet client
def _resnetsl_secondary_client_(arch, block, layers, models_pretrained, progress, **kwargs):
    """Construct the proxy-client model.

    ``models_pretrained`` and ``progress`` are accepted but not used; no
    weights are loaded. All other arguments are forwarded to
    ``ResNetProxies``.
    """
    proxy_model = ResNetProxies(arch, block, layers, **kwargs)
    return proxy_model
# Function to define a ResNet-110 model for main and proxy clients
def resnet_model_110sl(models_pretrained=False, progress=True, **kwargs):
    """Build the (main client, proxy client) pair for the ResNet-110 split model.

    Args:
        models_pretrained: Forwarded to the helper constructors.
        progress: Forwarded to the helper constructors.
        **kwargs: Forwarded to both model constructors. ``dataset`` must
            contain ``'cifar'``; when omitted, the constructors' own default
            (``'cifar10'``) applies.

    Returns:
        Tuple ``(primary_client, proxy_client)``.

    Raises:
        ValueError: If ``dataset`` is given and does not contain ``'cifar'``.
    """
    # Explicit check instead of ``assert`` so it still runs under ``python -O``;
    # a missing ``dataset`` falls back to the constructors' default.
    dataset = kwargs.get('dataset', 'cifar10')
    if 'cifar' not in dataset:
        raise ValueError(
            "resnet_model_110sl only supports CIFAR datasets, got dataset={!r}".format(dataset)
        )
    arch = 'resnet110_sl'
    primary = _resnetsl_primary_client_(arch, Bottleneck, [12, 12, 12, 12],
                                        models_pretrained, progress, **kwargs)
    proxy = _resnetsl_secondary_client_(arch, Bottleneck, [12, 12, 12, 12],
                                        models_pretrained, progress, **kwargs)
    return primary, proxy
# Function to define a Wide ResNet-50-2 model for main and proxy clients
def wide_resnetsl50_2(models_pretrained=False, progress=True, **kwargs):
    """Build the (main client, proxy client) pair for Wide ResNet-50-2.

    Overrides ``width_per_group`` with ``64 * 2`` and forwards every other
    keyword argument to both model constructors.

    Returns:
        Tuple ``(primary_client, proxy_client)``.
    """
    arch = 'wide_resnetsl50_2'
    kwargs['width_per_group'] = 64 * 2
    primary = _resnetsl_primary_client_(arch, Bottleneck, [3, 4, 6, 3],
                                        models_pretrained, progress, **kwargs)
    proxy = _resnetsl_secondary_client_(arch, Bottleneck, [3, 4, 6, 3],
                                        models_pretrained, progress, **kwargs)
    return primary, proxy
# Function to define a Wide ResNet-16-8 model for main and proxy clients
def wide_resnetsl16_8(models_pretrained=False, progress=True, **kwargs):
    """Build the (main client, proxy client) pair for Wide ResNet-16-8.

    Overrides ``width_per_group`` with ``64`` and forwards every other keyword
    argument to both model constructors.

    Returns:
        Tuple ``(primary_client, proxy_client)``.
    """
    arch = 'wide_resnetsl16_8'
    kwargs['width_per_group'] = 64
    primary = _resnetsl_primary_client_(arch, BasicBlock, [2, 2, 2, 2],
                                        models_pretrained, progress, **kwargs)
    proxy = _resnetsl_secondary_client_(arch, BasicBlock, [2, 2, 2, 2],
                                        models_pretrained, progress, **kwargs)
    return primary, proxy