sam
SAM(params, base_optimizer, rho=0.05, adaptive=True, **kwargs)
Bases: Optimizer
PyTorch implementation of the Sharpness-Aware Minimization paper: https://arxiv.org/abs/2010.01412 and https://arxiv.org/abs/2102.11600. Taken from: https://github.com/davda54/sam.
Parameters:
- params (list[Parameter]) – Model parameters.
- base_optimizer (Optimizer) – Optimizer to use.
- rho (float, default: 0.05) – Positive float value used to scale the gradients.
- adaptive (bool, default: True) – Boolean flag indicating whether to use the adaptive step update.
- **kwargs (Any, default: {}) – Additional parameters for the base optimizer.
Source code in quadra/optimizers/sam.py
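A minimal construction sketch, assuming the import path follows the source file above and that, as in the upstream davda54/sam implementation, base_optimizer is passed as the optimizer class with **kwargs forwarded to it; the model, lr, and momentum values are illustrative only:

```python
import torch
from quadra.optimizers.sam import SAM  # assumed import path, based on quadra/optimizers/sam.py

model = torch.nn.Linear(16, 2)  # illustrative placeholder model

# base_optimizer is the optimizer class; extra kwargs (lr, momentum) are
# forwarded to it when SAM instantiates the base optimizer internally.
optimizer = SAM(
    model.parameters(),
    torch.optim.SGD,
    rho=0.05,
    adaptive=True,
    lr=0.1,
    momentum=0.9,
)
```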
first_step(zero_grad=False)
First step for the SAM optimizer: perturbs the parameters in the direction of the current gradient (the ascent step).
Parameters:
- zero_grad (bool, default: False) – Boolean flag indicating whether to zero the gradients.
Returns:
- None
Source code in quadra/optimizers/sam.py
second_step(zero_grad=False)
Second step for the SAM optimizer: restores the original parameters and performs the actual update with the base optimizer.
Parameters:
- zero_grad (bool, default: False) – Boolean flag indicating whether to zero the gradients.
Returns:
- None
Source code in quadra/optimizers/sam.py
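A sketch of the manual two-step training loop built from first_step and second_step, assuming the optimizer constructed above; loader, criterion, inputs, and targets are illustrative names:

```python
criterion = torch.nn.CrossEntropyLoss()

for inputs, targets in loader:  # `loader` is an illustrative DataLoader
    # First forward/backward pass: gradients at the current weights,
    # then first_step perturbs the weights (and zeroes the gradients).
    criterion(model(inputs), targets).backward()
    optimizer.first_step(zero_grad=True)

    # Second forward/backward pass: gradients at the perturbed weights,
    # then second_step restores the weights and applies the base optimizer update.
    criterion(model(inputs), targets).backward()
    optimizer.second_step(zero_grad=True)
```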
step(closure=None)
Combined step for the SAM optimizer: performs the first and second step, re-evaluating the loss via the closure in between.
Parameters:
- closure (Callable | None, default: None) – Optional closure that re-evaluates the model and recomputes gradients with grad enabled.
Returns:
- None
Source code in quadra/optimizers/sam.py
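Alternatively, a sketch of the closure form: step is given a callable that re-runs the forward/backward pass, so both SAM steps happen inside a single step call. It reuses the illustrative model, criterion, and loader names from above:

```python
for inputs, targets in loader:
    def closure():
        # Re-runs the forward/backward pass; SAM uses this to obtain
        # gradients at the perturbed weights before the final update.
        loss = criterion(model(inputs), targets)
        loss.backward()
        return loss

    # Initial forward/backward pass at the current weights.
    loss = criterion(model(inputs), targets)
    loss.backward()

    optimizer.step(closure)
    optimizer.zero_grad()
```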