Weisen Pan 4ec0a23e73 Edge Federated Learning for Improved Training Efficiency
Change-Id: Ic4e43992e1674946cb69e0221659b0261259196c
2024-09-18 18:39:43 -07:00

327 lines
12 KiB
Python

# -*- coding: utf-8 -*-
# @Author: Weisen Pan
import logging
import torch
import torch.nn as nn
def apply_3x3_convolution(in_channels, out_channels, stride=1, groups=1, dilation=1):
    """
    Build a bias-free 3x3 ``nn.Conv2d``.

    The padding is set equal to ``dilation``, so with ``stride=1`` the
    spatial size of the input is preserved.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        stride (int, optional): Stride of the convolution. Default is 1.
        groups (int, optional): Number of blocked connections from input to output. Default is 1.
        dilation (int, optional): Spacing between kernel elements; also used as the padding. Default is 1.

    Returns:
        nn.Conv2d: The configured 3x3 convolutional layer (no bias term).
    """
    conv_options = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=False,
    )
    return nn.Conv2d(in_channels, out_channels, **conv_options)
def apply_1x1_convolution(in_channels, out_channels, stride=1):
    """
    Build a bias-free 1x1 (pointwise) ``nn.Conv2d``.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        stride (int, optional): Stride of the convolution. Default is 1.

    Returns:
        nn.Conv2d: The configured 1x1 convolutional layer (no bias term).
    """
    pointwise = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return pointwise
class BasicBlock(nn.Module):
    """
    Two-convolution residual block for ResNet.

    The main path is conv3x3 -> norm -> ReLU -> conv3x3 -> norm. The result
    is added to the shortcut (the input itself, or ``downsample(input)`` when
    a downsample module is supplied) and passed through a final ReLU.

    Attributes:
        expansion (int): Ratio of block output channels to ``out_channels`` (1 here).
        conv1 (nn.Conv2d): First 3x3 convolution; carries the block's stride.
        bn1 (nn.Module): Normalization applied after ``conv1``.
        relu (nn.ReLU): In-place ReLU shared by both activation points.
        conv2 (nn.Conv2d): Second 3x3 convolution (stride 1).
        bn2 (nn.Module): Normalization applied after ``conv2``.
        downsample (nn.Module or None): Optional transform for the shortcut path.
    """

    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1, downsample=None, norm_layer=None):
        """
        Set up the layers of the block.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            stride (int, optional): Stride of the first convolution. Default is 1.
            downsample (nn.Module, optional): Module applied to the input before the
                residual addition. Default is None (identity shortcut).
            norm_layer (callable, optional): Factory called with a channel count to
                build each normalization layer. Default is ``nn.BatchNorm2d``.
        """
        super().__init__()
        if not norm_layer:
            norm_layer = nn.BatchNorm2d

        # First stage: the (possibly strided) convolution.
        self.conv1 = apply_3x3_convolution(in_channels, out_channels, stride)
        self.bn1 = norm_layer(out_channels)
        self.relu = nn.ReLU(inplace=True)

        # Second stage: always stride 1.
        self.conv2 = apply_3x3_convolution(out_channels, out_channels)
        self.bn2 = norm_layer(out_channels)

        self.downsample = downsample

    def forward(self, x):
        """
        Run the block on ``x``.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: ``relu(main_path(x) + shortcut(x))``.
        """
        main = self.relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(main))

        # Use the raw input as the shortcut unless a downsample module is set.
        shortcut = x if self.downsample is None else self.downsample(x)

        main += shortcut
        return self.relu(main)
