lightning

BatchSizeFinder(find_train_batch_size=True, find_validation_batch_size=False, find_test_batch_size=False, find_predict_batch_size=False, mode='power', steps_per_trial=3, init_val=2, max_trials=25, batch_arg_name='batch_size')

Bases: BatchSizeFinder

Batch size finder that sets the proper model training status before each trial, since the current implementation shipped with lightning seems bugged. It also allows skipping batch size finding for specific stages.

Parameters:

  • find_train_batch_size (bool, default: True ) –

    Whether to find the training batch size.

  • find_validation_batch_size (bool, default: False ) –

    Whether to find the validation batch size.

  • find_test_batch_size (bool, default: False ) –

    Whether to find the test batch size.

  • find_predict_batch_size (bool, default: False ) –

    Whether to find the predict batch size.

  • mode (str, default: 'power' ) –

    The mode to use for batch size finding. See pytorch_lightning.callbacks.BatchSizeFinder for more details.

  • steps_per_trial (int, default: 3 ) –

    The number of steps per trial. See pytorch_lightning.callbacks.BatchSizeFinder for more details.

  • init_val (int, default: 2 ) –

    The initial value for batch size. See pytorch_lightning.callbacks.BatchSizeFinder for more details.

  • max_trials (int, default: 25 ) –

    The maximum number of trials. See pytorch_lightning.callbacks.BatchSizeFinder for more details.

  • batch_arg_name (str, default: 'batch_size' ) –

    The name of the batch size argument. See pytorch_lightning.callbacks.BatchSizeFinder for more details.

Source code in quadra/callbacks/lightning.py
def __init__(
    self,
    find_train_batch_size: bool = True,
    find_validation_batch_size: bool = False,
    find_test_batch_size: bool = False,
    find_predict_batch_size: bool = False,
    mode: str = "power",
    steps_per_trial: int = 3,
    init_val: int = 2,
    max_trials: int = 25,
    batch_arg_name: str = "batch_size",
) -> None:
    super().__init__(mode, steps_per_trial, init_val, max_trials, batch_arg_name)
    self.find_train_batch_size = find_train_batch_size
    self.find_validation_batch_size = find_validation_batch_size
    self.find_test_batch_size = find_test_batch_size
    self.find_predict_batch_size = find_predict_batch_size
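The four `find_*` flags control which stages are tuned, while the search itself follows lightning's `mode` semantics: with `mode='power'`, the batch size is doubled from `init_val` for up to `max_trials` attempts until a trial fails. A minimal, library-free sketch of that doubling loop (the `try_fit` callable is a hypothetical stand-in for running `steps_per_trial` training steps, not part of quadra's API):

```python
def power_scale_batch_size(try_fit, init_val=2, max_trials=25):
    """Double the batch size until a trial fails or max_trials is reached.

    ``try_fit`` stands in for running a few training steps at the given
    batch size; it should return True on success and False on failure
    (e.g. an out-of-memory error).
    """
    batch_size = init_val
    last_successful = None
    for _ in range(max_trials):
        if not try_fit(batch_size):
            break
        last_successful = batch_size
        batch_size *= 2
    return last_successful


# Example: pretend batches larger than 100 samples run out of memory.
best = power_scale_batch_size(lambda bs: bs <= 100)
print(best)  # → 64, the largest power-of-two batch size that "fits"
```

The real callback delegates this search to lightning's tuner; the sketch only illustrates why `init_val` and `max_trials` bound the search space.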

LightningTrainerBaseSetup(log_every_n_steps=1)

Bases: Callback

Custom callback used to setup a lightning trainer with default options.

Parameters:

  • log_every_n_steps (int, default: 1 ) –

    Default value for trainer.log_every_n_steps if the dataloader is too small.

Source code in quadra/callbacks/lightning.py
def __init__(self, log_every_n_steps: int = 1) -> None:
    self.log_every_n_steps = log_every_n_steps

on_fit_start(trainer, pl_module)

Called on every stage.

Source code in quadra/callbacks/lightning.py
@rank_zero_only
def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
    """Called on every stage."""
    if not hasattr(trainer, "datamodule") or not hasattr(trainer, "log_every_n_steps"):
        raise ValueError("Trainer must have a datamodule and log_every_n_steps attribute.")

    len_train_dataloader = len(trainer.datamodule.train_dataloader())
    if len_train_dataloader <= trainer.log_every_n_steps:
        if len_train_dataloader > self.log_every_n_steps:
            trainer.log_every_n_steps = self.log_every_n_steps
            log.info("`trainer.log_every_n_steps` is too high, setting it to %d", self.log_every_n_steps)
        else:
            trainer.log_every_n_steps = 1
            log.warning(
                "The default log_every_n_steps %d is too high given the datamodule length %d, falling back to 1",
                self.log_every_n_steps,
                len_train_dataloader,
            )
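The adjustment above can be read as a three-way rule: keep the trainer's value if the train dataloader has more batches than it, otherwise fall back to the callback's default, and to 1 if even that default exceeds the dataloader length. A small pure-function sketch of the same logic (the function name is illustrative, not part of quadra's API):

```python
def effective_log_every_n_steps(trainer_value, fallback, len_train_dataloader):
    """Mirror the log_every_n_steps fallback logic used in on_fit_start."""
    if len_train_dataloader > trainer_value:
        return trainer_value  # dataloader is long enough, keep the trainer's value
    if len_train_dataloader > fallback:
        return fallback  # trainer's value is too high, use the callback default
    return 1  # even the callback default is too high


print(effective_log_every_n_steps(50, 5, 200))  # → 50: 200 batches, keep as-is
print(effective_log_every_n_steps(50, 5, 10))   # → 5: fall back to the default
print(effective_log_every_n_steps(50, 5, 3))    # → 1: default is also too high
```

This guards against lightning warning (or silently logging nothing) when `log_every_n_steps` exceeds the number of batches per epoch.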