# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import torch

class DataPrefetcher:
    """Prefetch batches from a DataLoader onto the GPU one step ahead.

    Host-to-device copies for batch N+1 are issued on a dedicated CUDA
    stream while the consumer computes on batch N, hiding transfer latency.
    NOTE: ``non_blocking=True`` only overlaps when the DataLoader uses
    pinned memory (``pin_memory=True``); otherwise the copy is synchronous.
    Requires a CUDA-capable device.
    """

    def __init__(self, dataloader):
        # Own the iterator so we can stay exactly one batch ahead.
        self.dataloader = iter(dataloader)
        # Side stream so H2D copies overlap with compute on the current stream.
        self.cuda_stream = torch.cuda.Stream()
        # Initialize attributes before first use so the object is always
        # in a defined state, even if the loader is empty.
        self.batch_input = None
        self.batch_target = None
        # Kick off the first asynchronous copy immediately.
        self._load_next_batch()

    def _load_next_batch(self):
        """Fetch the next (input, target) pair and start its async GPU copy.

        On exhaustion, sets both attributes to None and returns quietly.
        """
        try:
            self.batch_input, self.batch_target = next(self.dataloader)
        except StopIteration:
            # No more data: signal exhaustion via None sentinels.
            self.batch_input, self.batch_target = None, None
            return

        # Issue the copies on the prefetch stream so they run concurrently
        # with whatever the consumer is doing on the current stream.
        with torch.cuda.stream(self.cuda_stream):
            self.batch_input = self.batch_input.cuda(non_blocking=True)
            self.batch_target = self.batch_target.cuda(non_blocking=True)

    def get_next_batch(self):
        """Return the prefetched (input, target) batch and preload the next.

        Returns (None, None) once the underlying dataloader is exhausted.
        """
        # Make the consumer stream wait until the prefetch copies finish.
        torch.cuda.current_stream().wait_stream(self.cuda_stream)

        current_input, current_target = self.batch_input, self.batch_target
        if current_input is None:
            # Loader exhausted; nothing to record or preload.
            return None, None

        # FIX: tell the caching allocator these tensors are now in use on the
        # current stream. Without this, their memory could be freed and reused
        # by the *next* prefetch copy on self.cuda_stream while the consumer
        # is still computing on them (a silent data-corruption race).
        current_input.record_stream(torch.cuda.current_stream())
        if current_target is not None:
            current_target.record_stream(torch.cuda.current_stream())

        # Start copying the following batch in the background.
        self._load_next_batch()
        return current_input, current_target

    def __iter__(self):
        # Allow: for inputs, targets in DataPrefetcher(loader): ...
        return self

    def __next__(self):
        """Iterator-protocol wrapper around get_next_batch()."""
        batch_input, batch_target = self.get_next_batch()
        if batch_input is None:
            raise StopIteration
        return batch_input, batch_target