Skip to content

scheduler

WarmupInit(scheduler_config)

Bases: Callback

Custom callback used to set up a warmup scheduler.

Parameters:

  • scheduler_config (DictConfig) –

    scheduler configuration.

Source code in quadra/callbacks/scheduler.py
20
21
22
23
24
def __init__(self, scheduler_config: DictConfig) -> None:
    """Keep a reference to the warmup scheduler configuration.

    Args:
        scheduler_config: scheduler configuration, instantiated later during fit.
    """
    self.scheduler_config = scheduler_config

on_fit_start(trainer, pl_module)

Called when fit begins.

Source code in quadra/callbacks/scheduler.py
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
@rank_zero_only
def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
    """Called when fit begins.

    If any configured scheduler is a ``CosineAnnealingWithLinearWarmUp``, force all
    optimizer learning rates to zero and re-instantiate the warmup scheduler with the
    effective batch size and per-device dataloader length, then swap it into the
    trainer's scheduler configs.

    Args:
        trainer: the running trainer; must expose a ``datamodule`` attribute.
        pl_module: the module being fitted; its device type decides the multi-GPU path.

    Raises:
        ValueError: if the trainer has no ``datamodule`` attribute.
    """
    if not hasattr(trainer, "datamodule"):
        raise ValueError("Trainer must have a datamodule attribute.")

    # Nothing to do unless a warmup scheduler is actually configured.
    if not any(isinstance(s.scheduler, CosineAnnealingWithLinearWarmUp) for s in trainer.lr_scheduler_configs):
        return

    log.info("Using warmup scheduler, forcing optimizer learning rate to zero.")
    for optimizer in trainer.optimizers:
        for param_group in optimizer.param_groups:
            param_group["lr"] = 0.0
        optimizer.defaults["lr"] = 0.0

    batch_size = trainer.datamodule.batch_size
    train_dataloader = trainer.datamodule.train_dataloader()
    len_train_dataloader = len(train_dataloader)
    if isinstance(trainer.device_ids, list) and pl_module.device.type == "cuda":
        # With multiple devices each one sees len(dataloader) / num_gpus batches,
        # rounded up when the trailing incomplete chunk is kept (drop_last=False).
        # BUGFIX: the divisibility check must use the full length BEFORE the floor
        # division; the original tested the already-divided value, so the extra
        # batch was (almost) never counted.
        num_gpus = len(trainer.device_ids)
        full_len = len_train_dataloader
        len_train_dataloader = full_len // num_gpus
        if not train_dataloader.drop_last and full_len % num_gpus != 0:
            len_train_dataloader += 1

    if len_train_dataloader == 1:
        log.warning(
            "From this dataset size, we can only generate single batch. The batch size will be set as length of"
            " the dataset "
        )
        batch_size = len(train_dataloader.dataset)

    # The effective (global) batch size scales with the number of devices.
    if isinstance(trainer.device_ids, list) and pl_module.device.type == "cuda":
        batch_size = batch_size * len(trainer.device_ids)

    scheduler = hydra.utils.instantiate(
        self.scheduler_config,
        optimizer=pl_module.optimizer,
        batch_size=batch_size,
        len_loader=len_train_dataloader,
    )

    # Replace every configured warmup scheduler with the freshly instantiated one.
    for i, s in enumerate(trainer.lr_scheduler_configs):
        if isinstance(s.scheduler, CosineAnnealingWithLinearWarmUp):
            trainer.lr_scheduler_configs[i].scheduler = scheduler