# -*- coding: utf-8 -*-
# @Author: Weisen Pan
import torch
from torch.optim import Optimizer


class CustomRMSprop(Optimizer):
    """
    Implements a modified version of the RMSprop algorithm with TensorFlow-style
    epsilon handling.

    Main differences in this implementation:
    1. Epsilon is added inside the square root, i.e. sqrt(square_avg + eps) rather
       than sqrt(square_avg) + eps.
    2. The moving average of squared gradients is initialized to 1 instead of 0.
    3. With lr_in_momentum enabled, the momentum buffer accumulates updates that are
       already scaled by the learning rate.
    """

    def __init__(self, params, lr=0.01, alpha=0.99, eps=1e-8, momentum=0,
                 weight_decay=0, centered=False, decoupled_decay=False,
                 lr_in_momentum=True):
        """
        Initializes the optimizer with the provided parameters.

        Arguments:
        - params: iterable of parameters to optimize or dicts defining parameter groups
        - lr: learning rate (default: 0.01)
        - alpha: smoothing constant for the moving average (default: 0.99)
        - eps: small value to prevent division by zero (default: 1e-8)
        - momentum: momentum factor (default: 0)
        - weight_decay: weight decay (L2 penalty) (default: 0)
        - centered: if True, compute centered RMSprop (default: False)
        - decoupled_decay: if True, decouples weight decay from the gradient update (default: False)
        - lr_in_momentum: if True, applies the learning rate inside the momentum buffer (default: True)
        """
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if eps < 0.0:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if momentum < 0.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight decay: {weight_decay}")
        if alpha < 0.0:
            raise ValueError(f"Invalid alpha value: {alpha}")

        # Store the optimizer defaults
        defaults = {
            'lr': lr,
            'alpha': alpha,
            'eps': eps,
            'momentum': momentum,
            'centered': centered,
            'weight_decay': weight_decay,
            'decoupled_decay': decoupled_decay,
            'lr_in_momentum': lr_in_momentum
        }
        super().__init__(params, defaults)

    def step(self, closure=None):
        """
        Performs a single optimization step.

        Arguments:
        - closure: a closure that reevaluates the model and returns the loss.
""" # Get the loss value if a closure is provided loss = closure() if closure is not None else None # Iterate over parameter groups for group in self.param_groups: lr = group['lr'] momentum = group['momentum'] weight_decay = group['weight_decay'] alpha = group['alpha'] eps = group['eps'] # Iterate over parameters in the group for p in group['params']: if p.grad is None: continue grad = p.grad.data # Get gradient data if grad.is_sparse: raise RuntimeError("RMSprop does not support sparse gradients.") # Get the state of the parameter state = self.state[p] # Initialize state if it doesn't exist if not state: state['step'] = 0 state['square_avg'] = torch.ones_like(p.data) # Initialize moving average of squared gradients to 1 if momentum > 0: state['momentum_buffer'] = torch.zeros_like(p.data) # Initialize momentum buffer if group['centered']: state['grad_avg'] = torch.zeros_like(p.data) # Initialize moving average of gradients if centered square_avg = state['square_avg'] one_minus_alpha = 1 - alpha state['step'] += 1 # Update the step count # Apply weight decay if weight_decay != 0: if group['decoupled_decay']: p.data.mul_(1 - lr * weight_decay) # Apply decoupled weight decay else: grad.add_(p.data, alpha=weight_decay) # Apply traditional weight decay # Update the moving average of squared gradients square_avg.add_((grad ** 2) - square_avg, alpha=one_minus_alpha) # Compute the denominator for gradient update if group['centered']: grad_avg = state['grad_avg'] grad_avg.add_(grad - grad_avg, alpha=one_minus_alpha) avg = (square_avg - grad_avg ** 2).add_(eps).sqrt_() # Centered RMSprop else: avg = square_avg.add_(eps).sqrt_() # Standard RMSprop # Apply momentum if needed if momentum > 0: buf = state['momentum_buffer'] if group['lr_in_momentum']: buf.mul_(momentum).addcdiv_(grad, avg, value=lr) # Apply learning rate inside momentum buffer p.data.add_(-buf) else: buf.mul_(momentum).addcdiv_(grad, avg) # Standard momentum update p.data.add_(buf, alpha=-lr) else: p.data.addcdiv_(grad, avg, value=-lr) # Update parameter without momentum return loss # Return the loss if closure was provided