Skip to content

base

Evaluation(config, model_path, device=None)

Bases: Generic[DataModuleT], Task[DataModuleT]

Base Evaluation Task with deployment models.

Parameters:

  • config (DictConfig) –

    The experiment configuration

  • model_path (str) –

    The model path.

  • device (str | None, default: None ) –

    Device to use for evaluation. If None, the device is automatically determined.

Source code in quadra/tasks/base.py
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
def __init__(
    self,
    config: DictConfig,
    model_path: str,
    device: str | None = None,
):
    """Initialize the evaluation task.

    Args:
        config: The experiment configuration.
        model_path: The model path.
        device: Device to use for evaluation. If None, the device is
            automatically determined.
    """
    super().__init__(config=config)

    # Resolve the device automatically when the caller does not specify one.
    self.device = utils.get_device() if device is None else device

    self.config = config
    self.model_path = model_path
    self.model_info_filename = "model.json"
    self.report_path = ""
    self.metadata = {"report_files": []}

    # Declared for type checkers; populated later (e.g. during prepare()).
    self.model_data: dict[str, Any]
    self._deployment_model: BaseEvaluationModel
    self.deployment_model_type: str

deployment_model: BaseEvaluationModel property writable

Deployment model.

prepare()

Prepare the evaluation.

Source code in quadra/tasks/base.py
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
def prepare(self) -> None:
    """Prepare the evaluation.

    Loads the model metadata file (``self.model_info_filename``) located next
    to the model, aligns the configured transform input height/width with the
    model's expected 2D input size (CxHxW), and finally assigns the deployment
    model from ``self.model_path``.

    Raises:
        ValueError: If the model info file does not contain a json object.
    """
    # Use a single pathlib join instead of mixing os.path.join with Path.
    model_info_path = Path(self.model_path).parent / self.model_info_filename
    with open(model_info_path) as f:
        self.model_data = json.load(f)

    if not isinstance(self.model_data, dict):
        raise ValueError("Model info file is not a valid json")

    for input_size in self.model_data["input_size"]:
        # Only 3-element sizes describe 2D models (CxHxW); skip everything else.
        if len(input_size) != 3:
            continue

        # Adjust the transform for 2D models (CxHxW)
        # We assume that each input size has the same height and width
        if input_size[1] != self.config.transforms.input_height:
            # Lazy %-style arguments for consistency with the module's other log calls.
            log.warning(
                "Input height of the model (%s) is different from the one specified "
                "in the config (%s). Fixing the config.",
                input_size[1],
                self.config.transforms.input_height,
            )
            self.config.transforms.input_height = input_size[1]

        if input_size[2] != self.config.transforms.input_width:
            log.warning(
                "Input width of the model (%s) is different from the one specified "
                "in the config (%s). Fixing the config.",
                input_size[2],
                self.config.transforms.input_width,
            )
            self.config.transforms.input_width = input_size[2]

    self.deployment_model = self.model_path  # type: ignore[assignment]

LightningTask(config, checkpoint_path=None, run_test=False, report=False)

Bases: Generic[DataModuleT], Task[DataModuleT]

Base Experiment Task.

Parameters:

  • config (DictConfig) –

    The experiment configuration

  • checkpoint_path (str | None, default: None ) –

    The path to the checkpoint to load the model from. Defaults to None.

  • run_test (bool, default: False ) –

    Whether to run the test after training. Defaults to False.

  • report (bool, default: False ) –

    Whether to generate a report. Defaults to False.

Source code in quadra/tasks/base.py
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
def __init__(
    self,
    config: DictConfig,
    checkpoint_path: str | None = None,
    run_test: bool = False,
    report: bool = False,
):
    """Initialize the Lightning-based task.

    Args:
        config: The experiment configuration.
        checkpoint_path: The path to the checkpoint to load the model from.
            Defaults to None.
        run_test: Whether to run the test after training. Defaults to False.
        report: Whether to generate a report. Defaults to False.
    """
    super().__init__(config=config)

    self.checkpoint_path = checkpoint_path
    self.run_test = run_test
    self.report = report

    # Declared for type checkers; assigned later through the property setters.
    self._module: LightningModule
    self._trainer: Trainer
    self._devices: int | list[int]
    self._callbacks: list[Callback]
    self._logger: list[Logger]

