# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import torch
import torch.nn as nn

def _count_rnn_cell(input_size, hidden_size, bias=True):
    """Calculate the total operations for a single RNN cell.
    
    Args:
        input_size (int): Size of the input.
        hidden_size (int): Size of the hidden state.
        bias (bool, optional): Whether the RNN cell uses bias. Defaults to True.
    
    Returns:
        int: Total number of operations for the RNN cell.
    """
    ops = hidden_size * (input_size + hidden_size) + hidden_size
    if bias:
        ops += hidden_size * 2
    return ops

def count_rnn_cell(cell: nn.RNNCell, x: tuple, y=None):
    """Accumulate the operation count of an ``nn.RNNCell`` forward pass.

    Usable directly as a forward hook (``module, inputs, output``): ``y`` is
    accepted for that signature and otherwise ignored, so two-argument calls
    keep working.

    Args:
        cell (nn.RNNCell): The cell being profiled. It must already carry a
            ``total_ops`` attribute (presumably set up by the profiler), which
            is incremented in place.
        x (tuple): Positional forward inputs. ``x[0]`` is the input tensor,
            either batched ``(batch, input_size)`` or unbatched ``(input_size,)``.
        y: Forward output (unused).
    """
    ops = _count_rnn_cell(cell.input_size, cell.hidden_size, cell.bias)
    inp = x[0]
    # A 1-D input is a single unbatched sample; size(0) would be input_size.
    batch_size = inp.size(0) if inp.dim() > 1 else 1
    total_ops = ops * batch_size
    cell.total_ops += torch.DoubleTensor([int(total_ops)])

def _count_gru_cell(input_size, hidden_size, bias=True):
    """Calculate the total operations for a single GRU cell.
    
    Args:
        input_size (int): Size of the input.
        hidden_size (int): Size of the hidden state.
        bias (bool, optional): Whether the GRU cell uses bias. Defaults to True.
    
    Returns:
        int: Total number of operations for the GRU cell.
    """
    ops = (hidden_size + input_size) * hidden_size + hidden_size
    if bias:
        ops += hidden_size * 2
    ops *= 2  # For reset and update gates

    ops += (hidden_size + input_size) * hidden_size + hidden_size  # Calculate new gate
    if bias:
        ops += hidden_size * 2
    ops += hidden_size  # Hadamard product
    ops += hidden_size * 3  # Final output

    return ops

def count_gru_cell(cell: nn.GRUCell, x: tuple, y=None):
    """Accumulate the operation count of an ``nn.GRUCell`` forward pass.

    Usable directly as a forward hook (``module, inputs, output``): ``y`` is
    accepted for that signature and otherwise ignored, so two-argument calls
    keep working.

    Args:
        cell (nn.GRUCell): The cell being profiled. It must already carry a
            ``total_ops`` attribute (presumably set up by the profiler), which
            is incremented in place.
        x (tuple): Positional forward inputs. ``x[0]`` is the input tensor,
            either batched ``(batch, input_size)`` or unbatched ``(input_size,)``.
        y: Forward output (unused).
    """
    ops = _count_gru_cell(cell.input_size, cell.hidden_size, cell.bias)
    inp = x[0]
    # A 1-D input is a single unbatched sample; size(0) would be input_size.
    batch_size = inp.size(0) if inp.dim() > 1 else 1
    total_ops = ops * batch_size
    cell.total_ops += torch.DoubleTensor([int(total_ops)])

def _count_lstm_cell(input_size, hidden_size, bias=True):
    """Calculate the total operations for a single LSTM cell.
    
    Args:
        input_size (int): Size of the input.
        hidden_size (int): Size of the hidden state.
        bias (bool, optional): Whether the LSTM cell uses bias. Defaults to True.
    
    Returns:
        int: Total number of operations for the LSTM cell.
    """
    ops = (input_size + hidden_size) * hidden_size + hidden_size
    if bias:
        ops += hidden_size * 2
    ops *= 4  # For input, forget, output, and cell gates

    ops += hidden_size * 3  # Cell state update
    ops += hidden_size  # Final output

    return ops

