84 lines
3.4 KiB
Python
84 lines
3.4 KiB
Python
# -*- coding: utf-8 -*-
|
|
# @Author: Weisen Pan
|
|
|
|
import os
|
|
import glob
|
|
from PIL import Image
|
|
import torch
|
|
from torch.utils.data import Dataset
|
|
|
|
# Root directory that holds the `train_images` and `test_images` folders.
# The default is the original hard-coded location; set the environment
# variable EDGEFLITE_MEDICAL_IMAGES_ROOT to point the module at a dataset
# stored elsewhere without editing this file. An empty value is treated as
# "not set" and falls back to the default.
_DEFAULT_MEDICAL_IMAGES_ROOT = (
    '/media/skydata/alpha0012/workspace/EdgeFLite/coremodel/dataset_hub/medical_images'
)
_MEDICAL_IMAGES_ROOT = (
    os.environ.get('EDGEFLITE_MEDICAL_IMAGES_ROOT') or _DEFAULT_MEDICAL_IMAGES_ROOT
)

# Folder paths for the datasets: index 0 is training, index 1 is testing.
FOLDER_PATHS = [
    os.path.join(_MEDICAL_IMAGES_ROOT, 'train_images'),
    os.path.join(_MEDICAL_IMAGES_ROOT, 'test_images'),
]
|
|
|
|
# Custom dataset class inheriting from PyTorch's Dataset class
|
|
class PillDataLarge(Dataset):
    """
    Image dataset read from a class-per-folder directory tree.

    The split directory comes from FOLDER_PATHS (index 0 for training,
    index 1 for testing). Every immediate sub-directory of it is one class,
    and all `*.jpg` files found beneath that sub-directory (recursively) are
    the samples of that class.
    """

    def __init__(self, train=True, apply_transformation=None, split_factor=1):
        """
        Initializes the dataset object.

        Args:
        - train (bool): If True, load the training dataset, otherwise load the test dataset.
        - apply_transformation (callable, optional): Transformation applied to each image.
          It must return a tensor, because the results are concatenated with torch.cat.
        - split_factor (int): Number of times the transformation is applied to each image.

        Raises:
        - ValueError: If split_factor is smaller than 1.
        """
        super().__init__()
        # Fail early: with fewer than one view there is nothing to concatenate
        # in __getitem__, which would otherwise surface as an opaque torch.cat error.
        if split_factor < 1:
            raise ValueError("split_factor must be >= 1, got {}".format(split_factor))

        self.train = train  # Selects the training or the testing folder
        self.apply_transformation = apply_transformation  # Per-image transformation (may be None)
        self.split_factor = split_factor  # Number of transformed views per image
        self.dataset = self._load_data()  # List of [file_path, class_id] entries

    def __len__(self):
        """
        Returns the total number of samples in the dataset.
        """
        return len(self.dataset)

    def _load_data(self):
        """
        Loads the data from the dataset folders.

        Returns:
        - dataset (list): A list of [image_file_path, class_id] entries. Class IDs
          follow the alphabetical order of the class folder names, and the files
          of each class are listed in sorted order.
        """
        folder_path = FOLDER_PATHS[0] if self.train else FOLDER_PATHS[1]

        # Only directories are classes. A stray file (README, .DS_Store, ...)
        # must not receive a class ID, otherwise it shifts the IDs of the real
        # classes and can make the train and test label spaces disagree.
        class_names = sorted(
            entry for entry in os.listdir(folder_path)
            if os.path.isdir(os.path.join(folder_path, entry))
        )

        dataset = []
        for class_id, class_name in enumerate(class_names):
            # Escape the directory part so characters such as '[' in a folder
            # name are matched literally instead of being read as glob syntax.
            class_dir = glob.escape(os.path.join(folder_path, class_name))
            pattern = os.path.join(class_dir, '**', '*.jpg')
            # glob returns files in arbitrary filesystem order; sort so that
            # sample indices are reproducible across runs and machines.
            for file_path in sorted(glob.glob(pattern, recursive=True)):
                dataset.append([file_path, class_id])

        return dataset

    def __getitem__(self, index):
        """
        Returns a sample and its corresponding label from the dataset.

        Args:
        - index (int): Index of the sample.

        Returns:
        - tuple: (image, label). With a transformation set, `image` is the
          concatenation along dim 0 of `split_factor` transformed views of the
          picture. Without a transformation, `image` is the untransformed PIL
          image. `label` is the class ID as a tensor.
        """
        image_path, class_id = self.dataset[index]
        label = torch.tensor(int(class_id))

        # Image.open keeps the underlying file open until the image is closed;
        # the context manager releases the handle once the pixels were used.
        with Image.open(image_path) as image:
            if self.apply_transformation is None:
                # Nothing to concatenate: hand back an in-memory copy that
                # stays valid after the file is closed.
                return image.copy(), label

            views = [self.apply_transformation(image) for _ in range(self.split_factor)]

        return torch.cat(views, dim=0), label
|
|
|
|
if __name__ == "__main__":
    # Manual check: build the dataset with its default arguments, then show
    # the number of samples followed by the first sample.
    pills = PillDataLarge()
    sample_count = len(pills)
    print(sample_count)
    first_sample = pills[0]
    print(first_sample)
|