# -*- coding: utf-8 -*-
# @Author: Weisen Pan

from __future__ import absolute_import, division, print_function
import json
import torch
from config import *

# Function to save hyperparameters into a JSON file
def store_hyperparameters_json(args):
    """Persist the run's hyperparameters as pretty-printed, key-sorted JSON.

    The file is written into ``args.model_dir`` (created on demand) and is
    named ``hparams_eval.json`` when ``args.evaluate`` is truthy, otherwise
    ``hparams_train.json``.
    """
    target_dir = args.model_dir
    # Make sure the destination directory exists before opening the file.
    os.makedirs(target_dir, exist_ok=True)
    # Evaluation runs and training runs keep separate hyperparameter files.
    basename = 'hparams_eval.json' if args.evaluate else 'hparams_train.json'
    destination = os.path.join(target_dir, basename)
    # vars() turns the argparse Namespace into a plain dict for serialization.
    with open(destination, 'w') as handle:
        json.dump(vars(args), handle, indent=4, sort_keys=True)

# Function to add parser arguments for command-line interface
def add_parser_arguments(parser):
    """Register all EdgeFLite CLI options on *parser*, parse sys.argv, and
    return the resulting namespace.

    Besides the raw parsed values, the returned ``args`` carries:
      - ``args.cuda``: True iff CUDA is available and ``--no_cuda`` was not set.
      - ``args.gpu_ids`` (CUDA only): parsed from a comma-separated string into
        a list of ints.
      - ``args.num_gpus`` (CUDA only): number of GPUs in ``gpu_ids``.
    """

    def _str2bool(value):
        # BUG FIX: `type=bool` is broken with argparse — bool("False") is True
        # because every non-empty string is truthy, so `--save_weight False`
        # used to enable weight saving. Parse common textual booleans instead;
        # argparse reports a ValueError from a type= callable as a normal
        # invalid-argument error.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', '1', 'yes', 'y'):
            return True
        if lowered in ('false', '0', 'no', 'n'):
            return False
        raise ValueError(f'Boolean value expected, got {value!r}')

    # Dataset and model settings
    parser.add_argument('--data', type=str, default=f"{data_dir}/dataset_hub/", help='Path to dataset')
    parser.add_argument('--model_dir', type=str, default="EdgeFLite", help='Directory to save the model')
    parser.add_argument('--arch', type=str, default='wide_resnet16_8', choices=[
        'resnet110', 'resnet_model_110sl', 'wide_resnet16_8', 'wide_resnetsl16_8', 
        'wide_resnet_model_50_2', 'wide_resnetsl50_2'], help='Neural architecture name')

    # Normalization and training settings
    parser.add_argument('--norm_mode', type=str, default='batch', choices=['batch', 'group', 'layer', 'instance', 'none'], help='Batch normalization style')
    parser.add_argument('--is_syncbn', default=0, type=int, help='Use nn.SyncBatchNorm or not')
    parser.add_argument('--workers', default=16, type=int, help='Number of data loading workers')
    parser.add_argument('--epochs', default=650, type=int, help='Total epochs to run')
    parser.add_argument('--start_epoch', default=0, type=int, help='Manual epoch number for restarts')
    parser.add_argument('--eval_per_epoch', default=1, type=int, help='Evaluation frequency per epoch')
    parser.add_argument('--spid', default="EdgeFLite", type=str, help='Experiment name')
    # Textual boolean parser so that `--save_weight False` actually disables saving.
    parser.add_argument('--save_weight', default=False, type=_str2bool, help='Save model weights')

    # Data augmentation settings
    parser.add_argument('--batch_size', default=128, type=int, help='Mini-batch size for training')
    parser.add_argument('--eval_batch_size', default=100, type=int, help='Mini-batch size for evaluation')
    parser.add_argument('--crop_size', default=32, type=int, help='Crop size for images')
    parser.add_argument('--output_stride', default=8, type=int, help='Output stride for model')
    parser.add_argument('--padding', default=4, type=int, help='Padding size for images')

    # Learning rate settings
    parser.add_argument('--lr_mode', type=str, default='cos', choices=['cos', 'step', 'poly', 'HTD', 'exponential'], help='Learning rate strategy')
    parser.add_argument('--lr', '--learning_rate', default=0.1, type=float, help='Initial learning rate')
    parser.add_argument('--optimizer', type=str, default='SGD', choices=['SGD', 'AdamW', 'RMSprop', 'RMSpropTF'], help='Optimizer choice')
    parser.add_argument('--lr_milestones', nargs='+', type=int, default=[100, 200], help='Epochs for learning rate steps')
    parser.add_argument('--lr_step_multiplier', default=0.1, type=float, help='Multiplier at learning rate milestones')
    parser.add_argument('--end_lr', type=float, default=1e-4, help='Ending learning rate')

    # Additional hyperparameters
    parser.add_argument('--weight_decay', default=1e-4, type=float, help='Weight decay for regularization')
    parser.add_argument('--momentum', default=0.9, type=float, help='Optimizer momentum')
    parser.add_argument('--print_freq', default=20, type=int, help='Print frequency for logging')

    # Federated learning settings
    parser.add_argument('--is_fed', default=1, type=int, help='Enable federated learning')
    parser.add_argument('--num_clusters', default=20, type=int, help='Number of clusters for federated learning')
    parser.add_argument('--num_selected', default=20, type=int, help='Number of clients selected for training per round')
    parser.add_argument('--num_rounds', default=300, type=int, help='Total number of training rounds')

    # Processing and decentralized training settings
    parser.add_argument('--gpu', default=None, type=int, help='GPU ID to use')
    parser.add_argument('--no_cuda', action='store_true', default=False, help='Disable CUDA training')
    parser.add_argument('--gpu_ids', type=str, default='0', help='Comma-separated list of GPU IDs for training')

    # Parse command-line arguments (reads sys.argv).
    args = parser.parse_args()

    # CUDA is used only when available and not explicitly disabled.
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    if args.cuda:
        # NOTE: gpu_ids/num_gpus are only defined on CUDA machines; code that
        # reads them must run behind an `args.cuda` check.
        args.gpu_ids = [int(s) for s in args.gpu_ids.split(',')]
        args.num_gpus = len(args.gpu_ids)

    return args