# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import torch
import torch.nn as nn


class PassThrough(nn.Module):
    """
    A placeholder module that simply returns the input tensor unchanged.
    """

    def __init__(self, **kwargs):
        super(PassThrough, self).__init__()

    def forward(self, input_tensor):
        return input_tensor


class LayerNormalization2D(nn.Module):
    """
    A layer normalization module for 4D inputs of shape (N, C, H, W), typically
    the outputs of convolutional layers. It normalizes over the (C, H, W)
    dimensions and optionally applies learned scaling (weight) and shifting
    (bias) parameters, which are created lazily on the first forward pass so
    that their shape matches the input.

    Arguments:
        epsilon: A small value added to the variance to avoid division by zero.
        use_weight: Whether to learn and apply a weight (scale) parameter.
        use_bias: Whether to learn and apply a bias (shift) parameter.
    """

    def __init__(self, epsilon=1e-05, use_weight=True, use_bias=True, **kwargs):
        super(LayerNormalization2D, self).__init__()
        self.epsilon = epsilon
        # These start out as booleans and are replaced by nn.Parameter
        # instances (or None) the first time forward() is called.
        self.use_weight = use_weight
        self.use_bias = use_bias

    def forward(self, input_tensor):
        # Resolve the boolean flags into actual parameters (or None) on the
        # first forward pass, once the input shape is known.
        if not isinstance(self.use_weight, (nn.Parameter, type(None))) or \
                not isinstance(self.use_bias, (nn.Parameter, type(None))):
            self._initialize_parameters(input_tensor)

        # Apply layer normalization over the (C, H, W) dimensions.
        return nn.functional.layer_norm(input_tensor, input_tensor.shape[1:],
                                        weight=self.use_weight, bias=self.use_bias,
                                        eps=self.epsilon)

    def _initialize_parameters(self, input_tensor):
        """
        Initialize weight and bias parameters for layer normalization.

        Arguments:
            input_tensor: The input tensor to the normalization layer; its
                (C, H, W) shape, device, and dtype determine the parameters.
        """
        channels, height, width = input_tensor.shape[1:]
        param_shape = [channels, height, width]
        # Create the parameters on the same device/dtype as the input so that
        # layer_norm does not mix CPU and GPU tensors.
        factory_kwargs = {'device': input_tensor.device, 'dtype': input_tensor.dtype}

        # Initialize the weight parameter if requested, otherwise register None.
        if self.use_weight:
            self.use_weight = nn.Parameter(torch.ones(param_shape, **factory_kwargs),
                                           requires_grad=True)
        else:
            # Remove the boolean attribute first; register_parameter() refuses
            # to shadow an existing non-parameter attribute.
            del self.use_weight
            self.register_parameter('use_weight', None)

        # Initialize the bias parameter if requested, otherwise register None.
        if self.use_bias:
            self.use_bias = nn.Parameter(torch.zeros(param_shape, **factory_kwargs),
                                         requires_grad=True)
        else:
            del self.use_bias
            self.register_parameter('use_bias', None)


class NormalizationLayer(nn.Module):
    """
    A flexible normalization factory that supports different types of
    normalization (batch, group, layer, instance, or none). This class is a
    wrapper that selects the appropriate normalization technique based on the
    norm_type argument.

    Arguments:
        norm_type: The type of normalization to apply
            ('batch', 'group', 'layer', 'instance', or 'none').
        epsilon: A small value to avoid division by zero (Default: 1e-05).
        momentum: Momentum for updating running statistics
            (Default: 0.1, applicable to batch norm).
        use_weight: Whether to learn weight parameters (Default: True).
        use_bias: Whether to learn bias parameters (Default: True).
        track_stats: Whether to track running statistics
            (Default: True, applicable to batch norm).
        group_norm_groups: Number of groups for group normalization (Default: 32).
    """

    def __init__(self, norm_type='batch', epsilon=1e-05, momentum=0.1,
                 use_weight=True, use_bias=True, track_stats=True,
                 group_norm_groups=32, **kwargs):
        super(NormalizationLayer, self).__init__()
        if norm_type not in ['batch', 'group', 'layer', 'instance', 'none']:
            raise ValueError('Unsupported norm_type: {}. Supported options: '
                             '"batch" | "group" | "layer" | "instance" | "none".'
                             .format(norm_type))
        self.norm_type = norm_type
        self.epsilon = epsilon
        self.momentum = momentum
        self.use_weight = use_weight
        self.use_bias = use_bias
        # An affine transformation is applied only when both weight and bias are learned.
        self.affine = self.use_weight and self.use_bias
        self.track_stats = track_stats
        self.group_norm_groups = group_norm_groups

    def forward(self, num_features):
        """
        Construct and return the normalization module selected by norm_type.
        Note that this module acts as a factory: calling it with a channel
        count returns a new normalization layer rather than normalizing a tensor.

        Arguments:
            num_features: The number of input channels or features.

        Returns:
            A normalization module corresponding to norm_type.
        """
        if self.norm_type == 'batch':
            # Batch Normalization over the channel dimension
            normalizer = nn.BatchNorm2d(num_features=num_features, eps=self.epsilon,
                                        momentum=self.momentum, affine=self.affine,
                                        track_running_stats=self.track_stats)
        elif self.norm_type == 'group':
            # Group Normalization
            normalizer = nn.GroupNorm(self.group_norm_groups, num_features,
                                      eps=self.epsilon, affine=self.affine)
        elif self.norm_type == 'layer':
            # Layer Normalization (lazy 2D variant defined above)
            normalizer = LayerNormalization2D(epsilon=self.epsilon,
                                              use_weight=self.use_weight,
                                              use_bias=self.use_bias)
        elif self.norm_type == 'instance':
            # Instance Normalization
            normalizer = nn.InstanceNorm2d(num_features, eps=self.epsilon,
                                           affine=self.affine)
        else:
            # No normalization; just pass the input through unchanged
            normalizer = PassThrough()
        return normalizer
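

# ---------------------------------------------------------------------------
# Illustrative usage (a minimal sketch, not part of the original module):
# NormalizationLayer is used as a factory -- calling it with a channel count
# returns a concrete normalization module that is then applied to tensors.
# The tensor shapes and norm types below are arbitrary example values.
if __name__ == '__main__':
    sample = torch.randn(4, 32, 16, 16)  # (N, C, H, W) example input

    for norm_type in ['batch', 'group', 'layer', 'instance', 'none']:
        factory = NormalizationLayer(norm_type=norm_type)
        norm_module = factory(num_features=32)  # build the concrete layer
        output = norm_module(sample)            # normalize the sample batch
        print('{:>8s} -> {} output shape {}'.format(
            norm_type, norm_module.__class__.__name__, tuple(output.shape)))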