# -*- coding: utf-8 -*-
# @Author: Weisen Pan

# Importing necessary PyTorch libraries
import torch
import torch.nn as nn

# Attempt to import model loading utilities from torch.hub; fall back to torch.utils.model_zoo if unavailable
try:
    from torch.hub import load_state_dict_from_url
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url

# Specify all the modules and functions to export
# (note: the ResNet-110 factory is defined below as resnet_model_110sl)
__all__ = ['resnet_model_110sl', 'wide_resnetsl50_2', 'wide_resnetsl16_8']

# Function for 3x3 convolution with padding
def apply_3x3_convolution(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)

# Function for 1x1 convolution, typically used to change the number of channels
def apply_1x1_convolution(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
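# Quick shape reference (added note, not in the original file): with the default
# stride both helpers preserve spatial size, since padding equals dilation for
# the 3x3 case and a 1x1 kernel needs no padding. For example (hypothetical tensors):
#   apply_3x3_convolution(16, 32)(torch.randn(1, 16, 32, 32))  # -> (1, 32, 32, 32)
#   apply_1x1_convolution(16, 64)(torch.randn(1, 16, 32, 32))  # -> (1, 64, 32, 32)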

# Basic Block class for ResNet (used in smaller networks like resnet_model_18/resnet_model_34)
class BasicBlock(nn.Module):
    expansion = 1  # Expansion factor for output channels

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        norm_layer = norm_layer or nn.BatchNorm2d
        # BasicBlock only supports groups=1 and base_width=64
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("BasicBlock does not support dilation greater than 1")

        # Define two 3x3 convolution layers with batch normalization and ReLU activation
        self.conv1 = apply_3x3_convolution(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = apply_3x3_convolution(planes, planes)
        self.bn2 = norm_layer(planes)
        # Optional downsample layer for changing the dimensions
        self.downsample = downsample
        self.stride = stride

    # Forward function defining the data flow through the block
    def forward(self, x):
        identity = x  # Save the input for residual connection

        # First convolution, batch norm, and ReLU
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        # Second convolution, batch norm
        out = self.conv2(out)
        out = self.bn2(out)

        # Apply downsample if needed to match dimensions for residual addition
        if self.downsample is not None:
            identity = self.downsample(x)

        # Residual connection (add identity to output)
        out += identity
        out = self.relu(out)

        return out
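# Usage sketch (added for illustration; values are hypothetical): when a block
# changes resolution or channel count, the caller must supply a matching
# downsample for the skip path, e.g.
#   ds = nn.Sequential(apply_1x1_convolution(64, 128, stride=2), nn.BatchNorm2d(128))
#   BasicBlock(64, 128, stride=2, downsample=ds)(torch.randn(2, 64, 32, 32))
#   # -> shape (2, 128, 16, 16)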

# Bottleneck block class for deeper ResNet architectures (e.g., resnet_model_50/resnet_model_101)
class Bottleneck(nn.Module):
    expansion = 4  # Expansion factor for output channels (output = input * 4)

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        norm_layer = norm_layer or nn.BatchNorm2d
        # Width of the block based on base_width and groups
        width = int(planes * (base_width / 64.)) * groups

        # Define 1x1, 3x3, and 1x1 convolutions with batch norm and ReLU activation
        self.conv1 = apply_1x1_convolution(inplanes, width)  # First 1x1 convolution
        self.bn1 = norm_layer(width)
        self.conv2 = apply_3x3_convolution(width, width, stride, groups, dilation)  # Main 3x3 convolution
        self.bn2 = norm_layer(width)
        self.conv3 = apply_1x1_convolution(width, planes * self.expansion)  # Final 1x1 convolution
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample  # Downsample layer for dimension adjustment
        self.stride = stride

    # Forward function defining the data flow through the bottleneck block
    def forward(self, x):
        identity = x  # Save the input for residual connection

        # First 1x1 convolution, batch norm, and ReLU
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        # Second 3x3 convolution, batch norm, and ReLU (operates on the previous
        # output, not the raw input)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        # Third 1x1 convolution, batch norm
        out = self.conv3(out)
        out = self.bn3(out)

        # Apply downsample if needed for residual connection
        if self.downsample is not None:
            identity = self.downsample(x)

        # Residual connection (add identity to output)
        out += identity
        out = self.relu(out)

        return out
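# Note (added for clarity): unlike BasicBlock, Bottleneck expands its output to
# planes * 4, so Bottleneck(64, 64) emits 256 channels; a downsample on the skip
# path is required whenever inplanes != planes * Bottleneck.expansion.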

# ResNet model for the main client (usually the primary model)
class PrimaryResNetClient(nn.Module):

    def __init__(self, arch, block, layers, num_classes=1000, zero_init_residual=True,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None, dataset='cifar10', split_factor=1, output_stride=8, dropout_p=None):
        super(PrimaryResNetClient, self).__init__()
        norm_layer = norm_layer or nn.BatchNorm2d
        self._norm_layer = norm_layer

        # Initialize the number of input channels based on the dataset and split factor
        inplanes_dict = {
            'cifar10': {1: 16, 2: 12, 4: 8, 8: 6, 16: 4, 32: 3},
            'cifar100': {1: 16, 2: 12, 4: 8, 8: 6, 16: 4, 32: 3},
            'skin_dataset': {1: 64, 2: 44, 4: 32, 8: 24},
            'pill_base': {1: 64, 2: 44, 4: 32, 8: 24},
            'medical_images': {1: 64, 2: 44, 4: 32, 8: 24},
        }
        self.inplanes = inplanes_dict[dataset][split_factor]

        # Adjust input planes if using a wide ResNet
        if 'wide_resnet' in arch:
            widen_factor = int(arch.split('_')[-1])
            self.inplanes *= int(max(widen_factor / (split_factor ** 0.5) + 0.4, 1.0))
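        # Worked example (added note, values hypothetical): for
        # arch='wide_resnetsl16_8' on cifar10 with split_factor=2, the base
        # inplanes is 12 and widen_factor is 8, so the multiplier is
        # int(max(8 / 2 ** 0.5 + 0.4, 1.0)) = int(6.06) = 6, giving inplanes = 72.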

        self.base_width = width_per_group
        self.dilation = 1
        replace_stride_with_dilation = replace_stride_with_dilation or [False, False, False]

        # Check if replace_stride_with_dilation is properly defined
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation must either be None or a tuple with three elements")

        # Initialize input layer depending on the dataset (small or large)
        if dataset in ['skin_dataset', 'pill_base', 'medical_images']:
            self.layer0 = self._initialize_primary_layer_large()
        else:
            self.layer0 = self._init_layer0_small()

        # Initialize model weights
        self._init_model_weights(zero_init_residual)

    # Define the large initial convolution layer for large datasets
    def _initialize_primary_layer_large(self):
        return nn.Sequential(
            nn.Conv2d(3, self.inplanes, kernel_size=3, stride=2, padding=1, bias=False),
            self._norm_layer(self.inplanes),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

    # Define the small initial convolution layer for smaller datasets like CIFAR
    def _init_layer0_small(self):
        return nn.Sequential(
            nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False),
            self._norm_layer(self.inplanes),
            nn.ReLU(inplace=True),
        )
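    # Note (added for clarity): the small stem keeps 32x32 CIFAR inputs at full
    # resolution, while the large stem reduces spatial size by 4x overall (a
    # stride-2 convolution followed by a stride-2 max-pool, e.g. 224 -> 56).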

    # Function to initialize weights in the network
    def _init_model_weights(self, zero_init_residual):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.SyncBatchNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=1e-3)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN weight in each residual block if specified,
        # so every block starts as an identity mapping and early training is more stable
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    # Forward pass: the primary client computes only the stem (layer0); the
    # remaining stages run on the proxy side (see ResNetProxies below)
    def forward(self, x):
        x = self.layer0(x)
        return x

# ResNet model for proxy clients (usually assisting the main model)
class ResNetProxies(nn.Module):

    def __init__(self, arch, block, layers, num_classes=1000, zero_init_residual=True,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None, dataset='cifar10', split_factor=1, output_stride=8, dropout_p=None):
        super(ResNetProxies, self).__init__()
        norm_layer = norm_layer or nn.BatchNorm2d
        self._norm_layer = norm_layer

        # Set input channels based on architecture, dataset, and split factor
        self.inplanes = self._set_input_planes(arch, dataset, split_factor, width_per_group)
        self.base_width = width_per_group

        # Define layers of the network (layer1, layer2, layer3); capture the base
        # width first, because _create_model_layer updates self.inplanes as it builds
        planes = self.inplanes
        self.layer1 = self._create_model_layer(block, planes, layers[0], stride=1)
        self.layer2 = self._create_model_layer(block, planes * 2, layers[1], stride=2)
        self.layer3 = self._create_model_layer(block, planes * 4, layers[2], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # Adaptive average pooling layer
        self.fc = nn.Linear(planes * 4 * block.expansion, num_classes)

        # Initialize model weights
        self._init_model_weights(zero_init_residual)

    # Set input channels based on dataset and split factor; mirrors the
    # per-dataset table in PrimaryResNetClient so both halves of the split
    # model agree on channel widths
    def _set_input_planes(self, arch, dataset, split_factor, width_per_group):
        inplanes_dict = {
            'cifar10': {1: 16, 2: 12, 4: 8, 8: 6, 16: 4, 32: 3},
            'cifar100': {1: 16, 2: 12, 4: 8, 8: 6, 16: 4, 32: 3},
            'skin_dataset': {1: 64, 2: 44, 4: 32, 8: 24},
            'pill_base': {1: 64, 2: 44, 4: 32, 8: 24},
            'medical_images': {1: 64, 2: 44, 4: 32, 8: 24},
        }
        inplanes = inplanes_dict[dataset][split_factor]

        # Adjust input planes for wide ResNet
        if 'wide_resnet' in arch:
            widen_factor = float(arch.split('_')[-1])
            inplanes *= int(max(widen_factor / (split_factor ** 0.5) + 0.4, 1.0))

        return inplanes

    # Function to create layers of the network (consisting of blocks)
    def _create_model_layer(self, block, planes, blocks, stride=1):
        # Build a downsample path when the residual branch changes resolution or
        # channel count, so the skip connection can be added to the block output
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                apply_1x1_convolution(self.inplanes, planes * block.expansion, stride),
                self._norm_layer(planes * block.expansion),
            )
        layers = [block(self.inplanes, planes, stride, downsample,
                        norm_layer=self._norm_layer)]  # First block
        self.inplanes = planes * block.expansion  # Update input planes
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, norm_layer=self._norm_layer))  # Additional blocks
        return nn.Sequential(*layers)
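    # Worked example (added note): with BasicBlock and base planes 128, layer1
    # keeps 128 channels at stride 1, layer2 moves to 256 channels at stride 2,
    # and layer3 to 512, matching the planes * 4 * expansion input of self.fc.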

    # Initialize weights in the network
    def _init_model_weights(self, zero_init_residual):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.SyncBatchNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=1e-3)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

        # Initialize residual weights for Bottleneck and BasicBlock if specified
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    # Forward pass for the proxy side: stages 1-3, pooling, and the classifier
    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

# Helper function to create the main ResNet client; models_pretrained and
# progress are accepted for API symmetry but no pretrained weights are loaded here
def _resnetsl_primary_client_(arch, block, layers, models_pretrained, progress, **kwargs):
    return PrimaryResNetClient(arch, block, layers, **kwargs)

# Helper function to create the proxy ResNet client
def _resnetsl_secondary_client_(arch, block, layers, models_pretrained, progress, **kwargs):
    return ResNetProxies(arch, block, layers, **kwargs)

# Function to define a ResNet-110 model for main and proxy clients
def resnet_model_110sl(models_pretrained=False, progress=True, **kwargs):
    assert 'cifar' in kwargs.get('dataset', 'cifar10')  # Ensure that a CIFAR dataset is used
    return _resnetsl_primary_client_('resnet110_sl', Bottleneck, [12, 12, 12, 12], models_pretrained, progress, **kwargs), \
        _resnetsl_secondary_client_('resnet110_sl', Bottleneck, [12, 12, 12, 12], models_pretrained, progress, **kwargs)

# Function to define a Wide ResNet-50-2 model for main and proxy clients
def wide_resnetsl50_2(models_pretrained=False, progress=True, **kwargs):
    kwargs['width_per_group'] = 64 * 2  # Double the per-group width for Wide ResNet-50-2
    return _resnetsl_primary_client_('wide_resnetsl50_2', Bottleneck, [3, 4, 6, 3], models_pretrained, progress, **kwargs), \
        _resnetsl_secondary_client_('wide_resnetsl50_2', Bottleneck, [3, 4, 6, 3], models_pretrained, progress, **kwargs)

# Function to define a Wide ResNet-16-8 model for main and proxy clients
def wide_resnetsl16_8(models_pretrained=False, progress=True, **kwargs):
    kwargs['width_per_group'] = 64  # Default per-group width; the widen factor 8 is parsed from the arch name
    return _resnetsl_primary_client_('wide_resnetsl16_8', BasicBlock, [2, 2, 2, 2], models_pretrained, progress, **kwargs), \
        _resnetsl_secondary_client_('wide_resnetsl16_8', BasicBlock, [2, 2, 2, 2], models_pretrained, progress, **kwargs)
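
# Minimal smoke test (added example; not part of the original training code).
# It assumes CIFAR-10 inputs and checks that the primary client's stem output
# feeds the proxy client end to end.
if __name__ == '__main__':
    primary, proxy = wide_resnetsl16_8(dataset='cifar10', num_classes=10, split_factor=1)
    images = torch.randn(2, 3, 32, 32)  # hypothetical CIFAR-sized batch
    stem_features = primary(images)     # layer0 only: (2, 128, 32, 32)
    logits = proxy(stem_features)       # layer1-3 + avgpool + fc: (2, 10)
    print(stem_features.shape, logits.shape)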