# -*- coding: utf-8 -*-
# @Author: Weisen Pan
import torch


class DataPrefetcher:
    """Overlap host-to-device copies with GPU compute for a DataLoader.

    While the caller processes batch N on the default CUDA stream, the next
    batch N+1 is copied to the GPU on a dedicated side stream. Requires a
    CUDA-capable device; the wrapped dataloader must yield
    ``(input, target)`` tensor pairs.
    """

    def __init__(self, dataloader):
        """Wrap *dataloader* and eagerly start prefetching the first batch.

        Args:
            dataloader: any iterable yielding ``(input, target)`` tensor pairs.
        """
        self.dataloader = iter(dataloader)
        # Dedicated stream so H2D copies run concurrently with compute on
        # the default stream.
        self.cuda_stream = torch.cuda.Stream()
        # Kick off the first transfer immediately.
        self._load_next_batch()

    def _load_next_batch(self):
        """Fetch the next batch and begin its async transfer to the GPU.

        On exhaustion, sets both ``batch_input`` and ``batch_target`` to
        ``None`` as the end-of-epoch sentinel.
        """
        try:
            self.batch_input, self.batch_target = next(self.dataloader)
        except StopIteration:
            self.batch_input, self.batch_target = None, None
            return
        # non_blocking=True only overlaps if the source tensors are in
        # pinned memory (DataLoader(pin_memory=True)); otherwise the copy
        # silently degrades to synchronous.
        with torch.cuda.stream(self.cuda_stream):
            self.batch_input = self.batch_input.cuda(non_blocking=True)
            self.batch_target = self.batch_target.cuda(non_blocking=True)

    def get_next_batch(self):
        """Return the prefetched ``(input, target)`` pair and start the next.

        Returns:
            ``(input, target)`` on the GPU, or ``(None, None)`` once the
            underlying dataloader is exhausted.
        """
        # Make the consumer stream wait until the side-stream copies finish.
        torch.cuda.current_stream().wait_stream(self.cuda_stream)
        current_input, current_target = self.batch_input, self.batch_target
        # BUG FIX: the tensors were produced on self.cuda_stream but are
        # consumed on the current stream. Without record_stream(), the
        # caching allocator may free and reuse their memory based only on
        # the producing stream's timeline, corrupting data still being read
        # by the consumer. (Same fix as NVIDIA Apex's data_prefetcher.)
        if current_input is not None:
            current_input.record_stream(torch.cuda.current_stream())
        if current_target is not None:
            current_target.record_stream(torch.cuda.current_stream())
        # Begin transferring the following batch while the caller works on
        # the current one.
        self._load_next_batch()
        return current_input, current_target

    def __iter__(self):
        """Allow ``for inputs, targets in DataPrefetcher(loader):`` usage."""
        return self

    def __next__(self):
        """Yield the next GPU batch; raise StopIteration when exhausted."""
        batch_input, batch_target = self.get_next_batch()
        if batch_input is None:
            raise StopIteration
        return batch_input, batch_target