# -*- coding: utf-8 -*-
# @Author: Weisen Pan
import torch
import torch.nn as nn
import torch.nn.functional as F


# Define the SmoothEntropyLoss class, which inherits from nn.Module
class SmoothEntropyLoss(nn.Module):
    def __init__(self, smoothing=0.1, reduction='mean'):
        # Initialize the parent class (nn.Module) and set the smoothing factor and reduction method
        super(SmoothEntropyLoss, self).__init__()
        self.smoothing = smoothing  # Label smoothing factor
        self.reduction_method = reduction  # Reduction method to apply to the loss

    def forward(self, predictions, targets):
        # Ensure that the batch sizes of predictions and targets match
        if predictions.shape[0] != targets.shape[0]:
            raise ValueError(
                f"Batch size of predictions ({predictions.shape[0]}) "
                f"does not match targets ({targets.shape[0]})."
            )

        # Ensure that the predictions tensor has at least 2 dimensions (batch_size x num_classes)
        if predictions.dim() < 2:
            raise ValueError(f"Predictions should have at least 2 dimensions, got {predictions.dim()}.")

        # Get the number of classes from the last dimension of predictions
        num_classes = predictions.size(-1)

        # Convert targets (class indices) to one-hot encoded format
        target_one_hot = F.one_hot(targets, num_classes=num_classes).type_as(predictions)

        # Apply label smoothing: smooth the one-hot encoded targets by distributing
        # some probability mass across all classes
        smooth_targets = target_one_hot * (1.0 - self.smoothing) + (self.smoothing / num_classes)

        # Compute the log probabilities of predictions (log-softmax for numerical stability)
        log_probabilities = F.log_softmax(predictions, dim=-1)

        # Per-sample loss: multiply log probabilities with the smoothed targets
        # and sum across classes
        loss_per_sample = -torch.sum(log_probabilities * smooth_targets, dim=-1)

        # Apply the specified reduction method to the computed loss
        if self.reduction_method == 'none':
            return loss_per_sample  # Return the unreduced loss for each sample
        elif self.reduction_method == 'sum':
            return torch.sum(loss_per_sample)  # Return the sum of the losses over all samples
        elif self.reduction_method == 'mean':
            return torch.mean(loss_per_sample)  # Return the mean loss over all samples
        else:
            raise ValueError(
                f"Invalid reduction option: {self.reduction_method}. "
                f"Expected 'none', 'sum', or 'mean'."
            )
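

# --- Usage example (a minimal sketch; the shapes, seed, and values below are
# illustrative assumptions, not part of the original module) ---
if __name__ == "__main__":
    torch.manual_seed(0)  # fixed seed so the demo output is reproducible

    criterion = SmoothEntropyLoss(smoothing=0.1, reduction='mean')
    logits = torch.randn(4, 10)          # batch of 4 samples, 10 classes
    labels = torch.randint(0, 10, (4,))  # integer class indices in [0, 10)

    loss = criterion(logits, labels)
    print(f"Smoothed cross-entropy loss: {loss.item():.4f}")

    # Sanity check: for these 2-D inputs the result should agree with PyTorch's
    # built-in label smoothing (nn.CrossEntropyLoss, available since PyTorch 1.10),
    # which uses the same smoothed-target formula (1 - s) * one_hot + s / num_classes.
    reference = nn.CrossEntropyLoss(label_smoothing=0.1)(logits, labels)
    assert torch.allclose(loss, reference), "custom loss diverges from built-in"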