40 lines
1.5 KiB
Python
# -*- coding: utf-8 -*-
|
|
# @Author: Weisen Pan
|
|
|
|
import torch
|
|
|
|
class DataPrefetcher:
    """Overlap host-to-device data transfer with GPU compute.

    Wraps a dataloader and copies each (input, target) batch to the GPU on a
    dedicated CUDA stream while the previous batch is being consumed.
    Requires a CUDA-capable device; batches are assumed to be
    (input_tensor, target_tensor) pairs — TODO confirm against the dataloader.
    """

    def __init__(self, dataloader):
        # Wrap the dataloader so batches are pulled on demand.
        self.dataloader = iter(dataloader)
        # Dedicated stream so the H2D copy can run concurrently with compute
        # on the default stream.
        self.cuda_stream = torch.cuda.Stream()
        # Kick off the first transfer immediately.
        self._load_next_batch()

    def _load_next_batch(self):
        """Fetch the next batch and start its asynchronous copy to the GPU.

        On dataloader exhaustion, sets ``batch_input``/``batch_target`` to
        ``None`` as end-of-epoch sentinels.
        """
        try:
            self.batch_input, self.batch_target = next(self.dataloader)
        except StopIteration:
            self.batch_input, self.batch_target = None, None
            return

        # non_blocking=True lets the copy overlap with compute; this only
        # truly overlaps when the source tensors are in pinned memory
        # (pin_memory=True on the dataloader) — TODO confirm at call site.
        with torch.cuda.stream(self.cuda_stream):
            self.batch_input = self.batch_input.cuda(non_blocking=True)
            self.batch_target = self.batch_target.cuda(non_blocking=True)

    def get_next_batch(self):
        """Return the prefetched batch and begin prefetching the next one.

        Returns:
            (input, target) tensors on the GPU, or (None, None) once the
            dataloader is exhausted.
        """
        # Make sure the async copy issued on the prefetch stream has finished
        # from the current stream's point of view.
        current_stream = torch.cuda.current_stream()
        current_stream.wait_stream(self.cuda_stream)

        current_input, current_target = self.batch_input, self.batch_target

        # Fix: record the tensors' use on the consumer stream so the caching
        # allocator does not recycle their memory for the next prefetch while
        # this batch is still in flight (see torch.Tensor.record_stream).
        if current_input is not None:
            current_input.record_stream(current_stream)
        if current_target is not None:
            current_target.record_stream(current_stream)

        # Start loading the next batch in the background while the caller
        # processes the current one.
        self._load_next_batch()

        return current_input, current_target
|