# -*- coding: utf-8 -*-
# @Author: Weisen Pan

from PIL import Image
from torch.utils.data import DataLoader, Dataset
import torch
import os

# Importing the HOME configuration
from config import HOME

class PillDataBase(Dataset):
    """Pill image dataset indexed by a ``train.txt`` / ``test.txt`` file.

    Each non-blank index line holds an image path and an integer label
    separated by whitespace. ``__getitem__`` calls the transformation
    ``split_factor`` times on the image, concatenates the resulting tensors
    along dim 0, and returns them together with the label.
    """

    # Absolute prefix baked into the image paths of the index files; it is
    # rewritten to ``self.data_dir`` when the index is loaded.
    _LEGACY_PATH_PREFIX = '/home/tung/Tung/research/Open-Pill/FACIL/data/Pill_Base_X'

    def __init__(self, data_dir=HOME + '/dataset_hub/pill_base', train=True, apply_transformation=None, split_factor=1):
        """
        Initialize the dataset and load the index of the selected split.

        Args:
            data_dir (str): Base directory; ``'/pill_base'`` is appended to it to
                locate the index files and images.
            train (bool): Load ``train.txt`` if True, otherwise ``test.txt``.
            apply_transformation (callable): Callable applied to the PIL image;
                it must return a tensor for ``__getitem__`` to succeed.
            split_factor (int): Number of transformed views produced per image
                (one transformation call each); the views are concatenated
                along dim 0.
        """
        self.train = train
        self.apply_transformation = apply_transformation
        self.split_factor = split_factor
        # NOTE(review): '/pill_base' is appended to data_dir, so the default
        # resolves to '<HOME>/dataset_hub/pill_base/pill_base'. Confirm this
        # nesting matches the on-disk layout.
        self.data_dir = data_dir + '/pill_base'
        self.dataset = self._load_data()

    def __len__(self) -> int:
        """Return the number of samples in the dataset."""
        return len(self.dataset)

    def _load_data(self) -> list:
        """
        Read the split's index file (``train.txt`` or ``test.txt``).

        Returns:
            list: ``[image_path, label]`` pairs with ``label`` as ``int`` and the
            legacy absolute prefix in ``image_path`` rewritten to ``self.data_dir``.

        Raises:
            FileNotFoundError: If the index file does not exist.
            ValueError: If a non-blank line is not ``<image_path> <label>`` or the
                label is not an integer.
        """
        dataset = []
        txt_path = os.path.join(self.data_dir, 'train.txt' if self.train else 'test.txt')

        with open(txt_path, 'r', encoding='utf-8') as file:
            for line_no, raw_line in enumerate(file, start=1):
                line = raw_line.strip()
                if not line:
                    # Tolerate blank lines (e.g. extra newlines at end of file).
                    continue
                # Split on the LAST whitespace run: this accepts tab or
                # multi-space separators and image paths that contain spaces.
                parts = line.rsplit(maxsplit=1)
                if len(parts) != 2:
                    raise ValueError(
                        f"{txt_path}:{line_no}: expected '<image_path> <label>', got {line!r}"
                    )
                filename, label = parts
                # Re-root the path recorded in the index onto the local data_dir.
                filename = filename.replace(self._LEGACY_PATH_PREFIX, self.data_dir)
                dataset.append([filename, int(label)])

        return dataset

    def __getitem__(self, index):
        """
        Return the sample at ``index``.

        Args:
            index (int): Position in ``self.dataset``.

        Returns:
            tuple: ``(views, label)``. ``views`` holds the ``split_factor``
            transformed images concatenated along dim 0, and ``label`` is a 0-d
            integer tensor.

        Raises:
            RuntimeError: If no transformation is set or ``split_factor`` < 1,
                because there would be no tensors to concatenate.
        """
        image_path, raw_label = self.dataset[index]
        label = torch.tensor(int(raw_label))

        # Without a transformation (or with split_factor < 1) no views are
        # produced and torch.cat would fail with an opaque message; fail early
        # and clearly, before touching the file.
        if not self.apply_transformation or self.split_factor < 1:
            raise RuntimeError(
                "PillDataBase needs an apply_transformation that returns tensors "
                "and split_factor >= 1 to build a sample"
            )

        # Use the image as a context manager so its file handle is released once
        # the views are built; PIL opens files lazily and would otherwise keep
        # the handle open until the object is garbage-collected.
        with Image.open(image_path) as image:
            views = [self.apply_transformation(image) for _ in range(self.split_factor)]

        return torch.cat(views, dim=0), label

if __name__ == "__main__":
    # Minimal smoke check: build the default training split and report its size.
    # Indexing samples additionally requires an apply_transformation that
    # returns tensors.
    dataset = PillDataBase()
    print(f"Loaded {len(dataset)} samples from {dataset.data_dir}")