Weisen Pan 4ec0a23e73 Edge Federated Learning for Improved Training Efficiency
Change-Id: Ic4e43992e1674946cb69e0221659b0261259196c
2024-09-18 18:39:43 -07:00


# -*- coding: utf-8 -*-
# @Author: Weisen Pan
import torch


class DataPrefetcher:
    """Prefetch batches from a DataLoader onto the GPU using a side CUDA stream."""

    def __init__(self, dataloader):
        # Wrap the dataloader in an iterator
        self.dataloader = iter(dataloader)
        # Side CUDA stream used for asynchronous host-to-device copies.
        # For the copies to actually overlap with compute, the DataLoader
        # should be created with pin_memory=True.
        self.cuda_stream = torch.cuda.Stream()
        # Preload the first batch
        self._load_next_batch()

    def _load_next_batch(self):
        try:
            # Fetch the next batch from the dataloader iterator
            self.batch_input, self.batch_target = next(self.dataloader)
        except StopIteration:
            # No more data: signal exhaustion with None
            self.batch_input, self.batch_target = None, None
            return
        # Transfer the batch to the GPU asynchronously on the prefetch stream
        with torch.cuda.stream(self.cuda_stream):
            self.batch_input = self.batch_input.cuda(non_blocking=True)
            self.batch_target = self.batch_target.cuda(non_blocking=True)

    def get_next_batch(self):
        # Make the current stream wait for the prefetch stream so the data is ready
        torch.cuda.current_stream().wait_stream(self.cuda_stream)
        current_input, current_target = self.batch_input, self.batch_target
        # Record the tensors on the current stream so the caching allocator does not
        # reuse their memory for the next prefetch while this batch is still in use
        if current_input is not None:
            current_input.record_stream(torch.cuda.current_stream())
        if current_target is not None:
            current_target.record_stream(torch.cuda.current_stream())
        # Start loading the next batch in the background while the current batch is processed
        self._load_next_batch()
        return current_input, current_target
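

# Illustrative usage sketch: a minimal training loop showing how DataPrefetcher
# is typically consumed. The dataset, model, optimizer, and loss below are
# hypothetical placeholders (not taken from this project), and a CUDA-capable
# GPU plus a DataLoader with pin_memory=True are assumed so that the
# non_blocking copies can overlap with computation.
if __name__ == "__main__":
    from torch.utils.data import DataLoader, TensorDataset

    # Synthetic data standing in for a real dataset
    dataset = TensorDataset(torch.randn(1024, 10), torch.randint(0, 2, (1024,)))
    loader = DataLoader(dataset, batch_size=64, pin_memory=True)

    model = torch.nn.Linear(10, 2).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    criterion = torch.nn.CrossEntropyLoss()

    prefetcher = DataPrefetcher(loader)
    inputs, targets = prefetcher.get_next_batch()
    while inputs is not None:
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()
        # get_next_batch returns the current batch and kicks off prefetching of the next one
        inputs, targets = prefetcher.get_next_batch()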