# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import json

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


def load_state_dict(file):
    """Load a state dict from a file, remapping storages to the CPU when the saved device is unavailable."""
    try:
        return torch.load(file)
    except (AssertionError, RuntimeError):
        # Newer PyTorch raises RuntimeError (rather than AssertionError) when the
        # checkpoint's device is unavailable; remap every storage onto the CPU.
        return torch.load(file, map_location=lambda storage, location: storage)
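

# Usage sketch (hypothetical checkpoint path; `model` is any nn.Module whose
# architecture matches the saved weights):
#     state = load_state_dict("checkpoint.pth")
#     model.load_state_dict(state)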


def flatten_parameters(model):
    """Flatten the parameters of the model into a single tensor."""
    return torch.cat([param.data.view(-1) for param in model.parameters()])


def set_flattened_parameters(model, flat_params):
    """Set the model's parameters from a flattened tensor."""
    prev_ind = 0
    for param in model.parameters():
        flat_size = int(np.prod(param.size()))
        param.data.copy_(flat_params[prev_ind:prev_ind + flat_size].view(param.size()))
        prev_ind += flat_size
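

# Round-trip sketch (the two-layer net below is an arbitrary example):
#     net = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
#     flat = flatten_parameters(net)       # single 1-D tensor holding every parameter value
#     set_flattened_parameters(net, flat)  # copies the values straight back in place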


class RollingAverage:
    """Class to maintain a running average of a quantity."""

    def __init__(self):
        self.steps = 0
        self.total = 0

    def update(self, val):
        self.total += val
        self.steps += 1

    def value(self):
        return self.total / float(self.steps) if self.steps > 0 else 0
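

# Usage sketch (the update values are made up):
#     loss_avg = RollingAverage()
#     loss_avg.update(1.0)
#     loss_avg.update(0.5)
#     loss_avg.value()  # -> 0.75, the mean of the two updates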


def compute_accuracy(output, target, topk=(1,)):
    """Compute the precision@k for the specified values of k."""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    return [correct[:k].reshape(-1).float().sum(0).mul_(100.0 / batch_size) for k in topk]
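

# Usage sketch (random logits and labels purely for illustration):
#     logits = torch.randn(8, 10)              # batch of 8 samples, 10 classes
#     labels = torch.randint(0, 10, (8,))
#     top1, top5 = compute_accuracy(logits, labels, topk=(1, 5))  # scalar tensors holding percentages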


class KLDivergenceLoss(nn.Module):
    """Kullback-Leibler divergence loss between temperature-softened student and teacher outputs."""

    def __init__(self, temperature=1):
        super(KLDivergenceLoss, self).__init__()
        self.temperature = temperature

    def forward(self, output_batch, teacher_outputs):
        # Soften both distributions with the same temperature; the small epsilon
        # keeps the teacher probabilities strictly positive.
        output_batch = F.log_softmax(output_batch / self.temperature, dim=1)
        teacher_outputs = F.softmax(teacher_outputs / self.temperature, dim=1) + 1e-7
        # Scale by T^2 so gradient magnitudes stay comparable across temperatures.
        return self.temperature ** 2 * nn.KLDivLoss(reduction='batchmean')(output_batch, teacher_outputs)
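

# Distillation sketch (student_logits / teacher_logits are assumed to be
# same-shaped class-score tensors produced elsewhere):
#     kd_loss_fn = KLDivergenceLoss(temperature=4)
#     loss = kd_loss_fn(student_logits, teacher_logits.detach())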


class CELoss(nn.Module):
    """Cross-entropy loss computed against soft (teacher-provided) targets."""

    def __init__(self, temperature=1):
        super(CELoss, self).__init__()
        self.temperature = temperature

    def forward(self, output_batch, teacher_outputs):
        output_batch = F.log_softmax(output_batch / self.temperature, dim=1)
        teacher_outputs = F.softmax(teacher_outputs / self.temperature, dim=1)
        # Per-sample cross-entropy against the teacher's softened distribution,
        # averaged over the batch and scaled by T^2.
        return -self.temperature ** 2 * torch.sum(output_batch * teacher_outputs) / teacher_outputs.size(0)


def save_dict_to_json(data, json_path):
    """Save a dictionary of floats to a JSON file."""
    with open(json_path, 'w') as f:
        json.dump({k: float(v) for k, v in data.items()}, f, indent=4)


def get_optimized_params(model, model_params, master_params):
    """Filter out batch norm parameters from weight decay to improve accuracy."""
    bn_params, remaining_params = split_bn_params(model, model_params, master_params)
    return [{'params': bn_params, 'weight_decay': 0}, {'params': remaining_params}]


def split_bn_params(model, model_params, master_params):
    """Split parameters into batch norm and non-batch norm groups."""
    def get_bn_params(module):
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            return set(module.parameters())
        return {p for child in module.children() for p in get_bn_params(child)}

    mod_bn_params = get_bn_params(model)
    # Materialize the pairs so they can be traversed twice; a bare zip() is a
    # one-shot iterator and would leave the second comprehension empty.
    zipped_params = list(zip(model_params, master_params))

    mas_bn_params = [p_mast for p_mod, p_mast in zipped_params if p_mod in mod_bn_params]
    mas_rem_params = [p_mast for p_mod, p_mast in zipped_params if p_mod not in mod_bn_params]

    return mas_bn_params, mas_rem_params
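

# Usage sketch (assumes plain single-precision training, where the model's own
# parameters double as the "master" copies expected by the signature):
#     params = list(model.parameters())
#     param_groups = get_optimized_params(model, params, params)
#     optimizer = torch.optim.SGD(param_groups, lr=0.1, momentum=0.9, weight_decay=1e-4)
#     # batch-norm parameters get weight_decay=0; every other parameter keeps 1e-4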