# -*- coding: utf-8 -*-
# @Author: Weisen Pan
import torch
import torch.nn as nn

__all__ = ['ResNet']


# Function to define a 3x3 convolution layer with padding
def apply_3x3_convolution(in_channels, out_channels, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


# Function to define a 1x1 convolution layer
def apply_1x1_convolution(in_channels, out_channels, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)


# BasicBlock class for the ResNet architecture
class BasicBlock(nn.Module):
    expansion = 1  # Expansion factor

    def __init__(self, in_channels, out_channels, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d  # Default normalization layer is BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # First convolution and batch normalization layer
        self.conv1 = apply_3x3_convolution(in_channels, out_channels, stride)
        self.bn1 = norm_layer(out_channels)
        self.relu = nn.ReLU(inplace=True)  # ReLU activation
        # Second convolution and batch normalization layer
        self.conv2 = apply_3x3_convolution(out_channels, out_channels)
        self.bn2 = norm_layer(out_channels)
        self.downsample = downsample  # If downsample is provided, use it

    def forward(self, x):
        identity = x  # Keep original input as identity for residual connection

        # Forward pass through first convolution, batch norm, and ReLU
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        # Forward pass through second convolution and batch norm
        out = self.conv2(out)
        out = self.bn2(out)

        # Downsample the identity if a downsample module is provided
        if self.downsample is not None:
            identity = self.downsample(x)

        # Add residual connection (identity)
        out += identity
        out = self.relu(out)  # Apply ReLU activation after addition

        return out


# Bottleneck class for deeper ResNet architectures
class Bottleneck(nn.Module):
    expansion = 4  # Expansion factor

    def __init__(self, in_channels, out_channels, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d  # Default normalization layer is BatchNorm2d
        width = int(out_channels * (base_width / 64.)) * groups  # Calculate width based on group size
        # First 1x1 convolution
        self.conv1 = apply_1x1_convolution(in_channels, width)
        self.bn1 = norm_layer(width)
        # Second 3x3 convolution
        self.conv2 = apply_3x3_convolution(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        # Third 1x1 convolution to match output channels
        self.conv3 = apply_1x1_convolution(width, out_channels * self.expansion)
        self.bn3 = norm_layer(out_channels * self.expansion)
        self.relu = nn.ReLU(inplace=True)  # ReLU activation
        self.downsample = downsample  # Downsample if provided

    def forward(self, x):
        identity = x  # Keep original input as identity for residual connection

        # First 1x1 convolution and ReLU
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        # Second 3x3 convolution and ReLU
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        # Third 1x1 convolution
        out = self.conv3(out)
        out = self.bn3(out)

        # Add downsampled identity if necessary
        if self.downsample is not None:
            identity = self.downsample(x)

        # Add residual connection (identity)
        out += identity
        out = self.relu(out)  # Apply ReLU activation after addition

        return out
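
# A minimal sanity-check sketch (this helper is an illustration, not part of the original
# module): it shows how a single Bottleneck transforms a feature map. The 1x1/3x3/1x1 stack
# widens the channel count by `Bottleneck.expansion`, so a matching downsample branch is
# needed on the identity path. All sizes below are illustrative assumptions.
def _demo_bottleneck_shapes():
    downsample = nn.Sequential(
        apply_1x1_convolution(16, 16 * Bottleneck.expansion),
        nn.BatchNorm2d(16 * Bottleneck.expansion),
    )
    block = Bottleneck(in_channels=16, out_channels=16, downsample=downsample)
    x = torch.randn(2, 16, 32, 32)
    out = block(x)
    # With stride=1 the spatial size is preserved; channels grow 16 -> 16 * 4 = 64
    assert out.shape == (2, 16 * Bottleneck.expansion, 32, 32)
    return out.shape
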
# ResNet class to build the entire ResNet model
class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=10, zero_init_residual=False, groups=1,
                 width_per_group=64, replace_stride_with_dilation=None, norm_layer=None, KD=False):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d  # Default normalization layer
        self._norm_layer = norm_layer
        self.inplanes = 16  # Initial number of channels
        self.dilation = 1  # Dilation factor
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]  # Default stride behavior
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups  # Number of groups for convolutions
        self.base_width = width_per_group  # Base width for groups

        # Initial convolutional layer with 3 input channels (RGB image)
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplanes)  # Batch normalization
        self.relu = nn.ReLU(inplace=True)  # ReLU activation
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  # Max pooling layer
        self.layer1 = self._create_model_layer(block, 16, layers[0])  # First block layer
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # Adaptive average pooling
        self.fc = nn.Linear(16 * block.expansion, num_classes)  # Fully connected layer
        self.KD = KD  # Knowledge Distillation flag

        for m in self.modules():
            # Initialize convolutional weights using He initialization
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            # Initialize batch normalization weights
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last batch norm layer of each block if zero_init_residual is True
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    # Helper function to create layers of blocks
    def _create_model_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                apply_1x1_convolution(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    # Forward pass of the ResNet model
    def forward(self, x):
        x = self.conv1(x)  # Initial convolution
        x = self.bn1(x)  # Batch normalization
        x = self.relu(x)  # ReLU activation
        extracted_features = x  # Feature extraction point

        x = self.layer1(x)  # Pass through the first layer

        x = self.avgpool(x)  # Adaptive average pooling
        x_f = x.view(x.size(0), -1)  # Flatten the features
        logits = self.fc(x_f)  # Fully connected layer for classification
        return logits, extracted_features  # Return logits and extracted features
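
# A minimal sketch (illustrative helper, not part of the original file) of what
# ResNet.forward returns: the class logits and the feature map captured right after the
# stem (conv1 -> bn1 -> relu), before layer1. The stem uses stride=1 / padding=1, so a
# CIFAR-sized 32x32 input keeps its resolution at the feature-extraction point.
def _demo_resnet_outputs():
    model = ResNet(BasicBlock, [1, 2, 2], num_classes=10)
    x = torch.randn(4, 3, 32, 32)
    logits, features = model(x)
    assert logits.shape == (4, 10)             # one score per class
    assert features.shape == (4, 16, 32, 32)   # stem output, inplanes=16 channels
    return logits.shape, features.shape
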
# Function to create the ResNet-5 model
def resnet5_56(num_classes, models_pretrained=False, path=None, **kwargs):
    """Constructs a ResNet-5 model."""
    model = ResNet(BasicBlock, [1, 2, 2], num_classes=num_classes, **kwargs)
    if models_pretrained:
        checkpoint = torch.load(path)
        state_dict = checkpoint['state_dict']
        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k.replace("module.", "")
            new_state_dict[name] = v
        model.load_state_dict(new_state_dict)
    return model


# Function to create the ResNet-8 model
def resnet8_56(num_classes, models_pretrained=False, path=None, **kwargs):
    """Constructs a ResNet-8 model."""
    model = ResNet(Bottleneck, [2, 2, 2], num_classes=num_classes, **kwargs)
    if models_pretrained:
        checkpoint = torch.load(path)
        state_dict = checkpoint['state_dict']
        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k.replace("module.", "")
            new_state_dict[name] = v
        model.load_state_dict(new_state_dict)
    return model
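
# A minimal usage sketch. "checkpoint.pth" is a hypothetical path shown only to illustrate
# the pretrained-loading branch, which expects a checkpoint dict with a 'state_dict' entry
# whose keys may carry a DataParallel-style "module." prefix (stripped before loading).
if __name__ == "__main__":
    net = resnet5_56(num_classes=10)  # randomly initialised ResNet-5 variant
    # net = resnet8_56(num_classes=10, models_pretrained=True, path="checkpoint.pth")
    logits, features = net(torch.randn(1, 3, 32, 32))
    print(logits.shape, features.shape)  # torch.Size([1, 10]) torch.Size([1, 16, 32, 32])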