Weisen Pan 4ec0a23e73 Edge Federated Learning for Improved Training Efficiency
Change-Id: Ic4e43992e1674946cb69e0221659b0261259196c
2024-09-18 18:39:43 -07:00


# -*- coding: utf-8 -*-
# @Author: Weisen Pan
import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def load_state_dict(file):
    """Load a state dict from a file, handling any potential location issues."""
    try:
        return torch.load(file)
    except AssertionError:
        return torch.load(file, map_location=lambda storage, location: storage)

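# Hypothetical usage sketch (illustrative, not part of the original module): restore a
# checkpoint saved with torch.save(model.state_dict(), path) into a model. The path
# name below is a placeholder.
def _example_load_checkpoint(model, checkpoint_path="checkpoint.pth"):
    state = load_state_dict(checkpoint_path)  # falls back to CPU mapping if needed
    model.load_state_dict(state)
    return model
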
def flatten_parameters(model):
    """Flatten the parameters of the model into a single tensor."""
    return torch.cat([param.data.view(-1) for param in model.parameters()])

def set_flattened_parameters(model, flat_params):
    """Set the model's parameters from a flattened tensor."""
    prev_ind = 0
    for param in model.parameters():
        flat_size = int(np.prod(param.size()))
        param.data.copy_(flat_params[prev_ind:prev_ind + flat_size].view(param.size()))
        prev_ind += flat_size

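# Hypothetical usage sketch (illustrative, not part of the original module): round-trip
# a model's parameters through a single flat vector, the form typically exchanged in a
# FedAvg-style aggregation step.
def _example_flatten_roundtrip():
    model = nn.Linear(4, 2)
    flat = flatten_parameters(model)       # 1-D tensor of length 4*2 + 2 = 10
    set_flattened_parameters(model, flat)  # copies the same values back in place
    return flat.numel()
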
class RollingAverage:
    """Class to maintain a running average of a quantity."""

    def __init__(self):
        self.steps = 0
        self.total = 0

    def update(self, val):
        self.total += val
        self.steps += 1

    def value(self):
        return self.total / float(self.steps) if self.steps > 0 else 0

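# Hypothetical usage sketch (illustrative, not part of the original module): track the
# mean training loss across batches with RollingAverage.
def _example_rolling_average(batch_losses=(0.9, 0.7, 0.5)):
    avg = RollingAverage()
    for loss in batch_losses:
        avg.update(loss)
    return avg.value()  # (0.9 + 0.7 + 0.5) / 3 = 0.7
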
def compute_accuracy(output, target, topk=(1,)):
    """Compute the precision@k for the specified values of k."""
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    return [correct[:k].reshape(-1).float().sum(0).mul_(100.0 / batch_size) for k in topk]

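# Hypothetical usage sketch (illustrative, not part of the original module): top-1 and
# top-5 accuracy on a random batch of 10-class logits.
def _example_compute_accuracy():
    logits = torch.randn(8, 10)         # batch of 8 samples, 10 classes
    labels = torch.randint(0, 10, (8,))
    top1, top5 = compute_accuracy(logits, labels, topk=(1, 5))
    return top1.item(), top5.item()     # percentages in [0, 100]
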
class KLDivergenceLoss(nn.Module):
    """Kullback-Leibler Divergence Loss."""

    def __init__(self, temperature=1):
        super(KLDivergenceLoss, self).__init__()
        self.temperature = temperature

    def forward(self, output_batch, teacher_outputs):
        output_batch = F.log_softmax(output_batch / self.temperature, dim=1)
        teacher_outputs = F.softmax(teacher_outputs / self.temperature, dim=1) + 1e-7
        return self.temperature ** 2 * nn.KLDivLoss(reduction='batchmean')(output_batch, teacher_outputs)

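# Hypothetical usage sketch (illustrative, not part of the original module): a
# knowledge-distillation-style loss between student and teacher logits, softened by
# the temperature.
def _example_kd_loss(temperature=4.0):
    student_logits = torch.randn(8, 10)
    teacher_logits = torch.randn(8, 10)
    criterion = KLDivergenceLoss(temperature=temperature)
    return criterion(student_logits, teacher_logits)  # scalar loss tensor
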
class CELoss(nn.Module):
    """Cross-Entropy Loss."""

    def __init__(self, temperature=1):
        super(CELoss, self).__init__()
        self.temperature = temperature

    def forward(self, output_batch, teacher_outputs):
        output_batch = F.log_softmax(output_batch / self.temperature, dim=1)
        teacher_outputs = F.softmax(teacher_outputs / self.temperature, dim=1)
        return -self.temperature ** 2 * torch.sum(output_batch * teacher_outputs) / teacher_outputs.size(0)

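# Hypothetical usage sketch (illustrative, not part of the original module): CELoss is
# the soft-target cross-entropy counterpart of KLDivergenceLoss and takes the same
# inputs; the two differ only by the teacher entropy, which is constant with respect
# to the student.
def _example_soft_ce_loss(temperature=4.0):
    student_logits = torch.randn(8, 10)
    teacher_logits = torch.randn(8, 10)
    return CELoss(temperature=temperature)(student_logits, teacher_logits)
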
def save_dict_to_json(data, json_path):
    """Save a dictionary of floats to a JSON file."""
    with open(json_path, 'w') as f:
        json.dump({k: float(v) for k, v in data.items()}, f, indent=4)

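# Hypothetical usage sketch (illustrative, not part of the original module): persist a
# metrics dictionary (e.g. values from RollingAverage) for later inspection. The file
# name is a placeholder.
def _example_save_metrics(json_path="metrics_example.json"):
    metrics = {"loss": 0.42, "top1": 91.3}
    save_dict_to_json(metrics, json_path)
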
def get_optimized_params(model, model_params, master_params):
    """Filter out batch norm parameters from weight decay to improve accuracy."""
    bn_params, remaining_params = split_bn_params(model, model_params, master_params)
    return [{'params': bn_params, 'weight_decay': 0}, {'params': remaining_params}]

def split_bn_params(model, model_params, master_params):
    """Split parameters into batch norm and non-batch norm."""
    def get_bn_params(module):
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            return set(module.parameters())
        return {p for child in module.children() for p in get_bn_params(child)}

    mod_bn_params = get_bn_params(model)
    # Materialize the pairs: a bare zip() is an iterator and would be exhausted by the
    # first comprehension, leaving the second one empty.
    zipped_params = list(zip(model_params, master_params))
    mas_bn_params = [p_mast for p_mod, p_mast in zipped_params if p_mod in mod_bn_params]
    mas_rem_params = [p_mast for p_mod, p_mast in zipped_params if p_mod not in mod_bn_params]
    return mas_bn_params, mas_rem_params
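
# Hypothetical usage sketch (illustrative, not part of the original module): build an
# optimizer whose batch-norm parameters are excluded from weight decay. Here the
# model's own parameters stand in for both `model_params` and `master_params`; they
# differ only in mixed-precision setups that keep a separate FP32 master copy.
def _example_optimizer_without_bn_decay(model, lr=0.1, weight_decay=1e-4):
    params = list(model.parameters())
    groups = get_optimized_params(model, params, params)
    # The BN group carries weight_decay=0; the remaining group uses the value below.
    return torch.optim.SGD(groups, lr=lr, momentum=0.9, weight_decay=weight_decay)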