61 lines
2.5 KiB
Python
61 lines
2.5 KiB
Python
# -*- coding: utf-8 -*-
|
|
# @Author: Weisen Pan
|
|
|
|
import torch
|
|
import numpy as np
|
|
|
|
@torch.no_grad()
def combine_mixup_data(x, y, alpha=1.0, use_cuda=True):
    """
    Perform the mixup operation on input data.

    Each sample in the batch is blended with another sample of the same batch,
    chosen by a random permutation of the batch indices.

    Args:
        x (Tensor): Input features. The first dimension is treated as the batch dimension.
        y (Tensor): Input labels corresponding to the features, indexed along their
            first dimension with the same permutation as x.
        alpha (float): Parameter of the Beta(alpha, alpha) distribution from which the
            interpolation coefficient is drawn. The default value is 1.0.
            If alpha <= 0, the coefficient is fixed to 1 (no mixup).
        use_cuda (bool): If True and x lives on a CUDA device, the permutation indices
            are moved to x's device. If x is on the CPU, the indices stay on the CPU,
            so the default value is safe on machines without CUDA.

    Returns:
        mixed_x (Tensor): Mixed input features, lam * x + (1 - lam) * x[index].
        y_a (Tensor): Original input labels corresponding to x.
        y_b (Tensor): Permuted input labels corresponding to the mixed samples.
        lam (float): The lambda value used for interpolation between samples.
    """
    # Draw lambda from the Beta distribution if alpha > 0, otherwise disable mixup
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1

    # Number of samples in the batch
    batch_size = x.size(0)

    # The permutation is drawn on the CPU so the random stream does not depend
    # on where x lives.
    index = torch.randperm(batch_size)

    # Only move the indices to the GPU when x is actually there; an unconditional
    # .cuda() fails on CPU-only machines and may pick a device other than x's.
    if use_cuda and x.is_cuda:
        index = index.to(x.device)

    # Indexing only the first dimension also supports 1-D feature tensors
    mixed_x = lam * x + (1 - lam) * x[index]

    # Original and permuted labels
    y_a, y_b = y, y[index]

    return mixed_x, y_a, y_b, lam
|
|
|
|
|
|
def mixup_loss_criterion(criterion, pred, y_a, y_b, lam):
    """
    Blend two losses of the same prediction according to the mixup coefficient.

    Args:
        criterion (callable): Loss function called as criterion(pred, target).
        pred: Model predictions, passed unchanged as the first argument of criterion.
        y_a: First set of targets; its loss is weighted by lam.
        y_b: Second set of targets; its loss is weighted by (1 - lam).
        lam (float): Interpolation weight between the two losses.

    Returns:
        The weighted sum lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b).
    """
    # Evaluate the criterion once per target set, first y_a and then y_b
    loss_for_a = criterion(pred, y_a)
    loss_for_b = criterion(pred, y_b)

    # Weight of the second loss is the complement of lam
    complement = 1 - lam

    blended = lam * loss_for_a + complement * loss_for_b
    return blended
|