# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import os
import torch
import torch.utils.data as data

# VisionDataset is a custom dataset class inheriting from PyTorch's Dataset class.
# It handles the initialization and representation of a vision-related dataset,
# including optional transformations of the input data and targets.
class VisionDataset(data.Dataset):
    """Base class for vision datasets with a root location and optional transformations.

    Subclasses must override ``__getitem__`` and ``__len__``; the versions
    defined here raise ``NotImplementedError``.

    Args:
        root: Dataset root location. A ``str`` is passed through
            ``os.path.expanduser`` (so a leading ``~`` is expanded); any other
            value is stored unchanged.
        apply_transformations: Optional joint transformation object, stored
            as-is on ``self.apply_transformations``.
        apply_transformation: Optional transformation for the input.
        target_apply_transformation: Optional transformation for the target.

    Raises:
        ValueError: If ``apply_transformations`` is given together with
            ``apply_transformation`` and/or ``target_apply_transformation``.
    """

    _repr_indent = 4  # Number of spaces used to indent body lines in __repr__

    def __init__(self, root, apply_transformations=None, apply_transformation=None, target_apply_transformation=None):
        # Only plain strings are expanded; other root values are kept untouched.
        self.root = os.path.expanduser(root) if isinstance(root, str) else root

        # The joint form and the separate form are mutually exclusive.
        has_joint = apply_transformations is not None
        has_separate = apply_transformation is not None or target_apply_transformation is not None
        if has_joint and has_separate:
            raise ValueError("Only one of 'apply_transformations' or 'apply_transformation/target_apply_transformation' can be provided.")

        # The separate transformations stay accessible as attributes in every case.
        self.apply_transformation = apply_transformation
        self.target_apply_transformation = target_apply_transformation

        # Separate transformations are combined into one StandardTransform so that
        # self.apply_transformations is the single joint entry point.
        if has_separate:
            apply_transformations = StandardTransform(apply_transformation, target_apply_transformation)
        self.apply_transformations = apply_transformations

    def __getitem__(self, index):
        """Return the item at ``index``; must be implemented by subclasses."""
        raise NotImplementedError

    def __len__(self):
        """Return the number of datapoints; must be implemented by subclasses."""
        raise NotImplementedError

    def __repr__(self):
        """Return a multi-line summary: class name, size, root, extras and transformations.

        Every body line, including each line of a multi-line transformation
        repr, is indented by ``_repr_indent`` spaces.
        """
        head = f"Dataset {self.__class__.__name__}"
        body = [f"Number of datapoints: {self.__len__()}"]
        if self.root is not None:
            body.append(f"Root location: {self.root}")
        body += self.extra_repr().splitlines()
        # hasattr guards against subclasses that never ran this class's __init__.
        if hasattr(self, "apply_transformations") and self.apply_transformations is not None:
            # Split the repr so that continuation lines are indented as well,
            # instead of only the first line.
            body += repr(self.apply_transformations).splitlines()
        indent = " " * self._repr_indent
        return "\n".join([head] + [indent + line for line in body])

    def _format_apply_transformation_repr(self, apply_transformation, head):
        """Return the repr of ``apply_transformation`` as lines prefixed by ``head``.

        The first line is prefixed with ``head``; following lines are padded
        with ``len(head)`` spaces so they align under the first one.
        """
        # An empty repr yields no lines at all; fall back to a single empty
        # line so the head is still emitted instead of raising IndexError.
        lines = repr(apply_transformation).splitlines() or [""]
        pad = " " * len(head)
        return [f"{head}{lines[0]}"] + [f"{pad}{line}" for line in lines[1:]]

    def extra_repr(self):
        """Hook for subclasses: extra text whose lines are added to ``__repr__``."""
        return ""


# StandardTransform applies apply_transformation to the input and
# target_apply_transformation to the target during dataset iteration or data loading.
class StandardTransform:
    """Pair an input transformation with a target transformation.

    Args:
        apply_transformation: Optional callable applied to the input.
        target_apply_transformation: Optional callable applied to the target.
    """

    def __init__(self, apply_transformation=None, target_apply_transformation=None):
        self.apply_transformation = apply_transformation
        self.target_apply_transformation = target_apply_transformation

    def __call__(self, input, target):
        """Apply the configured transformations and return ``(input, target)``.

        A transformation that is ``None`` leaves its value unchanged.
        """
        if self.apply_transformation is not None:
            input = self.apply_transformation(input)
        if self.target_apply_transformation is not None:
            target = self.target_apply_transformation(target)
        return input, target

    def _format_apply_transformation_repr(self, apply_transformation, head):
        """Return the repr of ``apply_transformation`` as lines prefixed by ``head``.

        The first line is prefixed with ``head``; following lines are padded
        with ``len(head)`` spaces so they align under the first one.
        """
        # An empty repr yields no lines at all; fall back to a single empty
        # line so the head is still emitted instead of raising IndexError.
        lines = repr(apply_transformation).splitlines() or [""]
        pad = " " * len(head)
        return [f"{head}{lines[0]}"] + [f"{pad}{line}" for line in lines[1:]]

    def __repr__(self):
        """Return the class name followed by one section per configured transformation."""
        body = [self.__class__.__name__]
        if self.apply_transformation is not None:
            body += self._format_apply_transformation_repr(self.apply_transformation, "apply_transformation: ")
        if self.target_apply_transformation is not None:
            body += self._format_apply_transformation_repr(self.target_apply_transformation, "Target apply_transformation: ")

        return '\n'.join(body)