lars

  • https://arxiv.org/pdf/1708.03888.pdf
  • https://github.com/pytorch/pytorch/blob/1.6/torch/optim/sgd.py

LARS(params, lr=required, momentum=0, dampening=0, weight_decay=0, nesterov=False, trust_coefficient=0.001, eps=1e-08)

Bases: Optimizer

Extends SGD in PyTorch with LARS scaling from the paper "Large Batch Training of Convolutional Networks" (https://arxiv.org/pdf/1708.03888.pdf).

Parameters:

  • params (list[Parameter]) –

    iterable of parameters to optimize or dicts defining parameter groups

  • lr (_RequiredParameter, default: required ) –

    learning rate

  • momentum (float, default: 0 ) –

    momentum factor (default: 0)

  • weight_decay (float, default: 0 ) –

    weight decay (L2 penalty) (default: 0)

  • dampening (float, default: 0 ) –

    dampening for momentum (default: 0)

  • nesterov (bool, default: False ) –

    enables Nesterov momentum (default: False)

  • trust_coefficient (float, default: 0.001 ) –

    trust coefficient that scales the layer-wise learning rate (default: 0.001)

  • eps (float, default: 1e-08 ) –

    small constant added to the denominator of the layer-wise learning rate for numerical stability (default: 1e-8).

Example

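import torch
from quadra.optimizers.lars import LARS  # module path taken from the source location of this class

model = torch.nn.Linear(10, 1)
input = torch.randn(10)  # torch.Tensor(10) would return uninitialized memory
target = torch.Tensor([1.])
loss_fn = lambda input, target: (input - target) ** 2

optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9)
optimizer.zero_grad()
loss_fn(model(input), target).backward()
optimizer.step()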

.. note:: The application of momentum in the SGD part is modified according to the PyTorch standards. LARS scaling fits into the equation in the following fashion.

.. math::
    \begin{aligned}
        g_{t+1} & = \text{lars\_lr} * (\beta * p_{t} + g_{t+1}), \\
        v_{t+1} & = \mu * v_{t} + g_{t+1}, \\
        p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
    \end{aligned}

where :math:`p`, :math:`g`, :math:`v`, :math:`\mu` and :math:`\beta` denote the
parameters, gradient, velocity, momentum, and weight decay respectively.
The :math:`\text{lars\_lr}` is defined by Eq. 6 in the paper.
The Nesterov version is analogously modified.
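
As a rough illustration of the layer-wise rate used in the equations above, the sketch below computes it for a single parameter vector in plain Python. It mirrors the expression in the `step` source shown on this page (`trust_coefficient * ||p|| / (||g|| + weight_decay * ||p|| + eps)`). The function name `lars_lr` and the use of plain lists instead of tensors are assumptions made for the example, not part of the library.

```python
import math


def lars_lr(p, g, weight_decay, trust_coefficient=0.001, eps=1e-8):
    """Layer-wise rate for one parameter vector (hypothetical helper).

    p and g are plain lists of floats standing in for a parameter tensor
    and its gradient.
    """
    p_norm = math.sqrt(sum(x * x for x in p))
    g_norm = math.sqrt(sum(x * x for x in g))
    # Scaling is skipped when weight decay or either norm is zero;
    # returning 1.0 here means "leave the gradient unscaled".
    if weight_decay == 0 or p_norm == 0 or g_norm == 0:
        return 1.0
    return trust_coefficient * p_norm / (g_norm + weight_decay * p_norm + eps)


# ||p|| = 5 and ||g|| = 1, so the rate is roughly 0.001 * 5 / (1 + 0.1 * 5)
rate = lars_lr([3.0, 4.0], [0.6, 0.8], weight_decay=0.1)
```

A large parameter norm relative to the gradient norm gives a larger rate, so each layer's step is sized relative to the magnitude of its own weights.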

.. warning:: Parameters with weight decay set to 0 will automatically be excluded from layer-wise LR scaling. This is to ensure consistency with papers like SimCLR and BYOL.
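
Because scaling is tied to weight decay, one way to control which parameters receive LARS scaling is to pass parameter groups. The sketch below is a minimal, hypothetical helper (`split_weight_decay_groups` is not part of the library) that puts biases and other one-dimensional parameters, such as normalization weights, into a group with `weight_decay=0`. The one-dimensional heuristic is an assumption based on common practice, not something this class enforces. Stand-in objects with an `ndim` attribute are used so the example runs without a model.

```python
from types import SimpleNamespace


def split_weight_decay_groups(named_params, weight_decay):
    """Build two parameter groups from (name, param) pairs.

    Parameters with ndim <= 1 or a name ending in ".bias" get
    weight_decay=0, which also excludes them from LARS scaling.
    """
    decay, no_decay = [], []
    for name, param in named_params:
        if param.ndim <= 1 or name.endswith(".bias"):
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


# Stand-ins for model.named_parameters(); only .ndim is inspected.
named = [
    ("0.weight", SimpleNamespace(ndim=2)),  # linear weight: decayed and scaled
    ("0.bias", SimpleNamespace(ndim=1)),    # bias: excluded
    ("1.weight", SimpleNamespace(ndim=1)),  # e.g. a normalization weight: excluded
]
groups = split_weight_decay_groups(named, weight_decay=1e-6)
```

With a real model, the result of `split_weight_decay_groups(model.named_parameters(), 1e-6)` could be passed as `params` to `LARS`, since the constructor accepts dicts defining parameter groups.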

Source code in quadra/optimizers/lars.py
def __init__(
    self,
    params: list[Parameter],
    lr: _RequiredParameter = required,
    momentum: float = 0,
    dampening: float = 0,
    weight_decay: float = 0,
    nesterov: bool = False,
    trust_coefficient: float = 0.001,
    eps: float = 1e-8,
):
    if lr is not required and lr < 0.0:  # type: ignore[operator]
        raise ValueError(f"Invalid learning rate: {lr}")
    if momentum < 0.0:
        raise ValueError(f"Invalid momentum value: {momentum}")
    if weight_decay < 0.0:
        raise ValueError(f"Invalid weight_decay value: {weight_decay}")

    defaults = {
        "lr": lr,
        "momentum": momentum,
        "dampening": dampening,
        "weight_decay": weight_decay,
        "nesterov": nesterov,
    }
    if nesterov and (momentum <= 0 or dampening != 0):
        raise ValueError("Nesterov momentum requires a momentum and zero dampening")

    self.eps = eps
    self.trust_coefficient = trust_coefficient

    super().__init__(params, defaults)

step(closure=None)

Performs a single optimization step.

Parameters:

  • closure (Callable | None, default: None ) –

    A closure that reevaluates the model and returns the loss. Defaults to None.

Source code in quadra/optimizers/lars.py
@torch.no_grad()
def step(self, closure: Callable | None = None):
    """Performs a single optimization step.

    Args:
        closure: A closure that reevaluates the model and returns the loss. Defaults to None.
    """
    loss = None
    if closure is not None:
        with torch.enable_grad():
            loss = closure()

    # exclude scaling for params with 0 weight decay
    for group in self.param_groups:
        weight_decay = group["weight_decay"]
        momentum = group["momentum"]
        dampening = group["dampening"]
        nesterov = group["nesterov"]

        for p in group["params"]:
            if p.grad is None:
                continue

            d_p = p.grad
            p_norm = torch.norm(p.data)
            g_norm = torch.norm(p.grad.data)

            # lars scaling + weight decay part
            if weight_decay != 0 and p_norm != 0 and g_norm != 0:
                lars_lr = p_norm / (g_norm + p_norm * weight_decay + self.eps)
                lars_lr *= self.trust_coefficient

                d_p = d_p.add(p, alpha=weight_decay)
                d_p *= lars_lr

            # sgd part
            if momentum != 0:
                param_state = self.state[p]
                if "momentum_buffer" not in param_state:
                    buf = param_state["momentum_buffer"] = torch.clone(d_p).detach()
                else:
                    buf = param_state["momentum_buffer"]
                    buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                if nesterov:
                    d_p = d_p.add(buf, alpha=momentum)
                else:
                    d_p = buf

            p.add_(d_p, alpha=-group["lr"])

    return loss
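
To make the control flow of `step` easier to follow, here is a minimal sketch of the same update for a single scalar parameter, written in plain Python. The function `lars_step` and its explicit `buf` argument are hypothetical: the real method keeps the momentum buffer in the optimizer state and operates on tensors.

```python
def lars_step(p, grad, buf, lr, momentum=0.0, dampening=0.0,
              weight_decay=0.0, nesterov=False,
              trust_coefficient=0.001, eps=1e-8):
    """One update for a scalar parameter; returns (new_p, new_buf)."""
    d_p = grad
    p_norm, g_norm = abs(p), abs(grad)

    # LARS scaling plus weight decay, skipped when weight_decay is 0
    if weight_decay != 0 and p_norm != 0 and g_norm != 0:
        lars_lr = trust_coefficient * p_norm / (g_norm + p_norm * weight_decay + eps)
        d_p = lars_lr * (d_p + weight_decay * p)

    # SGD momentum part
    if momentum != 0:
        if buf is None:
            buf = d_p  # first step: the buffer starts as the (scaled) gradient
        else:
            buf = momentum * buf + (1 - dampening) * d_p
        d_p = d_p + momentum * buf if nesterov else buf

    return p - lr * d_p, buf


# Two momentum steps without weight decay, so no LARS scaling is applied
p1, buf1 = lars_step(1.0, 0.5, None, lr=0.1, momentum=0.9)
p2, buf2 = lars_step(p1, 0.5, buf1, lr=0.1, momentum=0.9)

# One step with weight decay, so the gradient is rescaled layer-wise
p_scaled, _ = lars_step(2.0, 1.0, None, lr=1.0, weight_decay=0.5)
```

In the scaled case the step is tiny compared with plain SGD, because the trust coefficient (0.001 by default) multiplies the whole update, which is why LARS is usually paired with a much larger base learning rate.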