# -*- coding: utf-8 -*-
# @Author: Weisen Pan
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn import ensemble

from .mixup import mixup_loss_criterion, combine_mixup_data
from . import resnet, resnet_sl

__all__ = ['coremodel']


def _retrieve_network(arch='wide_resnet28_10'):
    """
    Get the network architecture based on the provided name.

    Args:
        arch (str): Name of the architecture.

    Returns:
        Callable: The network class or function corresponding to the given architecture.
    """
    networks = {
        'wide_resnet28_10': resnet.wide_resnet28_10,
        'wide_resnet16_8': resnet.wide_resnet16_8,
        'resnet110': resnet.resnet110,
        'wide_resnet_model_50_2': resnet.wide_resnet_model_50_2
    }
    if arch not in networks:
        raise ValueError(f"Architecture {arch} is not supported.")
    return networks[arch]


class coremodel(nn.Module):
    def __init__(self, args, norm_layer=None, criterion=None, progress=True):
        """
        Initialize the coremodel model with multiple sub-networks.

        Args:
            args (argparse.Namespace): Configuration arguments.
            norm_layer (callable, optional): Normalization layer.
            criterion (callable, optional): Loss function.
            progress (bool): Whether to show progress.
        """
        super(coremodel, self).__init__()

        # Configuration parameters
        self.split_factor = args.split_factor
        self.arch = args.arch
        self.loop_factor = args.loop_factor
        self.is_train_sep = args.is_train_sep
        self.epochs = args.epochs
        self.criterion = criterion
        self.is_diff_data_train = args.is_diff_data_train
        self.is_mixup = args.is_mixup
        self.mix_alpha = args.mix_alpha

        # Define model architectures
        valid_archs = [
            'resnet_model_50', 'resnet_model_101', 'resnet_model_152', 'resnet_model_200',
            'resnext50_32x4d', 'resnext101_32x4d', 'resnext101_64x4d',
            'resnext29_8x64d', 'resnext29_16x64d', 'resnet110', 'resnet164',
            'wide_resnet16_8', 'wide_resnet16_12', 'wide_resnet28_10',
            'wide_resnet40_10', 'wide_resnet52_8', 'wide_resnet_model_50_2',
            'wide_resnet_model_50_3', 'wide_resnet_model_101_2'
        ]
        if self.arch not in valid_archs:
            raise NotImplementedError(f"Architecture {self.arch} is not implemented.")

        model_args = {
            'num_classes': args.num_classes,
            'norm_layer': norm_layer,
            'dataset': args.dataset,
            'split_factor': self.split_factor,
            'output_stride': args.output_stride
        }

        # Initialize multiple sub-models based on the loop factor
        self.models = nn.ModuleList([
            _retrieve_network(self.arch)(models_models_pretrained=args.models_models_pretrained, **model_args)
            for _ in range(self.loop_factor)
        ])

        if args.is_identical_init:
            print("INFO: Using identical initialization.")
            self._identical_init()

        # Ensemble settings
        self.is_ensembled_loss = args.is_ensembled_loss if self.split_factor > 1 else False
        self.ensembled_loss_weight = args.ensembled_loss_weight
        self.is_ensembled_after_softmax = args.is_ensembled_after_softmax if self.split_factor > 1 else False
        self.is_max_ensemble = args.is_max_ensemble if self.split_factor > 1 else False

        # Co-training settings
        self.is_cot_loss = args.is_cot_loss if self.split_factor > 1 else False
        self.cot_weight = args.cot_weight
        self.is_cot_weight_warm_up = args.is_cot_weight_warm_up
        self.cot_weight_warm_up_epochs = args.cot_weight_warm_up_epochs
        self.cot_loss_choose = args.cot_loss_choose
        print(f"INFO: The co-training loss is {self.cot_loss_choose}.")

        self.num_classes = args.num_classes
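
    # NOTE: `_identical_init` is called in __init__ but is not defined in this
    # excerpt. The method below is a hedged sketch of the assumed behavior:
    # "identical initialization" is taken to mean copying the weights of the
    # first sub-model into every other sub-model so that all branches start
    # from the same parameters. Adjust or drop this if the original file
    # defines the method elsewhere.
    def _identical_init(self):
        """Make all sub-models start from the weights of self.models[0] (assumed semantics)."""
        with torch.no_grad():
            for i in range(1, self.loop_factor):
                for param_i, param_0 in zip(self.models[i].parameters(),
                                            self.models[0].parameters()):
                    param_i.data.copy_(param_0.data)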

    def forward(self, x, target=None, mode='train', epoch=0, streams=None):
        """
        Forward pass through the model with optional mixup and co-training loss.

        Args:
            x (Tensor): Input tensor.
            target (Tensor, optional): Target tensor for loss computation.
            mode (str): Mode of operation ('train', 'val', or 'test').
            epoch (int): Current epoch.
            streams (optional): Additional data streams.

        Returns:
            Tuple:
                - ensemble_output (Tensor): The ensemble output of shape [batch_size, num_classes].
                - outputs (Tensor): Stack of individual outputs of shape [split_factor, batch_size, num_classes].
                - ce_loss (Tensor): Sum of cross-entropy losses for each model.
                - cot_loss (Tensor): Co-training loss if applicable.
        """
        outputs, ce_losses = [], []

        if 'train' in mode:
            if self.is_mixup:
                x, y_a, y_b, lam = combine_mixup_data(x, target, alpha=self.mix_alpha)

            # Split input data based on the loop factor; otherwise every
            # sub-model sees the same input.
            if self.is_diff_data_train:
                all_x = torch.chunk(x, chunks=self.loop_factor, dim=1)
            else:
                all_x = [x] * self.loop_factor

            for i in range(self.loop_factor):
                x_input = all_x[i]
                output = self.models[i](x_input)
                loss = (mixup_loss_criterion(self.criterion, output, y_a, y_b, lam)
                        if self.is_mixup else self.criterion(output, target))
                outputs.append(output)
                ce_losses.append(loss)

        elif mode in ['val', 'test']:
            for i in range(self.loop_factor):
                output = self.models[i](x)
                loss = self.criterion(output, target) if self.criterion else torch.zeros(1)
                outputs.append(output)
                ce_losses.append(loss)

        else:
            return torch.ones(1), None, None, None

        # Calculate ensemble output and losses
        ensemble_output = self._collect_ensemble_output(outputs)
        ce_loss = torch.sum(torch.stack(ce_losses))

        if mode in ['val', 'test']:
            return ensemble_output, torch.stack(outputs, dim=0), ce_loss

        if self.is_cot_loss:
            cot_loss = self._calculate_co_training_loss(outputs, self.cot_loss_choose, epoch)
        else:
            cot_loss = torch.zeros_like(ce_loss)

        return ensemble_output, torch.stack(outputs, dim=0), ce_loss, cot_loss

    def _collect_ensemble_output(self, outputs):
        """
        Calculate the ensemble output from a list of tensors.

        Args:
            outputs (list of tensors): A list where each tensor has shape [batch_size, num_classes].

        Returns:
            Tensor: The ensemble output with shape [batch_size, num_classes].
        """
        stacked_outputs = torch.stack(outputs, dim=0)

        if self.is_ensembled_after_softmax:
            softmax_outputs = F.softmax(stacked_outputs, dim=-1)
            if self.is_max_ensemble:
                ensemble_output, _ = torch.max(softmax_outputs, dim=0)
            else:
                ensemble_output = torch.mean(softmax_outputs, dim=0)
        else:
            if self.is_max_ensemble:
                ensemble_output, _ = torch.max(stacked_outputs, dim=0)
            else:
                ensemble_output = torch.mean(stacked_outputs, dim=0)

        return ensemble_output

    def _calculate_co_training_loss(self, outputs, loss_choose, epoch=0):
        """
        Calculate the co-training loss between outputs of different networks.

        Args:
            outputs (list of tensors): A list where each tensor has shape [batch_size, num_classes].
            loss_choose (str): Type of co-training loss to compute ('js_divergence' or 'kl_seperate').
            epoch (int): Current epoch.

        Returns:
            Tensor: The computed co-training loss.
""" weight_now = self.cot_weight if self.is_cot_weight_warm_up and epoch < self.cot_weight_warm_up_epochs: weight_now = max(self.cot_weight * epoch / self.cot_weight_warm_up_epochs, 0.005) stacked_outputs = torch.stack(outputs, dim=0) if loss_choose == 'js_divergence': p_all = F.softmax(stacked_outputs, dim=-1) p_mean = torch.mean(p_all, dim=0) H_mean = (-p_mean * torch.log(p_mean + 1e-8)).sum(-1).mean() H_sep = (-p_all * F.log_softmax(stacked_outputs, dim=-1)).sum(-1).mean() cot_loss = weight_now * (H_mean - H_sep) elif loss_choose == 'kl_seperate': outputs_r1 = torch.repeat_interleave(stacked_outputs, self.split_factor - 1, dim=0) index_list = [j for i in range(self.split_factor) for j in range(self.split_factor) if j != i] outputs_r2 = torch.index_select(stacked_outputs, dim=0, index=torch.tensor(index_list, dtype=torch.long, device=stacked_outputs.device)) kl_loss = F.kl_div(F.log_softmax(outputs_r1, dim=-1), F.softmax(outputs_r2,”