# -*- coding: utf-8 -*-
# @Author: Weisen Pan

# Load environment modules and required dependencies.
# Versions are pinned so the job environment stays reproducible.
source /etc/profile.d/modules.sh

# Toolchain and runtime stack, loaded in dependency order.
modules=(
  gcc/11.2.0          # GCC 11.2.0
  openmpi/4.1.3       # OpenMPI 4.1.3
  cuda/11.5/11.5.2    # CUDA 11.5.2
  cudnn/8.3/8.3.3     # cuDNN 8.3.3
  nccl/2.11/2.11.4-1  # NCCL 2.11.4-1
  python/3.10/3.10.4  # Python 3.10.4
)
for mod in "${modules[@]}"; do
  module load "$mod"
done

# Activate the virtual Python environment (PyTorch 1.11 + Horovod).
source ~/venv/pytorch1.11+horovod/bin/activate

# Define the per-job log directory, clean up any stale copy, and recreate it.
# NOTE(review): JOB_NAME and JOB_ID are assumed to come from the scheduler
# environment. ${VAR:?} aborts with a clear message if either is unset, so we
# never run `rm -rf` on a degenerate path like ".../records/_".
LOG_PATH="/home/projadmin/Federated_Learning/project_EdgeFLite/records/${JOB_NAME:?}_${JOB_ID:?}"
rm -rf -- "${LOG_PATH}"    # Remove any existing log directory (quoted; -- stops option parsing)
mkdir -p -- "${LOG_PATH}"  # Create a fresh log directory

# Stage the dataset on job-local storage for faster I/O during training.
# NOTE(review): if "${SGE_LOCALDIR}/${JOB_ID}" does not pre-exist, `cp -r`
# creates DATA_STORAGE itself as a copy of .../cifar100/data (no nested
# "data/" subdirectory); if it does exist, the copy lands in
# "${DATA_STORAGE}data". Confirm which layout run_gkt.py expects for --data.
# NOTE(review): the staged data is CIFAR-100, while training below uses
# --dataset="cifar10" — confirm this combination is intentional.
DATA_STORAGE="${SGE_LOCALDIR:?}/${JOB_ID:?}/"  # Job-local scratch directory
cp -r ../summit2024/simpleFL/performance_test/cifar100/data "${DATA_STORAGE}" \
  || { echo "ERROR: failed to stage dataset into ${DATA_STORAGE}" >&2; exit 1; }

# Navigate to the project directory containing the training scripts.
# Abort if it is missing — otherwise the training command below would run
# from the wrong directory and fail in a confusing way.
cd EdgeFLite || { echo "ERROR: cannot cd into EdgeFLite" >&2; exit 1; }

# Execute the training script with federated learning parameters.
#
# BUG FIX: the original wrote `--flag \   # comment` on every line. In bash,
# a backslash followed by a space does NOT continue the line — it escapes the
# space. The command therefore ended after the first option, and each later
# `--flag` line ran as its own (failing) command. Collecting the options in
# an array allows one safe comment per option and a single clean invocation.
train_args=(
  --is_fed=1                        # Enable federated learning
  --fixed_cluster=0                 # Allow dynamic cluster formation
  --split_factor=1                  # Data split factor
  --num_clusters=20                 # Number of clusters
  --num_selected=20                 # Number of selected clients per round
  --arch="wide_resnet16_8"          # Network architecture: Wide ResNet 16-8
  --dataset="cifar10"               # Use CIFAR-10 dataset
  --num_classes=10                  # Number of classes in CIFAR-10
  --is_single_branch=0              # Multi-branch network
  --is_amp=0                        # Disable Automatic Mixed Precision (AMP)
  --num_rounds=300                  # Number of federated learning rounds
  --fed_epochs=1                    # Local training epochs per round
  --cifar10_non_iid="quantity_skew" # Non-IID data distribution: quantity skew
  --spid="FGKT_W168_20c_skew"       # Job identifier
  --data="${DATA_STORAGE}"          # Path to the staged dataset (quoted)
)
python run_gkt.py "${train_args[@]}"