# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import logging
from collections import OrderedDict

import torch
import torch.nn as nn


def apply_3x3_convolution(in_channels, out_channels, stride=1, groups=1, dilation=1):
    """
    Creates a bias-free 3x3 convolutional layer.

    Padding is set equal to ``dilation`` so that, with stride 1, the spatial
    size of the input is preserved.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        stride (int, optional): Stride of the convolution. Default is 1.
        groups (int, optional): Number of blocked connections from input to output. Default is 1.
        dilation (int, optional): Spacing between kernel elements. Default is 1.

    Returns:
        nn.Conv2d: A 3x3 convolutional layer.
    """
    return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def apply_1x1_convolution(in_channels, out_channels, stride=1):
    """
    Creates a bias-free 1x1 convolutional layer.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        stride (int, optional): Stride of the convolution. Default is 1.

    Returns:
        nn.Conv2d: A 1x1 convolutional layer.
    """
    return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    """
    A basic residual block: two 3x3 convolutions, each followed by a
    normalization layer, with ReLU activations and an identity (or
    downsampled) shortcut added before the final ReLU.

    Attributes:
        expansion (int): Ratio of block output channels to ``out_channels`` (1).
        conv1 (nn.Conv2d): First 3x3 convolution (carries the block's stride).
        bn1 (nn.Module): Normalization after ``conv1``.
        conv2 (nn.Conv2d): Second 3x3 convolution.
        bn2 (nn.Module): Normalization after ``conv2``.
        downsample (nn.Module or None): Shortcut projection used when the input
            and output shapes differ.
    """
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1, downsample=None, norm_layer=None):
        """
        Initializes the BasicBlock.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            stride (int, optional): Stride of the first 3x3 convolution. Default is 1.
            downsample (nn.Module, optional): Shortcut projection applied to the
                input when its shape differs from the output. Default is None.
            norm_layer (callable, optional): Factory taking a channel count and
                returning a normalization module. Default is ``nn.BatchNorm2d``.
        """
        super(BasicBlock, self).__init__()
        norm_layer = norm_layer or nn.BatchNorm2d
        self.conv1 = apply_3x3_convolution(in_channels, out_channels, stride)
        self.bn1 = norm_layer(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = apply_3x3_convolution(out_channels, out_channels)
        self.bn2 = norm_layer(out_channels)
        self.downsample = downsample

    def forward(self, x):
        """
        Defines the forward pass for the block.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: ReLU(residual branch + shortcut).
        """
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out


class Bottleneck(nn.Module):
    """
    A bottleneck residual block: 1x1 conv -> 3x3 conv -> 1x1 conv that expands
    the channel count by ``expansion``, with an identity (or downsampled)
    shortcut added before the final ReLU.

    Attributes:
        expansion (int): Ratio of block output channels to ``out_channels`` (4).
        conv1 (nn.Conv2d): First 1x1 convolution.
        conv2 (nn.Conv2d): 3x3 convolution (carries the block's stride).
        conv3 (nn.Conv2d): Final 1x1 convolution producing
            ``out_channels * expansion`` channels.
        downsample (nn.Module or None): Shortcut projection used when the input
            and output shapes differ.
    """
    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1, downsample=None, norm_layer=None):
        """
        Initializes the Bottleneck block.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Width of the inner convolutions; the block
                outputs ``out_channels * expansion`` channels.
            stride (int, optional): Stride of the 3x3 convolution. Default is 1.
            downsample (nn.Module, optional): Shortcut projection applied to the
                input when its shape differs from the output. Default is None.
            norm_layer (callable, optional): Factory taking a channel count and
                returning a normalization module. Default is ``nn.BatchNorm2d``.
        """
        super(Bottleneck, self).__init__()
        norm_layer = norm_layer or nn.BatchNorm2d
        # Base width is 64 with a single group, so the inner width equals out_channels.
        width = out_channels
        self.conv1 = apply_1x1_convolution(in_channels, width)
        self.bn1 = norm_layer(width)
        self.conv2 = apply_3x3_convolution(width, width, stride)
        self.bn2 = norm_layer(width)
        self.conv3 = apply_1x1_convolution(width, out_channels * self.expansion)
        self.bn3 = norm_layer(out_channels * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        """
        Defines the forward pass for the bottleneck block.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: ReLU(residual branch + shortcut).
        """
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out


class ResNet(nn.Module):
    """
    A three-stage ResNet: a 3x3 stem with 16 channels, followed by three
    residual stages with 16, 32 and 64 base channels (the last two halve the
    spatial size), global average pooling and a linear classifier.

    Attributes:
        conv1 (nn.Conv2d): Stem convolution (3 -> 16 channels, stride 1).
        bn1 (nn.Module): Stem normalization layer.
        layer1 (nn.Sequential): First residual stage.
        layer2 (nn.Sequential): Second residual stage.
        layer3 (nn.Sequential): Third residual stage.
        fc (nn.Linear): Fully connected output layer.
    """

    def __init__(self, block, layers, num_classes=10, zero_init_residual=False, norm_layer=None):
        """
        Initializes the ResNet architecture.

        Args:
            block (type): The block class (``BasicBlock`` or ``Bottleneck``).
            layers (list of int): Number of blocks in each of the three stages.
            num_classes (int, optional): Number of output classes. Default is 10.
            zero_init_residual (bool, optional): If True, the weight (gamma) of
                the last normalization layer in every residual branch is
                initialised to zero. Default is False.
            norm_layer (callable, optional): Factory taking a channel count and
                returning a normalization module; used for the stem, the
                shortcut projections and every block. Default is ``nn.BatchNorm2d``.
        """
        super(ResNet, self).__init__()
        norm_layer = norm_layer or nn.BatchNorm2d
        # Kept so _create_model_layer builds every stage with the same norm factory.
        self._norm_layer = norm_layer
        self.in_channels = 16

        self.conv1 = nn.Conv2d(3, self.in_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = norm_layer(self.in_channels)
        self.relu = nn.ReLU(inplace=True)

        self.layer1 = self._create_model_layer(block, 16, layers[0])
        self.layer2 = self._create_model_layer(block, 32, layers[1], stride=2)
        self.layer3 = self._create_model_layer(block, 64, layers[2], stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64 * block.expansion, num_classes)

        self._init_model_weights(zero_init_residual)

    def _create_model_layer(self, block, out_channels, blocks, stride=1):
        """
        Creates one residual stage and advances ``self.in_channels``.

        Args:
            block (type): The block class.
            out_channels (int): Base number of output channels for the stage.
            blocks (int): Number of blocks in the stage.
            stride (int, optional): Stride of the first block. Default is 1.

        Returns:
            nn.Sequential: A sequence of residual blocks.
        """
        norm_layer = self._norm_layer
        downsample = None
        if stride != 1 or self.in_channels != out_channels * block.expansion:
            # Project the shortcut so it matches the first block's output shape.
            downsample = nn.Sequential(
                apply_1x1_convolution(self.in_channels, out_channels * block.expansion, stride),
                norm_layer(out_channels * block.expansion),
            )

        layers = [block(self.in_channels, out_channels, stride, downsample, norm_layer)]
        self.in_channels = out_channels * block.expansion
        layers.extend(
            block(self.in_channels, out_channels, norm_layer=norm_layer) for _ in range(1, blocks)
        )
        return nn.Sequential(*layers)

    def _init_model_weights(self, zero_init_residual):
        """
        Initializes the weights of the model.

        Convolutions get Kaiming-normal (fan_out, ReLU) weights; BatchNorm2d and
        GroupNorm layers get weight 1 and bias 0.

        Args:
            zero_init_residual (bool): If True, zero the weight of the last
                normalization layer in each residual branch.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Must run after the loop above: self.modules() visits a block before its
        # child norm layers, so zeroing inside that loop would be overwritten.
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def forward(self, x):
        """
        Defines the forward pass of the ResNet.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            tuple: ``(logits, extracted_features)`` where ``extracted_features``
            is the output of the stem (conv1 -> bn1 -> ReLU), before any
            residual stage.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        extracted_features = x

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.avgpool(x)
        x_f = torch.flatten(x, 1)
        logits = self.fc(x_f)
        return logits, extracted_features


def _build_resnet(block, layers, num_classes, models_pretrained, path, **kwargs):
    """Build a ResNet and optionally load pretrained weights (shared by the public constructors)."""
    model = ResNet(block, layers, num_classes=num_classes, **kwargs)
    if models_pretrained:
        logging.info("Loading pretrained model from: %s", path)
        model.load_state_dict(_load_models_pretrained_weights(path))
    return model


def resnet32_models_pretrained(num_classes, models_pretrained=False, path=None, **kwargs):
    """
    Constructs a ResNet-32 model (BasicBlock, 5 blocks per stage).

    Args:
        num_classes (int): Number of output classes.
        models_pretrained (bool, optional): If True, loads pretrained weights. Default is False.
        path (str, optional): Path to the pretrained checkpoint. Required when
            ``models_pretrained`` is True.
        **kwargs: Forwarded to ``ResNet``.

    Returns:
        ResNet: A ResNet-32 model.
    """
    return _build_resnet(BasicBlock, [5, 5, 5], num_classes, models_pretrained, path, **kwargs)


def resnet56_models_pretrained(num_classes, models_pretrained=False, path=None, **kwargs):
    """
    Constructs a ResNet-56 model (Bottleneck, 6 blocks per stage).

    Args:
        num_classes (int): Number of output classes.
        models_pretrained (bool, optional): If True, loads pretrained weights. Default is False.
        path (str, optional): Path to the pretrained checkpoint. Required when
            ``models_pretrained`` is True.
        **kwargs: Forwarded to ``ResNet``.

    Returns:
        ResNet: A ResNet-56 model.
    """
    return _build_resnet(Bottleneck, [6, 6, 6], num_classes, models_pretrained, path, **kwargs)


def _load_models_pretrained_weights(path):
    """
    Loads pretrained weights from a checkpoint file.

    Accepts either a dict containing a ``'state_dict'`` entry or a raw state
    dict. A leading ``"module."`` prefix (as added by ``nn.DataParallel``) is
    stripped from every key.

    Args:
        path (str): Path to the checkpoint file.

    Returns:
        OrderedDict: State dictionary with the loaded weights.

    Raises:
        ValueError: If ``path`` is None.
    """
    if path is None:
        raise ValueError("A checkpoint path is required to load pretrained weights.")

    # NOTE(review): depending on the torch version, torch.load may unpickle
    # arbitrary objects; only load checkpoints from trusted sources.
    checkpoint = torch.load(path, map_location=torch.device('cpu'))
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint

    prefix = "module."
    return OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in state_dict.items()
    )