export

export_model(config, model, export_folder, half_precision, input_shapes=None, idx_to_class=None, pytorch_model_type='model')

Generate deployment models for the task.

Parameters:

  • config (DictConfig) –

    Experiment config

  • model (Any) –

    Model to be exported

  • export_folder (str) –

    Path to save the exported model

  • half_precision (bool) –

    Whether to use half precision for the exported model

  • input_shapes (Optional[List[Any]], default: None ) –

    Input shapes for the exported model

  • idx_to_class (Optional[Dict[int, str]], default: None ) –

    Mapping from class index to class name

  • pytorch_model_type (Literal['backbone', 'model'], default: 'model' ) –

    Type of the PyTorch model config to be exported; if it is 'backbone', the config.backbone section is saved to disk, otherwise config.model is saved

Returns:

  • Dict[str, Any]

    If the model is exported successfully, return a dictionary containing information about the exported model and

  • Dict[str, str]

    a second dictionary containing the paths to the exported models. Otherwise, return two empty dictionaries.

Source code in quadra/utils/export.py
def export_model(
    config: DictConfig,
    model: Any,
    export_folder: str,
    half_precision: bool,
    input_shapes: Optional[List[Any]] = None,
    idx_to_class: Optional[Dict[int, str]] = None,
    pytorch_model_type: Literal["backbone", "model"] = "model",
) -> Tuple[Dict[str, Any], Dict[str, str]]:
    """Generate deployment models for the task.

    Args:
        config: Experiment config
        model: Model to be exported
        export_folder: Path to save the exported model
        half_precision: Whether to use half precision for the exported model
        input_shapes: Input shapes for the exported model
        idx_to_class: Mapping from class index to class name
        pytorch_model_type: Type of the pytorch model config to be exported; if it is 'backbone', the
            config.backbone section is saved to disk, otherwise config.model is saved

    Returns:
        If the model is exported successfully, return a dictionary containing information about the exported model and
        a second dictionary containing the paths to the exported models. Otherwise, return two empty dictionaries.
    """
    if config.export is None or len(config.export.types) == 0:
        log.info("No export type specified skipping export")
        return {}, {}

    os.makedirs(export_folder, exist_ok=True)

    if input_shapes is None:
        # Try to get input shapes from config
        # If this is also None we will try to retrieve it from the ModelSignatureWrapper, if it fails we can't export
        input_shapes = config.export.input_shapes

    export_paths = {}

    for export_type in config.export.types:
        if export_type == "torchscript":
            out = export_torchscript_model(
                model=model,
                input_shapes=input_shapes,
                output_path=export_folder,
                half_precision=half_precision,
            )

            if out is None:
                log.warning("Torchscript export failed, enable debug logging for more details")
                continue

            export_path, input_shapes = out
            export_paths[export_type] = export_path
        elif export_type == "pytorch":
            export_path = export_pytorch_model(
                model=model,
                output_path=export_folder,
            )
            export_paths[export_type] = export_path
            with open(os.path.join(export_folder, "model_config.yaml"), "w") as f:
                OmegaConf.save(getattr(config, pytorch_model_type), f, resolve=True)
        elif export_type == "onnx":
            if not hasattr(config.export, "onnx"):
                log.warning("No onnx configuration found, skipping onnx export")
                continue

            out = export_onnx_model(
                model=model,
                output_path=export_folder,
                onnx_config=config.export.onnx,
                input_shapes=input_shapes,
                half_precision=half_precision,
            )

            if out is None:
                log.warning("ONNX export failed, enable debug logging for more details")
                continue

            export_path, input_shapes = out
            export_paths[export_type] = export_path
        else:
            log.warning("Export type: %s not implemented", export_type)

    if len(export_paths) == 0:
        log.warning("No export type was successful, no model will be available for deployment")
        return {}, export_paths

    model_json = {
        "input_size": input_shapes,
        "classes": idx_to_class,
        "mean": list(config.transforms.mean),
        "std": list(config.transforms.std),
    }

    return model_json, export_paths
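
Example: a minimal, hedged usage sketch. The config below is a hypothetical fragment containing only the fields export_model reads (export.types, transforms.mean, transforms.std); a real quadra experiment config has many more keys, and the folder and class names are illustrative:

import torch.nn as nn
from omegaconf import OmegaConf

from quadra.utils.export import export_model

# Hypothetical minimal config: only the keys export_model actually reads
config = OmegaConf.create(
    {
        "export": {"types": ["torchscript"]},
        "transforms": {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]},
    }
)

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 2))

model_json, export_paths = export_model(
    config=config,
    model=model,
    export_folder="deployment_model",
    half_precision=False,  # True requires a CUDA device
    input_shapes=[[3, 224, 224]],  # one RGB image input, batch dimension added internally
    idx_to_class={0: "good", 1: "bad"},
)
# model_json -> {"input_size": ..., "classes": ..., "mean": [...], "std": [...]}
# export_paths -> {"torchscript": "<absolute path to model.pt>"}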

export_onnx_model(model, output_path, onnx_config, input_shapes=None, half_precision=False, model_name='model.onnx')

Export a PyTorch model with ONNX.

Parameters:

  • model (Module) –

    PyTorch model to be exported

  • output_path (str) –

    Path to save the model

  • input_shapes (Optional[List[Any]], default: None ) –

    Input shapes for tracing

  • onnx_config (DictConfig) –

    ONNX export configuration

  • half_precision (bool, default: False ) –

    If True, the model will be exported with half precision

  • model_name (str, default: 'model.onnx' ) –

    Name of the exported model

Source code in quadra/utils/export.py
@torch.inference_mode()
def export_onnx_model(
    model: nn.Module,
    output_path: str,
    onnx_config: DictConfig,
    input_shapes: Optional[List[Any]] = None,
    half_precision: bool = False,
    model_name: str = "model.onnx",
) -> Optional[Tuple[str, Any]]:
    """Export a PyTorch model with ONNX.

    Args:
        model: PyTorch model to be exported
        output_path: Path to save the model
        input_shapes: Input shapes for tracing
        onnx_config: ONNX export configuration
        half_precision: If True, the model will be exported with half precision
        model_name: Name of the exported model
    """
    if not ONNX_AVAILABLE:
        log.warning("ONNX is not installed, can not export model in this format.")
        log.warning("Please install ONNX capabilities for quadra with: pip install .[onnx]")
        return None

    model.eval()
    if half_precision:
        model.to("cuda:0")
        model = model.half()
    else:
        model.cpu()

    if hasattr(onnx_config, "fixed_batch_size") and onnx_config.fixed_batch_size is not None:
        batch_size = onnx_config.fixed_batch_size
    else:
        batch_size = 1

    model_inputs = extract_torch_model_inputs(
        model=model, input_shapes=input_shapes, half_precision=half_precision, batch_size=batch_size
    )
    if model_inputs is None:
        return None

    if isinstance(model, ModelSignatureWrapper):
        model = model.instance

    inp, input_shapes = model_inputs

    os.makedirs(output_path, exist_ok=True)

    model_path = os.path.join(output_path, model_name)

    input_names = onnx_config.input_names if hasattr(onnx_config, "input_names") else None

    if input_names is None:
        input_names = []
        for i, _ in enumerate(inp):
            input_names.append(f"input_{i}")

    output = [model(*inp)]
    output_names = onnx_config.output_names if hasattr(onnx_config, "output_names") else None

    if output_names is None:
        output_names = []
        for i, _ in enumerate(output):
            output_names.append(f"output_{i}")

    dynamic_axes = onnx_config.dynamic_axes if hasattr(onnx_config, "dynamic_axes") else None

    if hasattr(onnx_config, "fixed_batch_size") and onnx_config.fixed_batch_size is not None:
        dynamic_axes = None
    else:
        if dynamic_axes is None:
            dynamic_axes = {}
            for i, _ in enumerate(input_names):
                dynamic_axes[input_names[i]] = {0: "batch_size"}

            for i, _ in enumerate(output_names):
                dynamic_axes[output_names[i]] = {0: "batch_size"}

    onnx_config = cast(Dict[str, Any], OmegaConf.to_container(onnx_config, resolve=True))

    onnx_config["input_names"] = input_names
    onnx_config["output_names"] = output_names
    onnx_config["dynamic_axes"] = dynamic_axes

    simplify = onnx_config.pop("simplify", False)
    _ = onnx_config.pop("fixed_batch_size", None)

    if len(inp) == 1:
        inp = inp[0]

    if isinstance(inp, list):
        inp = tuple(inp)  # onnx doesn't like lists representing tuples of inputs

    if isinstance(inp, dict):
        raise ValueError("ONNX export does not support model with dict inputs")

    try:
        torch.onnx.export(model=model, args=inp, f=model_path, **onnx_config)

        onnx_model = onnx.load(model_path)
        # Check if ONNX model is valid
        onnx.checker.check_model(onnx_model)
    except Exception as e:
        log.debug("ONNX export failed with error: %s", e)
        return None

    log.info("ONNX model saved to %s", os.path.join(os.getcwd(), model_path))

    if simplify:
        log.info("Attempting to simplify ONNX model")
        onnx_model = onnx.load(model_path)

        try:
            simplified_model, check = onnx_simplify(onnx_model)
        except Exception as e:
            log.debug("ONNX simplification failed with error: %s", e)
            check = False

        if not check:
            log.warning("Something failed during model simplification, only original ONNX model will be exported")
        else:
            model_filename, model_extension = os.path.splitext(model_name)
            model_name = f"{model_filename}_simplified{model_extension}"
            model_path = os.path.join(output_path, model_name)
            onnx.save(simplified_model, model_path)
            log.info("Simplified ONNX model saved to %s", os.path.join(os.getcwd(), model_path))

    return os.path.join(os.getcwd(), model_path), input_shapes
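
Example: a minimal sketch. Apart from simplify and fixed_batch_size, which this function consumes, every remaining key in onnx_config is forwarded as a keyword argument to torch.onnx.export, so opset_version below is just one such pass-through option:

import torch.nn as nn
from omegaconf import OmegaConf

from quadra.utils.export import export_onnx_model

model = nn.Sequential(nn.Linear(16, 4), nn.ReLU())

# "simplify" is popped by this function; "opset_version" goes to torch.onnx.export
onnx_config = OmegaConf.create({"opset_version": 16, "simplify": False})

out = export_onnx_model(
    model=model,
    output_path="deployment_model",
    onnx_config=onnx_config,
    input_shapes=[(16,)],  # one flat input of size 16, batch dimension added internally
)
if out is not None:
    model_path, input_shapes = out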

export_pytorch_model(model, output_path, model_name='model.pth')

Export a PyTorch model's parameter dictionary via its serialized state_dict.

Parameters:

  • model (Module) –

    PyTorch model to be exported

  • output_path (str) –

    Path to save the model

  • model_name (str, default: 'model.pth' ) –

    Name of the exported model

Returns:

  • str

    If the model is exported successfully, the path to the model is returned.

Source code in quadra/utils/export.py
def export_pytorch_model(model: nn.Module, output_path: str, model_name: str = "model.pth") -> str:
    """Export pytorch model's parameter dictionary using a deserialized state_dict.

    Args:
        model: PyTorch model to be exported
        output_path: Path to save the model
        model_name: Name of the exported model

    Returns:
        If the model is exported successfully, the path to the model is returned.

    """
    if isinstance(model, ModelSignatureWrapper):
        model = model.instance

    os.makedirs(output_path, exist_ok=True)
    model.eval()
    model.cpu()
    model_path = os.path.join(output_path, model_name)
    torch.save(model.state_dict(), model_path)
    log.info("Pytorch model saved to %s", os.path.join(output_path, model_name))

    return os.path.join(os.getcwd(), model_path)
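
Example: since only the state_dict is saved, reloading requires rebuilding the same architecture first:

import torch
import torch.nn as nn

from quadra.utils.export import export_pytorch_model

model = nn.Linear(16, 4)
model_path = export_pytorch_model(model=model, output_path="deployment_model")

# Reload the weights into a matching architecture
restored = nn.Linear(16, 4)
restored.load_state_dict(torch.load(model_path))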

export_torchscript_model(model, output_path, input_shapes=None, half_precision=False, model_name='model.pt')

Export a PyTorch model with TorchScript.

Parameters:

  • model (Module) –

    PyTorch model to be exported

  • input_shapes (Optional[List[Any]], default: None ) –

    Input shapes for tracing

  • output_path (str) –

    Path to save the model

  • half_precision (bool, default: False ) –

    If True, the model will be exported with half precision

  • model_name (str, default: 'model.pt' ) –

    Name of the exported model

Returns:

  • Optional[Tuple[str, Any]]

    If the model is exported successfully, the path to the model and the input shape are returned.

Source code in quadra/utils/export.py
@torch.inference_mode()
def export_torchscript_model(
    model: nn.Module,
    output_path: str,
    input_shapes: Optional[List[Any]] = None,
    half_precision: bool = False,
    model_name: str = "model.pt",
) -> Optional[Tuple[str, Any]]:
    """Export a PyTorch model with TorchScript.

    Args:
        model: PyTorch model to be exported
        input_shapes: Input shapes for tracing
        output_path: Path to save the model
        half_precision: If True, the model will be exported with half precision
        model_name: Name of the exported model

    Returns:
        If the model is exported successfully, the path to the model and the input shape are returned.

    """
    if isinstance(model, CflowLightning):
        log.warning("Exporting cflow model with torchscript is not supported yet.")
        return None

    model.eval()
    if half_precision:
        model.to("cuda:0")
        model = model.half()
    else:
        model.cpu()

    model_inputs = extract_torch_model_inputs(model, input_shapes, half_precision)

    if model_inputs is None:
        return None

    if isinstance(model, ModelSignatureWrapper):
        model = model.instance

    inp, input_shapes = model_inputs

    try:
        try:
            model_jit = torch.jit.trace(model, inp)
        except RuntimeError as e:
            log.warning("Standard tracing failed with exception %s, attempting tracing with strict=False", e)
            model_jit = torch.jit.trace(model, inp, strict=False)

        os.makedirs(output_path, exist_ok=True)

        model_path = os.path.join(output_path, model_name)
        model_jit.save(model_path)

        log.info("Torchscript model saved to %s", os.path.join(os.getcwd(), model_path))

        return os.path.join(os.getcwd(), model_path), input_shapes
    except Exception as e:
        log.debug("Failed to export torchscript model with exception: %s", e)
        return None
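
Example: a minimal sketch tracing a small convolutional model; on success the absolute model path and the input shapes are returned:

import torch.nn as nn

from quadra.utils.export import export_torchscript_model

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())

out = export_torchscript_model(
    model=model,
    output_path="deployment_model",
    input_shapes=[(3, 64, 64)],  # traced with a random (1, 3, 64, 64) tensor
)
if out is not None:
    model_path, input_shapes = out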

extract_torch_model_inputs(model, input_shapes=None, half_precision=False, batch_size=1)

Extract the input shapes for the given model and generate a list of torch tensors with the given device and dtype.

Parameters:

  • model (Union[Module, ModelSignatureWrapper]) –

    Module or ModelSignatureWrapper

  • input_shapes (Optional[List[Any]], default: None ) –

    Input shapes

  • half_precision (bool, default: False ) –

    If True, the model will be exported with half precision

  • batch_size (int, default: 1 ) –

    Batch size for the input shapes

Source code in quadra/utils/export.py
def extract_torch_model_inputs(
    model: Union[nn.Module, ModelSignatureWrapper],
    input_shapes: Optional[List[Any]] = None,
    half_precision: bool = False,
    batch_size: int = 1,
) -> Optional[Tuple[Union[List[Any], Tuple[Any, ...], torch.Tensor], List[Any]]]:
    """Extract the input shapes for the given model and generate a list of torch tensors with the
    given device and dtype.

    Args:
        model: Module or ModelSignatureWrapper
        input_shapes: Input shapes
        half_precision: If True, the model will be exported with half precision
        batch_size: Batch size for the input shapes
    """
    if isinstance(model, ModelSignatureWrapper):
        if input_shapes is None:
            input_shapes = model.input_shapes

    if input_shapes is None:
        log.warning(
            "Input shape is None, can not trace model! Please provide input_shapes in the task export configuration."
        )
        return None

    if half_precision:
        inp = generate_torch_inputs(
            input_shapes=input_shapes, device="cuda:0", half_precision=True, dtype=torch.float16, batch_size=batch_size
        )
    else:
        inp = generate_torch_inputs(
            input_shapes=input_shapes, device="cpu", half_precision=False, dtype=torch.float32, batch_size=batch_size
        )

    return inp, input_shapes
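
Example: a single flat input of size 8 with batch size 4 yields one random tensor of shape (4, 8):

import torch.nn as nn

from quadra.utils.export import extract_torch_model_inputs

model = nn.Linear(8, 2)

out = extract_torch_model_inputs(model, input_shapes=[(8,)], batch_size=4)
if out is not None:
    inputs, input_shapes = out  # inputs == [tensor of shape (4, 8)]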

generate_torch_inputs(input_shapes, device, half_precision=False, dtype=torch.float32, batch_size=1)

Given a list of input shapes that may contain lists, tuples or dicts, with tuples being the actual input shapes of the model, generate a matching structure of torch tensors with the given device and dtype.

Source code in quadra/utils/export.py
def generate_torch_inputs(
    input_shapes: List[Any],
    device: str,
    half_precision: bool = False,
    dtype: torch.dtype = torch.float32,
    batch_size: int = 1,
) -> Union[List[Any], Tuple[Any, ...], torch.Tensor]:
    """Given a list of input shapes that can contain either lists, tuples or dicts, with tuples being the input shapes
    of the model, generate a list of torch tensors with the given device and dtype.
    """
    if isinstance(input_shapes, list):
        if any(isinstance(inp, (Sequence, dict)) for inp in input_shapes):
            # Propagate batch_size so nested shapes share the same batch dimension
            return [generate_torch_inputs(inp, device, half_precision, dtype, batch_size) for inp in input_shapes]

        # Base case
        inp = torch.randn((batch_size, *input_shapes), dtype=dtype, device=device)

    if isinstance(input_shapes, dict):
        return {
            k: generate_torch_inputs(v, device, half_precision, dtype, batch_size) for k, v in input_shapes.items()
        }

    if isinstance(input_shapes, tuple):
        if any(isinstance(inp, (Sequence, dict)) for inp in input_shapes):
            # The tuple contains a list, tuple or dict
            return tuple(generate_torch_inputs(inp, device, half_precision, dtype, batch_size) for inp in input_shapes)

        # Base case
        inp = torch.randn((batch_size, *input_shapes), dtype=dtype, device=device)

    if half_precision:
        inp = inp.half()

    return inp
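
Example: two flat shape specifications produce a list of two random tensors sharing the batch dimension:

from quadra.utils.export import generate_torch_inputs

inputs = generate_torch_inputs(
    input_shapes=[(3, 32, 32), (10,)],
    device="cpu",
    batch_size=2,
)
# inputs -> [tensor of shape (2, 3, 32, 32), tensor of shape (2, 10)]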

get_export_extension(export_type)

Get the extension of the exported model.

Parameters:

  • export_type (str) –

    The type of the exported model.

Returns:

  • str

    The extension of the exported model.

Source code in quadra/utils/export.py
def get_export_extension(export_type: str) -> str:
    """Get the extension of the exported model.

    Args:
        export_type: The type of the exported model.

    Returns:
        The extension of the exported model.
    """
    if export_type == "onnx":
        extension = "onnx"
    elif export_type == "torchscript":
        extension = "pt"
    elif export_type == "pytorch":
        extension = "pth"
    else:
        raise ValueError(f"Unsupported export type {export_type}")

    return extension
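
Example:

from quadra.utils.export import get_export_extension

assert get_export_extension("torchscript") == "pt"
assert get_export_extension("pytorch") == "pth"
assert get_export_extension("onnx") == "onnx"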

import_deployment_model(model_path, inference_config, device, model_architecture=None)

Try to import a model for deployment. Currently supports TorchScript .pt files, PyTorch state dict .pth files, and ONNX .onnx files.

Parameters:

  • model_path (str) –

    Path to the model

  • inference_config (DictConfig) –

    Inference configuration, should contain keys for the different deployment models

  • device (str) –

    Device to load the model on

  • model_architecture (Optional[Module], default: None ) –

    Optional model architecture to use for loading a plain pytorch model

Returns:

  • BaseEvaluationModel

    The loaded deployment model

Source code in quadra/utils/export.py
def import_deployment_model(
    model_path: str, inference_config: DictConfig, device: str, model_architecture: Optional[nn.Module] = None
) -> BaseEvaluationModel:
    """Try to import a model for deployment, currently only supports torchscript .pt files and
    state dictionaries .pth files.

    Args:
        model_path: Path to the model
        inference_config: Inference configuration, should contain keys for the different deployment models
        device: Device to load the model on
        model_architecture: Optional model architecture to use for loading a plain pytorch model

    Returns:
        The loaded deployment model
    """
    log.info("Importing trained model")

    file_extension = os.path.splitext(os.path.basename(model_path))[1]
    deployment_model: Optional[BaseEvaluationModel] = None

    if file_extension == ".pt":
        deployment_model = TorchscriptEvaluationModel(config=inference_config.torchscript)
    elif file_extension == ".pth":
        if model_architecture is None:
            raise ValueError("model_architecture must be specified when loading a .pth file")

        deployment_model = TorchEvaluationModel(config=inference_config.pytorch, model_architecture=model_architecture)
    elif file_extension == ".onnx":
        deployment_model = ONNXEvaluationModel(config=inference_config.onnx)

    if deployment_model is not None:
        deployment_model.load_from_disk(model_path=model_path, device=device)

        log.info("Imported %s model", deployment_model.__class__.__name__)

        return deployment_model

    raise ValueError(
        f"Unable to load model with extension {file_extension}, valid extensions are: ['.pt', '.pth', '.onnx']"
    )
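
Example: a hedged sketch. The contents of each per-format inference config are defined elsewhere in quadra, so empty sections stand in here; only the section matching the file extension is read:

from omegaconf import OmegaConf

from quadra.utils.export import import_deployment_model

# Hypothetical inference config; real per-format options are project specific
inference_config = OmegaConf.create({"torchscript": {}, "pytorch": {}, "onnx": {}})

deployment_model = import_deployment_model(
    model_path="deployment_model/model.pt",  # the .pt extension selects the torchscript loader
    inference_config=inference_config,
    device="cpu",
)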