
base

ModelSignatureWrapper(model)

Bases: Module

Model wrapper used to retrieve the input shapes of a model. It wraps an nn.Module instance: the first call to the forward method retrieves the input shapes and stores them in the input_shapes attribute. It also saves the model summary to a file called model_summary.txt in the current working directory.
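The record-once, forward-always idea can be illustrated without PyTorch. The following is a minimal sketch under stated assumptions: `ShapeRecordingWrapper` and `FakeTensor` are hypothetical names, not part of quadra, and any object exposing a `.shape` attribute (such as a `torch.Tensor`) stands in for a tensor. It does not reproduce the model summary behaviour.

```python
from typing import Any, Callable, Optional


class FakeTensor:
    """Hypothetical stand-in for a tensor: only the .shape attribute matters here."""

    def __init__(self, *shape: int) -> None:
        self.shape = tuple(shape)


class ShapeRecordingWrapper:
    """Hypothetical wrapper: records input shapes on the first call, then only forwards."""

    def __init__(self, fn: Callable[..., Any]) -> None:
        self.fn = fn
        self.input_shapes: Optional[list] = None
        self.calls_recorded = 0  # how many times shapes were actually captured

    def __call__(self, *args: Any) -> Any:
        if self.input_shapes is None:
            # First call only: capture the shape of every positional argument.
            self.input_shapes = [tuple(a.shape) for a in args]
            self.calls_recorded += 1
        # Every call, including the first, is forwarded unchanged.
        return self.fn(*args)


wrapped = ShapeRecordingWrapper(lambda x: x.shape[0])
print(wrapped(FakeTensor(1, 3, 224, 224)))  # 1
print(wrapped(FakeTensor(8, 3, 224, 224)))  # 8
print(wrapped.input_shapes)  # [(1, 3, 224, 224)]
```

Because only the first call is recorded, the stored shapes reflect whatever batch was used at that point (here a batch of 1), not later calls.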

Source code in quadra/models/base.py
def __init__(self, model: nn.Module):
    super().__init__()
    self.instance = model
    self.input_shapes: Any = None
    self.disable = False

    if isinstance(self.instance, ModelSignatureWrapper):
        # Handle nested ModelSignatureWrapper
        self.input_shapes = self.instance.input_shapes
        self.instance = self.instance.instance
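The nested-wrapper branch above means wrapping an already wrapped model keeps only one level of wrapping and carries over any shapes already recorded. A minimal sketch of that behaviour, assuming a hypothetical `Wrapper` class that copies only the `__init__` logic shown above and a placeholder `Model` (neither is quadra code):

```python
from typing import Any


class Wrapper:
    """Hypothetical wrapper mirroring the nested-unwrapping logic of __init__ above."""

    def __init__(self, model: Any) -> None:
        self.instance = model
        self.input_shapes: Any = None
        self.disable = False

        if isinstance(self.instance, Wrapper):
            # Wrapping a wrapper: inherit its recorded shapes, keep only the innermost model.
            self.input_shapes = self.instance.input_shapes
            self.instance = self.instance.instance


class Model:
    """Hypothetical model placeholder."""


inner = Wrapper(Model())
inner.input_shapes = [(1, 3, 32, 32)]  # pretend shapes were already captured
outer = Wrapper(inner)

print(type(outer.instance).__name__)  # Model
print(outer.input_shapes)  # [(1, 3, 32, 32)]
```

A single unwrapping step is enough: every wrapper already holds an unwrapped model, so repeated wrapping never builds a chain.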

cpu(*args, **kwargs)

Handle calls to the cpu method: move the underlying model to the CPU and return the wrapper.

Source code in quadra/models/base.py
def cpu(self, *args, **kwargs):
    """Handle calls to the cpu method: move the underlying model to the CPU and return the wrapper."""
    self.instance = self.instance.cpu(*args, **kwargs)

    return self

forward(*args, **kwargs)

Retrieve the input shapes and forward the model. If the input shapes have already been retrieved, or a previous retrieval attempt failed, it just forwards the model.

Source code in quadra/models/base.py
def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
    """Retrieve the input shapes and forward the model. If the input shapes have already been retrieved, or a
    previous retrieval attempt failed, it just forwards the model.
    """
    if self.input_shapes is None and not self.disable:
        try:
            self.input_shapes = self._get_input_shapes(*args, **kwargs)
        except Exception:
            log.warning(
                "Failed to retrieve input shapes after forward! To export the model you'll need to "
                "provide the input shapes manually setting the config.export.input_shapes parameter! "
                "Alternatively you could try to use a forward with supported input types (and their compositions) "
                "(list, tuple, dict, tensors)."
            )
            self.disable = True

    return self.instance.forward(*args, **kwargs)
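The warning above names the supported input types: tensors and their compositions in lists, tuples and dicts. The body of `_get_input_shapes` is not shown on this page, so the following is only a hypothetical sketch of how such a recursive extraction could work, together with the capture-once, disable-on-failure logic from `forward`. `get_input_shapes`, `ShapeCapture` and `FakeTensor` are illustrative names, and this sketch records positional arguments only.

```python
import logging
from typing import Any

log = logging.getLogger(__name__)


class FakeTensor:
    """Hypothetical tensor stand-in exposing only a .shape attribute."""

    def __init__(self, *shape: int) -> None:
        self.shape = tuple(shape)


def get_input_shapes(obj: Any) -> Any:
    """Hypothetical helper: map tensors to shape tuples, preserving list/tuple/dict structure."""
    if isinstance(obj, dict):
        return {key: get_input_shapes(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [get_input_shapes(item) for item in obj]
    if isinstance(obj, tuple):
        return tuple(get_input_shapes(item) for item in obj)
    if hasattr(obj, "shape"):
        return tuple(obj.shape)
    raise TypeError(f"Unsupported input type: {type(obj).__name__}")


class ShapeCapture:
    """Hypothetical mirror of the capture-once / disable-on-failure logic in forward()."""

    def __init__(self) -> None:
        self.input_shapes: Any = None
        self.disable = False

    def observe(self, *args: Any) -> None:
        # Only try while nothing is recorded and capture has not been disabled.
        if self.input_shapes is None and not self.disable:
            try:
                self.input_shapes = [get_input_shapes(arg) for arg in args]
            except Exception:
                log.warning("Failed to retrieve input shapes; set them manually for export.")
                # Never retry after a failure, like self.disable = True above.
                self.disable = True


ok = ShapeCapture()
ok.observe(FakeTensor(1, 3, 64, 64), {"mask": FakeTensor(1, 1, 64, 64), "extra": [FakeTensor(1, 10)]})
print(ok.input_shapes)
# [(1, 3, 64, 64), {'mask': (1, 1, 64, 64), 'extra': [(1, 10)]}]

bad = ShapeCapture()
bad.observe("not a tensor")
print(bad.input_shapes, bad.disable)  # None True
```

Disabling after the first failure keeps the warning from being logged on every subsequent forward pass; the shapes then have to be supplied manually, as the warning message says.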

half(*args, **kwargs)

Handle calls to the half method: cast the underlying model to half precision and return the wrapper.

Source code in quadra/models/base.py
def half(self, *args, **kwargs):
    """Handle calls to the half method: cast the underlying model to half precision and return the wrapper."""
    self.instance = self.instance.half(*args, **kwargs)

    return self

to(*args, **kwargs)

Handle calls to the to method: apply it to the underlying model and return the wrapper.

Source code in quadra/models/base.py
def to(self, *args, **kwargs):
    """Handle calls to the to method: apply it to the underlying model and return the wrapper."""
    self.instance = self.instance.to(*args, **kwargs)

    return self
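The to, cpu and half methods all follow the same pattern: forward the call to the wrapped model, store whatever it returns, and return the wrapper itself, so chained calls such as `model = model.to(device)` keep the wrapper and its recorded input_shapes. A minimal sketch of that delegation, assuming hypothetical `Model` and `Wrapper` classes (not quadra code). `Model.to` deliberately returns a new object here to show why the reassignment matters; a real `nn.Module.to` modifies the module in place and returns it, in which case the reassignment is harmless.

```python
from typing import Any


class Model:
    """Hypothetical model whose .to() returns a new object."""

    def __init__(self, device: str = "cpu") -> None:
        self.device = device

    def to(self, device: str) -> "Model":
        return Model(device)

    def cpu(self) -> "Model":
        return self.to("cpu")


class Wrapper:
    """Hypothetical wrapper mirroring to()/cpu(): update the inner model, return the wrapper."""

    def __init__(self, model: Model) -> None:
        self.instance = model
        self.input_shapes = [(1, 3, 32, 32)]  # pretend shapes were already captured

    def to(self, *args: Any, **kwargs: Any) -> "Wrapper":
        # Reassign in case the inner call returns a different object.
        self.instance = self.instance.to(*args, **kwargs)
        return self

    def cpu(self, *args: Any, **kwargs: Any) -> "Wrapper":
        self.instance = self.instance.cpu(*args, **kwargs)
        return self


wrapper = Wrapper(Model())
moved = wrapper.to("cuda")
print(moved is wrapper)  # True
print(moved.instance.device)  # cuda
print(moved.cpu().instance.device)  # cpu
```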