Weisen Pan 4ec0a23e73 Edge Federated Learning for Improved Training Efficiency
Change-Id: Ic4e43992e1674946cb69e0221659b0261259196c
2024-09-18 18:39:43 -07:00

61 lines
2.5 KiB
Python

# -*- coding: utf-8 -*-
# @Author: Weisen Pan
import torch
import numpy as np
@torch.no_grad()
def combine_mixup_data(x, y, alpha=1.0, use_cuda=True):
    """
    Perform the mixup operation on input data.

    Each sample is blended with a randomly chosen partner from the same batch:
    ``mixed_x = lam * x + (1 - lam) * x[index]``, where ``index`` is a random
    permutation of the batch indices.

    Args:
        x (Tensor): Input features; the first dimension is the batch dimension.
        y (Tensor): Labels aligned with ``x`` along the first dimension.
        alpha (float): Concentration of the Beta(alpha, alpha) distribution that
            ``lam`` is drawn from. If ``alpha <= 0`` no mixing is performed and
            ``lam`` is 1.0.
        use_cuda (bool): Kept for backward compatibility. The permutation is
            always created on ``x.device``, so this flag no longer forces a
            ``.cuda()`` call (which crashed on machines without CUDA).

    Returns:
        mixed_x (Tensor): Mixed features, same shape as ``x``.
        y_a (Tensor): The original labels ``y``.
        y_b (Tensor): Labels of the partner samples, i.e. ``y[index]``.
        lam (float): The interpolation weight applied to ``x``.
    """
    # Draw lambda from Beta(alpha, alpha); alpha <= 0 disables mixing entirely.
    lam = float(np.random.beta(alpha, alpha)) if alpha > 0 else 1.0
    batch_size = x.size(0)
    # Create the permutation on the same device as x instead of calling .cuda()
    # unconditionally: that crashed without CUDA and mismatched CPU inputs.
    index = torch.randperm(batch_size, device=x.device)
    # x[index] permutes along dim 0 for any rank (x[index, :] fails on 1-D x).
    mixed_x = lam * x + (1 - lam) * x[index]
    # Index labels with a permutation on y's own device in case x and y differ.
    y_a, y_b = y, y[index.to(y.device)]
    return mixed_x, y_a, y_b, lam
def mixup_loss_criterion(criterion, pred, y_a, y_b, lam):
    """
    Blend two losses with the mixup weight.

    The prediction is scored against both label sets, and the two scores are
    combined as ``lam * loss(y_a) + (1 - lam) * loss(y_b)``.

    Args:
        criterion (callable): Loss function invoked as ``criterion(pred, target)``.
        pred (Tensor): Model output for the mixed inputs.
        y_a (Tensor): Labels of the original samples.
        y_b (Tensor): Labels of the partner (permuted) samples.
        lam (float): Weight given to the loss against ``y_a``.

    Returns:
        The weighted combination of the two losses.
    """
    loss_original = criterion(pred, y_a)
    loss_partner = criterion(pred, y_b)
    weighted_original = lam * loss_original
    weighted_partner = (1 - lam) * loss_partner
    return weighted_original + weighted_partner