Skip to content

utils

Common utility functions. Some of them are mostly based on https://github.com/ashleve/lightning-hydra-template.

AllGatherSyncFunction

Bases: torch.autograd.Function

Function to gather gradients from multiple GPUs.

HydraEncoder

Bases: json.JSONEncoder

Custom JSON encoder to handle OmegaConf objects.

default(o)

Convert OmegaConf objects to base python objects.

Source code in quadra/utils/utils.py
365
366
367
368
369
370
def default(self, o):
    """Serialize OmegaConf config objects as plain python containers."""
    # Only intercept OmegaConf configs; everything else goes through the
    # base encoder (which raises TypeError for unsupported types).
    if o is not None and OmegaConf.is_config(o):
        return OmegaConf.to_container(o)
    return json.JSONEncoder.default(self, o)

NumpyEncoder

Bases: json.JSONEncoder

Custom JSON encoder to handle numpy objects.

default(o)

Custom JSON encoder to handle numpy objects.

Source code in quadra/utils/utils.py
376
377
378
379
380
381
382
383
384
385
def default(self, o):
    """Serialize numpy arrays and numpy scalars as plain python values."""
    if o is None:
        return json.JSONEncoder.default(self, o)

    if isinstance(o, np.ndarray):
        # Single-element arrays become a scalar, everything else a nested list.
        return o.item() if o.size == 1 else o.tolist()

    if isinstance(o, np.number):
        return o.item()

    # Unsupported type: delegate to the base encoder (raises TypeError).
    return json.JSONEncoder.default(self, o)

concat_all_gather(tensor)

Performs all_gather operation on the provided tensors. *** Warning ***: torch.distributed.all_gather has no gradient.

