Skip to content

helpers

check_deployment_model(export_type)

Check that the runtime model is present: both the exported model file and its `model.json` must exist under `deployment_model/`.

Parameters:

  • export_type (str) –

    The type of the exported model.

Source code in quadra/utils/tests/helpers.py (lines 36–45)
def check_deployment_model(export_type: str):
    """Check that the exported runtime model files are present.

    Only the existence of the files is checked; their content is not validated.
    Paths are relative to the current working directory, so this must run from the
    experiment directory (e.g. after ``execute_quadra_experiment`` has changed into it).

    Args:
        export_type: The type of the exported model, used to derive the model file extension.

    Raises:
        AssertionError: If the exported model file or ``deployment_model/model.json`` is missing.
    """
    extension = get_export_extension(export_type)
    model_path = f"deployment_model/model.{extension}"
    config_path = "deployment_model/model.json"

    # Explicit messages: pytest only rewrites asserts in test modules/conftest, so a bare
    # ``assert`` in this helper would fail without saying which file is missing.
    assert os.path.exists(model_path), f"Missing exported model file: {os.path.abspath(model_path)}"
    assert os.path.exists(config_path), f"Missing model config file: {os.path.abspath(config_path)}"

execute_quadra_experiment(overrides, experiment_path)

Execute quadra experiment.

Source code in quadra/utils/tests/helpers.py (lines 21–33)
def execute_quadra_experiment(overrides: list[str], experiment_path: Path) -> None:
    """Execute a quadra experiment programmatically, without the CLI entry point.

    The config named ``config`` is composed from the ``quadra.configs`` module with the
    given overrides and passed directly to ``main``.

    Side effects:
        * ``experiment_path`` is created (with parents) if missing and becomes the process
          working directory. The previous working directory is NOT restored, so relative
          paths used afterwards (e.g. by ``check_deployment_model``) resolve inside it.
        * Hydra's global ``HydraConfig`` singleton is populated with the composed config.

    Args:
        overrides: Hydra overrides applied when composing the config.
        experiment_path: Directory in which the experiment is run.
    """
    with initialize_config_module(config_module="quadra.configs", version_base="1.3.0"):
        # exist_ok avoids the check-then-create race of a separate exists() test.
        experiment_path.mkdir(parents=True, exist_ok=True)
        os.chdir(experiment_path)
        # Include the ``hydra`` node so the config can be registered with HydraConfig below.
        cfg = compose(config_name="config", overrides=overrides, return_hydra_config=True)
        # Workaround for calling ``main`` without a decorated entry point: populate
        # HydraConfig manually.
        # Check https://github.com/facebookresearch/hydra/issues/2017 for more details.
        HydraConfig.instance().set_config(cfg)

        main(cfg)

get_quadra_test_device()

Get the device to use for the tests. If the QUADRA_TEST_DEVICE environment variable is set, it is used.

Source code in quadra/utils/tests/helpers.py (lines 48–50)
def get_quadra_test_device() -> str:
    """Get the device to use for the tests.

    If the QUADRA_TEST_DEVICE environment variable is set to a non-blank value, that value
    (with surrounding whitespace stripped) is used. Otherwise ``"cpu"`` is returned.

    Returns:
        The device string, e.g. ``"cpu"`` or ``"cuda:0"``.
    """
    # Treat an empty/blank value (e.g. ``QUADRA_TEST_DEVICE=`` exported in CI) as unset:
    # returning "" would make ``torch.device("")`` fail in ``setup_trainer_for_lightning``.
    device = os.environ.get("QUADRA_TEST_DEVICE", "").strip()
    return device or "cpu"

setup_trainer_for_lightning()

Set up the Lightning trainer depending on the device. If CUDA is used, the device index is also set. If CPU is used, the trainer is set to lightning_cpu.

Returns:

  • list[str]

    A list of overrides for the trainer.

Source code in quadra/utils/tests/helpers.py (lines 53–70)
def setup_trainer_for_lightning() -> list[str]:
    """Set up the Lightning trainer overrides depending on the test device.

    The device string comes from ``get_quadra_test_device``.

    * For a CUDA device, the ``lightning_gpu`` trainer is selected and pinned to the
      device index. A bare ``"cuda"`` with no index uses index 0.
    * Any other device type uses the ``lightning_cpu`` trainer.

    Returns:
        A list of Hydra overrides for the trainer.
    """
    torch_device = torch.device(get_quadra_test_device())
    if torch_device.type != "cuda":
        return ["trainer=lightning_cpu"]

    # ``torch.device("cuda").index`` is None; formatting it would yield the unusable
    # override "trainer.devices=[None]", so fall back to the first GPU instead.
    device_index = 0 if torch_device.index is None else torch_device.index
    return ["trainer=lightning_gpu", f"trainer.devices=[{device_index}]"]