# -*- coding: utf-8 -*-
# @Author: Weisen Pan
import warnings
from contextlib import contextmanager
import os
import shutil
import tempfile

import torch

from .folder import ImageFolder
from .utils import validate_integrity, extract_archive, verify_str_arg

# Maps each dataset split (train/val/devkit) to its archive filename and md5 checksum
ARCHIVE_info_map = {
    'train': ('ILSVRC2012_img_train.tar', '1d675b47d978889d74fa0da5fadfb00e'),
    'val': ('ILSVRC2012_img_val.tar', '29b22e2961454d5413ddabcf34fc5622'),
    'devkit': ('ILSVRC2012_devkit_t12.tar', 'fa75699e90414af021442c21a62c3abf')
}

# File in which the parsed metadata (class names, wnids, validation labels) is cached
info_map_FILE = "info_map.bin"


class ImageNet(ImageFolder):
    """`ImageNet <http://image-net.org/>`_ 2012 Classification Dataset.

    Args:
        root (str): Root directory of the ImageNet Dataset.
        split (str, optional): Dataset split, either ``train`` or ``val``.
        apply_transformation (callable, optional): A transformation applied to the PIL image.
        target_apply_transformation (callable, optional): A transformation applied to the target.
        loader (callable, optional): Function to load an image from its path.

    Attributes:
        classes (list): List of class name tuples.
        class_to_idx (dict): Mapping of class names to indices.
        wnids (list): List of WordNet IDs.
        wnid_to_idx (dict): Mapping of WordNet IDs to class indices.
        imgs (list): List of image path and class index tuples.
        targets (list): Class index values for each image in the dataset.
    """

    def __init__(self, root, split='train', download=None, **kwargs):
        # The dataset is no longer publicly downloadable, so the download flag
        # can only raise an error or warn about its deprecation.
        if download is True:
            raise RuntimeError("The dataset is no longer publicly accessible. "
                               "Please download the archives externally and place "
                               "them in the root directory.")
        elif download is False:
            warnings.warn("The download flag is deprecated, as the dataset is "
                          "no longer publicly accessible.", RuntimeWarning)

        # Expand the root directory path
        root = self.root = os.path.expanduser(root)

        # Validate the dataset split (must be either 'train' or 'val')
        self.split = verify_str_arg(split, "split", ("train", "val"))

        # Parse the dataset archives (train/val/devkit) and prepare the dataset
        self.extract_archives()

        # Load the WordNet ID to class-name mapping from the cached info_map file
        wnid_to_classes = load_information_map_file(self.root)[0]

        # Initialize the ImageFolder with the split folder (train/val directory)
        super().__init__(self.divide_folder_contents, **kwargs)

        # Set class-related attributes (ImageFolder reset self.root to the split folder)
        self.root = root
        self.wnids = self.classes
        self.wnid_to_idx = self.class_to_idx

        # Replace wnids with human-readable names and rebuild the class_to_idx mapping
        self.classes = [wnid_to_classes[wnid] for wnid in self.wnids]
        self.class_to_idx = {cls: idx
                             for idx, clss in enumerate(self.classes)
                             for cls in clss}

    def extract_archives(self):
        # If the info_map file is missing or invalid, (re)parse the devkit archive
        if not validate_integrity(os.path.join(self.root, info_map_FILE)):
            extract_devkit_archive(self.root)

        # If the split folder (train/val) does not exist, extract the respective archive
        if not os.path.isdir(self.divide_folder_contents):
            if self.split == 'train':
                process_train_archive(self.root)
            elif self.split == 'val':
                process_validation_archive(self.root)

    @property
    def divide_folder_contents(self):
        # Path of the folder containing the images for the current split
        return os.path.join(self.root, self.split)

    def extra_repr(self):
        # Additional representation for the dataset object (showing the split)
        return f"Split: {self.split}"


def load_information_map_file(root, file=None):
    # Load the cached metadata file from the root directory
    file = os.path.join(root, file or info_map_FILE)
    if validate_integrity(file):
        return torch.load(file)
    raise RuntimeError(f"The info_map file {file} is either missing or corrupted. "
                       "Please ensure it exists in the root directory.")


def _validate_archive_file(root, file, md5):
    # Verify that the archive file is present and its checksum matches
    if not validate_integrity(os.path.join(root, file), md5):
        raise RuntimeError(f"The archive {file} is either missing or corrupted. "
                           f"Please download it and place it in {root}.")
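

# Illustrative sketch (not part of the original module): the tuple layout that
# extract_devkit_archive() below caches in info_map.bin, assuming a standard
# ILSVRC2012 devkit. The wnid shown is just a familiar example.
#
#   wnid_to_classes, val_wnids = load_information_map_file(root)
#   wnid_to_classes["n01440764"]  # ('tench', 'Tinca tinca')
#   val_wnids[0]                  # wnid of the first (filename-sorted) val image
#   len(val_wnids)                # 50000, one label per validation image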
""" import scipy.io as sio # Parse info_map.mat from the devkit, containing class and WordNet ID information def read_info_map_mat_file(devkit_root): info_map_path = os.path.join(devkit_root, "data", "info_map.mat") info_map = sio.loadmat(info_map_path, squeeze_me=True)['synsets'] info_map = [info_map[idx] for idx, num_children in enumerate(info_map[4]) if num_children == 0] idcs, wnids, classes = zip(*info_map)[:3] classes = [tuple(clss.split(', ')) for clss in classes] return {idx: wnid for idx, wnid in zip(idcs, wnids)}, {wnid: clss for wnid, clss in zip(wnids, classes)} # Parse the validation ground truth file for image class labels def process_val_groundtruth_txt(devkit_root): file = os.path.join(devkit_root, "data", "ILSVRC2012_validation_ground_truth.txt") with open(file) as f: return [int(line.strip()) for line in f] # Context manager to handle temporary directories for archive extraction @contextmanager def get_tmp_dir(): tmp_dir = tempfile.mkdtemp() try: yield tmp_dir finally: shutil.rmtree(tmp_dir) # Extract and process the devkit archive file, md5 = ARCHIVE_info_map["devkit"] _validate_archive_file(root, file, md5) with get_tmp_dir() as tmp_dir: extract_archive(os.path.join(root, file), tmp_dir) devkit_root = os.path.join(tmp_dir, "ILSVRC2012_devkit_t12") idx_to_wnid, wnid_to_classes = read_info_map_mat_file(devkit_root) val_idcs = process_val_groundtruth_txt(devkit_root) val_wnids = [idx_to_wnid[idx] for idx in val_idcs] # Save the mappings to the info_map file torch.save((wnid_to_classes, val_wnids), os.path.join(root, info_map_FILE)) def process_train_archive(root, file=None, folder="train"): """Extract and organize the ImageNet 2012 train dataset. Args: root (str): Root directory containing the train dataset archive. file (str, optional): Archive filename. Defaults to 'ILSVRC2012_img_train.tar'. folder (str, optional): Destination folder. Defaults to 'train'. """ file, md5 = ARCHIVE_info_map["train"] _validate_archive_file(root, file, md5) train_root = os.path.join(root, folder) extract_archive(os.path.join(root, file), train_root) # Extract each class-specific archive in the train dataset for archive in os.listdir(train_root): extract_archive(os.path.join(train_root, archive), os.path.splitext(archive)[0], remove_finished=True) def process_validation_archive(root, file=None, wnids=None, folder="val"): """Extract and organize the ImageNet 2012 validation dataset. Args: root (str): Root directory containing the validation dataset archive. file (str, optional): Archive filename. Defaults to 'ILSVRC2012_img_val.tar'. wnids (list, optional): WordNet IDs for validation images. Defaults to None (loaded from info_map file). folder (str, optional): Destination folder. Defaults to 'val'. """ file, md5 = ARCHIVE_info_map["val"] if wnids is None: wnids = load_information_map_file(root)[1] _validate_archive_file(root, file, md5) val_root = os.path.join(root, folder) extract_archive(os.path.join(root, file), val_root) # Create directories for each WordNet ID (class) and move validation images into their respective folders for wnid in set(wnids): os.mkdir(os.path.join(val_root, wnid)) for wnid, img in zip(wnids, sorted(os