Skip to content

mlflow

get_mlflow_logger(trainer)

Safely get the MLflow logger from the Trainer's loggers.

Parameters:

  • trainer (Trainer) –

    PyTorch Lightning trainer.

Returns:

  • MLFlowLogger | None

    An MLflow logger if available, else None.

Source code in quadra/utils/mlflow.py (lines 82–99)
def get_mlflow_logger(trainer: Trainer) -> MLFlowLogger | None:
    """Safely get the MLflow logger from the Trainer loggers.

    The lookup order is:

    1. ``trainer.logger``.
    2. Every logger in ``trainer.loggers``, which holds all configured loggers;
       ``trainer.logger`` only exposes the first one.
    3. ``trainer.logger`` itself when it is a list, kept for backward compatibility.

    Args:
        trainer: PyTorch Lightning trainer.

    Returns:
        The first MLFlowLogger found, else None.
    """
    if isinstance(trainer.logger, MLFlowLogger):
        return trainer.logger

    candidates = []
    # ``trainer.logger`` is only the first logger, so an MLFlowLogger configured
    # after another logger can only be found through ``trainer.loggers``.
    all_loggers = getattr(trainer, "loggers", None)
    if isinstance(all_loggers, (list, tuple)):
        candidates.extend(all_loggers)
    if isinstance(trainer.logger, list):
        candidates.extend(trainer.logger)

    return next((logger for logger in candidates if isinstance(logger, MLFlowLogger)), None)

infer_signature_input_torch(input_tensor)

Recursively convert a tensor, or a flat sequence/dict of tensors, into the input format expected by mlflow.models.infer_signature.

Raises:

  • ValueError

    If the input type is not supported or when nested dicts or sequences are encountered.

Source code in quadra/utils/mlflow.py (lines 43–79)
def infer_signature_input_torch(input_tensor: Any) -> Any:
    """Recursively infer the signature input format to pass to mlflow.models.infer_signature.

    Each input type is converted as follows:

    - A tensor is detached, moved to CPU and converted to a numpy array.
    - A flat sequence of tensors becomes a dict keyed ``output_{i}``, because mlflow
      does not support sequence signatures.
    - A flat dict of tensors keeps its keys, with each value converted.

    Args:
        input_tensor: A tensor, a flat sequence of tensors or a flat dict of tensors.

    Returns:
        A numpy array for a tensor input, otherwise a dict mapping names to numpy arrays.

    Raises:
        ValueError: If the input type is not supported or when nested dicts or sequences are encountered.
    """
    if isinstance(input_tensor, torch.Tensor):
        # detach() first: Tensor.numpy() refuses tensors that require grad.
        return input_tensor.detach().cpu().numpy()

    if isinstance(input_tensor, (str, bytes)):
        # str/bytes are Sequences but not containers of tensors; reject them as unsupported
        # instead of walking them character by character.
        raise ValueError(f"Unable to infer signature for model output type {type(input_tensor)}")

    if isinstance(input_tensor, Sequence):
        # Mlflow currently does not support sequence signatures, so we use a dict instead
        items = [(f"output_{i}", x) for i, x in enumerate(input_tensor)]
    elif isinstance(input_tensor, dict):
        items = list(input_tensor.items())
    else:
        raise ValueError(f"Unable to infer signature for model output type {type(input_tensor)}")

    signature = {}
    for key, value in items:
        if isinstance(value, dict):
            # Nested dicts are not supported
            raise ValueError("Nested dicts are not supported")
        if isinstance(value, Sequence) and not isinstance(value, (str, bytes)):
            # Nested signature is currently not supported by mlflow.
            # TODO: Enable this once mlflow supports nested signatures
            raise ValueError("Nested sequences are not supported")

        signature[key] = infer_signature_input_torch(value)

    return signature

infer_signature_torch_model(model, data)

Infer the input and output signature of a PyTorch/TorchScript model.

Source code in quadra/utils/mlflow.py (lines 21–40)
@torch.inference_mode()
def infer_signature_torch_model(model: NnModuleT, data: list[Any]) -> ModelSignature | None:
    """Infer input and output signature for a PyTorch/Torchscript model.

    The model is switched to eval mode and moved to CPU before the forward pass. For an
    ``nn.Module`` these calls act on the caller's model. Top-level tensors in ``data`` are
    moved to CPU as well, so the forward pass does not fail on a device mismatch.

    Args:
        model: PyTorch or TorchScript model.
        data: Positional inputs, forwarded to the model as ``model(*data)``.

    Returns:
        The mlflow signature, or None if the inputs or outputs cannot be converted
        by ``infer_signature_input_torch``.
    """
    model = model.eval()
    model = model.cpu()
    # The model now lives on CPU, so its (top-level) tensor inputs must too.
    data = [x.cpu() if isinstance(x, torch.Tensor) else x for x in data]
    model_output = model(*data)

    try:
        output_signature = infer_signature_input_torch(model_output)
        # A single input is passed as-is; multiple inputs are converted as a sequence.
        signature_input = infer_signature_input_torch(data[0] if len(data) == 1 else data)
    except ValueError as exc:
        # quadra's get_logger cannot be imported here (circular import), so use stdlib logging.
        import logging

        logging.getLogger(__name__).warning(
            "Unable to infer signature for model output type %s: %s", type(model_output), exc
        )
        return None

    return infer_signature(signature_input, output_signature)