# src/axolotl/utils/schedulers.py
from torch.optim.lr_scheduler import LRScheduler


class InterpolatingLogScheduler(LRScheduler):
    def __init__(self, optimizer, num_steps, min_lr, max_lr, last_epoch=-1):
        """A scheduler that interpolates learning rates in a logarithmic fashion.

        Args:
            optimizer: a PyTorch optimizer
            num_steps: int, the number of steps over which to increase from min_lr to max_lr
            min_lr: float, the minimum learning rate
            max_lr: float, the maximum learning rate

        Usage:
            fc = nn.Linear(1, 1)
            optimizer = optim.Adam(fc.parameters())
            lr_scheduler = InterpolatingLogScheduler(optimizer, num_steps=400, min_lr=1e-6, max_lr=1e-4)
        """
        self.num_steps = num_steps
        self.min_lr = min_lr
        self.max_lr = max_lr
        # geometric ratio q, chosen so that min_lr * q ** (num_steps - 1) == max_lr
        self.q = (max_lr / min_lr) ** (1 / (num_steps - 1))
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch == 0:
            lr = self.min_lr
        elif self.last_epoch < self.num_steps:
            # FIXME: not perfect, since this doesn't account for how many steps are in an epoch, etc.
            lr = self.min_lr * (self.q ** self.last_epoch)
        else:
            lr = self.max_lr

        return [lr for _ in self.base_lrs]
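

# Minimal usage sketch: the toy Linear model, Adam optimizer, and step counts
# below are illustrative assumptions, not part of the axolotl module. It shows
# the learning rate ramping geometrically from min_lr to max_lr over num_steps
# and then holding at max_lr.
if __name__ == "__main__":
    from torch import nn, optim

    fc = nn.Linear(1, 1)
    optimizer = optim.Adam(fc.parameters(), lr=1e-6)
    scheduler = InterpolatingLogScheduler(optimizer, num_steps=4, min_lr=1e-6, max_lr=1e-4)

    for step in range(6):
        optimizer.step()  # a real loop would compute a loss and call backward() first
        scheduler.step()
        print(f"step {step}: lr={optimizer.param_groups[0]['lr']:.2e}")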