Bases: Module
Model wrapper used to retrieve the input shapes of a model. It can be used as a decorator of nn.Module: the first call to the
forward method retrieves the input shapes and stores them in the input_shapes attribute.
It will also save the model summary in a file called model_summary.txt in the current working directory.
Source code in quadra/models/base.py
21
22
23
24
25
26
27
28
29
def __init__(self, model: nn.Module):
    """Wrap ``model`` so that its first forward call records the input shapes.

    Wrappers are never stacked: when ``model`` is itself a ModelSignatureWrapper, the model
    it wraps is stored instead, and any input shapes it has already recorded are reused.

    Args:
        model: The module to wrap.
    """
    super().__init__()
    wrapped = model
    recorded_shapes: Any = None
    if isinstance(wrapped, ModelSignatureWrapper):
        # Nested wrapper: inherit its recorded shapes and unwrap to the real model.
        recorded_shapes = wrapped.input_shapes
        wrapped = wrapped.instance
    self.instance = wrapped
    # Filled lazily by the first forward call unless inherited above.
    self.input_shapes: Any = recorded_shapes
    # Set to True once shape capture has failed, so it is not retried.
    self.disable = False
|
Recursive function to retrieve the input shapes.
Source code in quadra/models/base.py
106
107
108
109
110
111
112
113
114
115
116
117
118
119
def _get_input_shape(self, inp: Sequence | torch.Tensor) -> list[Any] | tuple[Any, ...] | dict[str, Any]:
    """Recursively replace every tensor in ``inp`` with its shape minus the first dimension.

    Lists stay lists, tuples (including tuple subclasses) become plain tuples, and dicts keep
    their keys. Any other leaf type is rejected.

    Args:
        inp: A tensor, or an arbitrarily nested list/tuple/dict whose leaves are tensors.

    Returns:
        The same nested layout with each tensor replaced by ``tuple(tensor.shape[1:])``
        (presumably dropping the batch dimension -- confirm with callers).

    Raises:
        ValueError: If ``inp`` or any nested leaf has an unsupported type.
    """
    if isinstance(inp, list):
        return list(map(self._get_input_shape, inp))
    elif isinstance(inp, tuple):
        return tuple(map(self._get_input_shape, inp))
    elif isinstance(inp, torch.Tensor):
        # Keep every dimension except the first one.
        return tuple(inp.shape[1:])
    elif isinstance(inp, dict):
        return {key: self._get_input_shape(value) for key, value in inp.items()}
    else:
        raise ValueError(f"Input type {type(inp)} not supported")
|
Retrieve the input shapes from the input. The shapes are returned in the same order as the parameters of the forward method
signature.
Source code in quadra/models/base.py
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
def _get_input_shapes(self, *args: Any, **kwargs: Any) -> list[Any]:
    """Collect the input shapes of one forward call, ordered like the wrapped forward signature.

    The shapes of the positional arguments come first, in the order given. Every remaining
    parameter of the wrapped model's ``forward`` is then resolved from ``kwargs`` when it was
    passed by name, and from its declared default value otherwise. Keyword arguments that do
    not match a parameter are ignored.

    NOTE(review): a parameter whose default is ``None`` (or that has no default) makes
    ``_get_input_shape`` raise ValueError, which aborts the whole capture.

    Args:
        *args: Positional arguments of the forward call.
        **kwargs: Keyword arguments of the forward call.

    Returns:
        One ``_get_input_shape`` result per forward parameter.
    """
    shapes = [self._get_input_shape(value) for value in args]
    n_positional = len(args)
    forward_fn = self.instance.forward

    if not isinstance(forward_fn, torch.ScriptMethod):
        # Python forward: the bound method's signature omits ``self``, so the first
        # ``n_positional`` parameters are exactly the ones already filled positionally.
        parameters = inspect.signature(forward_fn).parameters
        for name in list(parameters)[n_positional:]:
            value = kwargs[name] if name in kwargs else parameters[name].default
            shapes.append(self._get_input_shape(value))
        return shapes

    # TorchScript forward: the schema lists ``self`` first, hence skipping one extra position.
    for position, argument in enumerate(forward_fn.schema.arguments):  # type: ignore[attr-defined]
        if position <= n_positional or argument.name == "self":
            continue
        value = kwargs[argument.name] if argument.name in kwargs else argument.default_value
        shapes.append(self._get_input_shape(value))
    return shapes
