# -*- coding: utf-8 -*-
# @Author: Weisen Pan
import math


class CustomScheduler:
    def __init__(self, mode='cosine', initial_lr=0.1, num_epochs=100, iters_per_epoch=300,
                 lr_milestones=None, lr_step=100, step_multiplier=0.1,
                 slow_start_epochs=0, slow_start_lr=1e-4, min_lr=1e-3,
                 multiplier=1.0, lower_bound=-6.0, upper_bound=3.0,
                 decay_factor=0.97, decay_epochs=0.8, staircase=True):
        """
        Initialize the learning rate scheduler.

        Parameters:
            mode (str): Mode for learning rate adjustment ('cosine', 'poly', 'HTD', 'step', 'exponential').
            initial_lr (float): Initial learning rate.
            num_epochs (int): Total number of epochs.
            iters_per_epoch (int): Number of iterations per epoch.
            lr_milestones (list): Epoch milestones for learning rate decay in 'step' mode.
            lr_step (int): Epoch step size for learning rate reduction in 'step' mode.
            step_multiplier (float): Multiplication factor for learning rate reduction in 'step' mode.
            slow_start_epochs (int): Number of slow-start (warm-up) epochs.
            slow_start_lr (float): Learning rate during warm-up.
            min_lr (float): Minimum learning rate limit.
            multiplier (float): Multiplication factor applied to parameter groups other than the first.
            lower_bound (float): Lower bound for the tanh function in 'HTD' mode.
            upper_bound (float): Upper bound for the tanh function in 'HTD' mode.
            decay_factor (float): Factor by which the learning rate decays in 'exponential' mode.
            decay_epochs (float): Number of epochs over which the learning rate decays in 'exponential' mode.
            staircase (bool): If True, apply step-wise learning rate decay in 'exponential' mode.
        """
        # Ensure valid mode selection
        assert mode in ['cosine', 'poly', 'HTD', 'step', 'exponential'], "Invalid mode."

        # Initialize learning rate settings
        self.initial_lr = initial_lr
        self.current_lr = initial_lr
        self.min_lr = min_lr
        self.mode = mode
        self.num_epochs = num_epochs
        self.iters_per_epoch = iters_per_epoch
        self.total_iterations = (num_epochs - slow_start_epochs) * iters_per_epoch
        self.slow_start_iters = slow_start_epochs * iters_per_epoch
        self.slow_start_lr = slow_start_lr
        self.multiplier = multiplier
        self.lr_step = lr_step
        self.lr_milestones = lr_milestones
        self.step_multiplier = step_multiplier
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        self.decay_factor = decay_factor
        self.decay_steps = decay_epochs * iters_per_epoch
        self.staircase = staircase

        print(f"INFO: Using {self.mode} learning rate scheduler with {slow_start_epochs} warm-up epochs.")

    def update_lr(self, optimizer, iteration, epoch):
        """Update the learning rate based on the current iteration and epoch."""
        current_iter = epoch * self.iters_per_epoch + iteration

        # During slow start, linearly increase the learning rate
        # (guard against division by zero when no warm-up epochs are configured)
        if self.slow_start_iters > 0 and current_iter <= self.slow_start_iters:
            lr = self.slow_start_lr + (self.initial_lr - self.slow_start_lr) * (current_iter / self.slow_start_iters)
        else:
            # After slow start, calculate the learning rate based on the selected mode
            lr = self._calculate_lr(current_iter - self.slow_start_iters)

        # Ensure the learning rate does not fall below the minimum limit
        self.current_lr = max(lr, self.min_lr)
        self._apply_lr(optimizer, self.current_lr)

    def _calculate_lr(self, adjusted_iter):
        """Calculate the learning rate based on the selected scheduling mode."""
        if self.mode == 'cosine':
            # Cosine annealing schedule
            return 0.5 * self.initial_lr * (1 + math.cos(math.pi * adjusted_iter / self.total_iterations))
        elif self.mode == 'poly':
            # Polynomial decay schedule
            return self.initial_lr * (1 - adjusted_iter / self.total_iterations) ** 0.9
        elif self.mode == 'HTD':
            # Hyperbolic tangent decay schedule
            ratio = adjusted_iter / self.total_iterations
            return 0.5 * self.initial_lr * (1 - math.tanh(self.lower_bound + (self.upper_bound - self.lower_bound) * ratio))
        elif self.mode == 'step':
            # Step decay schedule
            return self._step_lr(adjusted_iter)
        elif self.mode == 'exponential':
            # Exponential decay schedule
            power = math.floor(adjusted_iter / self.decay_steps) if self.staircase else adjusted_iter / self.decay_steps
            return self.initial_lr * (self.decay_factor ** power)
        else:
            raise NotImplementedError("Unknown learning rate mode.")

    def _step_lr(self, adjusted_iter):
        """Calculate the learning rate for the 'step' mode."""
        epoch = adjusted_iter // self.iters_per_epoch
        # Count how many milestones or steps have passed
        if self.lr_milestones:
            num_steps = sum(1 for milestone in self.lr_milestones if epoch >= milestone)
        else:
            num_steps = epoch // self.lr_step
        return self.initial_lr * (self.step_multiplier ** num_steps)

    def _apply_lr(self, optimizer, lr):
        """Apply the calculated learning rate to the optimizer."""
        for i, param_group in enumerate(optimizer.param_groups):
            # Apply the multiplier to parameter groups beyond the first one
            param_group['lr'] = lr * (self.multiplier if i >= 1 else 1.0)


def adjust_hyperparameters(args):
    """Adjust the learning rate and momentum based on the batch size."""
    print(f'Adjusting LR and momentum. Original LR: {args.lr}, Original momentum: {args.momentum}')

    # Set the reference batch size used for scaling
    if 'cifar' in args.dataset:
        standard_batch_size = 128
    else:
        raise NotImplementedError(f"No reference batch size defined for dataset '{args.dataset}'.")

    # Scale momentum and learning rate with the batch size
    args.momentum = args.momentum ** (args.batch_size / standard_batch_size)
    args.lr *= (args.batch_size / standard_batch_size)

    print(f'Adjusted LR: {args.lr}, Adjusted momentum: {args.momentum}')
    return args


def separate_parameters(model, weight_decay_for_norm=0):
    """Separate the model parameters into two groups: regular parameters and norm/bias parameters."""
    regular_params, norm_params = [], []
    for name, param in model.named_parameters():
        if param.requires_grad:
            # Parameters related to normalization and biases are treated separately
            if 'norm' in name or 'bias' in name:
                norm_params.append(param)
            else:
                regular_params.append(param)
    # Return parameter groups with the corresponding weight decay for norm/bias parameters
    return [{'params': regular_params}, {'params': norm_params, 'weight_decay': weight_decay_for_norm}]
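

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the original module).
# It assumes PyTorch is installed and uses a toy model and placeholder
# hyperparameters; it shows how `separate_parameters` and
# `CustomScheduler.update_lr` are intended to be wired into a per-iteration
# training loop, where the learning rate is refreshed before each step.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch
    from torch import nn

    # Toy model containing both regular weights and normalization/bias parameters
    model = nn.Sequential(nn.Linear(32, 64), nn.BatchNorm1d(64), nn.ReLU(), nn.Linear(64, 10))

    # Build two parameter groups so norm/bias parameters get their own weight decay
    param_groups = separate_parameters(model, weight_decay_for_norm=0.0)
    optimizer = torch.optim.SGD(param_groups, lr=0.1, momentum=0.9, weight_decay=5e-4)

    # Cosine schedule with a 5-epoch linear warm-up (small counts for a quick demo)
    num_epochs, iters_per_epoch = 10, 20
    scheduler = CustomScheduler(mode='cosine', initial_lr=0.1,
                                num_epochs=num_epochs, iters_per_epoch=iters_per_epoch,
                                slow_start_epochs=5, slow_start_lr=1e-4, min_lr=1e-5)

    for epoch in range(num_epochs):
        for iteration in range(iters_per_epoch):
            # Update the learning rate before each optimizer step;
            # forward pass, loss.backward(), and optimizer.step() would follow here.
            scheduler.update_lr(optimizer, iteration, epoch)
        print(f"epoch {epoch}: lr = {scheduler.current_lr:.6f}")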