Source code in quadra/utils/utils.py
412
413
414
415
416
417
418
419
420
421
@torch.no_grad()
def concat_all_gather(tensor):
    """Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    world_size = torch.distributed.get_world_size()
    # One placeholder tensor per process; all_gather fills them in place.
    gathered = [torch.ones_like(tensor) for _ in range(world_size)]
    torch.distributed.all_gather(gathered, tensor, async_op=False)

    return torch.cat(gathered, dim=0)

extras(config)

A couple of optional utilities, controlled by main config file: - disabling warnings - forcing debug friendly configuration - verifying experiment name is set when running in experiment mode. Modifies DictConfig in place.

Parameters:

  • config (DictConfig) –

    Configuration composed by Hydra.

Source code in quadra/utils/utils.py
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
def extras(config: DictConfig) -> None:
    """A couple of optional utilities, controlled by main config file:
    - disabling warnings
    - forcing debug friendly configuration
    - verifying experiment name is set when running in experiment mode.
    Modifies DictConfig in place.

    Args:
        config: Configuration composed by Hydra.
    """
    # Root logger level must be set before any module logger is created.
    logging.basicConfig()
    logging.getLogger().setLevel(config.core.log_level.upper())

    logger = get_logger(__name__)
    config.core.command += " ".join(sys.argv)
    config.core.experiment_path = os.getcwd()

    # Optionally silence all python warnings for cleaner logs.
    if config.get("ignore_warnings"):
        logger.info("Disabling python warnings! <config.ignore_warnings=True>")
        warnings.filterwarnings("ignore")

    # fast_dev_run implies a debugging session: debuggers don't play well
    # with GPUs or worker multiprocessing, so drop both.
    trainer_cfg = config.get("trainer")
    if trainer_cfg and trainer_cfg.get("fast_dev_run"):
        logger.info("Forcing debugger friendly configuration! <config.trainer.fast_dev_run=True>")
        if trainer_cfg.get("gpus"):
            config.trainer.devices = 1
            config.trainer.accelerator = "cpu"
            config.trainer.gpus = None
        if config.datamodule.get("pin_memory"):
            config.datamodule.pin_memory = False
        if config.datamodule.get("num_workers"):
            config.datamodule.num_workers = 0

finish(config, module, datamodule, trainer, callbacks, logger, export_folder)

Upload config files to MLFlow server.

Parameters:

  • config (DictConfig) –

    Configuration composed by Hydra.

  • module (pl.LightningModule) –

    LightningModule.

  • datamodule (pl.LightningDataModule) –

    LightningDataModule.

  • trainer (pl.Trainer) –

    LightningTrainer.

  • callbacks (List[pl.Callback]) –

    List of LightningCallbacks.

  • logger (List[pl.loggers.Logger]) –

    List of LightningLoggers.

  • export_folder (str) –

    Folder where the deployment models are exported.

Source code in quadra/utils/utils.py
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
def finish(
    config: DictConfig,
    module: pl.LightningModule,
    datamodule: pl.LightningDataModule,
    trainer: pl.Trainer,
    callbacks: List[pl.Callback],
    logger: List[pl.loggers.Logger],
    export_folder: str,
) -> None:
    """Upload config files to MLFlow server.

    Args:
        config: Configuration composed by Hydra.
        module: LightningModule.
        datamodule: LightningDataModule.
        trainer: LightningTrainer.
        callbacks: List of LightningCallbacks.
        logger: List of LightningLoggers.
        export_folder: Folder where the deployment models are exported.
    """
    # pylint: disable=unused-argument
    if len(logger) == 0 or not config.core.get("upload_artifacts"):
        return

    mlflow_logger = get_mlflow_logger(trainer=trainer)
    tensorboard_logger = get_tensorboard_logger(trainer=trainer)

    # Resolve once which of the expected artifact files exist in the run dir;
    # previously this scan was duplicated in both logger branches.
    file_names = ["config.yaml", "config_resolved.yaml", "config_tree.txt", "data/dataset.csv"]
    config_paths = [
        os.path.join(os.getcwd(), f) for f in file_names if os.path.isfile(os.path.join(os.getcwd(), f))
    ]

    if mlflow_logger is not None:
        for path in config_paths:
            mlflow_logger.experiment.log_artifact(
                run_id=mlflow_logger.run_id, local_path=path, artifact_path="metadata"
            )

        deployed_models = glob.glob(os.path.join(export_folder, "*"))
        model_json: Optional[Dict[str, Any]] = None

        model_json_path = os.path.join(export_folder, "model.json")
        if os.path.exists(model_json_path):
            # Explicit encoding for cross-platform reproducibility.
            with open(model_json_path, "r", encoding="utf-8") as json_file:
                model_json = json.load(json_file)

        if model_json is not None:
            for model_path in deployed_models:
                if not model_path.endswith(".pt"):
                    continue

                model, _ = quadra_export.import_deployment_model(model_path, device="cpu")

                input_size = model_json["input_size"]
                # Normalize single-input models to the list-of-input-sizes form.
                if not isinstance(input_size[0], list):
                    input_size = [input_size]

                inputs = cast(List[Any], quadra_export.generate_torch_inputs(input_size, device="cpu"))
                signature = infer_signature_torch_model(model, inputs)

                with mlflow.start_run(run_id=mlflow_logger.run_id) as _:
                    mlflow.pytorch.log_model(
                        model,
                        artifact_path=model_path,
                        signature=signature,
                    )

    if tensorboard_logger is not None:
        for path in config_paths:
            upload_file_tensorboard(file_path=path, tensorboard_logger=tensorboard_logger)

        tensorboard_logger.experiment.flush()

flatten_list(l)

Return an iterator over the flattened list.

Parameters:

Yields:

  • Any

    Iterator[Any]: the iterator over the flattened list

Source code in quadra/utils/utils.py
346
347
348
349
350
351
352
353
354
355
356
357
358
359
def flatten_list(l: Iterable[Any]) -> Iterator[Any]:
    """Return an iterator over the flattened list.

    Args:
        l: the list to be flattened

    Yields:
        Iterator[Any]: the iterator over the flattened list
    """
    for item in l:
        # Strings and bytes are iterable but treated as atomic leaves.
        if isinstance(item, (str, bytes)) or not isinstance(item, Iterable):
            yield item
        else:
            yield from flatten_list(item)

get_device(cuda=True)

Returns the device to use for training.

Parameters:

  • cuda (bool) –

    whether to use cuda or not

Returns:

  • str

    The device to use

Source code in quadra/utils/utils.py
323
324
325
326
327
328
329
330
331
332
333
334
335
def get_device(cuda: bool = True) -> str:
    """Returns the device to use for training.

    Args:
        cuda: whether to use cuda or not

    Returns:
        The device to use
    """
    # CUDA is used only when both requested and actually available.
    use_gpu = cuda and torch.cuda.is_available()
    return "cuda:0" if use_gpu else "cpu"

get_logger(name=__name__)

Initializes multi-GPU-friendly python logger.

Source code in quadra/utils/utils.py
37
38
39
40
41
42
43
44
45
46
def get_logger(name=__name__) -> logging.Logger:
    """Initializes multi-GPU-friendly python logger."""
    logger = logging.getLogger(name)

    # Wrap every level method with the rank-zero decorator so that in a
    # multi-GPU setup each message is emitted once instead of once per process.
    level_names = ("debug", "info", "warning", "error", "exception", "fatal", "critical")
    for level_name in level_names:
        wrapped = rank_zero_only(getattr(logger, level_name))
        setattr(logger, level_name, wrapped)

    return logger

get_tensorboard_logger(trainer)

Safely get tensorboard logger from Lightning Trainer loggers.

Parameters:

  • trainer (pl.Trainer) –

    Pytorch Lightning Trainer.

Returns:

  • Optional[TensorBoardLogger]

    A tensorboard logger if available, else None.

Source code in quadra/utils/utils.py
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
def get_tensorboard_logger(trainer: pl.Trainer) -> Optional[TensorBoardLogger]:
    """Safely get tensorboard logger from Lightning Trainer loggers.

    Args:
        trainer: Pytorch Lightning Trainer.

    Returns:
        A tensorboard logger if available, else None.
    """
    # Normalize the single-logger case to a list, then scan for the first match.
    candidates = trainer.logger if isinstance(trainer.logger, list) else [trainer.logger]
    for candidate in candidates:
        if isinstance(candidate, TensorBoardLogger):
            return candidate

    return None

load_envs(env_file=None)

Load all the environment variables defined in the `env_file`. This is equivalent to `. env_file` in bash.

It is possible to define all the system specific variables in the env_file.

Parameters:

  • env_file (Optional[str]) –

    the file that defines the environment variables to use. If None it searches for a .env file in the project.

Source code in quadra/utils/utils.py
304
305
306
307
308
309
310
311
312
313
314
def load_envs(env_file: Optional[str] = None) -> None:
    """Load all the environment variables defined in the `env_file`.
    This is equivalent to `. env_file` in bash.

    It is possible to define all the system specific variables in the `env_file`.

    Args:
        env_file: the file that defines the environment variables to use. If None
                     it searches for a `.env` file in the project.
    """
    # override=True: values from the env file win over pre-existing variables.
    dotenv.load_dotenv(override=True, dotenv_path=env_file)

log_hyperparameters(config, model, trainer)

This method controls which parameters from Hydra config are saved by Lightning loggers.

Additionally saves
  • number of trainable model parameters
Source code in quadra/utils/utils.py
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
@rank_zero_only
def log_hyperparameters(
    config: DictConfig,
    model: pl.LightningModule,
    trainer: pl.Trainer,
) -> None:
    """This method controls which parameters from Hydra config are saved by Lightning loggers.

    Additionally saves:
        - number of trainable model parameters

    Args:
        config: Configuration composed by Hydra.
        model: LightningModule whose parameter counts are logged.
        trainer: Trainer providing the logger(s) to send hyperparameters to.
    """
    log = get_logger(__name__)

    if not HydraConfig.initialized() or trainer.logger is None:
        return

    log.info("Logging hyperparameters!")
    hydra_cfg = HydraConfig.get()
    hydra_choices = OmegaConf.to_container(hydra_cfg.runtime.choices)

    # Fix: hparams must exist on both branches below; previously it was created
    # only inside the isinstance check, raising NameError on the else path.
    hparams: Dict[str, Any] = {}

    if isinstance(hydra_choices, dict):
        # For multirun override the choices that are not automatically updated
        for item in hydra_cfg.overrides.task:
            if "." in item:
                continue

            override, value = item.split("=")
            hydra_choices[override] = value

        hydra_choices_final = {}
        for k, v in hydra_choices.items():
            if isinstance(k, str):
                # "@" is not accepted in parameter keys by some logger backends.
                k_replaced = k.replace("@", "_at_")
                hydra_choices_final[k_replaced] = v
                if v is not None and isinstance(v, str) and "@" in v:
                    hydra_choices_final[k_replaced] = v.replace("@", "_at_")

        hparams.update(hydra_choices_final)
    else:
        logging.warning("Hydra choices is not a dictionary, skip adding them to the logger")

    # save number of model parameters
    hparams["model/params_total"] = sum(p.numel() for p in model.parameters())
    hparams["model/params_trainable"] = sum(p.numel() for p in model.parameters() if p.requires_grad)
    hparams["model/params_not_trainable"] = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    hparams["experiment_path"] = config.core.experiment_path
    hparams["command"] = config.core.command
    hparams["library/version"] = str(quadra.__version__)

    # Use DEVNULL instead of an unclosed open(os.devnull) handle (resource leak).
    if (
        subprocess.call(
            ["git", "-C", get_original_cwd(), "status"], stderr=subprocess.STDOUT, stdout=subprocess.DEVNULL
        )
        == 0
    ):
        hparams["git/commit"] = subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("ascii").strip()
        hparams["git/branch"] = (
            subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"]).decode("ascii").strip()
        )
        hparams["git/remote"] = subprocess.check_output(["git", "remote", "get-url", "origin"]).decode("ascii").strip()
    else:
        log.warning("Could not find git repository, skipping git commit and branch info")

    # send hparams to all loggers
    trainer.logger.log_hyperparams(hparams)

nested_set(dic, keys, value)

Assign the value of a dictionary using nested keys.

Source code in quadra/utils/utils.py
338
339
340
341
342
343
def nested_set(dic: Dict, keys: List[str], value: str) -> None:
    """Assign the value of a dictionary using nested keys."""
    # Walk (creating as needed) every level except the last one.
    current = dic
    for key in keys[:-1]:
        current = current.setdefault(key, {})

    current[keys[-1]] = value

print_config(config, fields=('trainer', 'model', 'datamodule', 'callbacks', 'logger', 'core', 'backbone', 'transforms', 'optimizer', 'scheduler'), resolve=True)

Prints content of DictConfig using Rich library and its tree structure.

Parameters:

  • config (DictConfig) –

    Configuration composed by Hydra.

  • fields (Sequence[str]) –

    Determines which main fields from config will be printed and in what order.

  • resolve (bool) –

    Whether to resolve reference fields of DictConfig.

Source code in quadra/utils/utils.py
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
@rank_zero_only
def print_config(
    config: DictConfig,
    fields: Sequence[str] = (
        "trainer",
        "model",
        "datamodule",
        "callbacks",
        "logger",
        "core",
        "backbone",
        "transforms",
        "optimizer",
        "scheduler",
    ),
    resolve: bool = True,
) -> None:
    """Prints content of DictConfig using Rich library and its tree structure.

    Args:
        config: Configuration composed by Hydra.
        fields: Determines which main fields from config will
            be printed and in what order.
        resolve: Whether to resolve reference fields of DictConfig.
    """
    style = "dim"
    tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)

    for section_name in fields:
        branch = tree.add(section_name, style=style, guide_style=style)

        section = config.get(section_name)
        # DictConfig sections are rendered as YAML, anything else verbatim.
        if isinstance(section, DictConfig):
            rendered = OmegaConf.to_yaml(section, resolve=resolve)
        else:
            rendered = str(section)

        branch.add(rich.syntax.Syntax(rendered, "yaml"))

    rich.print(tree)

    # Persist the rendered tree so it can be uploaded later as a run artifact.
    with open("config_tree.txt", "w") as fp:
        rich.print(tree, file=fp)

setup_opencv()

Setup OpenCV to use only one thread and not use OpenCL.

Source code in quadra/utils/utils.py
317
318
319
320
def setup_opencv() -> None:
    """Configure OpenCV for single-threaded execution without OpenCL.

    Limits OpenCV to one thread and disables OpenCL acceleration.
    """
    cv2.ocl.setUseOpenCL(False)
    cv2.setNumThreads(1)

upload_file_tensorboard(file_path, tensorboard_logger)

Upload a file to tensorboard handling different extensions.

Parameters:

  • file_path (str) –

    Path to the file to upload.

  • tensorboard_logger (TensorBoardLogger) –

    Tensorboard logger instance.

Source code in quadra/utils/utils.py
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
def upload_file_tensorboard(file_path: str, tensorboard_logger: TensorBoardLogger) -> None:
    """Upload a file to tensorboard handling different extensions.

    Args:
        file_path: Path to the file to upload.
        tensorboard_logger: Tensorboard logger instance.
    """
    tag = os.path.basename(file_path)
    ext = os.path.splitext(file_path)[1].lower()

    # Fix: open all text files with an explicit utf-8 encoding; previously only
    # the fallback branch specified it, making json/yaml reads locale-dependent.
    if ext == ".json":
        with open(file_path, "r", encoding="utf-8") as f:
            json_content = json.load(f)

            # Wrap in a fenced code block so tensorboard renders it monospaced.
            json_content = f"```json\n{json.dumps(json_content, indent=4)}\n```"
            tensorboard_logger.experiment.add_text(tag=tag, text_string=json_content, global_step=0)
    elif ext in [".yaml", ".yml"]:
        with open(file_path, "r", encoding="utf-8") as f:
            yaml_content = f.read()
            yaml_content = f"```yaml\n{yaml_content}\n```"
            tensorboard_logger.experiment.add_text(tag=tag, text_string=yaml_content, global_step=0)
    else:
        # Plain text: double-space before newlines forces markdown line breaks.
        with open(file_path, "r", encoding="utf-8") as f:
            tensorboard_logger.experiment.add_text(tag=tag, text_string=f.read().replace("\n", "  \n"), global_step=0)

    tensorboard_logger.experiment.flush()