158 lines
5.6 KiB
Python
158 lines
5.6 KiB
Python
# -*- coding: utf-8 -*-
|
|
# @Author: Weisen Pan
|
|
|
|
import logging
|
|
import numpy as np
|
|
import torch.utils.data as data
|
|
from PIL import Image
|
|
from torchvision.datasets import CIFAR100
|
|
|
|
# Logging setup.
# NOTE(review): this configures the *root* logger at import time (basicConfig
# attaches a default handler only if the root has none; the level is always
# forced to INFO), which affects every module that imports this file.
# Consider moving it to the program's entry point.
logger = logging.getLogger()
logging.basicConfig()
logger.setLevel(logging.INFO)

# File-name suffixes treated as loadable images.
IMG_EXTENSIONS = (
    '.jpg',
    '.jpeg',
    '.png',
    '.ppm',
    '.bmp',
    '.pgm',
    '.tif',
    '.tiff',
    '.webp',
)
|
|
|
|
def load_accimage(path):
    """Load an image with the accimage backend, falling back to PIL on I/O errors.

    ``accimage`` is imported lazily, so it is only required when this loader
    is actually used. If importing it fails, the ImportError propagates.

    Args:
        path (str): Path to the image file.

    Returns:
        The ``accimage.Image`` for *path*. If accimage raises an I/O error, the
        RGB image produced by ``load_image_pil`` is returned instead.
    """
    import accimage

    try:
        image = accimage.Image(path)
    except OSError:  # IOError is an alias of OSError in Python 3
        # accimage could not read the file; give PIL a chance instead.
        image = load_image_pil(path)
    return image
|
|
|
|
def load_image_pil(path):
    """Load an image with PIL and return it converted to RGB.

    The file is opened explicitly in binary mode so that its handle is closed
    deterministically when the ``with`` block exits.

    Args:
        path (str): Path to the image file.

    Returns:
        PIL.Image.Image: The decoded image in ``'RGB'`` mode.
    """
    # Image.open reads pixel data lazily, so the conversion (which forces the
    # read) must happen while the file handle is still open.
    with open(path, 'rb') as stream:
        rgb_image = Image.open(stream).convert('RGB')
    return rgb_image
|
|
|
|
def basic_loader(path):
    """Load an image using whichever backend torchvision is configured for.

    Dispatches to ``load_accimage`` when ``torchvision.get_image_backend()``
    reports ``'accimage'``; otherwise it uses ``load_image_pil``.

    Args:
        path (str): Path to the image file.

    Returns:
        The image object returned by the selected loader.
    """
    from torchvision import get_image_backend

    use_accimage = get_image_backend() == 'accimage'
    loader = load_accimage if use_accimage else load_image_pil
    return loader(path)
|
|
|
|
class CIFAR100_truncated(data.Dataset):
    """CIFAR-100 dataset that can be restricted to a subset of sample indices.

    The requested split is loaded through ``torchvision.datasets.CIFAR100``.
    When ``dataidxs`` is given, only those samples are kept. The green and blue
    channels of selected images can also be zeroed in place with
    :meth:`truncate_channel`.
    """

    def __init__(self, root, dataidxs=None, train=True, apply_transformation=None,
                 target_apply_transformation=None, download=False):
        """Initialize the dataset and materialize the (possibly truncated) split.

        Args:
            root (str): Root directory where the dataset is stored.
            dataidxs (sequence of int or None): Indices of the samples to keep.
                ``None`` keeps the whole split.
            train (bool): Load the training split if True, the test split otherwise.
            apply_transformation (callable, optional): Transformation applied to
                each image in ``__getitem__``.
            target_apply_transformation (callable, optional): Transformation
                applied to each target label in ``__getitem__``.
            download (bool): Download the dataset if it is not found under ``root``.
        """
        self.root = root
        self.dataidxs = dataidxs
        self.train = train
        self.apply_transformation = apply_transformation
        self.target_apply_transformation = target_apply_transformation
        self.download = download

        # Load the split once and keep only the requested samples.
        self.data, self.target = self.__build_truncated_dataset__()

    def __build_truncated_dataset__(self):
        """Load the CIFAR-100 split and keep only the samples in ``self.dataidxs``.

        Returns:
            tuple: ``(images, labels)``. ``images`` is ``CIFAR100.data``,
            indexed with ``dataidxs`` when given. ``labels`` is
            ``np.array(CIFAR100.targets)``, indexed the same way.
        """
        cifar_dataobj = CIFAR100(
            self.root,
            train=self.train,
            transform=self.apply_transformation,
            target_transform=self.target_apply_transformation,
            download=self.download,
        )

        # Local names avoid shadowing the module-level ``data`` alias.
        images = cifar_dataobj.data
        labels = np.array(cifar_dataobj.targets)

        if self.dataidxs is not None:
            images = images[self.dataidxs]
            labels = labels[self.dataidxs]

        return images, labels

    def truncate_channel(self, index):
        """Zero the green and blue channels of the selected images, in place.

        Only the red channel (channel 0 of the last axis) is kept, so the
        affected images become red-only. They are not grayscale.

        Args:
            index (array-like of int): Indices into ``self.data`` of the images
                to modify. An empty selection is a no-op.
        """
        selected = np.asarray(index)
        if selected.size == 0:
            # Nothing to modify. This also avoids indexing with an empty
            # float array, which NumPy rejects.
            return
        # One vectorized assignment over channels 1 (green) and 2 (blue) of
        # every selected image, instead of a per-image Python loop.
        self.data[selected, :, :, 1:3] = 0

    def __getitem__(self, index):
        """Return the ``(image, target)`` pair at *index*.

        Args:
            index (int): Index of the sample to retrieve.

        Returns:
            tuple: ``(image, target)``, each passed through its optional
            transformation when one was provided.
        """
        img, target = self.data[index], self.target[index]

        if self.apply_transformation is not None:
            img = self.apply_transformation(img)

        if self.target_apply_transformation is not None:
            target = self.target_apply_transformation(target)

        return img, target

    def __len__(self):
        """Return the number of samples in the (possibly truncated) dataset."""
        return len(self.data)
|