# -*- coding: utf-8 -*-
# @Author: Weisen Pan

# Import necessary libraries
import torch  # PyTorch for tensor computations and neural networks
from torch import nn  # Neural network module
# "decentralized" is not a valid import in PyTorch, possibly a typo. Removed for now.

# Check for available device (CPU or GPU)
# If a GPU is available (CUDA), the code will use it; otherwise, it falls back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
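
# Optional: seed the RNG so the random dummy input below is reproducible.
# (This seeding line is an illustrative addition; the seed value is arbitrary.)
torch.manual_seed(0)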

# Define normalization layer and the number of initial input channels for the convolutional layers
batch_norm_layer = nn.BatchNorm2d  # 2D Batch Normalization to stabilize training
initial_channels = 32  # Number of channels for the first convolutional layer

# Define the convolutional neural network (CNN) architecture using nn.Sequential
network = nn.Sequential(
    # 1st convolutional layer: takes 3 input channels (RGB image), outputs 'initial_channels' feature maps
    # Kernel size 3 with padding 1; stride 2 halves the spatial dimensions (64x64 -> 32x32)
    nn.Conv2d(in_channels=3, out_channels=initial_channels, kernel_size=3, stride=2, padding=1, bias=False),
    batch_norm_layer(initial_channels),  # Apply Batch Normalization to the output
    nn.ReLU(inplace=True),  # ReLU activation function to introduce non-linearity
    
    # 2nd convolutional layer: takes 'initial_channels' input, outputs the same number of feature maps
    # No downsampling here (stride 1)
    nn.Conv2d(in_channels=initial_channels, out_channels=initial_channels, kernel_size=3, stride=1, padding=1, bias=False),
    batch_norm_layer(initial_channels),  # Batch normalization for better convergence
    nn.ReLU(inplace=True),  # ReLU activation
    
    # 3rd convolutional layer: doubles the number of output channels (for deeper features)
    # Again, no downsampling (stride 1)
    nn.Conv2d(in_channels=initial_channels, out_channels=initial_channels * 2, kernel_size=3, stride=1, padding=1, bias=False),
    batch_norm_layer(initial_channels * 2),  # Batch normalization for the increased feature maps
    nn.ReLU(inplace=True),  # ReLU activation
    
    # Max pooling layer to further downsample the feature maps (reduces spatial dimensions)
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  # Pooling with kernel size 3 and stride 2 (32x32 -> 16x16)
).to(device)  # Move the model to the selected device so it matches the input tensor
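
# Quick sanity check of the spatial math: the stride-2 conv halves 64x64 to 32x32,
# the two stride-1 convs preserve 32x32, and the stride-2 max pool halves it again
# to 16x16. The probe below is an illustrative addition (not part of the original
# script) that confirms this and reports the trainable parameter count.
with torch.no_grad():
    probe_output = network(torch.randn(1, 3, 64, 64, device=device))
print(f"Probe output shape: {tuple(probe_output.shape)}")  # Expected: (1, 64, 16, 16)
print(f"Trainable parameters: {sum(p.numel() for p in network.parameters()):,}")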

# Create a dummy input tensor simulating a batch of 128 RGB images, each of size 64x64,
# placed on the same device as the model
sample_input = torch.randn(128, 3, 64, 64, device=device)

# Print the defined network architecture and the shape of the output after a forward pass
print(network)
# Perform a forward pass with the sample input and print the resulting output shape
print(network(sample_input).shape)  # Expected: torch.Size([128, 64, 16, 16])
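
# To show how this backbone might be used, the sketch below attaches a small
# classification head and runs one optimizer step on random labels. It is a
# minimal, illustrative addition: the head layout, the 10-class count, and the
# SGD hyperparameters are assumptions, not part of the original script.
classifier = nn.Sequential(
    network,                              # Reuse the convolutional backbone defined above
    nn.AdaptiveAvgPool2d(1),              # Global average pooling: (N, 64, 16, 16) -> (N, 64, 1, 1)
    nn.Flatten(),                         # Flatten to (N, 64)
    nn.Linear(initial_channels * 2, 10),  # Linear head for an assumed 10 classes
).to(device)

optimizer = torch.optim.SGD(classifier.parameters(), lr=0.1, momentum=0.9)
criterion = nn.CrossEntropyLoss()

labels = torch.randint(0, 10, (128,), device=device)  # Random dummy labels, one per sample
optimizer.zero_grad()                    # Clear any accumulated gradients
loss = criterion(classifier(sample_input), labels)  # Forward pass and loss
loss.backward()                          # Backpropagate
optimizer.step()                         # Update parameters
print(f"Dummy training-step loss: {loss.item():.4f}")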