# -*- coding: utf-8 -*-
# @Author: Weisen Pan
import torch
import numpy as np


@torch.no_grad()
def combine_mixup_data(x, y, alpha=1.0, use_cuda=True):
    """
    Perform the mixup operation on input data.

    Args:
        x (Tensor): Input features, typically from the dataset.
        y (Tensor): Input labels corresponding to the features.
        alpha (float): Mixup interpolation coefficient. The default value is 1.0.
            A higher value results in more mixing between samples.
        use_cuda (bool): If True, the permutation indices are placed on the GPU.

    Returns:
        mixed_x (Tensor): Mixed input features, a linear combination of x and a permuted version of x.
        y_a (Tensor): Original input labels corresponding to x.
        y_b (Tensor): Permuted input labels corresponding to the mixed samples.
        lam (float): The lambda value used for interpolation between samples.
    """
    # Draw lambda from the Beta(alpha, alpha) distribution if alpha > 0, otherwise set lam to 1 (no mixup)
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1

    # Get the batch size from the input tensor
    batch_size = x.size(0)

    # Generate a random permutation of indices for mixing,
    # placing it on the GPU when use_cuda is set, otherwise keeping it on the CPU
    index = torch.randperm(batch_size).cuda() if use_cuda else torch.randperm(batch_size)

    # Mix the features of the original and permuted samples using the lambda value
    mixed_x = lam * x + (1 - lam) * x[index, :]

    # Assign original and permuted labels to y_a and y_b, respectively
    y_a, y_b = y, y[index]

    # Return mixed features, original and permuted labels, and the lambda value
    return mixed_x, y_a, y_b, lam


def mixup_loss_criterion(criterion, pred, y_a, y_b, lam):
    """
    Compute the mixup loss using the provided criterion.

    Args:
        criterion (function): The loss function used to compute the error (e.g., CrossEntropyLoss).
        pred (Tensor): The model predictions, typically the output of a neural network.
        y_a (Tensor): The original labels corresponding to the original input features.
        y_b (Tensor): The permuted labels corresponding to the mixed input features.
        lam (float): The lambda value for mixup, used to interpolate between the two losses.

    Returns:
        loss (Tensor): The final mixup loss, computed as a weighted sum of the two losses.
    """
    # Compute the mixup loss by combining the loss from the original and permuted labels
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)
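

# A minimal usage sketch (not part of the original module): it assumes a made-up
# classification setup with hypothetical shapes and a simple linear model, purely to
# illustrate how combine_mixup_data and mixup_loss_criterion plug into one training step.
if __name__ == "__main__":
    # Hypothetical batch: 8 samples with 10 features each, labels drawn from 3 classes
    features = torch.randn(8, 10)
    labels = torch.randint(0, 3, (8,))

    # Illustrative model and loss; any classifier/criterion pair would be used the same way
    model = torch.nn.Linear(10, 3)
    criterion = torch.nn.CrossEntropyLoss()

    # Mix the batch on the CPU for this sketch (use_cuda=False)
    mixed_x, y_a, y_b, lam = combine_mixup_data(features, labels, alpha=1.0, use_cuda=False)

    # Forward pass on the mixed inputs, then interpolate the two losses with lam
    predictions = model(mixed_x)
    loss = mixup_loss_criterion(criterion, predictions, y_a, y_b, lam)
    print(f"lambda={lam:.3f}, mixup loss={loss.item():.4f}")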