Skip to content

base

ModelSignatureWrapper(model)

Bases: Module

Model wrapper used to retrieve the input shape. It can be used as a decorator of nn.Module: the first call to the forward method retrieves the input shapes and stores them in the input_shapes attribute. It will also save the model summary in a file called model_summary.txt in the current working directory.

Source code in quadra/models/base.py
21
22
23
24
25
26
27
28
29
30
def __init__(self, model: nn.Module):
    """Wrap ``model`` and start with no recorded input shapes.

    Args:
        model: Module to wrap. When it is already a ModelSignatureWrapper, its inner model is wrapped
            directly and its recorded input shapes are carried over, so wrappers do not stack.
    """
    super().__init__()
    already_wrapped = isinstance(model, ModelSignatureWrapper)

    # A nested ModelSignatureWrapper is unwrapped: keep its inner model and whatever shapes it recorded
    self.instance = model.instance if already_wrapped else model
    self.input_shapes: Any = model.input_shapes if already_wrapped else None
    self.disable = False

_get_input_shape(inp)

Recursive function to retrieve the input shapes.

Source code in quadra/models/base.py
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
def _get_input_shape(self, inp: Sequence | torch.Tensor) -> list[Any] | tuple[Any, ...] | dict[str, Any]:
    """Recursive function to retrieve the input shapes."""
    if isinstance(inp, list):
        return [self._get_input_shape(i) for i in inp]

    if isinstance(inp, tuple):
        return tuple(self._get_input_shape(i) for i in inp)

    if isinstance(inp, torch.Tensor):
        return tuple(inp.shape[1:])

    if isinstance(inp, dict):
        return {k: self._get_input_shape(v) for k, v in inp.items()}

    raise ValueError(f"Input type {type(inp)} not supported")

_get_input_shapes(*args, **kwargs)

Retrieve the input shapes from the input. Inputs will be in the same order as the forward method signature.

Source code in quadra/models/base.py
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
def _get_input_shapes(self, *args: Any, **kwargs: Any) -> list[Any]:
    """Collect the shape description of every input of the wrapped forward, in signature order.

    Positional inputs come first, in call order. The remaining parameters of the wrapped forward are
    then visited in declaration order: the keyword argument is used when one was passed, otherwise the
    parameter's declared default.
    """
    shapes = [self._get_input_shape(positional) for positional in args]
    consumed = len(args)
    forward_fn = self.instance.forward

    if isinstance(forward_fn, torch.ScriptMethod):
        # Torchscript backbones describe their arguments through the schema, which also lists "self";
        # that extra entry is why the positional offset is shifted by one.
        for position, argument in enumerate(forward_fn.schema.arguments):  # type: ignore[attr-defined]
            if position <= consumed or argument.name == "self":
                continue

            name = argument.name
            # Fall back to the default declared in the schema when the caller gave no keyword
            value = kwargs[name] if name in kwargs else argument.default_value
            shapes.append(self._get_input_shape(value))

        return shapes

    parameters = inspect.signature(forward_fn).parameters

    for position, name in enumerate(parameters):
        if position < consumed:
            # Already covered by a positional argument
            continue

        # Fall back to the default declared in the signature when the caller gave no keyword
        value = kwargs[name] if name in kwargs else parameters[name].default
        shapes.append(self._get_input_shape(value))

    return shapes

cpu(*args, **kwargs)

Handle calls to the cpu method: the call is forwarded to the underlying model and the wrapper itself is returned.

Source code in quadra/models/base.py
62
63
64
65
66
def cpu(self, *args, **kwargs):
    """Forward a ``cpu`` call to the wrapped model and return the wrapper.

    All arguments are passed unchanged to the wrapped model's ``cpu`` method. Whatever that call
    returns becomes the new wrapped instance, and the wrapper itself (not the inner model) is returned.
    """
    moved = self.instance.cpu(*args, **kwargs)
    self.instance = moved

    return self

forward(*args, **kwargs)

Retrieve the input shape and forward the model; if the input shape has already been retrieved, it will just forward the model.

Source code in quadra/models/base.py
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
    """Run the wrapped model's forward, recording the input shapes on the first usable call.

    Shapes are only computed while none have been stored and retrieval has not been disabled. If the
    retrieval raises, a warning is logged and retrieval is disabled for later calls. The wrapped
    forward is executed in every case and its result is returned.
    """
    shapes_missing = self.input_shapes is None

    if shapes_missing and not self.disable:
        try:
            retrieved = self._get_input_shapes(*args, **kwargs)
        except Exception:
            # Best effort only: a failed retrieval must never prevent the actual forward pass
            log.warning(
                "Failed to retrieve input shapes after forward! To export the model you'll need to "
                "provide the input shapes manually setting the config.export.input_shapes parameter! "
                "Alternatively you could try to use a forward with supported input types (and their compositions) "
                "(list, tuple, dict, tensors)."
            )
            self.disable = True
        else:
            self.input_shapes = retrieved

    output = self.instance.forward(*args, **kwargs)

    return output

half(*args, **kwargs)

Handle calls to the half method: the call is forwarded to the underlying model and the wrapper itself is returned.

Source code in quadra/models/base.py
56
57
58
59
60
def half(self, *args, **kwargs):
    """Forward a ``half`` call to the wrapped model and return the wrapper.

    All arguments are passed unchanged to the wrapped model's ``half`` method. Whatever that call
    returns becomes the new wrapped instance, and the wrapper itself (not the inner model) is returned.
    """
    converted = self.instance.half(*args, **kwargs)
    self.instance = converted

    return self

to(*args, **kwargs)

Handle calls to the to method: the call is forwarded to the underlying model and the wrapper itself is returned.

Source code in quadra/models/base.py
50
51
52
53
54
def to(self, *args, **kwargs):
    """Forward a ``to`` call to the wrapped model and return the wrapper.

    All arguments are passed unchanged to the wrapped model's ``to`` method. Whatever that call
    returns becomes the new wrapped instance, and the wrapper itself (not the inner model) is returned.
    """
    relocated = self.instance.to(*args, **kwargs)
    self.instance = relocated

    return self