# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import torch
import torch.nn as nn

# Try to import load_state_dict_from_url from torch.hub.
# If it fails (due to older versions), fall back to load_url from torch.utils.model_zoo.
try:
    from torch.hub import load_state_dict_from_url
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url

# List of all exportable models
__all__ = ['resnet110_sl', 'wide_resnetsl50_2', 'wide_resnetsl16_8']


def apply_3x3_convolution(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding."""
    return nn.Conv2d(
        in_planes,          # Number of input channels
        out_planes,         # Number of output channels
        kernel_size=3,      # Size of the filter
        stride=stride,      # Stride of the convolution
        padding=dilation,   # Padding for the convolution
        groups=groups,      # Group convolution
        bias=False,         # No bias in convolution
        dilation=dilation   # Dilation rate for dilated convolutions
    )


def apply_1x1_convolution(in_planes, out_planes, stride=1):
    """1x1 convolution."""
    return nn.Conv2d(
        in_planes,        # Number of input channels
        out_planes,       # Number of output channels
        kernel_size=1,    # Filter size is 1x1
        stride=stride,    # Stride of the convolution
        bias=False        # No bias in convolution
    )


class BasicBlock(nn.Module):
    """Basic block for ResNet."""
    expansion = 1  # No expansion in BasicBlock

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        self.conv1 = apply_3x3_convolution(inplanes, planes, stride)  # First 3x3 convolution
        self.bn1 = norm_layer(planes)         # First batch normalization
        self.relu = nn.ReLU(inplace=True)     # ReLU activation
        self.conv2 = apply_3x3_convolution(planes, planes)  # Second 3x3 convolution
        self.bn2 = norm_layer(planes)         # Second batch normalization
        self.downsample = downsample          # If there's downsampling (e.g., stride mismatch)

    def forward(self, x):
        identity = x  # Preserve the input as identity for skip connection

        out = self.conv1(x)    # Apply the first convolution
        out = self.bn1(out)    # Apply first batch normalization
        out = self.relu(out)   # Apply ReLU activation

        out = self.conv2(out)  # Apply the second convolution
        out = self.bn2(out)    # Apply second batch normalization

        # If downsample exists, apply it to the identity
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity        # Add skip connection
        out = self.relu(out)   # Final ReLU activation

        return out  # Return the result


class Bottleneck(nn.Module):
    """Bottleneck block for ResNet."""
    expansion = 4  # Bottleneck expands the channels by a factor of 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups  # Width of the block

        # 1x1 convolution (bottleneck)
        self.conv1 = apply_1x1_convolution(inplanes, width)
        self.bn1 = norm_layer(width)  # Batch normalization after 1x1 convolution

        # 3x3 convolution (main block)
        self.conv2 = apply_3x3_convolution(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)  # Batch normalization after 3x3 convolution

        # 1x1 convolution (bottleneck exit)
        self.conv3 = apply_1x1_convolution(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)  # Batch normalization after 1x1 exit
        self.relu = nn.ReLU(inplace=True)  # ReLU activation
        self.downsample = downsample       # Downsampling for skip connection, if needed

    def forward(self, x):
        identity = x  # Store input as identity for the skip connection

        out = self.conv1(x)    # Apply first 1x1 convolution
        out = self.bn1(out)    # Apply batch normalization
        out = self.relu(out)   # Apply ReLU

        out = self.conv2(out)  # Apply 3x3 convolution
        out = self.bn2(out)    # Apply batch normalization
        out = self.relu(out)   # Apply ReLU

        out = self.conv3(out)  # Apply 1x1 convolution
        out = self.bn3(out)    # Apply batch normalization

        # If downsample exists, apply it to the identity
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity       # Add skip connection
        out = self.relu(out)  # Final ReLU activation

        return out  # Return the result


class PrimaryResNetClient(nn.Module):
    """Main ResNet model for client."""

    def __init__(self, arch, block, layers, num_classes=1000, zero_init_residual=True,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None, dataset='cifar10', split_factor=1, output_stride=8,
                 dropout_p=None):
        super(PrimaryResNetClient, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        # Bookkeeping attributes read by _create_model_layer
        self.groups = groups
        self.base_width = width_per_group
        self.dilation = 1

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # Global average pooling before fully connected layer

        # Dictionary to store input channel size based on dataset and split factor
        inplanes_dict = {
            'cifar10': {1: 16, 2: 12, 4: 8, 8: 6, 16: 4, 32: 3},
            'cifar100': {1: 16, 2: 12, 4: 8, 8: 6, 16: 4},
            'skin_dataset': {1: 64, 2: 44, 4: 32, 8: 24},
            'pill_base': {1: 64, 2: 44, 4: 32, 8: 24},
            'medical_images': {1: 64, 2: 44, 4: 32, 8: 24},
        }
        self.inplanes = inplanes_dict[dataset][split_factor]  # Set initial input channels

        # Note: the residual stages self.layer0 and self.layer1 used in _forward_impl
        # are built from `block` and `layers` via _create_model_layer; their
        # construction is not shown in this section.

        self.fc = nn.Linear(self.inplanes * 4 * block.expansion, num_classes)  # Fully connected layer for classification

        # Initialize all layers
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.SyncBatchNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=1e-3)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

        # Optionally initialize the last batch normalization layer of each residual branch to zero
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _create_model_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Create a residual layer consisting of several blocks."""
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                apply_1x1_convolution(self.inplanes, planes * block.expansion, stride),  # Adjust input size for downsampling
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))  # Add the first block with downsample
        self.inplanes = planes * block.expansion  # Update inplanes for the next block
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))  # Add the remaining blocks
        return nn.Sequential(*layers)  # Return the stacked blocks

    def _forward_impl(self, x):
        """Implementation of the forward pass."""
        x = self.layer0(x)       # Initial layer
        extracted_features = x   # Save features after the initial layer
        x = self.layer1(x)       # First residual layer
        x = self.avgpool(x)      # Global average pooling
        x = torch.flatten(x, 1)  # Flatten the features into a 1D tensor
        logits = self.fc(x)      # Pass through the fully connected layer
        return logits, extracted_features  # Return logits and extracted features

    def forward(self, x):
        """Standard forward method."""
        return self._forward_impl(x)
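
# ---------------------------------------------------------------------------
# Illustrative usage sketch: a minimal smoke test for PrimaryResNetClient.
# Because the construction of self.layer0 / self.layer1 is not part of the
# code above, this sketch attaches an assumed CIFAR-10 style stem as layer0
# and a single BasicBlock stage as layer1. The [18, 18, 18] configuration,
# the stem shape, and the stage width (inplanes * 4, chosen so the output
# matches the fully connected layer) are assumptions for illustration only.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    # Build a client model for CIFAR-10 with split_factor=1 (inplanes=16).
    model = PrimaryResNetClient('resnet110_sl', BasicBlock, [18, 18, 18],
                                num_classes=10, dataset='cifar10', split_factor=1)

    # Hypothetical stem: 3x3 convolution from RGB to the initial channel count.
    model.layer0 = nn.Sequential(
        apply_3x3_convolution(3, model.inplanes),
        nn.BatchNorm2d(model.inplanes),
        nn.ReLU(inplace=True),
    )

    # Hypothetical first stage: widen to inplanes * 4 so the pooled features
    # match the fc layer's expected input size (inplanes * 4 * block.expansion).
    model.layer1 = model._create_model_layer(BasicBlock, model.inplanes * 4, blocks=18)

    logits, features = model(torch.randn(2, 3, 32, 32))
    print(logits.shape, features.shape)  # torch.Size([2, 10]) torch.Size([2, 16, 32, 32])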