callbacks: list[Callback] property writable

List[Callback]: The callbacks.

devices: int | list[int] property writable

List[int]: The device ids.

logger: list[Logger] property writable

List[Logger]: The loggers.

module: LightningModule property writable

trainer: Trainer property writable

add_callback(callback)

Add a callback to the trainer.

Parameters:

  • callback (Callback) –

    The callback to add

Source code in quadra/tasks/base.py
293
294
295
296
297
298
299
300
def add_callback(self, callback: Callback):
    """Add a callback to the trainer.

    Args:
        callback: The callback to add
    """
    # Append only when the trainer actually exposes a mutable callback list.
    existing_callbacks = getattr(self.trainer, "callbacks", None)
    if isinstance(existing_callbacks, list):
        existing_callbacks.append(callback)

execute()

Execute the experiment and all the steps.

Source code in quadra/tasks/base.py
302
303
304
305
306
307
308
309
310
311
312
def execute(self) -> None:
    """Execute the experiment and all the steps."""
    self.prepare()
    self.train()
    if self.run_test:
        self.test()
    # Export only when at least one export type is configured.
    export_config = self.config.export
    if export_config is not None and len(export_config.types) > 0:
        self.export()
    if self.report:
        self.generate_report()
    self.finalize()

finalize()

Finalize the experiment.

Source code in quadra/tasks/base.py
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
def finalize(self) -> None:
    """Finalize the experiment."""
    super().finalize()

    utils.finish(
        config=self.config,
        module=self.module,
        datamodule=self.datamodule,
        trainer=self.trainer,
        callbacks=self.callbacks,
        logger=self.logger,
        export_folder=self.export_folder,
    )

    # Report the best checkpoint path unless this was a fast-dev-run
    # (fast_dev_run disables checkpointing, so there is nothing to report).
    checkpoint_cb = self.trainer.checkpoint_callback
    is_fast_dev_run = self.config.trainer.get("fast_dev_run")
    if not is_fast_dev_run and checkpoint_cb is not None and hasattr(checkpoint_cb, "best_model_path"):
        log.info("Best model ckpt: %s", checkpoint_cb.best_model_path)

prepare()

Prepare the experiment.

Source code in quadra/tasks/base.py
122
123
124
125
126
127
128
129
130
131
132
133
134
def prepare(self) -> None:
    """Prepare the experiment."""
    super().prepare()
    cfg = self.config

    # Loggers are set up before callbacks because some callbacks
    # require an already-configured logger.
    if "logger" in cfg:
        self.logger = cfg.logger

    if "callbacks" in cfg:
        self.callbacks = cfg.callbacks

    # Devices are assigned before the trainer, which presumably consumes
    # them in its property setter — confirm against the property code.
    self.devices = cfg.trainer.devices
    self.trainer = cfg.trainer

test()

Test the model.

Source code in quadra/tasks/base.py
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
def test(self) -> Any:
    """Test the model."""
    log.info("Starting testing!")

    # Prefer the best checkpoint tracked by the checkpoint callback, when any.
    best_model = None
    checkpoint_cb = self.trainer.checkpoint_callback
    if checkpoint_cb is not None and hasattr(checkpoint_cb, "best_model_path"):
        candidate = checkpoint_cb.best_model_path
        if candidate is not None and len(candidate) > 0:
            best_model = candidate

    if best_model is None:
        log.warning(
            "No best checkpoint model found, using last weights for test, this might lead to worse results, "
            "consider using a checkpoint callback."
        )

    return self.trainer.test(model=self.module, datamodule=self.datamodule, ckpt_path=best_model)

train()

Train the model.

Source code in quadra/tasks/base.py
241
242
243
244
245
246
247
248
249
250
def train(self) -> None:
    """Train the model."""
    log.info("Starting training!")

    # Record hyperparameters with the configured loggers before fitting.
    utils.log_hyperparameters(config=self.config, model=self.module, trainer=self.trainer)

    self.trainer.fit(model=self.module, datamodule=self.datamodule)