|
cpu(*args, **kwargs)
Handle calls to the cpu method: move the underlying model and return the wrapper.
Source code in quadra/models/base.py
def cpu(self, *args, **kwargs):
    """Forward a ``cpu()`` call to the wrapped model and return this wrapper.

    The value returned by ``self.instance.cpu`` replaces ``self.instance``, so the wrapper
    keeps pointing at whatever object that call returns.

    Args:
        *args: Positional arguments forwarded to ``self.instance.cpu``.
        **kwargs: Keyword arguments forwarded to ``self.instance.cpu``.

    Returns:
        This wrapper (not the underlying model), so calls can be chained.
    """
    self.instance = self.instance.cpu(*args, **kwargs)
    return self
|
forward(*args, **kwargs)
Retrieve the input shapes and forward the model; if the input shapes have already been retrieved, just forward
the model.
Source code in quadra/models/base.py
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
    """Run the wrapped model, recording its input shapes on the first eligible call.

    Shapes are captured only while ``input_shapes`` is still unset and capture has not been
    disabled. If capture fails for any reason, a warning is logged and capture is disabled so
    later calls do not retry. The forward pass runs in every case.

    Args:
        *args: Positional arguments for the wrapped model's ``forward``.
        **kwargs: Keyword arguments for the wrapped model's ``forward``.

    Returns:
        Whatever ``self.instance.forward`` returns.
    """
    if self.input_shapes is not None or self.disable:
        # Shapes already known (or capture gave up earlier): plain pass-through.
        return self.instance.forward(*args, **kwargs)

    try:
        self.input_shapes = self._get_input_shapes(*args, **kwargs)
    except Exception:
        # Best effort only: a capture failure must never break inference.
        log.warning(
            "Failed to retrieve input shapes after forward! To export the model you'll need to "
            "provide the input shapes manually setting the config.export.input_shapes parameter! "
            "Alternatively you could try to use a forward with supported input types (and their compositions) "
            "(list, tuple, dict, tensors)."
        )
        self.disable = True
    return self.instance.forward(*args, **kwargs)
|
half(*args, **kwargs)
Handle calls to the half method: convert the underlying model and return the wrapper.
Source code in quadra/models/base.py
def half(self, *args, **kwargs):
    """Forward a ``half()`` call to the wrapped model and return this wrapper.

    The value returned by ``self.instance.half`` replaces ``self.instance``, so the wrapper
    keeps pointing at whatever object that call returns.

    Args:
        *args: Positional arguments forwarded to ``self.instance.half``.
        **kwargs: Keyword arguments forwarded to ``self.instance.half``.

    Returns:
        This wrapper (not the underlying model), so calls can be chained.
    """
    self.instance = self.instance.half(*args, **kwargs)
    return self
|
to(*args, **kwargs)
Handle calls to the to method: move/convert the underlying model and return the wrapper.
Source code in quadra/models/base.py
def to(self, *args, **kwargs):
    """Forward a ``to(...)`` call to the wrapped model and return this wrapper.

    The arguments (e.g. a device and/or dtype, as accepted by the wrapped model's ``to``) are
    passed through unchanged. The value returned by ``self.instance.to`` replaces
    ``self.instance``, so the wrapper keeps pointing at whatever object that call returns.

    Args:
        *args: Positional arguments forwarded to ``self.instance.to``.
        **kwargs: Keyword arguments forwarded to ``self.instance.to``.

    Returns:
        This wrapper (not the underlying model), so calls can be chained.
    """
    self.instance = self.instance.to(*args, **kwargs)
    return self
|