Weisen Pan 4ec0a23e73 Edge Federated Learning for Improved Training Efficiency
Change-Id: Ic4e43992e1674946cb69e0221659b0261259196c
2024-09-18 18:39:43 -07:00

# -*- coding: utf-8 -*-
# @Author: Weisen Pan
import logging
import numpy as np
import torch.utils.data as data
from PIL import Image
from torchvision.datasets import CIFAR100
# Configure logging
logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)
# Supported image extensions for loading images
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
def load_accimage(path):
"""
Attempts to load an image using the accimage backend.
If accimage fails, it falls back to using the PIL image loader.
Args:
path (str): Path to the image file.
Returns:
accimage.Image: The loaded image if successful, otherwise a PIL image.
"""
import accimage
try:
return accimage.Image(path)
except IOError:
# If accimage fails, use PIL to load the image
return load_image_pil(path)
def load_image_pil(path):
"""
Loads an image using PIL, ensuring that file handles are properly closed to prevent warnings.
Args:
path (str): Path to the image file.
Returns:
Image: The image loaded using PIL, converted to RGB format.
"""
    with open(path, 'rb') as f:
        img = Image.open(f)
        # Convert while the file is still open: PIL loads lazily, so converting
        # after the with-block would read from an already-closed file handle.
        return img.convert('RGB')
def basic_loader(path):
"""
Selects the appropriate image loader based on the backend configured by torchvision.
If the backend is 'accimage', it uses load_accimage; otherwise, it uses PIL.
Args:
path (str): Path to the image file.
Returns:
Image: The loaded image.
"""
from torchvision import get_image_backend
if get_image_backend() == 'accimage':
return load_accimage(path)
else:
return load_image_pil(path)
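# Usage sketch (illustrative only, not part of the training pipeline): which loader
# basic_loader dispatches to depends on torchvision's configured image backend;
# 'example.png' is a hypothetical path.
#
#   from torchvision import set_image_backend
#   set_image_backend('PIL')            # force the PIL code path
#   img = basic_loader('example.png')   # -> PIL.Image in RGB mode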
class CIFAR100_truncated(data.Dataset):
"""
    Custom dataset class for CIFAR100 that supports optional truncation.
    A subset of samples can be selected by index (e.g., one client's partition in
    federated learning), and image channels can be zeroed via truncate_channel().
"""
def __init__(self, root, dataidxs=None, train=True, apply_transformation=None, target_apply_transformation=None, download=False):
"""
Initializes the CIFAR100_truncated dataset.
Args:
root (str): The root directory where the dataset is stored.
dataidxs (list or None): List of indices for truncating the dataset, if applicable.
train (bool): Whether to load the training set (True) or the test set (False).
            apply_transformation (callable, optional): Transformation applied to the images.
            target_apply_transformation (callable, optional): Transformation applied to the targets (labels).
download (bool): Whether to download the dataset if it is not found in the root directory.
"""
self.root = root # Root directory where dataset is stored
self.dataidxs = dataidxs # List of indices for truncating the dataset
self.train = train # Specifies whether to load the training set
        self.apply_transformation = apply_transformation  # Optional transformation applied to images
        self.target_apply_transformation = target_apply_transformation  # Optional transformation applied to labels
self.download = download # Specifies whether to download the dataset if missing
# Build the truncated dataset based on the provided indices
self.data, self.target = self.__build_truncated_dataset__()
def __build_truncated_dataset__(self):
"""
Constructs the truncated dataset based on the provided data indices.
Returns:
tuple: The truncated data and corresponding target labels.
"""
        # The inner CIFAR100 object is only used as a data source here; the transforms
        # are applied by this wrapper's __getitem__, not by torchvision's own pipeline.
        cifar_dataobj = CIFAR100(self.root, train=self.train, transform=self.apply_transformation,
                                 target_transform=self.target_apply_transformation, download=self.download)
# Load all data and targets
data = cifar_dataobj.data
target = np.array(cifar_dataobj.targets)
# If specific indices are provided, truncate the dataset accordingly
if self.dataidxs is not None:
data = data[self.dataidxs]
target = target[self.dataidxs]
return data, target
def truncate_channel(self, index):
"""
        Modifies the selected images by zeroing out their green and blue channels,
        leaving only the red channel with non-zero values.
Args:
index (np.array): The indices of images to modify.
"""
for i in range(index.shape[0]):
gs_index = index[i]
self.data[gs_index, :, :, 1] = 0.0 # Set the green channel to 0
self.data[gs_index, :, :, 2] = 0.0 # Set the blue channel to 0
def __getitem__(self, index):
"""
Retrieves an image and its corresponding target (label) at the given index.
Args:
index (int): Index of the data point to retrieve.
Returns:
            tuple: (image, target) where the image has the transformation applied (if one is given) and the target is the label.
"""
img, target = self.data[index], self.target[index]
        # Apply any specified transformation to the image.
        # Note: `img` is a NumPy array (H, W, C) here, so the transform must accept ndarray input.
if self.apply_transformation is not None:
img = self.apply_transformation(img)
        # Apply any specified transformation to the target label
if self.target_apply_transformation is not None:
target = self.target_apply_transformation(target)
return img, target
def __len__(self):
"""
Returns the total number of data points in the dataset.
Returns:
int: The number of samples in the dataset.
"""
return len(self.data)
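# Minimal usage sketch (not part of the original module): builds one client's
# truncated view of CIFAR-100 from a random index subset and wraps it in a
# DataLoader. The './data' directory and the 5,000-sample shard size are
# illustrative assumptions, not values taken from the training configuration.
if __name__ == "__main__":
    from torchvision import transforms

    # ToTensor accepts the (H, W, C) uint8 NumPy arrays returned by __getitem__.
    apply_transformation = transforms.ToTensor()
    # Hypothetical per-client partition: 5,000 random indices out of the 50,000 training images.
    client_indices = np.random.permutation(50000)[:5000]
    client_dataset = CIFAR100_truncated(
        root="./data",
        dataidxs=client_indices,
        train=True,
        apply_transformation=apply_transformation,
        download=True,
    )
    # Optionally zero the green/blue channels of the first 10 images of the shard.
    client_dataset.truncate_channel(np.arange(10))
    loader = data.DataLoader(client_dataset, batch_size=64, shuffle=True)
    images, labels = next(iter(loader))
    logger.info("Client shard: %d samples, first batch shape %s",
                len(client_dataset), tuple(images.shape))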