110 lines
4.4 KiB
Python
110 lines
4.4 KiB
Python
# -*- coding: utf-8 -*-
|
|
# @Author: Weisen Pan
|
|
|
|
import logging
|
|
import numpy as np
|
|
import torch.utils.data as data
|
|
from PIL import Image
|
|
from torchvision.datasets import CIFAR10
|
|
|
|
# Module-wide logger: this is the *root* logger (getLogger with no name).
# basicConfig() adds a default handler to the root logger if it has none, and
# the explicit setLevel below forces INFO regardless of prior configuration.
# NOTE(review): configuring the root logger at import time affects every other
# library in the process; logging.getLogger(__name__) may be preferable.
logging.basicConfig()
logger = logging.getLogger(name=None)
logger.setLevel(level=logging.INFO)
|
|
|
|
# Supported image extensions
|
|
# These are the file extensions that the loaders will support for image formats
|
|
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
|
|
|
|
def load_accimage(path):
    """Load the image at *path* with accimage, falling back to PIL on IOError.

    Args:
        path: Filesystem path of the image file.

    Returns:
        An ``accimage.Image`` when accimage can read the file; otherwise the
        RGB image returned by ``load_image_pil``.

    Raises:
        ImportError: if the ``accimage`` package is not installed. The import
            is done inside the function so this module imports without it.
    """
    import accimage

    try:
        loaded = accimage.Image(path)
    except IOError:
        # accimage could not read/decode the file; retry with PIL instead.
        loaded = load_image_pil(path)
    return loaded
|
|
|
|
def load_image_pil(path):
    """Open the image at *path* with PIL and return it converted to RGB.

    The file is opened explicitly in binary mode inside a ``with`` block so the
    handle is closed deterministically (avoiding ResourceWarning). The RGB
    conversion happens while the file is still open, because ``Image.open``
    is lazy and the pixel data must be read before the handle is closed.

    Args:
        path: Filesystem path of the image file.

    Returns:
        A 3-channel (``'RGB'`` mode) PIL image.
    """
    with open(path, 'rb') as handle:
        rgb_image = Image.open(handle).convert('RGB')
    return rgb_image
|
|
|
|
def basic_loader(path):
    """Load the image at *path* using torchvision's configured image backend.

    The choice follows ``torchvision.get_image_backend()``, not whether
    accimage happens to be installed: ``'accimage'`` routes to
    ``load_accimage``, and any other backend routes to ``load_image_pil``.

    Args:
        path: Filesystem path of the image file.

    Returns:
        Whatever the selected loader returns.
    """
    from torchvision import get_image_backend

    use_accimage = get_image_backend() == 'accimage'
    loader = load_accimage if use_accimage else load_image_pil
    return loader(path)
|
|
|
|
class CIFAR10Truncated(data.Dataset):
    """CIFAR-10 split optionally restricted to a subset of sample indices.

    The full split is loaded once through ``torchvision.datasets.CIFAR10``.
    Only its raw ``data`` array and ``targets`` are kept, optionally filtered
    by ``dataidxs``. The image/label transforms are applied lazily in
    ``__getitem__``.

    Args:
        root: Root directory holding (or receiving) the CIFAR-10 files.
        dataidxs: Optional indices selecting which samples to keep. ``None``
            keeps the whole split.
        train: Load the training split when True, otherwise the test split.
        apply_transformation: Optional callable applied to each image.
        target_apply_transformation: Optional callable applied to each label.
        download: Forwarded to torchvision; download the data if it is missing.
    """

    def __init__(self, root, dataidxs=None, train=True, apply_transformation=None,
                 target_apply_transformation=None, download=False):
        self.root = root
        self.dataidxs = dataidxs
        self.train = train
        self.apply_transformation = apply_transformation
        self.target_apply_transformation = target_apply_transformation
        self.download = download

        # Load (and truncate) eagerly; __getitem__ only indexes these arrays.
        self.data, self.target = self._build_truncated_dataset()

    def _build_truncated_dataset(self):
        """Load the CIFAR-10 split and keep only the samples in ``dataidxs``.

        Returns:
            tuple: ``(images, labels)``. ``images`` is the torchvision dataset's
            ``data`` array. ``labels`` is a NumPy array built from its
            ``targets``. Both are indexed by ``self.dataidxs`` when that is not
            ``None``.
        """
        logger.info("Download: %s", self.download)

        # torchvision's CIFAR10 names these parameters ``transform`` and
        # ``target_transform`` (it has no ``apply_transformation`` keyword).
        # They are never applied here: only the raw arrays below are read, and
        # the transforms run per-sample in __getitem__.
        cifar_data = CIFAR10(
            self.root,
            train=self.train,
            transform=self.apply_transformation,
            target_transform=self.target_apply_transformation,
            download=self.download,
        )

        # Named to avoid shadowing the module-level ``data`` (torch.utils.data).
        images = cifar_data.data
        labels = np.array(cifar_data.targets)

        if self.dataidxs is not None:
            images = images[self.dataidxs]
            labels = labels[self.dataidxs]

        return images, labels

    def truncate_channel(self, indices):
        """Zero channels 1 and 2 of the selected images in place.

        Channels are indexed on the last axis (``data[idx, :, :, c]``). For
        RGB data, this zeroes the green and blue channels and keeps only red.
        Assumes ``self.data`` is 4-D with channels last, as torchvision's
        CIFAR10 ``data`` array is — TODO confirm if other data is ever used.

        Args:
            indices: Iterable of positions into ``self.data``. These are
                positions after any ``dataidxs`` truncation, not original
                CIFAR-10 indices.
        """
        for idx in indices:
            self.data[idx, :, :, 1:3] = 0

    def __getitem__(self, index):
        """Return the ``(image, target)`` pair at ``index``.

        Args:
            index (int): Position of the sample within the (truncated) dataset.

        Returns:
            tuple: ``(image, target)``. ``image`` is the raw row of
            ``self.data`` (it is not converted to a PIL image here), passed
            through ``apply_transformation`` when set. ``target`` is the label,
            passed through ``target_apply_transformation`` when set.
        """
        img, target = self.data[index], self.target[index]

        # NOTE(review): the image transform presumably expects an ndarray
        # (e.g. a pipeline starting with ToPILImage) — confirm with callers.
        if self.apply_transformation is not None:
            img = self.apply_transformation(img)

        if self.target_apply_transformation is not None:
            target = self.target_apply_transformation(target)

        return img, target

    def __len__(self):
        """Return the number of samples kept after truncation."""
        return len(self.data)
|