Weisen Pan 4ec0a23e73 Edge Federated Learning for Improved Training Efficiency
Change-Id: Ic4e43992e1674946cb69e0221659b0261259196c
2024-09-18 18:39:43 -07:00

110 lines
4.4 KiB
Python

# -*- coding: utf-8 -*-
# @Author: Weisen Pan
import logging
import numpy as np
import torch.utils.data as data
from PIL import Image
from torchvision.datasets import CIFAR10
# Logging setup. NOTE: this runs at import time and acts on the ROOT logger
# (getLogger() is called without a name), so importing this module sets the
# root logger's level to INFO for the whole process.
logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)
# Image file extensions recognised as supported image formats.
# NOTE(review): this constant is not referenced anywhere in the visible part
# of this file; presumably it is consumed by other modules - confirm.
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
# accimage-backed loader that falls back to the PIL loader on read errors
def load_accimage(path):
    """Load the image at ``path`` with accimage, falling back to PIL.

    Args:
        path: Location of the image file.

    Returns:
        An ``accimage.Image`` when accimage can read the file, otherwise
        whatever ``load_image_pil(path)`` returns.
    """
    # Imported inside the function, so importing this module does not
    # require accimage to be installed.
    import accimage

    try:
        decoded = accimage.Image(path)
    except OSError:  # IOError is an alias of OSError on Python 3
        # accimage could not read the file; let PIL try instead
        decoded = load_image_pil(path)
    return decoded
# PIL (Python Imaging Library) based loader
def load_image_pil(path):
    """Load the image at ``path`` with PIL and return it in RGB mode.

    Args:
        path: Location of the image file.

    Returns:
        The image converted to ``'RGB'`` mode.
    """
    # The file is opened explicitly (binary mode) and closed by the context
    # manager; the conversion happens while the handle is still open.
    with open(path, 'rb') as handle:
        rgb_image = Image.open(handle).convert('RGB')
    return rgb_image
# Default loader: dispatches on torchvision's configured image backend
def basic_loader(path):
    """Load the image at ``path`` using the active torchvision image backend.

    Args:
        path: Location of the image file.

    Returns:
        The result of ``load_accimage(path)`` when the backend is
        ``'accimage'``, otherwise the result of ``load_image_pil(path)``.
    """
    from torchvision import get_image_backend

    # Pick the loader first, then invoke it once
    chosen_loader = load_accimage if get_image_backend() == 'accimage' else load_image_pil
    return chosen_loader(path)
# CIFAR10 dataset wrapper that can be restricted to a subset of sample indices
class CIFAR10Truncated(data.Dataset):
    """CIFAR-10 dataset optionally truncated to a given set of sample indices.

    The requested split is loaded through ``torchvision.datasets.CIFAR10``;
    when ``dataidxs`` is provided only those samples are kept in
    ``self.data`` and ``self.target``.

    Args:
        root: Dataset root directory, forwarded to ``CIFAR10``.
        dataidxs: Indices of the samples to keep. ``None`` keeps every sample.
        train: Forwarded to ``CIFAR10`` to select the split.
        apply_transformation: Optional callable applied to each image in
            ``__getitem__``.
        target_apply_transformation: Optional callable applied to each label
            in ``__getitem__``.
        download: Forwarded to ``CIFAR10``.
    """

    def __init__(self, root, dataidxs=None, train=True, apply_transformation=None,
                 target_apply_transformation=None, download=False):
        self.root = root  # Root directory for the dataset
        self.dataidxs = dataidxs  # Subset of sample indices (optional)
        self.train = train  # Split selector forwarded to CIFAR10
        self.apply_transformation = apply_transformation  # Image transformation (optional)
        self.target_apply_transformation = target_apply_transformation  # Label transformation (optional)
        self.download = download  # Forwarded to CIFAR10
        # Load the split and keep only the requested samples
        self.data, self.target = self._build_truncated_dataset()

    def _build_truncated_dataset(self):
        """Load CIFAR-10 and return ``(data, target)`` filtered by ``self.dataidxs``.

        Returns:
            tuple: ``(data, target)`` where ``data`` is the image array taken
            from the torchvision dataset and ``target`` is a NumPy array of
            labels; both are indexed by ``self.dataidxs`` when it is not None.
        """
        logger.info("Download: %s", self.download)
        # torchvision's CIFAR10 names these keywords ``transform`` and
        # ``target_transform``; any other keyword name raises TypeError.
        cifar_data = CIFAR10(
            self.root,
            self.train,
            transform=self.apply_transformation,
            target_transform=self.target_apply_transformation,
            download=self.download,
        )
        # Raw images and labels; labels become an array so they can be
        # indexed with a list/array of indices below.
        data = cifar_data.data
        target = np.array(cifar_data.targets)
        if self.dataidxs is not None:
            data = data[self.dataidxs]
            target = target[self.dataidxs]
        return data, target

    def truncate_channel(self, indices):
        """Zero channels 1 and 2 (last axis) of the selected images in place.

        Args:
            indices: Iterable of positions into ``self.data``.
        """
        # assumes channel-last RGB data, i.e. channels 1/2 are green/blue
        # (as in torchvision's CIFAR10.data) - TODO confirm for other inputs
        for idx in indices:
            self.data[idx, :, :, 1] = 0  # second channel
            self.data[idx, :, :, 2] = 0  # third channel

    def __getitem__(self, index):
        """
        Args:
            index (int): Index of the image
        Returns:
            tuple: (image, target) where target is the class label.
        """
        img, target = self.data[index], self.target[index]
        # Apply the image transformation if one was specified
        if self.apply_transformation is not None:
            img = self.apply_transformation(img)
        # Apply the target transformation if one was specified
        if self.target_apply_transformation is not None:
            target = self.target_apply_transformation(target)
        return img, target

    def __len__(self):
        """Return the number of samples kept in the dataset."""
        return len(self.data)