def count_lstm_cell(cell: nn.LSTMCell, x: tuple, y=None):
    """Accumulate the operation count of an ``nn.LSTMCell`` forward pass.

    Usable directly as a forward hook (``module, inputs, output``): ``y`` is
    accepted for that signature and otherwise ignored, so two-argument calls
    keep working.

    Args:
        cell (nn.LSTMCell): The cell being profiled. It must already carry a
            ``total_ops`` attribute (presumably set up by the profiler), which
            is incremented in place.
        x (tuple): Positional forward inputs. ``x[0]`` is the input tensor,
            either batched ``(batch, input_size)`` or unbatched ``(input_size,)``.
        y: Forward output (unused).
    """
    ops = _count_lstm_cell(cell.input_size, cell.hidden_size, cell.bias)
    inp = x[0]
    # A 1-D input is a single unbatched sample; size(0) would be input_size.
    batch_size = inp.size(0) if inp.dim() > 1 else 1
    total_ops = ops * batch_size
    cell.total_ops += torch.DoubleTensor([int(total_ops)])

def _count_rnn_layers(model: nn.RNN, num_layers, input_size, hidden_size):
    """Calculate the per-sample, per-time-step operations of stacked RNN layers.

    Every layer holds one cell per direction, so a bidirectional model runs
    two cells per layer. Layers after the first consume the concatenated
    outputs of all directions of the layer below.

    Args:
        model (nn.RNN): The RNN model; ``bias`` and ``bidirectional`` are read from it.
        num_layers (int): Number of stacked layers.
        input_size (int): Size of the input to the first layer.
        hidden_size (int): Size of the hidden state.

    Returns:
        int: Total operations across all layers and directions for one time
        step of one sample.
    """
    num_directions = 2 if model.bidirectional else 1
    stacked_input_size = hidden_size * num_directions

    ops = _count_rnn_cell(input_size, hidden_size, model.bias)
    for _ in range(num_layers - 1):
        ops += _count_rnn_cell(stacked_input_size, hidden_size, model.bias)
    # Each direction has its own, identically sized cell in every layer.
    return ops * num_directions

def count_rnn(model: nn.RNN, x: tuple, y=None):
    """Accumulate the operation count of a full ``nn.RNN`` forward pass.

    Adds ``ops_per_step * processed (time step, sequence) pairs`` to
    ``model.total_ops``, where ``ops_per_step`` covers every layer and
    direction. Usable directly as a forward hook (``module, inputs, output``):
    ``y`` is accepted for that signature and otherwise ignored.

    Args:
        model (nn.RNN): The RNN being profiled. It must already carry a
            ``total_ops`` attribute (presumably set up by the profiler), which
            is incremented in place.
        x (tuple): Positional forward inputs. ``x[0]`` is a ``PackedSequence``,
            an unbatched ``(seq_len, input_size)`` tensor, or a batched 3-D
            tensor laid out according to ``model.batch_first``.
        y: Forward output (unused).
    """
    inp = x[0]
    if isinstance(inp, nn.utils.rnn.PackedSequence):
        # Packed data holds exactly one row per (time step, sequence) pair that
        # is actually processed, so padding is not counted.
        num_cell_steps = inp.data.size(0)
    elif inp.dim() == 2:
        # Unbatched input (seq_len, input_size): a single sequence.
        num_cell_steps = inp.size(0)
    else:
        batch_size = inp.size(0) if model.batch_first else inp.size(1)
        num_steps = inp.size(1) if model.batch_first else inp.size(0)
        num_cell_steps = num_steps * batch_size

    ops = _count_rnn_layers(model, model.num_layers, model.input_size, model.hidden_size)
    total_ops = ops * num_cell_steps
    model.total_ops += torch.DoubleTensor([int(total_ops)])

def _count_gru_layers(model: nn.GRU, num_layers, input_size, hidden_size):
    """Calculate the per-sample, per-time-step operations of stacked GRU layers.

    Every layer holds one cell per direction, so a bidirectional model runs
    two cells per layer. Layers after the first consume the concatenated
    outputs of all directions of the layer below.

    Args:
        model (nn.GRU): The GRU model; ``bias`` and ``bidirectional`` are read from it.
        num_layers (int): Number of stacked layers.
        input_size (int): Size of the input to the first layer.
        hidden_size (int): Size of the hidden state.

    Returns:
        int: Total operations across all layers and directions for one time
        step of one sample.
    """
    num_directions = 2 if model.bidirectional else 1
    stacked_input_size = hidden_size * num_directions

    ops = _count_gru_cell(input_size, hidden_size, model.bias)
    for _ in range(num_layers - 1):
        ops += _count_gru_cell(stacked_input_size, hidden_size, model.bias)
    # Each direction has its own, identically sized cell in every layer.
    return ops * num_directions

