Weisen Pan 4ec0a23e73 Edge Federated Learning for Improved Training Efficiency
Change-Id: Ic4e43992e1674946cb69e0221659b0261259196c
2024-09-18 18:39:43 -07:00

84 lines
3.4 KiB
Python

# -*- coding: utf-8 -*-
# @Author: Weisen Pan
import os
import glob
from PIL import Image
import torch
from torch.utils.data import Dataset
# Locations of the medical-image splits on disk. Each split gets its own named
# constant so the meaning of every FOLDER_PATHS index is explicit.
_TRAIN_IMAGES_DIR = '/media/skydata/alpha0012/workspace/EdgeFLite/coremodel/dataset_hub/medical_images/train_images'
_TEST_IMAGES_DIR = '/media/skydata/alpha0012/workspace/EdgeFLite/coremodel/dataset_hub/medical_images/test_images'

# Index 0 -> training images, index 1 -> test images.
FOLDER_PATHS = [_TRAIN_IMAGES_DIR, _TEST_IMAGES_DIR]
# Image-folder dataset: every sub-directory of the split root is one class.
class PillDataLarge(Dataset):
    """Dataset of ``.jpg`` images stored as one sub-directory per class.

    The split root is scanned once at construction time. Class IDs are the
    positions of the class directory names in sorted order, and every
    ``.jpg`` file found (recursively) below a class directory becomes one
    ``[file_path, class_id]`` entry in ``self.dataset``.
    """

    def __init__(self, train=True, apply_transformation=None, split_factor=1,
                 data_root=None):
        """Index the image files of the selected split.

        Args:
            train (bool): If True use the training folder
                (``FOLDER_PATHS[0]``), otherwise the test folder
                (``FOLDER_PATHS[1]``). Ignored for path selection when
                ``data_root`` is given.
            apply_transformation (callable, optional): Called on the opened
                PIL image; ``__getitem__`` concatenates its results with
                ``torch.cat``, so it is expected to return a tensor.
            split_factor (int): How many times the transformation is applied
                to each image. Must be at least 1.
            data_root (str, optional): Directory holding the class folders.
                When None (the default) the module-level ``FOLDER_PATHS``
                entry selected by ``train`` is used.

        Raises:
            ValueError: If ``split_factor`` is smaller than 1.
        """
        # Fail early: with fewer than one view there is nothing to
        # concatenate, so every sample access would fail later anyway.
        if split_factor < 1:
            raise ValueError(
                "split_factor must be >= 1, got {}".format(split_factor)
            )
        self.train = train
        self.apply_transformation = apply_transformation
        self.split_factor = split_factor
        self.data_root = data_root
        self.dataset = self._load_data()

    def __len__(self):
        """Return the number of image samples found for this split."""
        return len(self.dataset)

    def _load_data(self):
        """Scan the split folder and build the sample index.

        Returns:
            list: ``[file_path, class_id]`` pairs, grouped by class in sorted
            class-name order, with the files of each class in sorted path
            order so the index is reproducible across runs and filesystems.
        """
        if self.data_root is not None:
            folder_path = self.data_root
        else:
            folder_path = FOLDER_PATHS[0] if self.train else FOLDER_PATHS[1]

        # Only directories are classes; a stray file in the split root would
        # otherwise claim a class ID and shift all the following ones.
        class_names = sorted(
            entry for entry in os.listdir(folder_path)
            if os.path.isdir(os.path.join(folder_path, entry))
        )

        dataset = []
        for class_id, class_name in enumerate(class_names):
            folder_class = os.path.join(folder_path, class_name)
            # Escape the directory part so glob metacharacters in folder
            # names (e.g. '[' or '*') are matched literally.
            pattern = os.path.join(glob.escape(folder_class), '**', '*.jpg')
            for file_path in sorted(glob.glob(pattern, recursive=True)):
                dataset.append([file_path, class_id])
        return dataset

    def __getitem__(self, index):
        """Load one sample.

        Args:
            index (int): Position of the sample in ``self.dataset``.

        Returns:
            tuple: ``(image, label)``. ``label`` is an integer tensor with
            the class ID. With a transformation, ``image`` is the
            ``split_factor`` transformed views concatenated along dimension
            0. Without one, ``image`` is an in-memory copy of the PIL image,
            unchanged.
        """
        image_path, class_id = self.dataset[index]
        label = torch.tensor(int(class_id))

        # The context manager releases the file handle; all work that needs
        # the image data is done before leaving the block.
        with Image.open(image_path) as image:
            if self.apply_transformation is None:
                # Nothing to concatenate: hand back the image itself instead
                # of calling torch.cat on an empty list.
                return image.copy(), label
            views = [
                self.apply_transformation(image)
                for _ in range(self.split_factor)
            ]
        return torch.cat(views, dim=0), label
def _main():
    """Build the dataset with default arguments and print a quick summary."""
    pill_data = PillDataLarge()
    print(len(pill_data))  # number of indexed samples
    print(pill_data[0])  # first (image, label) sample


# Run the manual check only when the file is executed as a script.
if __name__ == "__main__":
    _main()