
export

export_model(config, model, export_folder, half_precision, input_shapes=None, idx_to_class=None, pytorch_model_type='model')

Generate deployment models for the task.

Parameters:

  • config (DictConfig) –

    Experiment config

  • model (Any) –

    Model to be exported

  • export_folder (str) –

    Path to save the exported model

  • half_precision (bool) –

    Whether to use half precision for the exported model

  • input_shapes (list[Any] | None, default: None ) –

    Input shapes for the exported model

  • idx_to_class (dict[int, str] | None, default: None ) –

    Mapping from class index to class name

  • pytorch_model_type (Literal['backbone', 'model'], default: 'model' ) –

    Type of the PyTorch model config to be exported. If it is "backbone", the config.backbone section is saved to disk; otherwise config.model is saved.

Returns:

  • dict[str, Any]

    If the model is exported successfully, a dictionary containing information about the exported model, and

  • dict[str, str]

    a second dictionary containing the paths to the exported models. If no export succeeds, two empty dictionaries are returned.

Source code in quadra/utils/export.py
def export_model(
    config: DictConfig,
    model: Any,
    export_folder: str,
    half_precision: bool,
    input_shapes: list[Any] | None = None,
    idx_to_class: dict[int, str] | None = None,
    pytorch_model_type: Literal["backbone", "model"] = "model",
) -> tuple[dict[str, Any], dict[str, str]]:
    """Generate deployment models for the task.

    Args:
        config: Experiment config
        model: Model to be exported
        export_folder: Path to save the exported model
        half_precision: Whether to use half precision for the exported model
        input_shapes: Input shapes for the exported model
        idx_to_class: Mapping from class index to class name
        pytorch_model_type: Type of the PyTorch model config to be exported. If it is "backbone", the
            config.backbone section is saved to disk, otherwise config.model is saved

    Returns:
        If the model is exported successfully, return a dictionary containing information about the exported model and
        a second dictionary containing the paths to the exported models. Otherwise, return two empty dictionaries.
    """
    if config.export is None or len(config.export.types) == 0:
        log.info("No export type specified skipping export")
        return {}, {}

    os.makedirs(export_folder, exist_ok=True)

    if input_shapes is None:
        # Try to get input shapes from config
        # If this is also None we will try to retrieve it from the ModelSignatureWrapper, if it fails we can't export
        input_shapes = config.export.input_shapes

    export_paths = {}

    for export_type in config.export.types:
        if export_type == "torchscript":
            out = export_torchscript_model(
                model=model,
                input_shapes=input_shapes,
                output_path=export_folder,
                half_precision=half_precision,
            )

            if out is None:
                log.warning("Torchscript export failed, enable debug logging for more details")
                continue

            export_path, input_shapes = out
            export_paths[export_type] = export_path
        elif export_type == "pytorch":
            export_path = export_pytorch_model(
                model=model,
                output_path=export_folder,
            )
            export_paths[export_type] = export_path
            with open(os.path.join(export_folder, "model_config.yaml"), "w") as f:
                OmegaConf.save(getattr(config, pytorch_model_type), f, resolve=True)
        elif export_type == "onnx":
            if not hasattr(config.export, "onnx"):
                log.warning("No onnx configuration found, skipping onnx export")
                continue

            out = export_onnx_model(
                model=model,
                output_path=export_folder,
                onnx_config=config.export.onnx,
                input_shapes=input_shapes,
                half_precision=half_precision,
            )

            if out is None:
                log.warning("ONNX export failed, enable debug logging for more details")
                continue

            export_path, input_shapes = out
            export_paths[export_type] = export_path
        else:
            log.warning("Export type: %s not implemented", export_type)

    if len(export_paths) == 0:
        log.warning("No export type was successful, no model will be available for deployment")
        return {}, export_paths

    model_json = {
        "input_size": input_shapes,
        "classes": idx_to_class,
        "mean": list(config.transforms.mean),
        "std": list(config.transforms.std),
    }

    return model_json, export_paths
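
A minimal sketch of calling export_model by hand; only the config fields this function actually reads are shown (a real experiment config is composed by Hydra and contains much more), and torchvision's resnet18 stands in for a trained model.

import torchvision
from omegaconf import OmegaConf

from quadra.utils.export import export_model

# Only the fields read by export_model are included; a "pytorch" export
# type would additionally require a config.model (or config.backbone) section.
config = OmegaConf.create(
    {
        "export": {"types": ["torchscript"], "input_shapes": [[3, 224, 224]]},
        "transforms": {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]},
    }
)

model_json, export_paths = export_model(
    config=config,
    model=torchvision.models.resnet18(),
    export_folder="deployment_model",
    half_precision=False,  # half precision requires a CUDA device
    idx_to_class={0: "cat", 1: "dog"},
)
# On success export_paths maps "torchscript" to the saved model path, while
# model_json records input sizes, classes and normalization statistics.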

export_onnx_model(model, output_path, onnx_config, input_shapes=None, half_precision=False, model_name='model.onnx')

Export a PyTorch model with ONNX.

Parameters:

  • model (Module) –

    PyTorch model to be exported

  • output_path (str) –

    Path to save the model

  • onnx_config (DictConfig) –

    ONNX export configuration

  • input_shapes (list[Any] | None, default: None ) –

    Input shapes for tracing

  • half_precision (bool, default: False ) –

    If True, the model will be exported with half precision

  • model_name (str, default: 'model.onnx' ) –

    Name of the exported model

Source code in quadra/utils/export.py
@torch.inference_mode()
def export_onnx_model(
    model: nn.Module,
    output_path: str,
    onnx_config: DictConfig,
    input_shapes: list[Any] | None = None,
    half_precision: bool = False,
    model_name: str = "model.onnx",
) -> tuple[str, Any] | None:
    """Export a PyTorch model with ONNX.

    Args:
        model: PyTorch model to be exported
        output_path: Path to save the model
        onnx_config: ONNX export configuration
        input_shapes: Input shapes for tracing
        half_precision: If True, the model will be exported with half precision
        model_name: Name of the exported model
    """
    if not ONNX_AVAILABLE:
        log.warning("ONNX is not installed, can not export model in this format.")
        log.warning("Please install ONNX capabilities for quadra with: poetry install -E onnx")
        return None

    model.eval()
    if half_precision:
        model.to("cuda:0")
        model = model.half()
    else:
        model.cpu()

    if hasattr(onnx_config, "fixed_batch_size") and onnx_config.fixed_batch_size is not None:
        batch_size = onnx_config.fixed_batch_size
    else:
        batch_size = 1

    model_inputs = extract_torch_model_inputs(
        model=model, input_shapes=input_shapes, half_precision=half_precision, batch_size=batch_size
    )
    if model_inputs is None:
        return None

    if isinstance(model, ModelSignatureWrapper):
        model = model.instance

    inp, input_shapes = model_inputs

    os.makedirs(output_path, exist_ok=True)

    model_path = os.path.join(output_path, model_name)

    input_names = onnx_config.input_names if hasattr(onnx_config, "input_names") else None

    if input_names is None:
        input_names = []
        for i, _ in enumerate(inp):
            input_names.append(f"input_{i}")

    output = [model(*inp)]
    output_names = onnx_config.output_names if hasattr(onnx_config, "output_names") else None

    if output_names is None:
        output_names = []
        for i, _ in enumerate(output):
            output_names.append(f"output_{i}")

    dynamic_axes = onnx_config.dynamic_axes if hasattr(onnx_config, "dynamic_axes") else None

    if hasattr(onnx_config, "fixed_batch_size") and onnx_config.fixed_batch_size is not None:
        dynamic_axes = None
    elif dynamic_axes is None:
        dynamic_axes = {}
        for i, _ in enumerate(input_names):
            dynamic_axes[input_names[i]] = {0: "batch_size"}

        for i, _ in enumerate(output_names):
            dynamic_axes[output_names[i]] = {0: "batch_size"}

    modified_onnx_config = cast(dict[str, Any], OmegaConf.to_container(onnx_config, resolve=True))

    modified_onnx_config["input_names"] = input_names
    modified_onnx_config["output_names"] = output_names
    modified_onnx_config["dynamic_axes"] = dynamic_axes

    simplify = modified_onnx_config.pop("simplify", False)
    _ = modified_onnx_config.pop("fixed_batch_size", None)

    if len(inp) == 1:
        inp = inp[0]

    if isinstance(inp, list):
        inp = tuple(inp)  # onnx doesn't like lists representing tuples of inputs

    if isinstance(inp, dict):
        raise ValueError("ONNX export does not support model with dict inputs")

    try:
        torch.onnx.export(model=model, args=inp, f=model_path, **modified_onnx_config)

        onnx_model = onnx.load(model_path)
        # Check if ONNX model is valid
        onnx.checker.check_model(onnx_model)
    except Exception as e:
        log.debug("ONNX export failed with error: %s", e)
        return None

    log.info("ONNX model saved to %s", os.path.join(os.getcwd(), model_path))

    if half_precision:
        is_export_ok = _safe_export_half_precision_onnx(
            model=model,
            export_model_path=model_path,
            inp=inp,
            onnx_config=onnx_config,
            input_shapes=input_shapes,
            input_names=input_names,
        )

        if not is_export_ok:
            return None

    if simplify:
        log.info("Attempting to simplify ONNX model")
        onnx_model = onnx.load(model_path)

        try:
            simplified_model, check = onnx_simplify(onnx_model)
        except Exception as e:
            log.debug("ONNX simplification failed with error: %s", e)
            check = False

        if not check:
            log.warning("Something failed during model simplification, only original ONNX model will be exported")
        else:
            model_filename, model_extension = os.path.splitext(model_name)
            model_name = f"{model_filename}_simplified{model_extension}"
            model_path = os.path.join(output_path, model_name)
            onnx.save(simplified_model, model_path)
            log.info("Simplified ONNX model saved to %s", os.path.join(os.getcwd(), model_path))

    return os.path.join(os.getcwd(), model_path), input_shapes
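
A minimal sketch, assuming the onnx extra is installed (poetry install -E onnx); the onnx_config keys shown here (opset_version, simplify) are respectively forwarded to torch.onnx.export and used to toggle the simplification step above.

import torch.nn as nn
from omegaconf import OmegaConf

from quadra.utils.export import export_onnx_model

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.AdaptiveAvgPool2d(1), nn.Flatten())
onnx_config = OmegaConf.create({"opset_version": 16, "simplify": False})

out = export_onnx_model(
    model=model,
    output_path="deployment_model",
    onnx_config=onnx_config,
    input_shapes=[(3, 224, 224)],
    half_precision=False,  # True would require a CUDA device
)
if out is not None:
    # Input/output names and batch-size dynamic axes were auto-generated.
    model_path, input_shapes = out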

export_pytorch_model(model, output_path, model_name='model.pth')

Export a PyTorch model's parameter dictionary by serializing its state_dict.

Parameters:

  • model (Module) –

    PyTorch model to be exported

  • output_path (str) –

    Path to save the model

  • model_name (str, default: 'model.pth' ) –

    Name of the exported model

Returns:

  • str

    The path to the exported model.

Source code in quadra/utils/export.py
def export_pytorch_model(model: nn.Module, output_path: str, model_name: str = "model.pth") -> str:
    """Export pytorch model's parameter dictionary using a deserialized state_dict.

    Args:
        model: PyTorch model to be exported
        output_path: Path to save the model
        model_name: Name of the exported model

    Returns:
        The path to the exported model.

    """
    if isinstance(model, ModelSignatureWrapper):
        model = model.instance

    os.makedirs(output_path, exist_ok=True)
    model.eval()
    model.cpu()
    model_path = os.path.join(output_path, model_name)
    torch.save(model.state_dict(), model_path)
    log.info("Pytorch model saved to %s", os.path.join(output_path, model_name))

    return os.path.join(os.getcwd(), model_path)
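
Because only the state_dict is serialized, the same architecture must be rebuilt before the weights can be restored; a minimal sketch:

import torch
import torch.nn as nn

from quadra.utils.export import export_pytorch_model

model = nn.Linear(16, 4)
model_path = export_pytorch_model(model=model, output_path="deployment_model")

# Restoring requires instantiating the architecture again.
restored = nn.Linear(16, 4)
restored.load_state_dict(torch.load(model_path))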

export_torchscript_model(model, output_path, input_shapes=None, half_precision=False, model_name='model.pt')

Export a PyTorch model with TorchScript.

Parameters:

  • model (Module) –

    PyTorch model to be exported

  • input_shapes (list[Any] | None, default: None ) –

    Input shapes for tracing

  • output_path (str) –

    Path to save the model

  • half_precision (bool, default: False ) –

    If True, the model will be exported with half precision

  • model_name (str, default: 'model.pt' ) –

    Name of the exported model

Returns:

  • tuple[str, Any] | None

    If the model is exported successfully, the path to the model and the input shape are returned.

Source code in quadra/utils/export.py
@torch.inference_mode()
def export_torchscript_model(
    model: nn.Module,
    output_path: str,
    input_shapes: list[Any] | None = None,
    half_precision: bool = False,
    model_name: str = "model.pt",
) -> tuple[str, Any] | None:
    """Export a PyTorch model with TorchScript.

    Args:
        model: PyTorch model to be exported
        input_shapes: Input shapes for tracing
        output_path: Path to save the model
        half_precision: If True, the model will be exported with half precision
        model_name: Name of the exported model

    Returns:
        If the model is exported successfully, the path to the model and the input shape are returned.

    """
    if isinstance(model, CflowLightning):
        log.warning("Exporting cflow model with torchscript is not supported yet.")
        return None

    model.eval()
    if half_precision:
        model.to("cuda:0")
        model = model.half()
    else:
        model.cpu()

    model_inputs = extract_torch_model_inputs(model, input_shapes, half_precision)

    if model_inputs is None:
        return None

    if isinstance(model, ModelSignatureWrapper):
        model = model.instance

    inp, input_shapes = model_inputs

    try:
        try:
            model_jit = torch.jit.trace(model, inp)
        except RuntimeError as e:
            log.warning("Standard tracing failed with exception %s, attempting tracing with strict=False", e)
            model_jit = torch.jit.trace(model, inp, strict=False)

        os.makedirs(output_path, exist_ok=True)

        model_path = os.path.join(output_path, model_name)
        model_jit.save(model_path)

        log.info("Torchscript model saved to %s", os.path.join(os.getcwd(), model_path))

        return os.path.join(os.getcwd(), model_path), input_shapes
    except Exception as e:
        log.debug("Failed to export torchscript model with exception: %s", e)
        return None
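
A minimal sketch tracing a small model on CPU; with half_precision=True the model would instead be moved to cuda:0 and traced in fp16.

import torch
import torch.nn as nn

from quadra.utils.export import export_torchscript_model

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
out = export_torchscript_model(
    model=model,
    output_path="deployment_model",
    input_shapes=[(3, 64, 64)],
)
if out is not None:
    model_path, input_shapes = out
    traced = torch.jit.load(model_path)  # ready for deployment inference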

extract_torch_model_inputs(model, input_shapes=None, half_precision=False, batch_size=1)

Extract the input shapes for the given model and generate a list of torch tensors with the given device and dtype.

Parameters:

  • model (Module | ModelSignatureWrapper) –

    Module or ModelSignatureWrapper

  • input_shapes (list[Any] | None, default: None ) –

    Input shapes

  • half_precision (bool, default: False ) –

    If True, the model will be exported with half precision

  • batch_size (int, default: 1 ) –

    Batch size for the input shapes

Source code in quadra/utils/export.py
def extract_torch_model_inputs(
    model: nn.Module | ModelSignatureWrapper,
    input_shapes: list[Any] | None = None,
    half_precision: bool = False,
    batch_size: int = 1,
) -> tuple[list[Any] | tuple[Any, ...] | torch.Tensor, list[Any]] | None:
    """Extract the input shapes for the given model and generate a list of torch tensors with the
    given device and dtype.

    Args:
        model: Module or ModelSignatureWrapper
        input_shapes: Input shapes
        half_precision: If True, the model will be exported with half precision
        batch_size: Batch size for the input shapes
    """
    if isinstance(model, ModelSignatureWrapper) and input_shapes is None:
        input_shapes = model.input_shapes

    if input_shapes is None:
        log.warning(
            "Input shape is None, can not trace model! Please provide input_shapes in the task export configuration."
        )
        return None

    if half_precision:
        # TODO: This doesn't support bfloat16!!
        inp = generate_torch_inputs(
            input_shapes=input_shapes, device="cuda:0", half_precision=True, dtype=torch.float16, batch_size=batch_size
        )
    else:
        inp = generate_torch_inputs(
            input_shapes=input_shapes, device="cpu", half_precision=False, dtype=torch.float32, batch_size=batch_size
        )

    return inp, input_shapes
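
A minimal sketch with a plain nn.Module, where the shapes must be passed explicitly; a ModelSignatureWrapper would instead supply the shapes recorded from its forward calls.

import torch.nn as nn

from quadra.utils.export import extract_torch_model_inputs

out = extract_torch_model_inputs(model=nn.Identity(), input_shapes=[(3, 32, 32)])
if out is not None:
    inp, input_shapes = out
    print(inp[0].shape)  # torch.Size([1, 3, 32, 32])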

generate_torch_inputs(input_shapes, device, half_precision=False, dtype=torch.float32, batch_size=1)

Given a list of input shapes whose entries may be lists, tuples, or dicts (tuples being the actual input shapes of the model), generate matching torch tensors with the given device and dtype.

Source code in quadra/utils/export.py
def generate_torch_inputs(
    input_shapes: list[Any],
    device: str | torch.device,
    half_precision: bool = False,
    dtype: torch.dtype = torch.float32,
    batch_size: int = 1,
) -> list[Any] | tuple[Any, ...] | torch.Tensor:
    """Given a list of input shapes that can contain either lists, tuples or dicts, with tuples being the input shapes
    of the model, generate a list of torch tensors with the given device and dtype.
    """
    inp = None

    if isinstance(input_shapes, (ListConfig, DictConfig)):
        input_shapes = OmegaConf.to_container(input_shapes)

    if isinstance(input_shapes, list):
        if any(isinstance(inp, (Sequence, dict)) for inp in input_shapes):
            return [generate_torch_inputs(inp, device, half_precision, dtype, batch_size) for inp in input_shapes]

        # Base case
        inp = torch.randn((batch_size, *input_shapes), dtype=dtype, device=device)

    if isinstance(input_shapes, dict):
        return {k: generate_torch_inputs(v, device, half_precision, dtype, batch_size) for k, v in input_shapes.items()}

    if isinstance(input_shapes, tuple):
        if any(isinstance(inp, (Sequence, dict)) for inp in input_shapes):
            # The tuple contains a list, tuple or dict
            return tuple(generate_torch_inputs(inp, device, half_precision, dtype, batch_size) for inp in input_shapes)

        # Base case
        inp = torch.randn((batch_size, *input_shapes), dtype=dtype, device=device)

    if inp is None:
        raise RuntimeError("Something went wrong during model export, unable to parse input shapes")

    if half_precision:
        inp = inp.half()

    return inp
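
For example, a list of shape tuples yields one tensor per entry, each with the batch dimension prepended:

from quadra.utils.export import generate_torch_inputs

inputs = generate_torch_inputs(input_shapes=[(3, 224, 224), (10,)], device="cpu")
print([t.shape for t in inputs])  # [torch.Size([1, 3, 224, 224]), torch.Size([1, 10])]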

get_export_extension(export_type)

Get the extension of the exported model.

Parameters:

  • export_type (str) –

    The type of the exported model.

Returns:

  • str

    The extension of the exported model.

Source code in quadra/utils/export.py
def get_export_extension(export_type: str) -> str:
    """Get the extension of the exported model.

    Args:
        export_type: The type of the exported model.

    Returns:
        The extension of the exported model.
    """
    if export_type == "onnx":
        extension = "onnx"
    elif export_type == "torchscript":
        extension = "pt"
    elif export_type == "pytorch":
        extension = "pth"
    else:
        raise ValueError(f"Unsupported export type {export_type}")

    return extension
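
For example:

from quadra.utils.export import get_export_extension

extension = get_export_extension("torchscript")
model_name = f"model.{extension}"  # "model.pt"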

import_deployment_model(model_path, inference_config, device, model_architecture=None)

Try to import a model for deployment; currently supports TorchScript .pt files, state dictionary .pth files and ONNX .onnx files.

Parameters:

  • model_path (str) –

    Path to the model

  • inference_config (DictConfig) –

    Inference configuration, should contain keys for the different deployment models

  • device (str) –

    Device to load the model on

  • model_architecture (Module | None, default: None ) –

    Optional model architecture to use for loading a plain pytorch model

Returns:

  • BaseEvaluationModel

    The loaded deployment model.

Source code in quadra/utils/export.py
def import_deployment_model(
    model_path: str,
    inference_config: DictConfig,
    device: str,
    model_architecture: nn.Module | None = None,
) -> BaseEvaluationModel:
    """Try to import a model for deployment, currently only supports torchscript .pt files and
    state dictionaries .pth files.

    Args:
        model_path: Path to the model
        inference_config: Inference configuration, should contain keys for the different deployment models
        device: Device to load the model on
        model_architecture: Optional model architecture to use for loading a plain pytorch model

    Returns:
        The loaded deployment model
    """
    log.info("Importing trained model")

    file_extension = os.path.splitext(os.path.basename(model_path))[1]
    deployment_model: BaseEvaluationModel | None = None

    if file_extension == ".pt":
        deployment_model = TorchscriptEvaluationModel(config=inference_config.torchscript)
    elif file_extension == ".pth":
        if model_architecture is None:
            raise ValueError("model_architecture must be specified when loading a .pth file")

        deployment_model = TorchEvaluationModel(config=inference_config.pytorch, model_architecture=model_architecture)
    elif file_extension == ".onnx":
        deployment_model = ONNXEvaluationModel(config=inference_config.onnx)

    if deployment_model is not None:
        deployment_model.load_from_disk(model_path=model_path, device=device)

        log.info("Imported %s model", deployment_model.__class__.__name__)

        return deployment_model

    raise ValueError(f"Unable to load model with extension {file_extension}, valid extensions are: ['.pt', 'pth']")