
sam

SAM(params, base_optimizer, rho=0.05, adaptive=True, **kwargs)

Bases: Optimizer

PyTorch implementation of Sharpness-Aware Minimization (https://arxiv.org/abs/2010.01412) and its adaptive variant (https://arxiv.org/abs/2102.11600). Taken from: https://github.com/davda54/sam.

Parameters:

  • params (list[Parameter]) –

    model parameters.

  • base_optimizer (Optimizer) –

    optimizer to use.

  • rho (float, default: 0.05 ) –

    Positive float used to scale the gradient perturbation (the neighborhood size rho in the SAM paper).

  • adaptive (bool, default: True ) –

    Boolean flag indicating whether to use the adaptive, element-wise scaling of the perturbation (ASAM-style update).

  • **kwargs (Any, default: {} ) –

    Additional parameters for the base optimizer.

Source code in quadra/optimizers/sam.py
def __init__(
    self,
    params: list[Parameter],
    base_optimizer: torch.optim.Optimizer,
    rho: float = 0.05,
    adaptive: bool = True,
    **kwargs: Any,
):
    assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

    defaults = {"rho": rho, "adaptive": adaptive, **kwargs}
    super().__init__(params, defaults)

    if callable(base_optimizer):
        # base_optimizer passed as a class or factory: instantiate it on SAM's param groups
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
    else:
        # base_optimizer already constructed: use it as-is
        self.base_optimizer = base_optimizer
    self.rho = rho
    self.adaptive = adaptive
    self.param_groups = self.base_optimizer.param_groups
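
A minimal construction sketch (the linear model below is a placeholder, not part of quadra; any torch.optim optimizer class can serve as base_optimizer, and extra keyword arguments such as lr and momentum are forwarded to it):

import torch
from quadra.optimizers.sam import SAM

model = torch.nn.Linear(10, 2)  # placeholder model used only for illustration
optimizer = SAM(
    list(model.parameters()),
    base_optimizer=torch.optim.SGD,  # passed as a class, instantiated on SAM's param_groups
    rho=0.05,
    adaptive=True,
    lr=0.1,
    momentum=0.9,  # forwarded to the SGD base optimizer
)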

first_step(zero_grad=False)

First step for SAM optimizer.

Parameters:

  • zero_grad (bool, default: False ) –

    Boolean flag indicating whether to zero the gradients.

Returns:

  • None

    None

Source code in quadra/optimizers/sam.py
@torch.no_grad()
def first_step(self, zero_grad: bool = False) -> None:
    """First step for SAM optimizer.

    Args:
        zero_grad: Boolean flag indicating whether to zero the gradients.

    Returns:
        None
    """
    grad_norm = self._grad_norm()
    for group in self.param_groups:
        scale = self.rho / (grad_norm + 1e-12)

        for p in group["params"]:
            if p.grad is None:
                continue
            e_w = (torch.pow(p, 2) if self.adaptive else 1.0) * p.grad * scale.to(p)
            p.add_(e_w)  # climb to the local maximum "w + e(w)"
            self.state[p]["e_w"] = e_w

    if zero_grad:
        self.zero_grad()
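
In formula form, and assuming _grad_norm returns the global L2 norm of |w| ⊙ g in the adaptive case and of g otherwise (as in the reference repository), the perturbation applied to each parameter is:

e(w) = \rho \, \frac{g}{\lVert g \rVert_2}                              \quad \text{(adaptive = False)}

e(w) = \rho \, \frac{w^{2} \odot g}{\lVert \, |w| \odot g \, \rVert_2}  \quad \text{(adaptive = True, ASAM)}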

second_step(zero_grad=False)

Second step for SAM optimizer.

Parameters:

  • zero_grad (bool, default: False ) –

    Boolean flag indicating whether to zero the gradients.

Returns:

  • None

    None

Source code in quadra/optimizers/sam.py
@torch.no_grad()
def second_step(self, zero_grad: bool = False) -> None:
    """Second step for SAM optimizer.

    Args:
        zero_grad: Boolean flag indicating whether to zero the gradients.

    Returns:
        None

    """
    for group in self.param_groups:
        for p in group["params"]:
            if p.grad is None:
                continue
            p.sub_(self.state[p]["e_w"])  # get back to "w" from "w + e(w)"

    self.base_optimizer.step()  # do the actual "sharpness-aware" update

    if zero_grad:
        self.zero_grad()
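
first_step and second_step are meant to bracket two forward/backward passes per batch. A minimal training-loop sketch (model, dataloader, and criterion are placeholders, not part of quadra):

for inputs, targets in dataloader:
    # First forward/backward: gradients at w, then move to the perturbed point w + e(w)
    criterion(model(inputs), targets).backward()
    optimizer.first_step(zero_grad=True)

    # Second forward/backward: gradients at w + e(w), then update the original w
    criterion(model(inputs), targets).backward()
    optimizer.second_step(zero_grad=True)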

step(closure=None)

Step for SAM optimizer.

Parameters:

  • closure (Callable | None, default: None ) –

    Optional closure that re-evaluates the model and computes the loss; it is wrapped with torch.enable_grad() so the backward pass can run between the two SAM steps.

Returns:

  • None

    None

Source code in quadra/optimizers/sam.py
@torch.no_grad()
def step(self, closure: Callable | None = None) -> None:  # type: ignore[override]
    """Step for SAM optimizer.

    Args:
        closure: Optional closure that re-evaluates the model and computes the loss; it is
            wrapped with torch.enable_grad() before being called between the two steps.

    Returns:
        None

    """
    if closure is not None:
        closure = torch.enable_grad()(closure)

    self.first_step(zero_grad=True)
    if closure is not None:
        closure()
    self.second_step(zero_grad=False)
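
step fuses the two calls when a closure is provided: gradients must already be populated before step is called, first_step perturbs the weights, the closure recomputes loss and gradients at the perturbed point, and second_step applies the base optimizer update. A sketch using the same placeholder model and criterion as above:

def closure():
    # Recompute loss and gradients at the perturbed weights w + e(w)
    loss = criterion(model(inputs), targets)
    loss.backward()
    return loss

criterion(model(inputs), targets).backward()  # gradients at w, needed by first_step
optimizer.step(closure)
optimizer.zero_grad()  # second_step is called with zero_grad=False inside step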