class Bottleneck(nn.Module):
    """
    Three-convolution bottleneck residual block for ResNet.

    The main path is conv1x1 -> norm -> ReLU -> conv3x3 -> norm -> ReLU ->
    conv1x1 -> norm, producing ``out_channels * expansion`` channels. The
    result is added to the shortcut (the input itself, or
    ``downsample(input)`` when a downsample module is supplied) and passed
    through a final ReLU.

    Attributes:
        expansion (int): Ratio of block output channels to ``out_channels`` (4 here).
        conv1 (nn.Conv2d): First 1x1 convolution.
        bn1 (nn.Module): Normalization applied after ``conv1``.
        conv2 (nn.Conv2d): 3x3 convolution; carries the block's stride.
        bn2 (nn.Module): Normalization applied after ``conv2``.
        conv3 (nn.Conv2d): Second 1x1 convolution, expanding the channel count.
        bn3 (nn.Module): Normalization applied after ``conv3``.
        relu (nn.ReLU): In-place ReLU shared by all activation points.
        downsample (nn.Module or None): Optional transform for the shortcut path.
    """

    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1, downsample=None, norm_layer=None):
        """
        Set up the layers of the bottleneck block.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Base channel count; the block outputs
                ``out_channels * expansion`` channels.
            stride (int, optional): Stride of the 3x3 convolution. Default is 1.
            downsample (nn.Module, optional): Module applied to the input before the
                residual addition. Default is None (identity shortcut).
            norm_layer (callable, optional): Factory called with a channel count to
                build each normalization layer. Default is ``nn.BatchNorm2d``.
        """
        super().__init__()
        if not norm_layer:
            norm_layer = nn.BatchNorm2d

        # Base width of 64 over 64 leaves the inner width equal to out_channels.
        base_width = 64
        width = int(out_channels * (base_width / 64))
        expanded_channels = out_channels * self.expansion

        self.conv1 = apply_1x1_convolution(in_channels, width)
        self.bn1 = norm_layer(width)

        self.conv2 = apply_3x3_convolution(width, width, stride)
        self.bn2 = norm_layer(width)

        self.conv3 = apply_1x1_convolution(width, expanded_channels)
        self.bn3 = norm_layer(expanded_channels)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        """
        Run the bottleneck block on ``x``.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: ``relu(main_path(x) + shortcut(x))``.
        """
        main = self.relu(self.bn1(self.conv1(x)))
        main = self.relu(self.bn2(self.conv2(main)))
        main = self.bn3(self.conv3(main))

        # Use the raw input as the shortcut unless a downsample module is set.
        shortcut = x if self.downsample is None else self.downsample(x)

        main += shortcut
        return self.relu(main)
