# -*- coding: utf-8 -*-
# @Author: Weisen Pan
"""ResNet variants split into a primary client and a proxy client.

Each public factory returns a ``(primary, proxy)`` pair of modules. The primary
client holds only the input stem (``layer0``). The proxy client holds the
residual stages, global pooling and the linear classifier, which consume the
primary client's output feature map.
"""

# Importing necessary PyTorch libraries
import torch
import torch.nn as nn

# Attempt to import model loading utilities from torch.hub; fall back to
# torch.utils.model_zoo on PyTorch versions where torch.hub lacks it.
# NOTE(review): not referenced in this file -- the factories accept
# ``models_pretrained`` but pretrained-weight loading is not implemented.
try:
    from torch.hub import load_state_dict_from_url
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url

# Public API. ``resnet110_sl`` is an alias of ``resnet_model_110sl``.
__all__ = ['resnet110_sl', 'resnet_model_110sl', 'wide_resnetsl50_2', 'wide_resnetsl16_8']

# Stem width (number of channels produced by the primary client's layer0),
# indexed by dataset and split factor. The primary and proxy clients share this
# table so the primary's output width always equals the proxy's input width.
_BASE_INPLANES = {
    'cifar10': {1: 16, 2: 12, 4: 8, 8: 6, 16: 4, 32: 3},
    'cifar100': {1: 16, 2: 12, 4: 8, 8: 6, 16: 4, 32: 3},
    'skin_dataset': {1: 64, 2: 44, 4: 32, 8: 24},
    'pill_base': {1: 64, 2: 44, 4: 32, 8: 24},
    'medical_images': {1: 64, 2: 44, 4: 32, 8: 24},
}

# Datasets that use the large (strided + max-pooled) input stem.
_LARGE_INPUT_DATASETS = ('skin_dataset', 'pill_base', 'medical_images')


def _compute_inplanes(arch, dataset, split_factor):
    """Return the stem width for ``arch`` on ``dataset`` with ``split_factor``.

    For wide variants the widen factor is parsed from the last ``_``-separated
    token of ``arch`` (e.g. ``'wide_resnetsl16_8'`` -> 8). It is divided by
    ``sqrt(split_factor)``, offset by 0.4, truncated and clamped to at least 1,
    and then multiplies the base width from ``_BASE_INPLANES``.

    Raises:
        KeyError: if ``dataset``/``split_factor`` has no entry in the table.
    """
    try:
        inplanes = _BASE_INPLANES[dataset][split_factor]
    except KeyError:
        raise KeyError(
            "unsupported dataset/split_factor combination: "
            "dataset={!r}, split_factor={!r}".format(dataset, split_factor)) from None
    if 'wide_resnet' in arch:
        widen_factor = float(arch.split('_')[-1])
        inplanes *= int(max(widen_factor / (split_factor ** 0.5) + 0.4, 1.0))
    return inplanes


# Function for 3x3 convolution with padding
def apply_3x3_convolution(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


# Function for 1x1 convolution, typically used to change the number of channels
def apply_1x1_convolution(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    """Two-conv residual block (as in ResNet-18/34): 3x3 -> 3x3 plus identity.

    Output has ``planes * expansion`` (= ``planes``) channels. ``downsample``,
    if given, is applied to the input so that it matches the output for the
    residual sum.

    Raises:
        ValueError: if ``groups != 1`` or ``base_width != 64``.
        NotImplementedError: if ``dilation > 1``.
    """

    expansion = 1  # Output channels = planes * expansion

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        norm_layer = norm_layer or nn.BatchNorm2d
        # BasicBlock only supports groups=1 and base_width=64
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("BasicBlock does not support dilation greater than 1")

        # Two 3x3 convolutions; only the first one may be strided.
        self.conv1 = apply_3x3_convolution(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = apply_3x3_convolution(planes, planes)
        self.bn2 = norm_layer(planes)
        # Optional projection of the identity branch to the output shape
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x  # Saved for the residual connection

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # Match the identity's shape to the output before adding
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out


class Bottleneck(nn.Module):
    """Three-conv residual block (as in ResNet-50/101): 1x1 -> 3x3 -> 1x1.

    The inner width is ``int(planes * base_width / 64) * groups``. The output
    has ``planes * expansion`` channels. ``downsample``, if given, is applied
    to the input so that it matches the output for the residual sum.
    """

    expansion = 4  # Output channels = planes * 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        norm_layer = norm_layer or nn.BatchNorm2d
        # Width of the block based on base_width and groups
        width = int(planes * (base_width / 64.)) * groups

        self.conv1 = apply_1x1_convolution(inplanes, width)  # Reduce channels
        self.bn1 = norm_layer(width)
        self.conv2 = apply_3x3_convolution(width, width, stride, groups, dilation)  # Spatial conv
        self.bn2 = norm_layer(width)
        self.conv3 = apply_1x1_convolution(width, planes * self.expansion)  # Expand channels
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample  # Optional identity projection
        self.stride = stride

    def forward(self, x):
        identity = x  # Saved for the residual connection

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        # Each stage consumes the previous stage's output, not the block input.
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        # Match the identity's shape to the output before adding
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out


def _initialize_weights(model, zero_init_residual):
    """Initialize the parameters of every submodule of ``model`` in place.

    Conv weights are Kaiming-normal (fan_out, ReLU). Norm layers start at
    weight 1 / bias 0. Linear weights are drawn from N(0, 1e-3) with zero bias.
    If ``zero_init_residual`` is true, the last norm layer of each residual
    block is zeroed, so each block initially outputs ``relu(identity)``.
    """
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.SyncBatchNorm)):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, std=1e-3)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    if zero_init_residual:
        for m in model.modules():
            if isinstance(m, Bottleneck):
                nn.init.constant_(m.bn3.weight, 0)
            elif isinstance(m, BasicBlock):
                nn.init.constant_(m.bn2.weight, 0)


