Bases: Module
Model wrapper used to retrieve the input shapes. It can be used as a decorator of an nn.Module: the first call to the
forward method retrieves the input shapes and stores them in the input_shapes attribute.
Source code in quadra/models/base.py
15
16
17
18
19
20
21
22
23
def __init__(self, model: nn.Module):
    """Wrap ``model`` so that its first forward call records the input shapes.

    Args:
        model: Module to wrap. If it is itself a ``ModelSignatureWrapper`` it is
            unwrapped, so wrappers never nest. The state it has already recorded
            (``input_shapes`` and ``disable``) is carried over.
    """
    super().__init__()
    self.instance = model
    # Input shapes captured on the first forward call; None until then.
    self.input_shapes: Any = None
    # Becomes True once shape retrieval has failed, so forward stops retrying it.
    self.disable = False
    if isinstance(self.instance, ModelSignatureWrapper):
        # Handle nested ModelSignatureWrapper: keep a single level of wrapping and
        # preserve what the inner wrapper already learned. That includes a previous
        # failure, so the retrieval attempt (and its warning) is not repeated.
        self.input_shapes = self.instance.input_shapes
        self.disable = self.instance.disable
        self.instance = self.instance.instance
|
forward(*args, **kwargs)
Retrieve the input shape and forward the model.
Source code in quadra/models/base.py
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
    """Retrieve the input shapes on the first call, then forward the wrapped model.

    If no shapes are recorded yet and retrieval has not been disabled, the shapes of
    the given inputs are computed with ``self._get_input_shapes`` and stored in
    ``self.input_shapes``. If that fails, a warning is logged and retrieval is
    disabled for later calls. The forward pass itself is never affected.

    Args:
        *args: Positional inputs, passed unchanged to the wrapped model.
        **kwargs: Keyword inputs, passed unchanged to the wrapped model.

    Returns:
        Whatever the wrapped model returns.
    """
    if self.input_shapes is None and not self.disable:
        try:
            self.input_shapes = self._get_input_shapes(*args, **kwargs)
        except Exception:  # pylint: disable=broad-except
            # Deliberate best effort: shape retrieval must never break inference.
            # Avoid circular import
            # pylint: disable=import-outside-toplevel
            from quadra.utils.utils import get_logger  # noqa

            log = get_logger(__name__)
            log.warning(
                "Failed to retrieve input shapes after forward! To export the model you'll need to "
                "provide the input shapes manually setting the config.export.input_shapes parameter! "
                "Alternatively you could try to use a forward with supported input types (and their compositions) "
                "(list, tuple, dict, tensors)."
            )
            # Keep the root cause available for debugging without cluttering the warning.
            log.debug("Input shape retrieval failed with the following error:", exc_info=True)
            self.disable = True
    # Call the module itself rather than ``.forward`` so that hooks registered on
    # the wrapped model also run (calling ``.forward`` directly skips them).
    return self.instance(*args, **kwargs)
|
to(*args, **kwargs)
Handle calls to the `to` method, returning a new ModelSignatureWrapper around the moved underlying model.
Source code in quadra/models/base.py
def to(self, *args, **kwargs):
    """Move/cast the wrapped model and return a new wrapper around it.

    All arguments are forwarded unchanged to ``self.instance.to``. The returned
    wrapper keeps the input shapes this wrapper has already recorded, and whether
    retrieval was disabled. Moving the model to another device or dtype therefore
    does not discard the captured signature.

    Returns:
        A new ``ModelSignatureWrapper`` around the object returned by
        ``self.instance.to``.
    """
    moved = ModelSignatureWrapper(self.instance.to(*args, **kwargs))
    # A freshly built wrapper starts with input_shapes=None and disable=False;
    # carry over what this wrapper already knows.
    moved.input_shapes = self.input_shapes
    moved.disable = self.disable
    return moved
|