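# -*- coding: utf-8 -*-
# @Author: Weisen Pan
#
# Grid Engine (SGE) batch job: stages the CIFAR-100 dataset into the job's
# local directory ($SGE_LOCALDIR/$JOB_ID) and runs EdgeFLite federated
# training (run_gkt.py, 20 clusters, 650 rounds). Requires bash (`source`).
# Relative paths assume the job starts in the directory that contains
# EdgeFLite/, with summit2024/ as a sibling of that directory.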

# Load the toolchain this job needs via Environment Modules, then activate
# the Python virtual environment. Any failure aborts the job immediately:
# otherwise training would fail later (or run on the wrong stack) after the
# GPU allocation has already been consumed.
source /etc/profile.d/modules.sh || {
  printf 'error: cannot source %s\n' /etc/profile.d/modules.sh >&2
  exit 1
}

required_modules=(
  gcc/11.2.0           # GCC compiler
  openmpi/4.1.3        # OpenMPI for distributed computing
  cuda/11.5/11.5.2     # CUDA for GPU acceleration
  cudnn/8.3/8.3.3      # cuDNN for deep learning frameworks
  nccl/2.11/2.11.4-1   # NCCL for multi-GPU communication
  python/3.10/3.10.4   # Python 3.10 environment
)
# Load one module at a time so a failure names the offending module.
for mod in "${required_modules[@]}"; do
  module load "$mod" || {
    printf 'error: module load %s failed\n' "$mod" >&2
    exit 1
  }
done

# Activate the PyTorch 1.11 + Horovod virtual environment.
venv_activate="${HOME}/venv/pytorch1.11+horovod/bin/activate"
source "$venv_activate" || {
  printf 'error: cannot activate virtualenv %s\n' "$venv_activate" >&2
  exit 1
}

# Per-job record directory, keyed on the SGE job name and id. Any previous
# contents are wiped so the job starts from a clean directory.
# JOB_ID is required: without it the path would collapse to
# "records/${JOB_NAME}_", which is shared across runs and would be deleted.
# RECORDS_ROOT may be overridden from the environment; the default is the
# original hard-coded location.
# NOTE(review): LOG_PATH is not passed to the training command in this
# script — confirm how run_gkt.py is expected to locate it.
: "${JOB_ID:?JOB_ID is not set (this script expects to run under SGE)}"
RECORDS_ROOT="${RECORDS_ROOT:-/home/projadmin/Federated_Learning/project_EdgeFLite/records}"
LOG_PATH="${RECORDS_ROOT}/${JOB_NAME}_${JOB_ID}"
rm -rf -- "${LOG_PATH:?}"   # :? refuses to run rm on an empty path
mkdir -p -- "$LOG_PATH" || {
  printf 'error: cannot create log directory %s\n' "$LOG_PATH" >&2
  exit 1
}

# Copy the CIFAR-100 dataset into the job's local directory under
# SGE_LOCALDIR. The source path is relative to the current working
# directory, so this must run before changing into EdgeFLite/.
# Both variables are required: if either is unset, DATA_DIR would degrade
# to a path at (or near) the filesystem root.
# NOTE(review): if DATA_DIR already exists, `cp -r` places the dataset at
# "${DATA_DIR}data" rather than copying its contents into DATA_DIR.
: "${SGE_LOCALDIR:?SGE_LOCALDIR is not set (this script expects to run under SGE)}"
: "${JOB_ID:?JOB_ID is not set (this script expects to run under SGE)}"
DATA_DIR="${SGE_LOCALDIR}/${JOB_ID}/"
cp -r -- ../summit2024/simpleFL/performance_test/cifar100/data "$DATA_DIR" || {
  printf 'error: failed to copy dataset to %s\n' "$DATA_DIR" >&2
  exit 1
}

# Change into the EdgeFLite checkout; the training command below invokes
# run_gkt.py by a relative path. Abort rather than run from the wrong place.
cd EdgeFLite || {
  printf 'error: cannot cd to %s\n' "EdgeFLite" >&2
  exit 1
}

# Start federated training with the parameters below.
# The arguments are collected in an array so each flag can keep its own
# comment. A comment after a `\` line continuation does not work: the
# backslash then escapes the following space instead of the newline, so the
# command ends on that line and the remaining flags run as separate commands.
train_args=(
  --is_fed=1                   # Enable federated learning
  --fixed_cluster=0            # Use dynamic clustering
  --split_factor=1             # Split factor (TODO confirm: data vs. model split)
  --num_clusters=20            # Number of clusters
  --num_selected=20            # Number of selected clients per round
  --arch="resnet_model_110sl"  # Model architecture (ResNet-110 variant; TODO confirm meaning of "sl")
  --dataset="cifar100"         # Dataset used (CIFAR-100)
  --num_classes=100            # Number of classes in the dataset
  --is_single_branch=0         # Enable multi-branch model
  --is_amp=0                   # Disable automatic mixed precision
  --num_rounds=650             # Number of federated learning rounds
  --fed_epochs=1               # Number of local epochs per federated round
  --spid="FGKT_R110_20c_650r"  # Experiment ID for logging and tracking
  --data="${DATA_DIR}"         # Path to the staged dataset
)
python run_gkt.py "${train_args[@]}"