# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import os

import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset

# Importing the HOME configuration
from config import HOME


class PillDataBase(Dataset):
    """Pill image dataset backed by a ``train.txt`` / ``test.txt`` index file.

    Each non-blank line of the index file holds an image path followed by an
    integer label. Samples are returned as ``(images, label)`` where
    ``images`` is the concatenation (along dim 0) of ``split_factor``
    independently transformed copies of the image.
    """

    # Path prefix that appears inside the index files; it is replaced with
    # ``self.data_dir`` when the index is loaded.
    _INDEX_PATH_PREFIX = '/home/tung/Tung/research/Open-Pill/FACIL/data/Pill_Base_X'

    def __init__(self, data_dir=HOME + '/dataset_hub/pill_base', train=True,
                 apply_transformation=None, split_factor=1):
        """
        Initialize the dataset and eagerly load its index file.

        Args:
            data_dir (str): Parent directory of the dataset; ``'/pill_base'``
                is appended to it to form the actual data directory.
            train (bool): If True read ``train.txt``, otherwise ``test.txt``.
            apply_transformation (callable): Transformation applied to each
                opened image. It must return a tensor, because the results
                are concatenated with ``torch.cat``; indexing the dataset
                without one raises ``RuntimeError``.
            split_factor (int): Number of times the transformation is applied
                to each image; the results are concatenated along dim 0.

        Raises:
            OSError: If the index file cannot be opened.
            ValueError: If an index line is malformed.
        """
        self.train = train
        self.apply_transformation = apply_transformation
        self.split_factor = split_factor
        self.data_dir = data_dir + '/pill_base'
        self.dataset = self._load_data()

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.dataset)

    def _load_data(self):
        """
        Read the index file (``train.txt`` or ``test.txt``) from ``self.data_dir``.

        Returns:
            list: ``[image_path, label]`` pairs, with ``label`` as an int.

        Raises:
            OSError: If the index file cannot be opened.
            ValueError: If a non-blank line lacks a label or the label is not
                an integer.
        """
        dataset = []
        txt_path = os.path.join(self.data_dir, 'train.txt' if self.train else 'test.txt')
        with open(txt_path, 'r') as file:
            for line in file:
                entry = line.strip()
                # Tolerate blank lines (e.g. trailing newlines at end of file).
                if not entry:
                    continue
                # Split on the LAST run of whitespace so that image paths
                # containing spaces are kept intact.
                filename, label = entry.rsplit(None, 1)
                # Adjust the image path to the correct directory structure
                filename = filename.replace(self._INDEX_PATH_PREFIX, self.data_dir)
                # Append the image file path and label as an integer
                dataset.append([filename, int(label)])
        return dataset

    def __getitem__(self, index):
        """
        Retrieve the sample at the given index.

        Args:
            index (int): Index of the image and label to retrieve.

        Returns:
            tuple: A tensor of the concatenated transformed images and the
            corresponding label tensor.

        Raises:
            IndexError: If ``index`` is out of range.
            RuntimeError: If no transformation was provided or
                ``split_factor`` is smaller than 1, since there would be
                nothing to concatenate.
        """
        # Look the sample up first so an out-of-range index still raises
        # IndexError (which plain ``for`` iteration relies on to stop).
        sample = self.dataset[index]
        image_path = sample[0]
        label = torch.tensor(int(sample[1]))

        if self.apply_transformation is None:
            raise RuntimeError(
                "PillDataBase requires 'apply_transformation' (a callable that "
                "returns a tensor) to build samples, but none was provided."
            )

        # Image.open is lazy and keeps the file handle open; the context
        # manager guarantees it is released once the transforms have run.
        with Image.open(image_path) as image:
            # Apply the transformation split_factor times to the same image
            images = [self.apply_transformation(image) for _ in range(self.split_factor)]

        if not images:
            raise RuntimeError(
                "split_factor must be at least 1, got {!r}".format(self.split_factor)
            )

        # Concatenate all transformed image splits into a single tensor
        return torch.cat(images, dim=0), label


if __name__ == "__main__":
    # Example of how to instantiate and use the dataset
    dataset = PillDataBase()