# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn import ensemble

from .mixup import mixup_loss_criterion, combine_mixup_data
from . import resnet, resnet_sl

__all__ = ['coremodel']


def _retrieve_network(arch='wide_resnet28_10'):
    """
    Get the network architecture based on the provided name.

    Args:
        arch (str): Name of the architecture.

    Returns:
        Callable: The network class or function corresponding to the given architecture.
    """
    networks = {
        'wide_resnet28_10': resnet.wide_resnet28_10,
        'wide_resnet16_8': resnet.wide_resnet16_8,
        'resnet110': resnet.resnet110,
        'wide_resnet_model_50_2': resnet.wide_resnet_model_50_2,
    }
    if arch not in networks:
        raise ValueError(f"Architecture {arch} is not supported.")
    return networks[arch]
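
# Hedged usage sketch (not part of the original file): `_retrieve_network`
# returns a constructor; the keyword arguments below mirror the call made in
# coremodel.__init__ and the example values are assumptions for a
# CIFAR-style setup.
#
#   net_fn = _retrieve_network('wide_resnet16_8')
#   net = net_fn(num_classes=10, norm_layer=None, dataset='cifar10',
#                split_factor=2, output_stride=8,
#                models_models_pretrained=False)

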
class coremodel(nn.Module):
    def __init__(self, args, norm_layer=None, criterion=None, progress=True):
        """
        Initialize the coremodel model with multiple sub-networks.

        Args:
            args (argparse.Namespace): Configuration arguments.
            norm_layer (callable, optional): Normalization layer.
            criterion (callable, optional): Loss function.
            progress (bool): Whether to show progress.
        """
        super(coremodel, self).__init__()

        # Configuration parameters
        self.split_factor = args.split_factor
        self.arch = args.arch
        self.loop_factor = args.loop_factor
        self.is_train_sep = args.is_train_sep
        self.epochs = args.epochs
        self.criterion = criterion
        self.is_diff_data_train = args.is_diff_data_train
        self.is_mixup = args.is_mixup
        self.mix_alpha = args.mix_alpha

        # Supported model architectures
        valid_archs = [
            'resnet_model_50', 'resnet_model_101', 'resnet_model_152', 'resnet_model_200',
            'resnext50_32x4d', 'resnext101_32x4d', 'resnext101_64x4d',
            'resnext29_8x64d', 'resnext29_16x64d', 'resnet110', 'resnet164',
            'wide_resnet16_8', 'wide_resnet16_12', 'wide_resnet28_10', 'wide_resnet40_10',
            'wide_resnet52_8', 'wide_resnet_model_50_2', 'wide_resnet_model_50_3', 'wide_resnet_model_101_2'
        ]

        if self.arch not in valid_archs:
            raise NotImplementedError(f"Architecture {self.arch} is not implemented.")

        model_args = {
            'num_classes': args.num_classes,
            'norm_layer': norm_layer,
            'dataset': args.dataset,
            'split_factor': self.split_factor,
            'output_stride': args.output_stride
        }

        # Initialize multiple sub-models based on the loop factor
        self.models = nn.ModuleList([
            _retrieve_network(self.arch)(models_models_pretrained=args.models_models_pretrained, **model_args)
            for _ in range(self.loop_factor)
        ])

        if args.is_identical_init:
            print("INFO: Using identical initialization.")
            self._identical_init()

        # Ensemble settings
        self.is_ensembled_loss = args.is_ensembled_loss if self.split_factor > 1 else False
        self.ensembled_loss_weight = args.ensembled_loss_weight
        self.is_ensembled_after_softmax = args.is_ensembled_after_softmax if self.split_factor > 1 else False
        self.is_max_ensemble = args.is_max_ensemble if self.split_factor > 1 else False

        # Co-training settings
        self.is_cot_loss = args.is_cot_loss if self.split_factor > 1 else False
        self.cot_weight = args.cot_weight
        self.is_cot_weight_warm_up = args.is_cot_weight_warm_up
        self.cot_weight_warm_up_epochs = args.cot_weight_warm_up_epochs
        self.cot_loss_choose = args.cot_loss_choose
        print(f"INFO: The co-training loss is {self.cot_loss_choose}.")
        self.num_classes = args.num_classes
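
    # Hedged construction sketch (not in the original source): the Namespace
    # fields below are exactly those read from `args` in __init__ and forward();
    # the example values are assumptions, not defaults from the repository.
    #
    #   args = argparse.Namespace(
    #       arch='wide_resnet28_10', num_classes=10, dataset='cifar10',
    #       split_factor=2, loop_factor=2, epochs=200, output_stride=8,
    #       is_train_sep=False, is_diff_data_train=False,
    #       is_mixup=True, mix_alpha=0.2,
    #       models_models_pretrained=False, is_identical_init=True,
    #       is_ensembled_loss=False, ensembled_loss_weight=1.0,
    #       is_ensembled_after_softmax=True, is_max_ensemble=False,
    #       is_cot_loss=True, cot_weight=0.5, is_cot_weight_warm_up=True,
    #       cot_weight_warm_up_epochs=30, cot_loss_choose='js_divergence')
    #   model = coremodel(args, criterion=nn.CrossEntropyLoss())
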
    def forward(self, x, target=None, mode='train', epoch=0, streams=None):
        """
        Forward pass through the model with optional mixup and co-training loss.

        Args:
            x (Tensor): Input tensor.
            target (Tensor, optional): Target tensor for loss computation.
            mode (str): Mode of operation ('train', 'val', or 'test').
            epoch (int): Current epoch.
            streams (optional): Additional data streams.

        Returns:
            Tuple:
                - ensemble_output (Tensor): The ensemble output of shape [batch_size, num_classes].
                - outputs (Tensor): Stack of individual outputs of shape [split_factor, batch_size, num_classes].
                - ce_loss (Tensor): Sum of cross-entropy losses for each model.
                - cot_loss (Tensor): Co-training loss if applicable.
        """
        outputs, ce_losses = [], []

        if 'train' in mode:
            if self.is_mixup:
                x, y_a, y_b, lam = combine_mixup_data(x, target, alpha=self.mix_alpha)

            # Split input data based on the loop factor
            all_x = torch.chunk(x, chunks=self.loop_factor, dim=1) if self.is_diff_data_train else [x]

            for i in range(self.loop_factor):
                # When the data is not split per sub-network, every model sees the full input
                x_input = all_x[i] if self.is_diff_data_train else all_x[0]
                output = self.models[i](x_input)
                loss = mixup_loss_criterion(self.criterion, output, y_a, y_b, lam) if self.is_mixup else self.criterion(output, target)
                outputs.append(output)
                ce_losses.append(loss)

        elif mode in ['val', 'test']:
            for i in range(self.loop_factor):
                output = self.models[i](x)
                loss = self.criterion(output, target) if self.criterion else torch.zeros(1)
                outputs.append(output)
                ce_losses.append(loss)

        else:
            return torch.ones(1), None, None, None

        # Calculate ensemble output and losses
        ensemble_output = self._collect_ensemble_output(outputs)
        ce_loss = torch.sum(torch.stack(ce_losses))

        if mode in ['val', 'test']:
            return ensemble_output, torch.stack(outputs, dim=0), ce_loss

        if self.is_cot_loss:
            cot_loss = self._calculate_co_training_loss(outputs, self.cot_loss_choose, epoch)
        else:
            cot_loss = torch.zeros_like(ce_loss)

        return ensemble_output, torch.stack(outputs, dim=0), ce_loss, cot_loss
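
    # Hedged training-step sketch (not from the original source): illustrates the
    # call signature and the four return values of forward() in training mode.
    # `images`, `labels`, `optimizer`, and `epoch` are assumed to come from the
    # surrounding training loop.
    #
    #   ensemble_out, all_outs, ce_loss, cot_loss = model(images, target=labels,
    #                                                     mode='train', epoch=epoch)
    #   total_loss = ce_loss + cot_loss
    #   optimizer.zero_grad()
    #   total_loss.backward()
    #   optimizer.step()
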
    def _collect_ensemble_output(self, outputs):
        """
        Calculate the ensemble output from a list of tensors.

        Args:
            outputs (list of tensors): A list where each tensor has shape [batch_size, num_classes].

        Returns:
            Tensor: The ensemble output with shape [batch_size, num_classes].
        """
        stacked_outputs = torch.stack(outputs, dim=0)

        if self.is_ensembled_after_softmax:
            softmax_outputs = F.softmax(stacked_outputs, dim=-1)
            if self.is_max_ensemble:
                ensemble_output, _ = torch.max(softmax_outputs, dim=0)
            else:
                ensemble_output = torch.mean(softmax_outputs, dim=0)
        else:
            if self.is_max_ensemble:
                ensemble_output, _ = torch.max(stacked_outputs, dim=0)
            else:
                ensemble_output = torch.mean(stacked_outputs, dim=0)

        return ensemble_output
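
    # Added note: with S = split_factor sub-networks and logits z_s, the four
    # branches above compute mean_s softmax(z_s), max_s softmax(z_s),
    # mean_s z_s, or max_s z_s, selected by is_ensembled_after_softmax and
    # is_max_ensemble.
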
    def _calculate_co_training_loss(self, outputs, loss_choose, epoch=0):
        """
        Calculate the co-training loss between the outputs of different networks.

        Args:
            outputs (list of tensors): A list where each tensor has shape [batch_size, num_classes].
            loss_choose (str): Type of co-training loss to compute ('js_divergence' or 'kl_seperate').
            epoch (int): Current epoch.

        Returns:
            Tensor: The computed co-training loss.
        """
        # Optionally warm up the co-training weight over the first epochs
        weight_now = self.cot_weight
        if self.is_cot_weight_warm_up and epoch < self.cot_weight_warm_up_epochs:
            weight_now = max(self.cot_weight * epoch / self.cot_weight_warm_up_epochs, 0.005)

        stacked_outputs = torch.stack(outputs, dim=0)

        if loss_choose == 'js_divergence':
            p_all = F.softmax(stacked_outputs, dim=-1)
            p_mean = torch.mean(p_all, dim=0)
            H_mean = (-p_mean * torch.log(p_mean + 1e-8)).sum(-1).mean()
            H_sep = (-p_all * F.log_softmax(stacked_outputs, dim=-1)).sum(-1).mean()
            cot_loss = weight_now * (H_mean - H_sep)

        elif loss_choose == 'kl_seperate':
            # Pair every output with every other output for a pairwise KL term
            outputs_r1 = torch.repeat_interleave(stacked_outputs, self.split_factor - 1, dim=0)
            index_list = [j for i in range(self.split_factor) for j in range(self.split_factor) if j != i]
            outputs_r2 = torch.index_select(stacked_outputs, dim=0, index=torch.tensor(index_list, dtype=torch.long, device=stacked_outputs.device))
            # Assumption: 'batchmean' reduction and direct weighting for the
            # pairwise KL term; the fallback branch below is likewise an assumption.
            kl_loss = F.kl_div(F.log_softmax(outputs_r1, dim=-1),
                               F.softmax(outputs_r2, dim=-1),
                               reduction='batchmean')
            cot_loss = weight_now * kl_loss

        else:
            raise NotImplementedError(f"Co-training loss {loss_choose} is not supported.")

        return cot_loss
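
    # Added note (hedged): the 'js_divergence' branch above implements a
    # Jensen-Shannon-style regularizer. With S sub-networks and per-network
    # class distributions p_s = softmax(z_s):
    #   p_mean  = (1/S) * sum_s p_s
    #   cot_loss = weight_now * ( H(p_mean) - (1/S) * sum_s H(p_s) )
    # which is zero when all sub-networks agree and grows as they diverge.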
|