qwerrwe / src /axolotl /utils /schedulers.py
winglian's picture
black formatting
2bc1a5b
raw
history blame
1.33 kB
from torch.optim.lr_scheduler import LRScheduler
class InterpolatingLogScheduler(LRScheduler):
    """Learning-rate warm-up that grows geometrically from ``min_lr`` to ``max_lr``.

    Every parameter group receives the same learning rate; the optimizer's
    own base learning rates are ignored.
    """

    def __init__(self, optimizer, num_steps, min_lr, max_lr, last_epoch=-1):
        """A scheduler that interpolates learning rates in a logarithmic fashion

        Args:
        - optimizer: pytorch optimizer
        - num_steps: int, the number of steps over which to increase from the min_lr to the max_lr
        - min_lr: float, the minimum learning rate (must be non-zero, it is used as a divisor)
        - max_lr: float, the maximum learning rate
        - last_epoch: int, index of the last step already taken; forwarded to ``LRScheduler``

        Usage:
            fc = nn.Linear(1,1)
            optimizer = optim.Adam(fc.parameters())
            lr_scheduler = InterpolatingLogScheduler(optimizer, num_steps=400, min_lr=1e-6, max_lr=1e-4)
        """
        self.num_steps = num_steps
        self.min_lr = min_lr
        self.max_lr = max_lr
        if num_steps == 1:
            # With a single step there is nothing to interpolate: get_lr() goes
            # straight from min_lr to max_lr and never reads q. Use a neutral
            # ratio instead of dividing by (num_steps - 1) == 0.
            self.q = 1.0
        else:
            # Constant per-step multiplier chosen so that
            # min_lr * q ** (num_steps - 1) == max_lr.
            self.q = (max_lr / min_lr) ** (1 / (num_steps - 1))
        # The attributes above must exist before the base initialiser runs,
        # because it performs an initial step that calls get_lr().
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate for the current step, one entry per param group.

        - ``last_epoch <= 0``: ``min_lr``
        - ``0 < last_epoch < num_steps``: ``min_lr * q ** (last_epoch - 1)``
          (so step 1 also yields ``min_lr``)
        - ``last_epoch >= num_steps``: ``max_lr``
        """
        if self.last_epoch <= 0:
            lr = self.min_lr
        elif self.last_epoch < self.num_steps:
            lr = self.min_lr * (self.q ** (self.last_epoch - 1))
        else:
            lr = self.max_lr
        return [lr for _ in self.base_lrs]