class PrimaryResNetClient(nn.Module):
    """Primary (main) client: only the input stem ``layer0``.

    ``forward`` returns a feature map with ``self.inplanes`` channels. Its width
    comes from ``_compute_inplanes(arch, dataset, split_factor)``.
    ``block``, ``layers``, ``num_classes``, ``groups``, ``output_stride`` and
    ``dropout_p`` are accepted for signature parity with ``ResNetProxies``
    and are not used here.

    Raises:
        KeyError: for an unsupported ``dataset``/``split_factor`` combination.
        ValueError: if ``replace_stride_with_dilation`` does not have 3 elements.
    """

    def __init__(self, arch, block, layers, num_classes=1000, zero_init_residual=True,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None, dataset='cifar10', split_factor=1, output_stride=8,
                 dropout_p=None):
        super(PrimaryResNetClient, self).__init__()
        norm_layer = norm_layer or nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = _compute_inplanes(arch, dataset, split_factor)
        self.base_width = width_per_group
        self.dilation = 1

        replace_stride_with_dilation = replace_stride_with_dilation or [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation must either be None or a tuple with three elements")

        # Large-image datasets get a strided stem plus max-pooling (4x spatial
        # reduction); the other datasets keep the input resolution.
        if dataset in _LARGE_INPUT_DATASETS:
            self.layer0 = self._initialize_primary_layer_large()
        else:
            self.layer0 = self._init_layer0_small()

        self._init_model_weights(zero_init_residual)

    def _initialize_primary_layer_large(self):
        """Stem for large inputs: stride-2 3x3 conv, norm, ReLU, stride-2 max-pool."""
        return nn.Sequential(
            nn.Conv2d(3, self.inplanes, kernel_size=3, stride=2, padding=1, bias=False),
            self._norm_layer(self.inplanes),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

    def _init_layer0_small(self):
        """Stem for CIFAR-sized inputs: stride-1 3x3 conv, norm, ReLU."""
        return nn.Sequential(
            nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False),
            self._norm_layer(self.inplanes),
            nn.ReLU(inplace=True),
        )

    def _init_model_weights(self, zero_init_residual):
        """Initialize all submodule parameters (see ``_initialize_weights``)."""
        _initialize_weights(self, zero_init_residual)

    def forward(self, x):
        x = self.layer0(x)
        return x


class ResNetProxies(nn.Module):
    """Proxy client: three residual stages, global average pooling, linear head.

    The input must have ``_compute_inplanes(arch, dataset, split_factor)``
    channels, i.e. the output of the matching ``PrimaryResNetClient``. Stage
    ``i`` (i = 0, 1, 2) uses ``base * 2**i`` planes, and stages 2 and 3 halve
    the spatial resolution. Only ``layers[0:3]`` is used. ``groups``,
    ``replace_stride_with_dilation``, ``output_stride`` and ``dropout_p`` are
    accepted for signature parity and are not used.

    Raises:
        KeyError: for an unsupported ``dataset``/``split_factor`` combination.
    """

    def __init__(self, arch, block, layers, num_classes=1000, zero_init_residual=True,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None, dataset='cifar10', split_factor=1, output_stride=8,
                 dropout_p=None):
        super(ResNetProxies, self).__init__()
        norm_layer = norm_layer or nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = self._set_input_planes(arch, dataset, split_factor, width_per_group)
        self.base_width = width_per_group

        # _create_model_layer advances self.inplanes, so stage widths must be
        # derived from the base width captured here.
        base_planes = self.inplanes
        self.layer1 = self._create_model_layer(block, base_planes, layers[0], stride=1)
        self.layer2 = self._create_model_layer(block, base_planes * 2, layers[1], stride=2)
        self.layer3 = self._create_model_layer(block, base_planes * 4, layers[2], stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(base_planes * 4 * block.expansion, num_classes)

        self._init_model_weights(zero_init_residual)

    def _set_input_planes(self, arch, dataset, split_factor, width_per_group):
        """Return the expected input width (``width_per_group`` is unused)."""
        return _compute_inplanes(arch, dataset, split_factor)

    def _create_model_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks with ``planes`` planes.

        The first block applies ``stride``. If it changes resolution or channel
        count, its identity branch gets a 1x1-conv + norm projection. Updates
        ``self.inplanes`` to the stage's output width.
        """
        norm_layer = self._norm_layer
        out_channels = planes * block.expansion

        downsample = None
        if stride != 1 or self.inplanes != out_channels:
            downsample = nn.Sequential(
                apply_1x1_convolution(self.inplanes, out_channels, stride),
                norm_layer(out_channels),
            )

        stage = [block(self.inplanes, planes, stride, downsample, norm_layer=norm_layer)]
        self.inplanes = out_channels
        stage.extend(block(self.inplanes, planes, norm_layer=norm_layer) for _ in range(1, blocks))
        return nn.Sequential(*stage)

    def _init_model_weights(self, zero_init_residual):
        """Initialize all submodule parameters (see ``_initialize_weights``)."""
        _initialize_weights(self, zero_init_residual)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x


def _resnetsl_primary_client_(arch, block, layers, models_pretrained, progress, **kwargs):
    """Build the primary client. ``models_pretrained`` and ``progress`` are currently ignored."""
    return PrimaryResNetClient(arch, block, layers, **kwargs)


def _resnetsl_secondary_client_(arch, block, layers, models_pretrained, progress, **kwargs):
    """Build the proxy client. ``models_pretrained`` and ``progress`` are currently ignored."""
    return ResNetProxies(arch, block, layers, **kwargs)


def resnet_model_110sl(models_pretrained=False, progress=True, **kwargs):
    """ResNet-110 (Bottleneck, layers ``[12, 12, 12, 12]``) as primary + proxy clients.

    ``kwargs`` are forwarded to both model constructors.

    Returns:
        tuple: ``(PrimaryResNetClient, ResNetProxies)``.

    Raises:
        ValueError: if ``dataset`` (default ``'cifar10'``) is not a CIFAR dataset.
    """
    # Same default as the model constructors, so omitting ``dataset`` is valid.
    dataset = kwargs.get('dataset', 'cifar10')
    if 'cifar' not in dataset:
        raise ValueError("resnet110_sl only supports CIFAR datasets, got dataset={!r}".format(dataset))
    layers = [12, 12, 12, 12]
    return (
        _resnetsl_primary_client_('resnet110_sl', Bottleneck, layers, models_pretrained, progress, **kwargs),
        _resnetsl_secondary_client_('resnet110_sl', Bottleneck, layers, models_pretrained, progress, **kwargs),
    )


# Name advertised in __all__.
resnet110_sl = resnet_model_110sl


def wide_resnetsl50_2(models_pretrained=False, progress=True, **kwargs):
    """Wide ResNet-50-2 (Bottleneck, layers ``[3, 4, 6, 3]``) as primary + proxy clients.

    Sets ``width_per_group`` to 128. The widen factor 2 is parsed from the
    arch name.

    Returns:
        tuple: ``(PrimaryResNetClient, ResNetProxies)``.
    """
    kwargs['width_per_group'] = 64 * 2
    layers = [3, 4, 6, 3]
    return (
        _resnetsl_primary_client_('wide_resnetsl50_2', Bottleneck, layers, models_pretrained, progress, **kwargs),
        _resnetsl_secondary_client_('wide_resnetsl50_2', Bottleneck, layers, models_pretrained, progress, **kwargs),
    )


def wide_resnetsl16_8(models_pretrained=False, progress=True, **kwargs):
    """Wide ResNet-16-8 (BasicBlock, layers ``[2, 2, 2, 2]``) as primary + proxy clients.

    Sets ``width_per_group`` to 64. The widen factor 8 is parsed from the
    arch name.

    Returns:
        tuple: ``(PrimaryResNetClient, ResNetProxies)``.
    """
    kwargs['width_per_group'] = 64
    layers = [2, 2, 2, 2]
    return (
        _resnetsl_primary_client_('wide_resnetsl16_8', BasicBlock, layers, models_pretrained, progress, **kwargs),
        _resnetsl_secondary_client_('wide_resnetsl16_8', BasicBlock, layers, models_pretrained, progress, **kwargs),
    )