Skip to content

base

ModelSignatureWrapper(model)

Bases: Module

Model wrapper used to retrieve the input shape. It can be used as a decorator of nn.Module: the first call to the forward method retrieves the input shape and stores it in the `input_shapes` attribute.

Source code in quadra/models/base.py
15
16
17
18
19
20
21
22
23
24
def __init__(self, model: nn.Module):
    """Wrap ``model`` so that the shapes of its forward inputs can be recorded.

    When ``model`` is already a ``ModelSignatureWrapper`` it is unwrapped instead of
    being nested: the inner wrapper's recorded ``input_shapes`` are carried over and
    its wrapped model becomes this wrapper's ``instance``.

    Args:
        model: The module to wrap (possibly another ``ModelSignatureWrapper``).
    """
    super().__init__()

    carried_shapes: Any = None
    if isinstance(model, ModelSignatureWrapper):
        # Handle nested ModelSignatureWrapper: reuse its recorded shapes and wrapped model
        carried_shapes = model.input_shapes
        model = model.instance

    self.instance = model
    # Shapes captured on the first forward call (None until then)
    self.input_shapes: Any = carried_shapes
    # Set to True once shape retrieval has failed, so it is not retried
    self.disable = False

forward(*args, **kwargs)

Retrieve the input shape and forward the model.

Source code in quadra/models/base.py
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
    """Record the input shapes on first use, then delegate to the wrapped model.

    Shape retrieval is attempted only while ``input_shapes`` is still ``None`` and
    ``disable`` is False. If retrieval raises, a warning is logged and ``disable`` is
    set so the attempt is not repeated; the forward pass itself always runs.

    Args:
        *args: Positional inputs passed unchanged to ``self.instance.forward``.
        **kwargs: Keyword inputs passed unchanged to ``self.instance.forward``.

    Returns:
        Whatever ``self.instance.forward`` returns.
    """
    should_capture = not self.disable and self.input_shapes is None
    if should_capture:
        try:
            captured = self._get_input_shapes(*args, **kwargs)
        except Exception:
            # Avoid circular import
            # pylint: disable=import-outside-toplevel
            from quadra.utils.utils import get_logger  # noqa

            logger = get_logger(__name__)
            logger.warning(
                "Failed to retrieve input shapes after forward! To export the model you'll need to "
                "provide the input shapes manually setting the config.export.input_shapes parameter! "
                "Alternatively you could try to use a forward with supported input types (and their compositions) "
                "(list, tuple, dict, tensors)."
            )
            # Best effort only: stop trying after the first failure
            self.disable = True
        else:
            self.input_shapes = captured

    return self.instance.forward(*args, **kwargs)

to(*args, **kwargs)

Handle calls to the `to` method by forwarding them to the wrapped model.

Source code in quadra/models/base.py
47
48
49
def to(self, *args, **kwargs):
    """Move/cast the wrapped model and return this wrapper.

    ``nn.Module.to`` modifies the module in place and returns ``self``. This override
    follows the same contract, so ``model = model.to(device)`` keeps the input shapes
    already captured (and the ``disable`` flag). Previously a fresh wrapper was built
    here, which silently reset ``input_shapes`` to ``None``.

    Args:
        *args: Positional arguments forwarded to ``self.instance.to``.
        **kwargs: Keyword arguments forwarded to ``self.instance.to``.

    Returns:
        This wrapper, now holding whatever ``self.instance.to`` returned.
    """
    # Reassign in case the wrapped object's ``to`` returns a new object rather than itself
    self.instance = self.instance.to(*args, **kwargs)
    return self