class ResNet(nn.Module):
    """
    ResNet architecture with a 3x3 stem and three residual stages.

    The network is: conv3x3 stem (3 -> 16 channels) -> norm -> ReLU ->
    layer1 (16 base channels) -> layer2 (32, stride 2) -> layer3 (64, stride 2)
    -> global average pooling -> fully connected classifier.

    Attributes:
        in_channels (int): Running channel count used while building the stages.
        conv1 (nn.Conv2d): Initial convolutional layer (expects 3 input channels).
        bn1 (nn.Module): Normalization layer applied after ``conv1``.
        relu (nn.ReLU): In-place ReLU of the stem.
        layer1 (nn.Sequential): First residual stage.
        layer2 (nn.Sequential): Second residual stage.
        layer3 (nn.Sequential): Third residual stage.
        avgpool (nn.AdaptiveAvgPool2d): Global average pooling to 1x1.
        fc (nn.Linear): Fully connected output layer.
    """

    def __init__(self, block, layers, num_classes=10, zero_init_residual=False, norm_layer=None):
        """
        Initializes the ResNet architecture.

        Args:
            block (type): The block class (BasicBlock or Bottleneck).
            layers (list of int): Number of blocks per stage; the first three
                entries are used.
            num_classes (int, optional): Number of output classes. Default is 10.
            zero_init_residual (bool, optional): Whether to zero-initialize the last
                normalization layer of every residual block. Default is False.
            norm_layer (callable, optional): Factory called with a channel count to
                build each normalization layer. Default is ``nn.BatchNorm2d``.
        """
        super(ResNet, self).__init__()
        # Keep the factory so the stem, the blocks and the downsample paths
        # all use the same normalization type.
        self._norm_layer = norm_layer or nn.BatchNorm2d
        self.in_channels = 16
        self.conv1 = nn.Conv2d(3, self.in_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = self._norm_layer(self.in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._create_model_layer(block, 16, layers[0])
        self.layer2 = self._create_model_layer(block, 32, layers[1], stride=2)
        self.layer3 = self._create_model_layer(block, 64, layers[2], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64 * block.expansion, num_classes)
        self._init_model_weights(zero_init_residual)

    def _create_model_layer(self, block, out_channels, blocks, stride=1):
        """
        Creates a residual stage and advances ``self.in_channels``.

        Args:
            block (type): The block class.
            out_channels (int): Base number of output channels; the stage outputs
                ``out_channels * block.expansion`` channels.
            blocks (int): Number of blocks in the stage.
            stride (int, optional): Stride for the first block. Default is 1.

        Returns:
            nn.Sequential: A sequence of residual blocks.
        """
        norm_layer = self._norm_layer
        expanded_channels = out_channels * block.expansion

        # The shortcut needs a projection whenever the first block changes the
        # spatial size (stride) or the channel count.
        downsample = None
        if stride != 1 or self.in_channels != expanded_channels:
            downsample = nn.Sequential(
                apply_1x1_convolution(self.in_channels, expanded_channels, stride),
                norm_layer(expanded_channels),
            )

        stage = [block(self.in_channels, out_channels, stride, downsample, norm_layer)]
        self.in_channels = expanded_channels
        # Remaining blocks keep the resolution and channel count.
        stage.extend(
            block(self.in_channels, out_channels, norm_layer=norm_layer)
            for _ in range(1, blocks)
        )
        return nn.Sequential(*stage)

    def _init_model_weights(self, zero_init_residual):
        """
        Initializes the weights of the model.

        Convolutions get Kaiming-normal initialization and BatchNorm layers get
        weight 1 / bias 0. When requested, the last normalization layer of every
        residual block is then zero-initialized.

        Args:
            zero_init_residual (bool): If True, zero the weight of the last
                normalization layer in each residual block.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-init runs as a separate pass: modules() yields a block before
        # its children, so doing it in the loop above would let the generic
        # BatchNorm initialization reset the zeroed weight back to 1.
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def forward(self, x):
        """
        Defines the forward pass of the ResNet.

        Args:
            x (torch.Tensor): Input tensor with 3 channels.

        Returns:
            tuple: ``(logits, extracted_features)`` where ``logits`` is the
            classifier output and ``extracted_features`` is the output of the
            stem (conv1 -> norm -> ReLU), taken before the residual stages.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        # Stem activations are returned alongside the logits.
        extracted_features = x
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.avgpool(x)
        # Flatten everything except the batch dimension for the classifier.
        x_f = x.view(x.size(0), -1)
        logits = self.fc(x_f)
        return logits, extracted_features
def resnet32_models_pretrained(num_classes, models_pretrained=False, path=None, **kwargs):
    """
    Build a ResNet-32 (``BasicBlock``, five blocks per stage).

    Args:
        num_classes (int): Number of output classes.
        models_pretrained (bool, optional): If True, load weights from ``path``. Default is False.
        path (str, optional): Path to the pretrained checkpoint. Default is None.
        **kwargs: Extra keyword arguments forwarded to ``ResNet``.

    Returns:
        ResNet: The constructed model, with checkpoint weights loaded when requested.
    """
    blocks_per_stage = [5, 5, 5]
    net = ResNet(BasicBlock, blocks_per_stage, num_classes=num_classes, **kwargs)
    if not models_pretrained:
        return net

    pretrained_weights = _load_models_pretrained_weights(path)
    net.load_state_dict(pretrained_weights)
    return net
def resnet56_models_pretrained(num_classes, models_pretrained=False, path=None, **kwargs):
    """
    Constructs a ResNet-56 model (``Bottleneck``, six blocks per stage).

    Args:
        num_classes (int): Number of output classes.
        models_pretrained (bool, optional): If True, loads pretrained weights. Default is False.
        path (str, optional): Path to the pretrained weights. Default is None.
        **kwargs: Extra keyword arguments forwarded to ``ResNet``.

    Returns:
        ResNet: A ResNet-56 model.
    """
    model = ResNet(Bottleneck, [6, 6, 6], num_classes=num_classes, **kwargs)
    if models_pretrained:
        # Log only when weights are actually loaded, so the message is never
        # emitted for a freshly initialized model.
        logging.info("Loading pretrained model from: %s", path)
        model.load_state_dict(_load_models_pretrained_weights(path))
    return model
def _load_models_pretrained_weights(path):
    """
    Loads pretrained weights from a checkpoint.

    The checkpoint is loaded onto the CPU and must be a mapping with a
    ``'state_dict'`` entry. Leading ``"module."`` prefixes are stripped from
    every parameter name; occurrences of ``"module."`` elsewhere in a name are
    left untouched.

    Args:
        path (str): Path to the checkpoint file.

    Returns:
        OrderedDict: State dictionary with the prefix-free parameter names,
        in the same order as in the checkpoint.

    Raises:
        KeyError: If the checkpoint has no ``'state_dict'`` entry.
    """
    from collections import OrderedDict

    # NOTE(security): torch.load unpickles the file, so only checkpoints from
    # trusted sources should be passed here.
    checkpoint = torch.load(path, map_location=torch.device('cpu'))
    state_dict = checkpoint['state_dict']

    prefix = "module."
    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
        # Strip the prefix only at the start of the name; a plain
        # str.replace would also corrupt names such as "submodule.weight".
        while key.startswith(prefix):
            key = key[len(prefix):]
        new_state_dict[key] = value
    return new_state_dict