def count_gru(model: nn.GRU, x: tuple, y=None):
    """Accumulate the operation count of a full ``nn.GRU`` forward pass.

    Adds ``ops_per_step * processed (time step, sequence) pairs`` to
    ``model.total_ops``, where ``ops_per_step`` covers every layer and
    direction. Usable directly as a forward hook (``module, inputs, output``):
    ``y`` is accepted for that signature and otherwise ignored.

    Args:
        model (nn.GRU): The GRU being profiled. It must already carry a
            ``total_ops`` attribute (presumably set up by the profiler), which
            is incremented in place.
        x (tuple): Positional forward inputs. ``x[0]`` is a ``PackedSequence``,
            an unbatched ``(seq_len, input_size)`` tensor, or a batched 3-D
            tensor laid out according to ``model.batch_first``.
        y: Forward output (unused).
    """
    inp = x[0]
    if isinstance(inp, nn.utils.rnn.PackedSequence):
        # Packed data holds exactly one row per (time step, sequence) pair that
        # is actually processed, so padding is not counted.
        num_cell_steps = inp.data.size(0)
    elif inp.dim() == 2:
        # Unbatched input (seq_len, input_size): a single sequence.
        num_cell_steps = inp.size(0)
    else:
        batch_size = inp.size(0) if model.batch_first else inp.size(1)
        num_steps = inp.size(1) if model.batch_first else inp.size(0)
        num_cell_steps = num_steps * batch_size

    ops = _count_gru_layers(model, model.num_layers, model.input_size, model.hidden_size)
    total_ops = ops * num_cell_steps
    model.total_ops += torch.DoubleTensor([int(total_ops)])

def _count_lstm_layers(model: nn.LSTM, num_layers, input_size, hidden_size):
    """Calculate the per-sample, per-time-step operations of stacked LSTM layers.

    Every layer holds one cell per direction, so a bidirectional model runs
    two cells per layer. Layers after the first consume the concatenated
    outputs of all directions of the layer below.

    NOTE(review): ``model.proj_size`` is not read here, so LSTMs with
    projections are counted as if they had none.

    Args:
        model (nn.LSTM): The LSTM model; ``bias`` and ``bidirectional`` are read from it.
        num_layers (int): Number of stacked layers.
        input_size (int): Size of the input to the first layer.
        hidden_size (int): Size of the hidden state.

    Returns:
        int: Total operations across all layers and directions for one time
        step of one sample.
    """
    num_directions = 2 if model.bidirectional else 1
    stacked_input_size = hidden_size * num_directions

    ops = _count_lstm_cell(input_size, hidden_size, model.bias)
    for _ in range(num_layers - 1):
        ops += _count_lstm_cell(stacked_input_size, hidden_size, model.bias)
    # Each direction has its own, identically sized cell in every layer.
    return ops * num_directions

def count_lstm(model: nn.LSTM, x: tuple, y=None):
    """Accumulate the operation count of a full ``nn.LSTM`` forward pass.

    Adds ``ops_per_step * processed (time step, sequence) pairs`` to
    ``model.total_ops``, where ``ops_per_step`` covers every layer and
    direction. Usable directly as a forward hook (``module, inputs, output``):
    ``y`` is accepted for that signature and otherwise ignored.

    Args:
        model (nn.LSTM): The LSTM being profiled. It must already carry a
            ``total_ops`` attribute (presumably set up by the profiler), which
            is incremented in place.
        x (tuple): Positional forward inputs. ``x[0]`` is a ``PackedSequence``,
            an unbatched ``(seq_len, input_size)`` tensor, or a batched 3-D
            tensor laid out according to ``model.batch_first``.
        y: Forward output (unused).
    """
    inp = x[0]
    if isinstance(inp, nn.utils.rnn.PackedSequence):
        # Packed data holds exactly one row per (time step, sequence) pair that
        # is actually processed, so padding is not counted.
        num_cell_steps = inp.data.size(0)
    elif inp.dim() == 2:
        # Unbatched input (seq_len, input_size): a single sequence.
        num_cell_steps = inp.size(0)
    else:
        batch_size = inp.size(0) if model.batch_first else inp.size(1)
        num_steps = inp.size(1) if model.batch_first else inp.size(0)
        num_cell_steps = num_steps * batch_size

    ops = _count_lstm_layers(model, model.num_layers, model.input_size, model.hidden_size)
    total_ops = ops * num_cell_steps
    model.total_ops += torch.DoubleTensor([int(total_ops)])