""" BROS Copyright 2022-present NAVER Corp. Apache License v2.0 """ import math import numpy as np from torch.optim.lr_scheduler import LambdaLR def linear_scheduler(optimizer, warmup_steps, training_steps, last_epoch=-1): """linear_scheduler with warmup from huggingface""" def lr_lambda(current_step): if current_step < warmup_steps: return float(current_step) / float(max(1, warmup_steps)) return max( 0.0, float(training_steps - current_step) / float(max(1, training_steps - warmup_steps)), ) return LambdaLR(optimizer, lr_lambda, last_epoch) def cosine_scheduler( optimizer, warmup_steps, training_steps, cycles=0.5, last_epoch=-1 ): """Cosine LR scheduler with warmup from huggingface""" def lr_lambda(current_step): if current_step < warmup_steps: return current_step / max(1, warmup_steps) progress = current_step - warmup_steps progress /= max(1, training_steps - warmup_steps) return max(0.0, 0.5 * (1.0 + math.cos(math.pi * cycles * 2 * progress))) return LambdaLR(optimizer, lr_lambda, last_epoch) def multistep_scheduler(optimizer, warmup_steps, milestones, gamma=0.1, last_epoch=-1): def lr_lambda(current_step): if current_step < warmup_steps: # calculate a warmup ratio return current_step / max(1, warmup_steps) else: # calculate a multistep lr scaling ratio idx = np.searchsorted(milestones, current_step) return gamma ** idx return LambdaLR(optimizer, lr_lambda, last_epoch)