optimizers

LARS(params, lr=required, momentum=0, dampening=0, weight_decay=0, nesterov=False, trust_coefficient=0.001, eps=1e-08)

Bases: Optimizer

Extends SGD in PyTorch with LARS scaling from the paper Large Batch Training of Convolutional Networks (https://arxiv.org/pdf/1708.03888.pdf).

Parameters:

  • params (List[Parameter]) –

    iterable of parameters to optimize or dicts defining parameter groups

  • lr (_RequiredParameter) –

    learning rate

  • momentum (float) –

    momentum factor (default: 0)

  • weight_decay (float) –

    weight decay (L2 penalty) (default: 0)

  • dampening (float) –

    dampening for momentum (default: 0)

  • nesterov (bool) –

    enables Nesterov momentum (default: False)

  • trust_coefficient (float) –

    trust coefficient for computing LR (default: 0.001)

  • eps (float) –

    eps for division denominator (default: 1e-8).

Example

model = torch.nn.Linear(10, 1)
input = torch.Tensor(10)
target = torch.Tensor([1.])
loss_fn = lambda input, target: (input - target) ** 2

optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9)
optimizer.zero_grad()
loss_fn(model(input), target).backward()
optimizer.step()

.. note:: The application of momentum in the SGD part is modified according to the PyTorch standards. LARS scaling fits into the equation in the following fashion.

.. math::
    \begin{aligned}
        g_{t+1} & = \text{lars\_lr} * (\beta * p_{t} + g_{t+1}), \\
        v_{t+1} & = \mu * v_{t} + g_{t+1}, \\
        p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
    \end{aligned}

where :math:`p`, :math:`g`, :math:`v`, :math:`\mu` and :math:`\beta` denote the
parameters, gradient, velocity, momentum, and weight decay respectively.
The :math:`\text{lars\_lr}` term is defined by Eq. 6 in the paper.
The Nesterov version is analogously modified.

.. warning:: Parameters with weight decay set to 0 will automatically be excluded from layer-wise LR scaling. This is to ensure consistency with papers like SimCLR and BYOL.
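In practice this means that, for example, bias parameters can be kept out of the layer-wise scaling simply by placing them in a parameter group with weight_decay=0. A minimal sketch, assuming the import path quadra.optimizers.lars and a hypothetical model and bias/weight split (neither is part of the documented API):

import torch

from quadra.optimizers.lars import LARS

# Hypothetical model: biases go into a group with weight_decay=0, so they
# receive a plain SGD update, while the weights get LARS scaling.
model = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.ReLU(), torch.nn.Linear(10, 1))
decay, no_decay = [], []
for name, param in model.named_parameters():
    (no_decay if name.endswith("bias") else decay).append(param)

optimizer = LARS(
    [
        {"params": decay, "weight_decay": 1e-4},
        {"params": no_decay, "weight_decay": 0.0},  # excluded from layer-wise LR scaling
    ],
    lr=0.1,
    momentum=0.9,
)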

Source code in quadra/optimizers/lars.py
def __init__(
    self,
    params: List[Parameter],
    lr: _RequiredParameter = required,
    momentum: float = 0,
    dampening: float = 0,
    weight_decay: float = 0,
    nesterov: bool = False,
    trust_coefficient: float = 0.001,
    eps: float = 1e-8,
):
    if lr is not required and lr < 0.0:
        raise ValueError(f"Invalid learning rate: {lr}")
    if momentum < 0.0:
        raise ValueError(f"Invalid momentum value: {momentum}")
    if weight_decay < 0.0:
        raise ValueError(f"Invalid weight_decay value: {weight_decay}")

    defaults = {
        "lr": lr,
        "momentum": momentum,
        "dampening": dampening,
        "weight_decay": weight_decay,
        "nesterov": nesterov,
    }
    if nesterov and (momentum <= 0 or dampening != 0):
        raise ValueError("Nesterov momentum requires a momentum and zero dampening")

    self.eps = eps
    self.trust_coefficient = trust_coefficient

    super().__init__(params, defaults)

step(closure=None)

Performs a single optimization step.

Parameters:

  • closure (Optional[Callable]) –

    A closure that reevaluates the model and returns the loss. Defaults to None.

Source code in quadra/optimizers/lars.py
@torch.no_grad()
def step(self, closure: Optional[Callable] = None):
    """Performs a single optimization step.

    Args:
        closure: A closure that reevaluates the model and returns the loss. Defaults to None.
    """
    loss = None
    if closure is not None:
        with torch.enable_grad():
            loss = closure()

    # exclude scaling for params with 0 weight decay
    for group in self.param_groups:
        weight_decay = group["weight_decay"]
        momentum = group["momentum"]
        dampening = group["dampening"]
        nesterov = group["nesterov"]

        for p in group["params"]:
            if p.grad is None:
                continue

            d_p = p.grad
            p_norm = torch.norm(p.data)
            g_norm = torch.norm(p.grad.data)

            # lars scaling + weight decay part
            if weight_decay != 0:
                if p_norm != 0 and g_norm != 0:
                    lars_lr = p_norm / (g_norm + p_norm * weight_decay + self.eps)
                    lars_lr *= self.trust_coefficient

                    d_p = d_p.add(p, alpha=weight_decay)
                    d_p *= lars_lr

            # sgd part
            if momentum != 0:
                param_state = self.state[p]
                if "momentum_buffer" not in param_state:
                    buf = param_state["momentum_buffer"] = torch.clone(d_p).detach()
                else:
                    buf = param_state["momentum_buffer"]
                    buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                if nesterov:
                    d_p = d_p.add(buf, alpha=momentum)
                else:
                    d_p = buf

            p.add_(d_p, alpha=-group["lr"])

    return loss
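As with other PyTorch optimizers, step() also accepts a closure that re-evaluates the loss; the closure is run under torch.enable_grad() and its return value is passed back. A minimal sketch (the model, data, and loss below are placeholder assumptions):

import torch

from quadra.optimizers.lars import LARS

model = torch.nn.Linear(10, 1)
optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9)
x, y = torch.randn(4, 10), torch.randn(4, 1)  # hypothetical batch

def closure():
    # recompute loss and gradients; step() calls this under torch.enable_grad()
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    return loss

loss = optimizer.step(closure)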

SAM(params, base_optimizer, rho=0.05, adaptive=True, **kwargs)

Bases: torch.optim.Optimizer

PyTorch implementation of Sharpness-Aware Minimization (https://arxiv.org/abs/2010.01412 and https://arxiv.org/abs/2102.11600). Taken from https://github.com/davda54/sam.

Parameters:

  • params (List[Parameter]) –

    model parameters.

  • base_optimizer (torch.optim.Optimizer) –

    base optimizer that performs the actual parameter update; it can be passed either as an optimizer class (instantiated internally on the SAM parameter groups) or as an already constructed instance.

  • rho (float) –

    Positive float value used to scale the gradients.

  • adaptive (bool) –

    Boolean flag indicating whether to use adaptive step update.

  • **kwargs (Any) –

    Additional parameters for the base optimizer.

Source code in quadra/optimizers/sam.py
def __init__(
    self,
    params: List[Parameter],
    base_optimizer: torch.optim.Optimizer,
    rho: float = 0.05,
    adaptive: bool = True,
    **kwargs: Any,
):
    assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

    defaults = {"rho": rho, "adaptive": adaptive, **kwargs}
    super().__init__(params, defaults)

    if callable(base_optimizer):
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
    else:
        self.base_optimizer = base_optimizer
    self.rho = rho
    self.adaptive = adaptive
    self.param_groups = self.base_optimizer.param_groups
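A minimal construction sketch, assuming the import path quadra.optimizers.sam and a placeholder model; passing the optimizer class lets SAM instantiate the base optimizer on its own parameter groups, forwarding the extra keyword arguments:

import torch

from quadra.optimizers.sam import SAM

model = torch.nn.Linear(10, 1)  # hypothetical model
optimizer = SAM(model.parameters(), torch.optim.SGD, rho=0.05, adaptive=True, lr=0.1, momentum=0.9)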

first_step(zero_grad=False)

First step for SAM optimizer.

Parameters:

  • zero_grad (bool) –

    Boolean flag indicating whether to zero the gradients.

Returns:

  • None

    None

Source code in quadra/optimizers/sam.py
@torch.no_grad()
def first_step(self, zero_grad: bool = False) -> None:
    """First step for SAM optimizer.

    Args:
        zero_grad: Boolean flag indicating whether to zero the gradients.

    Returns:
        None
    """
    grad_norm = self._grad_norm()
    for group in self.param_groups:
        scale = self.rho / (grad_norm + 1e-12)

        for p in group["params"]:
            if p.grad is None:
                continue
            e_w = (torch.pow(p, 2) if self.adaptive else 1.0) * p.grad * scale.to(p)
            p.add_(e_w)  # climb to the local maximum "w + e(w)"
            self.state[p]["e_w"] = e_w

    if zero_grad:
        self.zero_grad()

second_step(zero_grad=False)

Second step for SAM optimizer.

Parameters:

  • zero_grad (bool) –

    Boolean flag indicating whether to zero the gradients.

Returns:

  • None

    None

Source code in quadra/optimizers/sam.py
@torch.no_grad()
def second_step(self, zero_grad: bool = False) -> None:
    """Second step for SAM optimizer.

    Args:
        zero_grad: Boolean flag indicating whether to zero the gradients.

    Returns:
        None

    """
    for group in self.param_groups:
        for p in group["params"]:
            if p.grad is None:
                continue
            p.sub_(self.state[p]["e_w"])  # get back to "w" from "w + e(w)"

    self.base_optimizer.step()  # do the actual "sharpness-aware" update

    if zero_grad:
        self.zero_grad()
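Together, first_step and second_step give the usual two-pass SAM update: compute gradients at w, climb to w + e(w), recompute gradients there, then let the base optimizer update the original weights. A minimal single-batch sketch (model, data, and loss are placeholder assumptions):

import torch

from quadra.optimizers.sam import SAM

model = torch.nn.Linear(10, 1)
optimizer = SAM(model.parameters(), torch.optim.SGD, lr=0.1, momentum=0.9)
x, y = torch.randn(8, 10), torch.randn(8, 1)  # hypothetical batch

# first pass: gradients at w, then perturb the weights to "w + e(w)"
torch.nn.functional.mse_loss(model(x), y).backward()
optimizer.first_step(zero_grad=True)

# second pass: gradients at the perturbed point, then the base optimizer updates w
torch.nn.functional.mse_loss(model(x), y).backward()
optimizer.second_step(zero_grad=True)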

step(closure=None)

Step for SAM optimizer.

Parameters:

  • closure (Optional[Callable]) –

    A closure that re-evaluates the model and returns the loss; it is wrapped with torch.enable_grad() and called between the first and second SAM steps. Defaults to None.

Returns:

  • None

    None

Source code in quadra/optimizers/sam.py
@torch.no_grad()
def step(self, closure: Optional[Callable] = None) -> None:
    """Step for SAM optimizer.

    Args:
        closure: Optional closure that re-evaluates the model; it is run under torch.enable_grad() between the two steps.

    Returns:
        None

    """
    if closure is not None:
        closure = torch.enable_grad()(closure)

    self.first_step(zero_grad=True)
    if closure is not None:
        closure()
    self.second_step(zero_grad=False)
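When a closure is preferred over calling the two steps manually, step() runs first_step, re-evaluates the closure under torch.enable_grad(), and then runs second_step. A minimal sketch (placeholder model and data as above):

import torch

from quadra.optimizers.sam import SAM

model = torch.nn.Linear(10, 1)
optimizer = SAM(model.parameters(), torch.optim.SGD, lr=0.1, momentum=0.9)
x, y = torch.randn(8, 10), torch.randn(8, 1)  # hypothetical batch

def closure():
    # recompute loss and gradients at the perturbed weights "w + e(w)"
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    return loss

# gradients at w must already exist before step() is called
torch.nn.functional.mse_loss(model(x), y).backward()
optimizer.step(closure)
optimizer.zero_grad()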