# -*- coding: utf-8 -*-
# @Author: Weisen Pan
import json

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


def load_state_dict(file):
    """Load a state dict from a file, falling back to CPU mapping if needed."""
    try:
        return torch.load(file)
    except AssertionError:
        # Checkpoint was saved on a device that is unavailable here; map to CPU.
        return torch.load(file, map_location=lambda storage, location: storage)


def flatten_parameters(model):
    """Flatten the parameters of the model into a single 1-D tensor."""
    return torch.cat([param.data.view(-1) for param in model.parameters()])


def set_flattened_parameters(model, flat_params):
    """Set the model's parameters from a flattened tensor."""
    prev_ind = 0
    for param in model.parameters():
        flat_size = int(np.prod(param.size()))
        param.data.copy_(flat_params[prev_ind:prev_ind + flat_size].view(param.size()))
        prev_ind += flat_size


class RollingAverage:
    """Maintain a running average of a scalar quantity."""

    def __init__(self):
        self.steps = 0
        self.total = 0

    def update(self, val):
        self.total += val
        self.steps += 1

    def value(self):
        return self.total / float(self.steps) if self.steps > 0 else 0


def compute_accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy (precision@k) for the specified values of k."""
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    return [correct[:k].reshape(-1).float().sum(0).mul_(100.0 / batch_size) for k in topk]


class KLDivergenceLoss(nn.Module):
    """Temperature-scaled Kullback-Leibler divergence loss."""

    def __init__(self, temperature=1):
        super(KLDivergenceLoss, self).__init__()
        self.temperature = temperature

    def forward(self, output_batch, teacher_outputs):
        output_batch = F.log_softmax(output_batch / self.temperature, dim=1)
        # Small epsilon keeps the teacher distribution strictly positive.
        teacher_outputs = F.softmax(teacher_outputs / self.temperature, dim=1) + 1e-7
        return self.temperature ** 2 * nn.KLDivLoss(reduction='batchmean')(output_batch, teacher_outputs)


class CELoss(nn.Module):
    """Temperature-scaled cross-entropy loss against soft teacher targets."""

    def __init__(self, temperature=1):
        super(CELoss, self).__init__()
        self.temperature = temperature

    def forward(self, output_batch, teacher_outputs):
        output_batch = F.log_softmax(output_batch / self.temperature, dim=1)
        teacher_outputs = F.softmax(teacher_outputs / self.temperature, dim=1)
        return -self.temperature ** 2 * torch.sum(output_batch * teacher_outputs) / teacher_outputs.size(0)


def save_dict_to_json(data, json_path):
    """Save a dictionary of floats to a JSON file."""
    with open(json_path, 'w') as f:
        json.dump({k: float(v) for k, v in data.items()}, f, indent=4)


def get_optimized_params(model, model_params, master_params):
    """Build optimizer parameter groups that exclude batch norm parameters from weight decay."""
    bn_params, remaining_params = split_bn_params(model, model_params, master_params)
    return [{'params': bn_params, 'weight_decay': 0}, {'params': remaining_params}]


def split_bn_params(model, model_params, master_params):
    """Split master parameters into batch norm and non-batch norm groups."""
    def get_bn_params(module):
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            return set(module.parameters())
        return {p for child in module.children() for p in get_bn_params(child)}

    mod_bn_params = get_bn_params(model)
    # Materialize the zip so it can be iterated over twice below.
    zipped_params = list(zip(model_params, master_params))
    mas_bn_params = [p_mast for p_mod, p_mast in zipped_params if p_mod in mod_bn_params]
    mas_rem_params = [p_mast for p_mod, p_mast in zipped_params if p_mod not in mod_bn_params]
    return mas_bn_params, mas_rem_params
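

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original utilities): a minimal, hypothetical
# example of how these helpers might fit together in a distillation-style
# training step. The models, data, temperature, and file name below are
# illustrative assumptions, not values prescribed by this module.
if __name__ == "__main__":
    torch.manual_seed(0)

    # Tiny stand-in student/teacher; any modules with matching output sizes work.
    student = nn.Sequential(nn.Linear(8, 16), nn.BatchNorm1d(16), nn.ReLU(), nn.Linear(16, 4))
    teacher = nn.Sequential(nn.Linear(8, 4))

    # No separate FP16 master copy here, so model_params and master_params coincide.
    params = list(student.parameters())
    optimizer = torch.optim.SGD(get_optimized_params(student, params, params),
                                lr=0.1, momentum=0.9, weight_decay=1e-4)

    kd_criterion = KLDivergenceLoss(temperature=4)
    hard_criterion = nn.CrossEntropyLoss()
    loss_avg, acc_avg = RollingAverage(), RollingAverage()

    inputs = torch.randn(32, 8)
    targets = torch.randint(0, 4, (32,))

    student.train()
    logits = student(inputs)
    with torch.no_grad():
        teacher_logits = teacher(inputs)

    loss = hard_criterion(logits, targets) + kd_criterion(logits, teacher_logits)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Track metrics with RollingAverage and compute_accuracy.
    top1, = compute_accuracy(logits, targets, topk=(1,))
    loss_avg.update(loss.item())
    acc_avg.update(top1.item())

    # Persist scalar metrics (e.g. for experiment tracking); the path is a placeholder.
    save_dict_to_json({'loss': loss_avg.value(), 'top1': acc_avg.value()}, 'metrics_demo.json')

    # Round-trip the flattened parameter vector through the flatten/set helpers.
    flat = flatten_parameters(student)
    set_flattened_parameters(student, flat)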