Skip to content

lightning

BatchSizeFinder(find_train_batch_size=True, find_validation_batch_size=False, find_test_batch_size=False, find_predict_batch_size=False, mode='power', steps_per_trial=3, init_val=2, max_trials=25, batch_arg_name='batch_size')

Bases: BatchSizeFinder

Batch size finder that sets the proper model training status itself, as the current implementation from lightning appears to be bugged. It also allows skipping some of the batch size finding steps.

Parameters:

  • find_train_batch_size (bool, default: True ) –

    Whether to find the training batch size.

  • find_validation_batch_size (bool, default: False ) –

    Whether to find the validation batch size.

  • find_test_batch_size (bool, default: False ) –

    Whether to find the test batch size.

  • find_predict_batch_size (bool, default: False ) –

    Whether to find the predict batch size.

  • mode (str, default: 'power' ) –

    The mode to use for batch size finding. See pytorch_lightning.callbacks.BatchSizeFinder for more details.

  • steps_per_trial (int, default: 3 ) –

    The number of steps per trial. See pytorch_lightning.callbacks.BatchSizeFinder for more details.

  • init_val (int, default: 2 ) –

    The initial value for batch size. See pytorch_lightning.callbacks.BatchSizeFinder for more details.

  • max_trials (int, default: 25 ) –

    The maximum number of trials. See pytorch_lightning.callbacks.BatchSizeFinder for more details.

  • batch_arg_name (str, default: 'batch_size' ) –

    The name of the batch size argument. See pytorch_lightning.callbacks.BatchSizeFinder for more details.

Source code in quadra/callbacks/lightning.py
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
def __init__(
    self,
    find_train_batch_size: bool = True,
    find_validation_batch_size: bool = False,
    find_test_batch_size: bool = False,
    find_predict_batch_size: bool = False,
    mode: str = "power",
    steps_per_trial: int = 3,
    init_val: int = 2,
    max_trials: int = 25,
    batch_arg_name: str = "batch_size",
) -> None:
    """Configure the underlying finder and record which stages to tune.

    Args:
        find_train_batch_size: Whether to search the training batch size.
        find_validation_batch_size: Whether to search the validation batch size.
        find_test_batch_size: Whether to search the test batch size.
        find_predict_batch_size: Whether to search the predict batch size.
        mode: Search strategy, forwarded to the parent finder.
        steps_per_trial: Steps executed per trial, forwarded to the parent finder.
        init_val: Starting batch size, forwarded to the parent finder.
        max_trials: Maximum number of trials, forwarded to the parent finder.
        batch_arg_name: Name of the batch size attribute, forwarded to the parent finder.
    """
    super().__init__(
        mode=mode,
        steps_per_trial=steps_per_trial,
        init_val=init_val,
        max_trials=max_trials,
        batch_arg_name=batch_arg_name,
    )
    # Per-stage switches controlling which trainer stages run the search.
    self.find_train_batch_size = find_train_batch_size
    self.find_validation_batch_size = find_validation_batch_size
    self.find_test_batch_size = find_test_batch_size
    self.find_predict_batch_size = find_predict_batch_size

scale_batch_size(trainer, pl_module)

Scale the batch size.

