# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import glob
import os
from typing import Callable, List, Optional

from PIL import Image
import torch
from torch.utils.data import Dataset

# Default dataset locations: index 0 is the training split, index 1 the test split.
# They are only used when no explicit ``data_dir`` is passed to ``PillDataLarge``.
FOLDER_PATHS = [
    '/media/skydata/alpha0012/workspace/EdgeFLite/coremodel/dataset_hub/medical_images/train_images',
    '/media/skydata/alpha0012/workspace/EdgeFLite/coremodel/dataset_hub/medical_images/test_images',
]


class PillDataLarge(Dataset):
    """Image dataset laid out as ``<root>/<class_name>/**/*.jpg``.

    Each sub-directory of the root folder is one class. Class IDs are assigned
    by enumerating the sorted sub-directory names, and every ``.jpg`` file found
    (recursively) below a class directory becomes one sample.
    """

    def __init__(
        self,
        train: bool = True,
        apply_transformation: Optional[Callable] = None,
        split_factor: int = 1,
        data_dir: Optional[str] = None,
    ):
        """
        Initialize the dataset and index all image files.

        Args:
        - train (bool): If True, use the training folder (``FOLDER_PATHS[0]``),
          otherwise the test folder (``FOLDER_PATHS[1]``). Ignored for locating
          the data when ``data_dir`` is given.
        - apply_transformation (callable, optional): Transformation applied to
          each opened PIL image. Its outputs are concatenated with ``torch.cat``,
          so it must return a tensor. Required for ``__getitem__``.
        - split_factor (int): Number of times the transformation is applied to
          the same image; the results are concatenated along dimension 0.
        - data_dir (str, optional): Root folder to read instead of the default
          entry of ``FOLDER_PATHS``.

        Raises:
        - FileNotFoundError: If the selected root folder does not exist.
        """
        self.train = train
        self.apply_transformation = apply_transformation
        self.split_factor = split_factor
        self.data_dir = data_dir
        self.dataset = self._load_data()

    def __len__(self) -> int:
        """Return the total number of samples in the dataset."""
        return len(self.dataset)

    def _load_data(self) -> List[list]:
        """
        Scan the root folder and build the sample index.

        Returns:
        - dataset (list): A list of ``[file_path, class_id]`` entries, ordered
          by class name and then by file path.
        """
        if self.data_dir is not None:
            folder_path = self.data_dir
        else:
            folder_path = FOLDER_PATHS[0] if self.train else FOLDER_PATHS[1]

        # Only directories are classes. A stray file in the root folder would
        # otherwise take a class ID and shift the labels of every later class.
        class_names = sorted(
            name
            for name in os.listdir(folder_path)
            if os.path.isdir(os.path.join(folder_path, name))
        )

        dataset = []
        for class_id, class_name in enumerate(class_names):
            class_dir = os.path.join(folder_path, class_name)
            # Escape the directory part so glob metacharacters in a path
            # ('[', '*', '?') are matched literally.
            pattern = os.path.join(glob.escape(class_dir), '**', '*.jpg')
            # glob returns files in filesystem order; sort for reproducibility.
            for file_path in sorted(glob.glob(pattern, recursive=True)):
                dataset.append([file_path, class_id])
        return dataset

    def __getitem__(self, index):
        """
        Return one sample and its label.

        Args:
        - index (int): Index of the sample.

        Returns:
        - tuple: ``(image, label)`` where ``image`` is the concatenation (along
          dimension 0) of ``split_factor`` transformed views of the image and
          ``label`` is a 0-dim integer tensor holding the class ID.

        Raises:
        - IndexError: If ``index`` is out of range.
        - RuntimeError: If no transformation was configured or ``split_factor``
          is smaller than 1, because there would be nothing to concatenate.
        """
        image_path, class_id = self.dataset[index]

        if self.apply_transformation is None:
            raise RuntimeError(
                "PillDataLarge requires apply_transformation to convert images "
                "to tensors; got None."
            )
        if self.split_factor < 1:
            raise RuntimeError(
                f"split_factor must be >= 1, got {self.split_factor!r}."
            )

        label = torch.tensor(int(class_id))

        # PIL opens files lazily; keep the file open while the transformation
        # reads the pixels, then close it so file handles are not leaked.
        with Image.open(image_path) as image:
            views = [
                self.apply_transformation(image)
                for _ in range(self.split_factor)
            ]

        return torch.cat(views, dim=0), label


if __name__ == "__main__":
    dataset = PillDataLarge()  # Index the default training split
    print(len(dataset))  # Number of samples found
    # No transformation is configured here, so show the raw index entry
    # ([file_path, class_id]) rather than a tensor sample.
    if len(dataset) > 0:
        print(dataset.dataset[0])