Bases: Callback
Custom callback used to set up a warmup scheduler.
Parameters:

- scheduler_config (DictConfig) – configuration used to instantiate the warmup scheduler.
Source code in quadra/callbacks/scheduler.py
def __init__(
    self,
    scheduler_config: DictConfig,
) -> None:
    """Store the scheduler configuration for later instantiation.

    Args:
        scheduler_config: DictConfig passed to ``hydra.utils.instantiate`` in
            ``on_fit_start`` to build the warmup scheduler.
    """
    self.scheduler_config = scheduler_config
on_fit_start(trainer, pl_module)
Called when fit begins.
Source code in quadra/callbacks/scheduler.py
@rank_zero_only
def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
    """Called when fit begins.

    If any configured LR scheduler is a ``CosineAnnealingWithLinearWarmUp``,
    force every optimizer learning rate to zero (the warmup ramps it back up),
    re-instantiate the scheduler from ``self.scheduler_config`` with the
    effective batch size and dataloader length, and swap it into the trainer's
    scheduler configs.

    Args:
        trainer: running trainer; must expose a ``datamodule`` attribute.
        pl_module: attached module, used for its device type and ``optimizer``.

    Raises:
        ValueError: if the trainer has no ``datamodule`` attribute.
    """
    if not hasattr(trainer, "datamodule"):
        raise ValueError("Trainer must have a datamodule attribute.")
    # Nothing to do unless a warmup scheduler is actually configured.
    if not any(isinstance(s.scheduler, CosineAnnealingWithLinearWarmUp) for s in trainer.lr_scheduler_configs):
        return
    log.info("Using warmup scheduler, forcing optimizer learning rate to zero.")
    for i, _ in enumerate(trainer.optimizers):
        for param_group in trainer.optimizers[i].param_groups:
            param_group["lr"] = 0.0
        trainer.optimizers[i].defaults["lr"] = 0.0
    batch_size = trainer.datamodule.batch_size
    train_dataloader = trainer.datamodule.train_dataloader()
    len_train_dataloader = len(train_dataloader)
    if isinstance(trainer.device_ids, list) and pl_module.device.type == "cuda":
        # Under multi-GPU training each device sees only its shard of batches.
        num_gpus = len(trainer.device_ids)
        len_train_dataloader = len_train_dataloader // num_gpus
        if not train_dataloader.drop_last:
            # NOTE(review): this tests the remainder of the already-divided
            # length; ceil-division would normally test
            # len(train_dataloader) % num_gpus — confirm intent.
            len_train_dataloader += int((len_train_dataloader % num_gpus) != 0)
    if len_train_dataloader == 1:
        # Fixed typo in user-facing warning: "lenght" -> "length".
        log.warning(
            "From this dataset size, we can only generate single batch. The batch size will be set as length of"
            " the dataset "
        )
        batch_size = len(train_dataloader.dataset)
        if isinstance(trainer.device_ids, list) and pl_module.device.type == "cuda":
            # Dataset is sharded across GPUs, so scale back to the global size.
            batch_size = batch_size * len(trainer.device_ids)
    scheduler = hydra.utils.instantiate(
        self.scheduler_config,
        optimizer=pl_module.optimizer,
        batch_size=batch_size,
        len_loader=len_train_dataloader,
    )
    # Replace every warmup scheduler entry with the freshly built instance.
    for i, s in enumerate(trainer.lr_scheduler_configs):
        if isinstance(s.scheduler, CosineAnnealingWithLinearWarmUp):
            trainer.lr_scheduler_configs[i].scheduler = scheduler