# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import logging

import torch
import torch.nn as nn


def apply_3x3_convolution(in_channels, out_channels, stride=1, groups=1, dilation=1):
    """
    Creates a 3x3 convolutional layer with padding.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        stride (int, optional): Stride of the convolution. Default is 1.
        groups (int, optional): Number of blocked connections from input to output. Default is 1.
        dilation (int, optional): Spacing between kernel elements. Default is 1.

    Returns:
        nn.Conv2d: A 3x3 convolutional layer.
    """
    return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def apply_1x1_convolution(in_channels, out_channels, stride=1):
    """
    Creates a 1x1 convolutional layer.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        stride (int, optional): Stride of the convolution. Default is 1.

    Returns:
        nn.Conv2d: A 1x1 convolutional layer.
    """
    return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    """
    A basic block for ResNet.

    This block consists of two convolutional layers with batch normalization and ReLU activation.

    Attributes:
        expansion (int): The expansion factor of the block.
        conv1 (nn.Conv2d): First convolutional layer.
        bn1 (nn.BatchNorm2d): First batch normalization layer.
        conv2 (nn.Conv2d): Second convolutional layer.
        bn2 (nn.BatchNorm2d): Second batch normalization layer.
        downsample (nn.Module): Downsample layer if input and output dimensions differ.
    """
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1, downsample=None, norm_layer=None):
        """
        Initializes the BasicBlock.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            stride (int, optional): Stride for the convolutional layers. Default is 1.
            downsample (nn.Module, optional): Downsample layer if input dimensions differ. Default is None.
            norm_layer (nn.Module, optional): Normalization layer. Default is BatchNorm2d.
        """
        super(BasicBlock, self).__init__()
        norm_layer = norm_layer or nn.BatchNorm2d
        self.conv1 = apply_3x3_convolution(in_channels, out_channels, stride)
        self.bn1 = norm_layer(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = apply_3x3_convolution(out_channels, out_channels)
        self.bn2 = norm_layer(out_channels)
        self.downsample = downsample

    def forward(self, x):
        """
        Defines the forward pass for the block.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor after applying the block.
        """
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out
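

# Illustrative usage (sketch): with stride=1 and matching channel counts, a
# BasicBlock preserves both spatial size and channel count, so no downsample
# module is needed, e.g.:
#
#     >>> blk = BasicBlock(16, 16)
#     >>> blk(torch.randn(2, 16, 32, 32)).shape
#     torch.Size([2, 16, 32, 32])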


class Bottleneck(nn.Module):
    """
    A bottleneck block for ResNet.

    This block reduces the number of input channels before performing convolution and then expands it back.

    Attributes:
        expansion (int): The expansion factor of the block.
        conv1 (nn.Conv2d): First 1x1 convolutional layer.
        conv2 (nn.Conv2d): 3x3 convolutional layer.
        conv3 (nn.Conv2d): Second 1x1 convolutional layer.
        downsample (nn.Module): Downsample layer if input and output dimensions differ.
    """
    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1, downsample=None, norm_layer=None):
        """
        Initializes the Bottleneck block.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            stride (int, optional): Stride for the convolutional layers. Default is 1.
            downsample (nn.Module, optional): Downsample layer if input dimensions differ. Default is None.
            norm_layer (nn.Module, optional): Normalization layer. Default is BatchNorm2d.
        """
        super(Bottleneck, self).__init__()
        norm_layer = norm_layer or nn.BatchNorm2d
        width = int(out_channels * (64 / 64))  # Base width
        self.conv1 = apply_1x1_convolution(in_channels, width)
        self.bn1 = norm_layer(width)
        self.conv2 = apply_3x3_convolution(width, width, stride)
        self.bn2 = norm_layer(width)
        self.conv3 = apply_1x1_convolution(width, out_channels * self.expansion)
        self.bn3 = norm_layer(out_channels * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        """
        Defines the forward pass for the bottleneck block.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor after applying the block.
        """
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out
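

# Illustrative usage (sketch): because expansion == 4, a Bottleneck emits
# out_channels * 4 feature maps, so the residual path needs a downsample module
# whenever the input channel count differs from that, e.g.:
#
#     >>> ds = nn.Sequential(apply_1x1_convolution(16, 64), nn.BatchNorm2d(64))
#     >>> blk = Bottleneck(16, 16, downsample=ds)
#     >>> blk(torch.randn(2, 16, 32, 32)).shape
#     torch.Size([2, 64, 32, 32])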


class ResNet(nn.Module):
    """
    ResNet architecture.

    This class constructs a ResNet model with a specified block type and layer configuration.

    Attributes:
        conv1 (nn.Conv2d): Initial convolutional layer.
        bn1 (nn.BatchNorm2d): Initial batch normalization layer.
        layer1 (nn.Sequential): First residual layer.
        layer2 (nn.Sequential): Second residual layer.
        layer3 (nn.Sequential): Third residual layer.
        fc (nn.Linear): Fully connected output layer.
    """

    def __init__(self, block, layers, num_classes=10, zero_init_residual=False, norm_layer=None):
        """
        Initializes the ResNet architecture.

        Args:
            block (nn.Module): The block type (BasicBlock or Bottleneck).
            layers (list of int): Number of blocks per layer.
            num_classes (int, optional): Number of output classes. Default is 10.
            zero_init_residual (bool, optional): Whether to zero-initialize the last batch norm in each residual block. Default is False.
            norm_layer (nn.Module, optional): Normalization layer. Default is BatchNorm2d.
        """
        super(ResNet, self).__init__()
        self._norm_layer = norm_layer or nn.BatchNorm2d
        self.in_channels = 16

        # CIFAR-style stem: a single 3x3 convolution that keeps the input resolution.
        self.conv1 = nn.Conv2d(3, self.in_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = self._norm_layer(self.in_channels)
        self.relu = nn.ReLU(inplace=True)

        self.layer1 = self._create_model_layer(block, 16, layers[0])
        self.layer2 = self._create_model_layer(block, 32, layers[1], stride=2)
        self.layer3 = self._create_model_layer(block, 64, layers[2], stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64 * block.expansion, num_classes)

        self._init_model_weights(zero_init_residual)

    def _create_model_layer(self, block, out_channels, blocks, stride=1):
        """
        Creates a residual layer.

        Args:
            block (nn.Module): The block type.
            out_channels (int): Number of output channels.
            blocks (int): Number of blocks in the layer.
            stride (int, optional): Stride for the first block. Default is 1.

        Returns:
            nn.Sequential: A sequence of residual blocks.
        """
        downsample = None
        if stride != 1 or self.in_channels != out_channels * block.expansion:
            downsample = nn.Sequential(
                apply_1x1_convolution(self.in_channels, out_channels * block.expansion, stride),
                self._norm_layer(out_channels * block.expansion),
            )

        layers = [block(self.in_channels, out_channels, stride, downsample, norm_layer=self._norm_layer)]
        self.in_channels = out_channels * block.expansion
        layers.extend(block(self.in_channels, out_channels, norm_layer=self._norm_layer) for _ in range(1, blocks))
        return nn.Sequential(*layers)

    def _init_model_weights(self, zero_init_residual):
        """
        Initializes the weights of the model.

        Args:
            zero_init_residual (bool): If True, zero-initializes the last batch norm weight in each residual block.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            if zero_init_residual and isinstance(m, Bottleneck):
                nn.init.constant_(m.bn3.weight, 0)
            elif zero_init_residual and isinstance(m, BasicBlock):
                nn.init.constant_(m.bn2.weight, 0)

    def forward(self, x):
        """
        Defines the forward pass of the ResNet.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            tuple: Logits and the feature map extracted after the stem.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        extracted_features = x

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.avgpool(x)
        x_f = x.view(x.size(0), -1)
        logits = self.fc(x_f)
        return logits, extracted_features
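

# With this CIFAR-style layout, the conventional depth count is
# 2 * sum(layers) + 2 for BasicBlock (two 3x3 convs per block) and
# 3 * sum(layers) + 2 for Bottleneck (three convs per block), counting the stem
# convolution and the final fully connected layer; [5, 5, 5] BasicBlocks therefore
# give ResNet-32 and [6, 6, 6] Bottlenecks give ResNet-56, matching the factory
# functions below.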


def resnet32_models_pretrained(num_classes, models_pretrained=False, path=None, **kwargs):
    """
    Constructs a ResNet-32 model.

    Args:
        num_classes (int): Number of output classes.
        models_pretrained (bool, optional): If True, loads pretrained weights. Default is False.
        path (str, optional): Path to the pretrained weights. Default is None.

    Returns:
        ResNet: A ResNet-32 model.
    """
    model = ResNet(BasicBlock, [5, 5, 5], num_classes=num_classes, **kwargs)
    if models_pretrained:
        model.load_state_dict(_load_models_pretrained_weights(path))
    return model


def resnet56_models_pretrained(num_classes, models_pretrained=False, path=None, **kwargs):
    """
    Constructs a ResNet-56 model.

    Args:
        num_classes (int): Number of output classes.
        models_pretrained (bool, optional): If True, loads pretrained weights. Default is False.
        path (str, optional): Path to the pretrained weights. Default is None.

    Returns:
        ResNet: A ResNet-56 model.
    """
    model = ResNet(Bottleneck, [6, 6, 6], num_classes=num_classes, **kwargs)
    if models_pretrained:
        logging.info("Loading pretrained model from: %s", path)
        model.load_state_dict(_load_models_pretrained_weights(path))
    return model


def _load_models_pretrained_weights(path):
    """
    Loads pretrained weights from a checkpoint.

    Args:
        path (str): Path to the checkpoint file.

    Returns:
        dict: State dictionary with the loaded weights.
    """
    from collections import OrderedDict

    checkpoint = torch.load(path, map_location=torch.device('cpu'))
    state_dict = checkpoint['state_dict']
    new_state_dict = OrderedDict()

    # Strip the "module." prefix added by nn.DataParallel so the keys match a plain model.
    for k, v in state_dict.items():
        new_state_dict[k.replace("module.", "")] = v

    return new_state_dict
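

# Minimal smoke test: an illustrative sketch assuming CIFAR-style 32x32 RGB inputs;
# it builds the ResNet-32 variant and checks the shapes returned by forward().
if __name__ == "__main__":
    model = resnet32_models_pretrained(num_classes=10)
    model.eval()
    with torch.no_grad():
        dummy_images = torch.randn(2, 3, 32, 32)  # two fake CIFAR-sized images
        logits, stem_features = model(dummy_images)
    print(logits.shape)         # torch.Size([2, 10])
    print(stem_features.shape)  # torch.Size([2, 16, 32, 32])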