Source code in quadra/callbacks/lightning.py
488
489
490
491
492
493
494
495
496
497
498
499
500
501
def scale_batch_size(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
    """Run the batch size search and store the resulting optimal size.

    Args:
        trainer: Trainer whose dataloader batch size is being tuned.
        pl_module: Module under tuning (required by the callback API).
    """
    self.optimal_batch_size = _scale_batch_size(
        trainer,
        self._mode,
        self._steps_per_trial,
        self._init_val,
        self._max_trials,
        self._batch_arg_name,
    )

    # Stop the remainder of the tuning run once a size has been found.
    if self._early_exit:
        raise _TunerExitException()

LightningTrainerBaseSetup(log_every_n_steps=1)

Bases: Callback

Custom callback used to setup a lightning trainer with default options.

Parameters:

  • log_every_n_steps (int, default: 1 ) –

    Default value for trainer.log_every_n_steps if the dataloader is too small.

Source code in quadra/callbacks/lightning.py
384
385
def __init__(self, log_every_n_steps: int = 1) -> None:
    """Store the fallback logging interval used when the train dataloader is short.

    Args:
        log_every_n_steps: Value assigned to ``trainer.log_every_n_steps`` if the
            configured one exceeds the dataloader length.
    """
    self.log_every_n_steps = log_every_n_steps

on_fit_start(trainer, pl_module)

Called on every stage.

Source code in quadra/callbacks/lightning.py
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
@rank_zero_only
def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
    """Called on every stage.

    Clamps ``trainer.log_every_n_steps`` when the training dataloader contains
    fewer batches than the configured logging interval, so logging still occurs
    at least once per epoch.

    Args:
        trainer: Trainer being configured; must expose ``datamodule`` and
            ``log_every_n_steps``.
        pl_module: Module being fitted (unused).

    Raises:
        ValueError: If the trainer lacks the required attributes.
    """
    if not hasattr(trainer, "datamodule") or not hasattr(trainer, "log_every_n_steps"):
        raise ValueError("Trainer must have a datamodule and log_every_n_steps attribute.")

    len_train_dataloader = len(trainer.datamodule.train_dataloader())
    if len_train_dataloader <= trainer.log_every_n_steps:
        if len_train_dataloader > self.log_every_n_steps:
            trainer.log_every_n_steps = self.log_every_n_steps
            log.info("`trainer.log_every_n_steps` is too high, setting it to %d", self.log_every_n_steps)
        else:
            # Even the callback default is too high for this dataloader: log every step.
            trainer.log_every_n_steps = 1
            log.warning(
                # Fixed typo in the original message ("lenght" -> "length").
                "The default log_every_n_steps %d is too high given the datamodule length %d, fallback to 1",
                self.log_every_n_steps,
                len_train_dataloader,
            )

__scale_batch_dump_params(trainer)

Dump the parameters that need to be reset after the batch size finder.

Source code in quadra/callbacks/lightning.py
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
def __scale_batch_dump_params(trainer: pl.Trainer) -> dict[str, Any]:
    """Snapshot the trainer state that the batch size finder will mutate.

    Returns:
        A dict holding the loggers, callbacks, loop limits and loop state that
        must be restored once the batch size search is finished.
    """
    loop = trainer._active_loop
    assert loop is not None

    dumped: dict[str, Any] = {
        "loggers": trainer.loggers,
        "callbacks": trainer.callbacks,  # type: ignore[attr-defined]
    }
    if isinstance(loop, pl.loops._FitLoop):
        dumped["max_steps"] = trainer.max_steps
        dumped["limit_train_batches"] = trainer.limit_train_batches
        dumped["limit_val_batches"] = trainer.limit_val_batches
    elif isinstance(loop, pl.loops._EvaluationLoop):
        stage = trainer.state.stage
        assert stage is not None
        dumped["limit_eval_batches"] = getattr(trainer, f"limit_{stage.dataloader_prefix}_batches")
        dumped["loop_verbose"] = loop.verbose

    # Deep copy so later mutations of the loop cannot corrupt the snapshot.
    dumped["loop_state_dict"] = deepcopy(loop.state_dict())
    return dumped

__scale_batch_reset_params(trainer, steps_per_trial)

Reset the parameters that need to be reset after the batch size finder.

Source code in quadra/callbacks/lightning.py
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
def __scale_batch_reset_params(trainer: pl.Trainer, steps_per_trial: int) -> None:
    """Put the trainer into a minimal configuration for one batch size trial.

    Args:
        trainer: Trainer to reconfigure in place.
        steps_per_trial: Number of steps each trial is allowed to run.
    """
    from pytorch_lightning.loggers.logger import DummyLogger  # pylint: disable=import-outside-toplevel

    # Silence real logging and disable other callbacks during the trials.
    if trainer.logger is not None:
        trainer.logger = DummyLogger()
    else:
        trainer.logger = None
    trainer.callbacks = []  # type: ignore[attr-defined]

    loop = trainer._active_loop
    assert loop is not None
    if isinstance(loop, pl.loops._FitLoop):
        trainer.limit_train_batches = 1.0
        trainer.limit_val_batches = steps_per_trial
        trainer.fit_loop.epoch_loop.max_steps = steps_per_trial
    elif isinstance(loop, pl.loops._EvaluationLoop):
        stage = trainer.state.stage
        assert stage is not None
        setattr(trainer, f"limit_{stage.dataloader_prefix}_batches", steps_per_trial)
        # Keep trial output quiet.
        loop.verbose = False

__scale_batch_restore_params(trainer, params)

Restore the parameters that need to be reset after the batch size finder.

Source code in quadra/callbacks/lightning.py
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
def __scale_batch_restore_params(trainer: pl.Trainer, params: dict[str, Any]) -> None:
    """Restore the trainer state captured by the matching dump call.

    Args:
        trainer: Trainer to restore in place.
        params: Snapshot previously produced before the batch size search.
    """
    # TODO: There are more states that needs to be reset (#4512 and #4870)
    trainer.loggers = params["loggers"]
    trainer.callbacks = params["callbacks"]  # type: ignore[attr-defined]

    loop = trainer._active_loop
    assert loop is not None
    if isinstance(loop, pl.loops._FitLoop):
        loop.epoch_loop.max_steps = params["max_steps"]
        trainer.limit_train_batches = params["limit_train_batches"]
        trainer.limit_val_batches = params["limit_val_batches"]
    elif isinstance(loop, pl.loops._EvaluationLoop):
        stage = trainer.state.stage
        assert stage is not None
        setattr(trainer, f"limit_{stage.dataloader_prefix}_batches", params["limit_eval_batches"])

    # Reload the loop state from a copy so the snapshot stays pristine.
    loop.load_state_dict(deepcopy(params["loop_state_dict"]))
    loop.restarting = False
    if "loop_verbose" in params and isinstance(loop, pl.loops._EvaluationLoop):
        loop.verbose = params["loop_verbose"]

    # make sure the loop's state is reset
    _reset_dataloaders(trainer)
    loop.reset()