
base

LearningRateScheduler(optimizer, init_lr)

Bases: _LRScheduler

Provides the interface for learning rate schedulers.

Note

Do not use this class directly; use one of its subclasses.

Source code in quadra/schedulers/base.py
def __init__(self, optimizer: Optimizer, init_lr: Tuple[float, ...]):
    # pylint: disable=super-init-not-called
    self.optimizer = optimizer
    self.init_lr = init_lr

get_lr()

Get the current learning rate of the optimizer's first parameter group. Returns None if no optimizer is set (or it has no parameter groups).

Source code in quadra/schedulers/base.py
def get_lr(self):
    """Get the current learning rate if the optimizer is available."""
    if self.optimizer is not None:
        for g in self.optimizer.param_groups:
            return g["lr"]

    return None
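Because the loop returns on its first iteration, get_lr reports only the first parameter group's learning rate, even when groups differ. The sketch below mirrors that logic so it runs without PyTorch. It assumes a hypothetical StubOptimizer that models only the param_groups attribute (a list of dicts, the layout torch.optim.Optimizer uses). first_group_lr is likewise a hypothetical name, not part of quadra.

```python
from typing import Any, Dict, List, Optional


class StubOptimizer:
    """Hypothetical stand-in: only the `param_groups` attribute is modelled."""

    def __init__(self, param_groups: List[Dict[str, Any]]):
        self.param_groups = param_groups


def first_group_lr(optimizer: Optional[StubOptimizer]) -> Optional[float]:
    """Mirror of get_lr: the first group's lr, or None when nothing is available."""
    if optimizer is not None:
        for g in optimizer.param_groups:
            return g["lr"]  # returns immediately, later groups are never read
    return None


opt = StubOptimizer([{"lr": 0.1}, {"lr": 0.01}])
print(first_group_lr(opt))   # 0.1 (the second group's 0.01 is ignored)
print(first_group_lr(None))  # None
```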

set_lr(lr)

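Set the learning rate of every parameter group of the optimizer. A one-element lr is applied to all groups; otherwise lr[i] is applied to group i. Groups whose fix_lr key is truthy are instead reset to their initial value from init_lr, indexed the same way.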

Source code in quadra/schedulers/base.py
def set_lr(self, lr: Tuple[float, ...]):
    """Set the learning rate for the optimizer."""
    if self.optimizer is not None:
        for i, g in enumerate(self.optimizer.param_groups):
            if "fix_lr" in g and g["fix_lr"]:
                if len(lr) == 1:
                    lr_to_set = self.init_lr[0]
                else:
                    lr_to_set = self.init_lr[i]
            else:
                if len(lr) == 1:
                    lr_to_set = lr[0]
                else:
                    lr_to_set = lr[i]
            g["lr"] = lr_to_set
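The fix_lr branching is easiest to follow on concrete data. The sketch below re-expresses the set_lr logic as a standalone function over a bare list of param-group dicts, so it runs without PyTorch. The name apply_lr is hypothetical.

```python
from typing import Any, Dict, List, Tuple


def apply_lr(
    param_groups: List[Dict[str, Any]],
    lr: Tuple[float, ...],
    init_lr: Tuple[float, ...],
) -> None:
    """Mirror of set_lr, operating on a list of param-group dicts."""
    for i, g in enumerate(param_groups):
        # Groups flagged with a truthy "fix_lr" are pinned to their initial lr.
        source = init_lr if g.get("fix_lr") else lr
        # The broadcast check looks at len(lr), even when reading from init_lr.
        g["lr"] = source[0] if len(lr) == 1 else source[i]


groups = [{"lr": 0.1}, {"lr": 0.01, "fix_lr": True}]
apply_lr(groups, lr=(0.05, 0.005), init_lr=(0.1, 0.01))
print([g["lr"] for g in groups])  # [0.05, 0.01]
```

Because the one-element check is made on lr rather than init_lr, calling with a single-element lr such as (0.05,) resets a fixed group to init_lr[0] (here 0.1), not to its own entry of init_lr.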

step(*args, **kwargs)

Base method; must be implemented by subclasses. Calling it on the base class raises NotImplementedError.

Source code in quadra/schedulers/base.py
def step(self, *args, **kwargs):
    """Base method, must be implemented by the sub classes."""
    raise NotImplementedError
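A concrete scheduler overrides step and typically computes new rates and passes them to set_lr. The sketch below is written under stated assumptions so it runs without PyTorch or quadra. It uses a hypothetical BaseScheduler that copies only the __init__, set_lr (without fix_lr handling) and step logic listed on this page, not the real _LRScheduler base. StubOptimizer, StepDecayScheduler, step_size and gamma are all illustrative names, not quadra's API.

```python
from typing import Any, Dict, List, Tuple


class StubOptimizer:
    """Hypothetical stand-in exposing only `param_groups`."""

    def __init__(self, param_groups: List[Dict[str, Any]]):
        self.param_groups = param_groups


class BaseScheduler:
    """Simplified copy of the listed base class (fix_lr handling omitted)."""

    def __init__(self, optimizer: StubOptimizer, init_lr: Tuple[float, ...]):
        self.optimizer = optimizer
        self.init_lr = init_lr

    def set_lr(self, lr: Tuple[float, ...]) -> None:
        for i, g in enumerate(self.optimizer.param_groups):
            g["lr"] = lr[0] if len(lr) == 1 else lr[i]

    def step(self, *args, **kwargs):
        raise NotImplementedError


class StepDecayScheduler(BaseScheduler):
    """Hypothetical subclass: scale each initial lr by gamma every step_size steps."""

    def __init__(self, optimizer, init_lr, step_size: int = 2, gamma: float = 0.5):
        super().__init__(optimizer, init_lr)
        self.step_size = step_size
        self.gamma = gamma
        self.num_steps = 0

    def step(self, *args, **kwargs) -> None:
        self.num_steps += 1
        # Number of completed decay periods determines the multiplier.
        factor = self.gamma ** (self.num_steps // self.step_size)
        self.set_lr(tuple(base * factor for base in self.init_lr))


opt = StubOptimizer([{"lr": 0.1}])
sched = StepDecayScheduler(opt, init_lr=(0.1,), step_size=2, gamma=0.5)
for _ in range(4):
    sched.step()
    print(opt.param_groups[0]["lr"])  # 0.1, 0.05, 0.05, 0.025
```

Recomputing from init_lr on every call, rather than multiplying the current rate, keeps the schedule a pure function of the step count.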