PlaceholderTask

Bases: Task

Placeholder task.

execute()

Execute the task and all the steps.

Source code in quadra/tasks/base.py
318
319
320
321
322
def execute(self) -> None:
    """Execute the task and all the steps."""
    version = str(get_version())
    log.info("Running Placeholder Task.")
    log.info("Quadra Version: %s", version)
    log.info("If you are reading this, it means that library is installed correctly!")

Task(config)

Bases: Generic[DataModuleT]

Base Experiment Task.

Parameters:

  • config (DictConfig) –

    The experiment configuration.

Source code in quadra/tasks/base.py
35
36
37
38
39
40
def __init__(self, config: DictConfig):
    """Initialize the base task.

    Args:
        config: The experiment configuration.
    """
    self.config = config
    # Folder (relative to the run directory) where exported models are placed.
    self.export_folder: str = "deployment_model"

    # Declared for type checkers; populated later by subclasses/properties.
    self._datamodule: DataModuleT
    self.metadata: dict[str, Any]

    # Persist the resolved configuration as soon as the task is created.
    self.save_config()

datamodule: DataModuleT property writable

execute()

Execute the experiment and all the steps.

Source code in quadra/tasks/base.py
84
85
86
87
88
89
90
91
92
def execute(self) -> None:
    """Execute the experiment and all the steps."""
    self.prepare()
    self.train()
    self.test()
    # Export only when at least one export type is configured.
    export_config = self.config.export
    if export_config is not None and len(export_config.types) > 0:
        self.export()
    self.generate_report()
    self.finalize()

export()

Export model for production.

Source code in quadra/tasks/base.py
72
73
74
def export(self) -> None:
    """Export model for production.

    Placeholder implementation: subclasses that support deployment
    export override this method.
    """
    message = "Export model for production not implemented for this task!"
    log.info(message)

finalize()

Finalize the experiment.

Source code in quadra/tasks/base.py
80
81
82
def finalize(self) -> None:
    """Finalize the experiment."""
    # Results land in the current working directory (Hydra runs chdir there).
    results_dir = os.getcwd()
    log.info("Results are saved in %s", results_dir)

generate_report()

Generate a report.

Source code in quadra/tasks/base.py
76
77
78
def generate_report(self) -> None:
    """Generate a report.

    Placeholder implementation: subclasses override this to produce reports.
    """
    message = "Report generation not implemented for this task!"
    log.info(message)

prepare()

Prepare the experiment.

Source code in quadra/tasks/base.py
48
49
50
def prepare(self) -> None:
    """Prepare the experiment.

    Assigns the datamodule from the experiment configuration. NOTE(review):
    ``datamodule`` is a writable property, so this assignment presumably
    routes through a setter that instantiates the config node — confirm
    against the property implementation.
    """
    self.datamodule = self.config.datamodule

save_config()

Save the experiment configuration when running a Hydra experiment.

Source code in quadra/tasks/base.py
42
43
44
45
46
def save_config(self) -> None:
    """Save the experiment configuration when running a Hydra experiment.

    Writes the fully resolved configuration to ``config_resolved.yaml`` in the
    current working directory; does nothing when Hydra is not initialized.
    """
    if HydraConfig.initialized():
        with open("config_resolved.yaml", "w") as fp:
            # Pass the open handle directly: the previous code passed
            # ``fp.name`` (a path string), making OmegaConf re-open the file
            # by path while the held handle went unused.
            OmegaConf.save(config=OmegaConf.to_container(self.config, resolve=True), f=fp)

test()

Test the model.

Source code in quadra/tasks/base.py
68
69
70
def test(self) -> Any:
    """Test the model.

    Placeholder implementation: subclasses override this to run testing.
    """
    message = "Testing not implemented for this task!"
    log.info(message)

train()

Train the model.

Source code in quadra/tasks/base.py
64
65
66
def train(self) -> Any:
    """Train the model.

    Placeholder implementation: subclasses override this to run training.
    """
    message = "Training not implemented for this task!"
    log.info(message)