82 lines
6.2 KiB
Python
82 lines
6.2 KiB
Python
# -*- coding: utf-8 -*-
|
|
# @Author: Weisen Pan
|
|
|
|
from __future__ import absolute_import, division, print_function
|
|
import json
|
|
import torch
|
|
from config import *
|
|
|
|
# Function to save hyperparameters into a JSON file
|
|
def store_hyperparameters_json(args):
    """Save hyperparameters to a JSON file inside ``args.model_dir``.

    The directory is created if it does not exist. The file is named
    ``hparams_eval.json`` when ``args.evaluate`` is truthy and
    ``hparams_train.json`` otherwise (including when the namespace has no
    ``evaluate`` attribute at all).

    Args:
        args: Namespace-like object (e.g. ``argparse.Namespace``) exposing
            ``model_dir``. Every attribute of ``args`` is written out, so all
            values must be JSON-serializable.

    Raises:
        TypeError: If an attribute of ``args`` cannot be serialized by
            ``json.dump``.
        OSError: If the directory or the file cannot be created.
    """
    # Imported locally so this function does not rely on ``os`` leaking into
    # the module namespace through ``from config import *``.
    import os

    # Create the model directory if it does not exist
    os.makedirs(args.model_dir, exist_ok=True)

    # ``evaluate`` is not registered by ``add_parser_arguments`` in this file,
    # so a missing attribute is treated as training mode instead of raising.
    is_eval = getattr(args, 'evaluate', False)
    filename = os.path.join(
        args.model_dir,
        'hparams_eval.json' if is_eval else 'hparams_train.json',
    )

    # ``vars`` returns the namespace's own attribute dictionary
    hparams = vars(args)

    # Indented, key-sorted output keeps the file stable and diff-friendly
    with open(filename, 'w') as f:
        json.dump(hparams, f, indent=4, sort_keys=True)
|
|
|
|
# Function to register the command-line arguments, parse them, and derive the CUDA/GPU settings
|
|
def _parse_bool_flag(value):
    """Convert a command-line string such as ``"true"`` or ``"0"`` to a bool."""
    # Local import: this file does not import ``argparse`` explicitly.
    import argparse

    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")


def add_parser_arguments(parser):
    """Register all command-line options on ``parser``, parse them, and return the result.

    Besides registering the options, this function calls
    ``parser.parse_args()`` (which reads ``sys.argv``) and adds derived
    attributes to the returned namespace:

    * ``cuda``: True when ``--no_cuda`` was not given and
      ``torch.cuda.is_available()`` reports a device.
    * ``gpu_ids``: converted from the comma-separated string to a list of
      ints, only when ``cuda`` is True; otherwise it stays a string.
    * ``num_gpus``: length of ``gpu_ids``, set only when ``cuda`` is True.

    Args:
        parser: An ``argparse.ArgumentParser`` (or compatible) instance.

    Returns:
        The parsed argument namespace with the derived attributes above.
    """
    # Dataset and model settings
    parser.add_argument('--data', type=str, default=f"{data_dir}/dataset_hub/", help='Path to dataset')
    parser.add_argument('--model_dir', type=str, default="EdgeFLite", help='Directory to save the model')
    parser.add_argument('--arch', type=str, default='wide_resnet16_8', choices=[
        'resnet110', 'resnet_model_110sl', 'wide_resnet16_8', 'wide_resnetsl16_8',
        'wide_resnet_model_50_2', 'wide_resnetsl50_2'], help='Neural architecture name')

    # Normalization and training settings
    parser.add_argument('--norm_mode', type=str, default='batch', choices=['batch', 'group', 'layer', 'instance', 'none'], help='Batch normalization style')
    parser.add_argument('--is_syncbn', default=0, type=int, help='Use nn.SyncBatchNorm or not')
    parser.add_argument('--workers', default=16, type=int, help='Number of data loading workers')
    parser.add_argument('--epochs', default=650, type=int, help='Total epochs to run')
    parser.add_argument('--start_epoch', default=0, type=int, help='Manual epoch number for restarts')
    parser.add_argument('--eval_per_epoch', default=1, type=int, help='Evaluation frequency per epoch')
    parser.add_argument('--spid', default="EdgeFLite", type=str, help='Experiment name')
    # ``type=bool`` would turn every non-empty string (including "False")
    # into True, so an explicit string-to-bool converter is used instead.
    parser.add_argument('--save_weight', default=False, type=_parse_bool_flag, help='Save model weights')

    # Data augmentation settings
    parser.add_argument('--batch_size', default=128, type=int, help='Mini-batch size for training')
    parser.add_argument('--eval_batch_size', default=100, type=int, help='Mini-batch size for evaluation')
    parser.add_argument('--crop_size', default=32, type=int, help='Crop size for images')
    parser.add_argument('--output_stride', default=8, type=int, help='Output stride for model')
    parser.add_argument('--padding', default=4, type=int, help='Padding size for images')

    # Learning rate settings
    parser.add_argument('--lr_mode', type=str, default='cos', choices=['cos', 'step', 'poly', 'HTD', 'exponential'], help='Learning rate strategy')
    parser.add_argument('--lr', '--learning_rate', default=0.1, type=float, help='Initial learning rate')
    parser.add_argument('--optimizer', type=str, default='SGD', choices=['SGD', 'AdamW', 'RMSprop', 'RMSpropTF'], help='Optimizer choice')
    parser.add_argument('--lr_milestones', nargs='+', type=int, default=[100, 200], help='Epochs for learning rate steps')
    parser.add_argument('--lr_step_multiplier', default=0.1, type=float, help='Multiplier at learning rate milestones')
    parser.add_argument('--end_lr', type=float, default=1e-4, help='Ending learning rate')

    # Additional hyperparameters
    parser.add_argument('--weight_decay', default=1e-4, type=float, help='Weight decay for regularization')
    parser.add_argument('--momentum', default=0.9, type=float, help='Optimizer momentum')
    parser.add_argument('--print_freq', default=20, type=int, help='Print frequency for logging')

    # Federated learning settings
    parser.add_argument('--is_fed', default=1, type=int, help='Enable federated learning')
    parser.add_argument('--num_clusters', default=20, type=int, help='Number of clusters for federated learning')
    parser.add_argument('--num_selected', default=20, type=int, help='Number of clients selected for training per round')
    parser.add_argument('--num_rounds', default=300, type=int, help='Total number of training rounds')

    # Processing and decentralized training settings
    parser.add_argument('--gpu', default=None, type=int, help='GPU ID to use')
    parser.add_argument('--no_cuda', action='store_true', default=False, help='Disable CUDA training')
    parser.add_argument('--gpu_ids', type=str, default='0', help='Comma-separated list of GPU IDs for training')

    # Parse command-line arguments (reads sys.argv)
    args = parser.parse_args()

    # CUDA is used only when it was not disabled and a device is available
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    if args.cuda:
        # Skip empty entries so a trailing comma such as "0,1," does not crash
        args.gpu_ids = [int(s) for s in args.gpu_ids.split(',') if s.strip()]
        args.num_gpus = len(args.gpu_ids)
    # NOTE: without CUDA, ``gpu_ids`` stays a string and ``num_gpus`` is not
    # set; this mirrors the original behavior.

    return args
|