Skip to content

ssl

BYOL(config, checkpoint_path=None, run_test=False, **kwargs)

Bases: SSL

BYOL model as a pytorch_lightning.LightningModule.

Parameters:

  • config (DictConfig) –

    the main config

  • checkpoint_path (str | None, default: None ) –

    If specified, a trained model is returned, with weights loaded from this checkpoint path. Defaults to None.

  • run_test (bool, default: False ) –

    Whether to run final test

  • **kwargs (Any, default: {} ) –

    Keyword arguments

Source code in quadra/tasks/ssl.py (lines 297-314)
def __init__(
    self,
    config: DictConfig,
    checkpoint_path: str | None = None,
    run_test: bool = False,
    **kwargs: Any,
):
    """Initialise the BYOL task.

    Args:
        config: The main experiment config.
        checkpoint_path: Optional checkpoint path, forwarded unchanged to the parent task.
        run_test: Whether to run the final test.
        **kwargs: Extra keyword arguments, forwarded unchanged to the parent ``__init__``.
    """
    super().__init__(config=config, checkpoint_path=checkpoint_path, run_test=run_test, **kwargs)
    # Networks are only declared here; they are instantiated in ``prepare``.
    # Student branch.
    self.student_model: nn.Module
    self.student_projection_mlp: nn.Module
    self.student_prediction_mlp: nn.Module
    # Teacher branch.
    self.teacher_model: nn.Module
    self.teacher_projection_mlp: nn.Module

learnable_parameters()

Get the learnable parameters.

Source code in quadra/tasks/ssl.py (lines 316-322)
def learnable_parameters(self) -> list[nn.Parameter]:
    """Collect the parameters updated by the optimizer.

    Only the student branch is returned: the student backbone, its projection MLP and
    its prediction MLP, in that order. No teacher module is included.

    Returns:
        A flat list with the parameters of the three student modules.
    """
    trainable_modules = (
        self.student_model,
        self.student_projection_mlp,
        self.student_prediction_mlp,
    )
    return [param for module in trainable_modules for param in module.parameters()]

prepare()

Prepare the experiment.

Source code in quadra/tasks/ssl.py (lines 324-334)
def prepare(self) -> None:
    """Build the BYOL networks and the training objects.

    The teacher backbone is instantiated from the same ``config.model.student`` node as the
    student, and both projection MLPs come from ``config.model.projection_mlp``. The student
    also gets a prediction MLP. The optimizer, scheduler and module are then assigned from
    the config, in that order.
    """
    super().prepare()
    instantiate = hydra.utils.instantiate
    model_cfg = self.config.model
    self.student_model = instantiate(model_cfg.student)
    self.teacher_model = instantiate(model_cfg.student)
    self.student_projection_mlp = instantiate(model_cfg.projection_mlp)
    self.student_prediction_mlp = instantiate(model_cfg.prediction_mlp)
    self.teacher_projection_mlp = instantiate(model_cfg.projection_mlp)
    # Keep this order: presumably the setters below rely on the networks built above.
    self.optimizer = self.config.optimizer
    self.scheduler = self.config.scheduler
    self.module = self.config.model.module

Barlow(config, checkpoint_path=None, run_test=False)

Bases: SimCLR

Barlow model as a pytorch_lightning.LightningModule.

Parameters:

  • config (DictConfig) –

    the main config

  • checkpoint_path (str | None, default: None ) –

    If specified, a trained model is returned, with weights loaded from this checkpoint path. Defaults to None.

  • run_test (bool, default: False ) –

    Whether to run final test

Source code in quadra/tasks/ssl.py (lines 259-265)
def __init__(
    self,
    config: DictConfig,
    checkpoint_path: str | None = None,
    run_test: bool = False,
):
    """Initialise the Barlow task by delegating entirely to the parent task.

    Args:
        config: The main experiment config.
        checkpoint_path: Optional checkpoint path, forwarded to the parent task.
        run_test: Whether to run the final test.
    """
    super().__init__(
        config=config,
        checkpoint_path=checkpoint_path,
        run_test=run_test,
    )

prepare()

Prepare the experiment.

Source code in quadra/tasks/ssl.py (lines 267-282)
def prepare(self) -> None:
    """Prepare the experiment.

    Deliberately skips ``SimCLR.prepare``: ``super(SimCLR, self)`` resolves to the next
    ``prepare`` in the MRO after ``SimCLR``. The method then:

    1. builds the backbone from ``config.model.model``;
    2. multiplies the projection MLP ``hidden_dim`` and ``output_dim`` by
       ``config.model.projection_mlp_mult``;
    3. builds the projection head from the updated node;
    4. assigns the optimizer, scheduler and module from the config, in that order.

    NOTE(review): the two dimensions are multiplied in place on ``self.config``, so calling
    this method more than once would compound the multiplier. Confirm it is only called once.
    """
    super(SimCLR, self).prepare()
    self.backbone = hydra.utils.instantiate(self.config.model.model)

    # open_dict temporarily lifts OmegaConf struct mode so the config node can be modified.
    with open_dict(self.config):
        self.config.model.projection_mlp.hidden_dim = (
            self.config.model.projection_mlp.hidden_dim * self.config.model.projection_mlp_mult
        )
        self.config.model.projection_mlp.output_dim = (
            self.config.model.projection_mlp.output_dim * self.config.model.projection_mlp_mult
        )
    # Instantiated after the update so the head uses the scaled dimensions.
    self.projection_mlp = hydra.utils.instantiate(self.config.model.projection_mlp)
    self.optimizer = self.config.optimizer
    self.scheduler = self.config.scheduler
    self.module = self.config.model.module

DINO(config, checkpoint_path=None, run_test=False)

Bases: SSL

DINO model as a pytorch_lightning.LightningModule.

Parameters:

  • config (DictConfig) –

    the main config

  • checkpoint_path (str | None, default: None ) –

    If specified, a trained model is returned, with weights loaded from this checkpoint path. Defaults to None.

  • run_test (bool, default: False ) –

    Whether to run final test

Source code in quadra/tasks/ssl.py (lines 379-389)
def __init__(
    self,
    config: DictConfig,
    checkpoint_path: str | None = None,
    run_test: bool = False,
):
    """Initialise the DINO task.

    Args:
        config: The main experiment config.
        checkpoint_path: Optional checkpoint path, forwarded to the parent task.
        run_test: Whether to run the final test.
    """
    super().__init__(
        config=config,
        checkpoint_path=checkpoint_path,
        run_test=run_test,
    )
    # Declared only; the networks are instantiated in ``prepare``.
    self.student_model: nn.Module
    self.student_projection_mlp: nn.Module
    self.teacher_model: nn.Module
    self.teacher_projection_mlp: nn.Module

learnable_parameters()

Get the learnable parameters.

Source code in quadra/tasks/ssl.py (lines 391-395)
def learnable_parameters(self) -> list[nn.Parameter]:
    """Return the parameters of the student backbone followed by those of its projection MLP.

    No teacher module is included.
    """
    return [*self.student_model.parameters(), *self.student_projection_mlp.parameters()]

prepare()

Prepare the experiment.

Source code in quadra/tasks/ssl.py (lines 397-406)
def prepare(self) -> None:
    """Build the DINO student/teacher networks and the training objects.

    The teacher backbone is instantiated from the same ``config.model.student`` node as the
    student. The two projection heads come from separate nodes
    (``student_projection_mlp`` and ``teacher_projection_mlp``). The optimizer, scheduler and
    module are then assigned from the config, in that order.
    """
    super().prepare()
    instantiate = hydra.utils.instantiate
    model_cfg = self.config.model
    self.student_model = cast(nn.Module, instantiate(model_cfg.student))
    self.teacher_model = cast(nn.Module, instantiate(model_cfg.student))
    self.student_projection_mlp = cast(nn.Module, instantiate(model_cfg.student_projection_mlp))
    self.teacher_projection_mlp = cast(nn.Module, instantiate(model_cfg.teacher_projection_mlp))
    # Keep this order: presumably the setters below rely on the networks built above.
    self.optimizer = self.config.optimizer
    self.scheduler = self.config.scheduler
    self.module = self.config.model.module

EmbeddingVisualization(config, model_path, report_folder='embeddings', embedding_image_size=None)

Bases: Task

Visualization task for learned embeddings.

Parameters:

  • config (DictConfig) –

    The loaded experiment config

  • model_path (str) –

    The path to a deployment model

  • report_folder (str, default: 'embeddings' ) –

    Where to save the embeddings

  • embedding_image_size (int | None, default: None ) –

    If not None, rescale the images associated with the embeddings. TensorBoard saves to disk a large sprite containing all the images arranged in a grid; if this sprite is too large, it cannot be loaded in the browser. Rescaling the output images from the model input size to something smaller solves this issue. The value is an int so that images are always rescaled to a square.

Source code in quadra/tasks/ssl.py (lines 451-480)
def __init__(
    self,
    config: DictConfig,
    model_path: str,
    report_folder: str = "embeddings",
    embedding_image_size: int | None = None,
):
    """Set up the embedding visualization task.

    Args:
        config: The loaded experiment config.
        model_path: Path to a deployment model. The embeddings are written under
            ``os.path.join(model_path, report_folder)``.
        report_folder: Sub-folder of ``model_path`` where the embeddings are saved.
        embedding_image_size: If not None, side length the images attached to the
            embeddings are rescaled to.

    Raises:
        ValueError: If ``model_path`` is None.
    """
    super().__init__(config=config)

    self.config = config
    self.model_path = model_path
    # Fail fast, before querying the device or touching the filesystem.
    if self.model_path is None:
        raise ValueError(
            "Model path cannot be found!, please specify it in the config or pass it as an argument for evaluation"
        )
    self.metadata = {
        "report_files": [],
    }
    self._device = utils.get_device()
    log.info("Current device: %s", self._device)

    self.report_folder = report_folder
    self.embeddings_path = os.path.join(self.model_path, self.report_folder)
    # exist_ok=True replaces the racy "check exists, then create" sequence.
    os.makedirs(self.embeddings_path, exist_ok=True)
    self.embedding_writer = SummaryWriter(self.embeddings_path)
    self.writer_step = 0  # global_step passed to TensorBoard when writing embeddings
    self.embedding_image_size = embedding_image_size
    # Declared only; presumably populated through the ``deployment_model`` setter in ``prepare``.
    self._deployment_model: BaseEvaluationModel
    self.deployment_model_type: str

deployment_model property writable

Get the deployment model.

prepare()

Prepare the evaluation.

Source code in quadra/tasks/ssl.py (lines 506-509)
def prepare(self) -> None:
    """Run the parent preparation, then load the deployment model from ``model_path``.

    The path is handed to the writable ``deployment_model`` property, whose setter does the loading.
    """
    super().prepare()
    model_location = self.model_path
    self.deployment_model = model_location

test()

Run embeddings extraction.

Source code in quadra/tasks/ssl.py (lines 511-559)
@torch.no_grad()
def test(self) -> None:
    """Extract embeddings for the test split and write them to TensorBoard's projector.

    For every test batch this method:

    - runs the deployment model;
    - de-normalises the images with ``config.transforms.std`` / ``mean``;
    - optionally rescales them to ``embedding_image_size``;
    - records ``(class index, class name, file path)`` metadata.

    Everything is then written with a single ``add_embedding`` call. If the test set is
    empty, a warning is logged and nothing is written.
    """
    # idx_to_class is read from the validation dataset, available after the "fit" setup.
    self.datamodule.setup("fit")
    idx_to_class = self.datamodule.val_dataset.idx_to_class
    self.datamodule.setup("test")
    # Build the test dataloader once; the same loader is iterated and used to look up file paths.
    dataloader = self.datamodule.test_dataloader()
    # NOTE(review): file paths are matched to batches by position, which assumes the test
    # dataloader does not shuffle and that `samples` holds entries whose first item is the path.
    samples = dataloader.dataset.samples

    std = torch.tensor(self.config.transforms.std).view(1, -1, 1, 1)
    mean = torch.tensor(self.config.transforms.mean).view(1, -1, 1, 1)

    # Only the first parameter's dtype decides whether inputs are cast to half precision.
    first_param = next(iter(self.deployment_model.parameters()), None)
    is_half_precision = first_param is not None and first_param.dtype == torch.half

    image_batches: list[torch.Tensor] = []
    embedding_batches: list[torch.Tensor] = []
    metadata: list[tuple[int, str, str]] = []
    counter = 0
    for im, target in tqdm(dataloader):
        if is_half_precision:
            im = im.half()

        x = self.deployment_model(im.to(self.device))
        batch_size = len(im)
        targets = [int(t.item()) for t in target]
        class_names = [idx_to_class[t.item()] for t in target]
        file_paths = [s[0] for s in samples[counter : counter + batch_size]]
        embedding_batches.append(x.cpu())

        # Undo the input normalisation so the sprite shows the original images.
        im = im * std + mean
        if self.embedding_image_size is not None:
            im = interpolate(im, self.embedding_image_size)

        image_batches.append(im.cpu())
        metadata.extend(zip(targets, class_names, file_paths, strict=False))
        counter += batch_size

    # torch.cat on an empty list raises an opaque error; report the real cause instead.
    if not embedding_batches:
        log.warning("No test samples found, skipping embedding export")
        return

    self.embedding_writer.add_embedding(
        torch.cat(embedding_batches, dim=0),
        metadata=metadata,
        label_img=torch.cat(image_batches, dim=0),
        global_step=self.writer_step,
        metadata_header=["class", "class_name", "path"],
    )

SSL(config, run_test=False, report=False, checkpoint_path=None)

Bases: LightningTask

SSL Task.

Parameters:

  • config (DictConfig) –

    The experiment configuration

  • checkpoint_path (str | None, default: None ) –

    The path to the checkpoint to load the model from. Defaults to None.

  • report (bool, default: False ) –

    Whether to create the report

  • run_test (bool, default: False ) –

    Whether to run final test

Source code in quadra/tasks/ssl.py (lines 36-52)
def __init__(
    self,
    config: DictConfig,
    run_test: bool = False,
    report: bool = False,
    checkpoint_path: str | None = None,
):
    """Initialise the base SSL task.

    Args:
        config: The experiment configuration.
        run_test: Whether to run the final test.
        report: Whether to create the report.
        checkpoint_path: Optional checkpoint to load the model from.
    """
    super().__init__(config=config, checkpoint_path=checkpoint_path, run_test=run_test, report=report)
    # Folder where ``export`` writes the deployment model and its model.json.
    self.export_folder = "deployment_model"
    # Declared only; presumably assigned later via subclasses and the optimizer/scheduler setters.
    self._backbone: nn.Module
    self._optimizer: torch.optim.Optimizer
    self._lr_scheduler: torch.optim.lr_scheduler._LRScheduler

optimizer property writable

Get the optimizer.

scheduler property writable

Get the scheduler.

export()

Deploy a model ready for production.

Source code in quadra/tasks/ssl.py (lines 96-115)
def export(self) -> None:
    """Export ``self.module.model`` for deployment and write its ``model.json`` descriptor.

    Half precision is requested when the trainer precision string contains "16". If
    ``export_model`` produces no exported files, nothing else is written. Otherwise the
    returned model description is dumped to ``<export_folder>/model.json`` using
    ``utils.HydraEncoder``.
    """
    # NOTE(review): substring check, so "bf16-mixed" also counts as half precision; confirm intended.
    use_half = "16" in self.trainer.precision
    shapes = self.config.export.input_shapes

    model_json, export_paths = export_model(
        config=self.config,
        model=self.module.model,
        export_folder=self.export_folder,
        half_precision=use_half,
        input_shapes=shapes,
        idx_to_class=None,
    )

    if not export_paths:
        return

    json_path = os.path.join(self.export_folder, "model.json")
    with open(json_path, "w") as fp:
        json.dump(model_json, fp, cls=utils.HydraEncoder)

learnable_parameters()

Get the learnable parameters.

Source code in quadra/tasks/ssl.py (lines 54-56)
def learnable_parameters(self) -> list[nn.Parameter]:
    """Return the parameters to optimise; this is an abstract hook for concrete SSL tasks.

    Raises:
        NotImplementedError: Always, on the base class.
    """
    message = "This method must be implemented by the subclass"
    raise NotImplementedError(message)

test()

Test the model.

Source code in quadra/tasks/ssl.py (lines 89-94)
def test(self) -> None:
    """Run the final test with the last epoch's weights.

    Nothing happens unless ``run_test`` is set. Testing is also skipped when
    ``trainer.fast_dev_run`` is enabled in the config.
    """
    if not self.run_test or self.config.trainer.get("fast_dev_run"):
        return
    log.info("Starting testing!")
    log.info("Using last epoch's weights for testing.")
    # ckpt_path=None: test the in-memory module as it is now, without reloading a checkpoint.
    self.trainer.test(datamodule=self.datamodule, model=self.module, ckpt_path=None)

SimCLR(config, checkpoint_path=None, run_test=False)

Bases: SSL

SimCLR model as a pytorch_lightning.LightningModule.

Parameters:

  • config (DictConfig) –

    the main config

  • checkpoint_path (str | None, default: None ) –

    If specified, a trained model is returned, with weights loaded from this checkpoint path. Defaults to None.

  • run_test (bool, default: False ) –

    Whether to run final test

Source code in quadra/tasks/ssl.py (lines 198-206)
def __init__(
    self,
    config: DictConfig,
    checkpoint_path: str | None = None,
    run_test: bool = False,
):
    """Initialise the SimCLR task.

    Args:
        config: The main experiment config.
        checkpoint_path: Optional checkpoint path, forwarded to the parent task.
        run_test: Whether to run the final test.
    """
    super().__init__(
        config=config,
        checkpoint_path=checkpoint_path,
        run_test=run_test,
    )
    # Declared only; both networks are instantiated in ``prepare``.
    self.backbone: nn.Module
    self.projection_mlp: nn.Module

learnable_parameters()

Get the learnable parameters.

Source code in quadra/tasks/ssl.py (lines 208-210)
def learnable_parameters(self) -> list[nn.Parameter]:
    """Return the backbone parameters followed by the projection MLP parameters."""
    params: list[nn.Parameter] = list(self.backbone.parameters())
    params.extend(self.projection_mlp.parameters())
    return params

prepare()

Prepare the experiment.

Source code in quadra/tasks/ssl.py (lines 212-219)
def prepare(self) -> None:
    """Build the SimCLR backbone and projection head, then the training objects.

    Both networks are instantiated from ``config.model`` (its ``model`` and ``projection_mlp``
    nodes). The optimizer, scheduler and module are then assigned from the config, in that order.
    """
    super().prepare()
    instantiate = hydra.utils.instantiate
    model_cfg = self.config.model
    self.backbone = instantiate(model_cfg.model)
    self.projection_mlp = instantiate(model_cfg.projection_mlp)
    # Keep this order: presumably the setters below rely on the networks built above.
    self.optimizer = self.config.optimizer
    self.scheduler = self.config.scheduler
    self.module = self.config.model.module

Simsiam(config, checkpoint_path=None, run_test=False)

Bases: SSL

Simsiam model as a pytorch_lightning.LightningModule.

Parameters:

  • config (DictConfig) –

    the main config

  • checkpoint_path (str | None, default: None ) –

    If specified, a trained model is returned, with weights loaded from this checkpoint path. Defaults to None.

  • run_test (bool, default: False ) –

    Whether to run final test

Source code in quadra/tasks/ssl.py (lines 129-138)
def __init__(
    self,
    config: DictConfig,
    checkpoint_path: str | None = None,
    run_test: bool = False,
):
    """Initialise the SimSiam task.

    Args:
        config: The main experiment config.
        checkpoint_path: Optional checkpoint path, forwarded to the parent task.
        run_test: Whether to run the final test.
    """
    super().__init__(
        config=config,
        checkpoint_path=checkpoint_path,
        run_test=run_test,
    )
    # Declared only; the three networks are instantiated in ``prepare``.
    self.backbone: nn.Module
    self.projection_mlp: nn.Module
    self.prediction_mlp: nn.Module

module property writable

Get the module of the model.

learnable_parameters()

Get the learnable parameters.

Source code in quadra/tasks/ssl.py (lines 140-146)
def learnable_parameters(self) -> list[nn.Parameter]:
    """Return the parameters of the backbone, projection MLP and prediction MLP, in that order."""
    collected: list[nn.Parameter] = []
    for module in (self.backbone, self.projection_mlp, self.prediction_mlp):
        collected.extend(module.parameters())
    return collected

prepare()

Prepare the experiment.

Source code in quadra/tasks/ssl.py (lines 148-156)
def prepare(self) -> None:
    """Build the SimSiam backbone, projection MLP and prediction MLP, then the training objects.

    The networks are instantiated from the ``model``, ``projection_mlp`` and ``prediction_mlp``
    nodes of ``config.model``. The optimizer, scheduler and module are then assigned from the
    config, in that order.
    """
    super().prepare()
    instantiate = hydra.utils.instantiate
    model_cfg = self.config.model
    self.backbone = instantiate(model_cfg.model)
    self.projection_mlp = instantiate(model_cfg.projection_mlp)
    self.prediction_mlp = instantiate(model_cfg.prediction_mlp)
    # Keep this order: presumably the setters below rely on the networks built above.
    self.optimizer = self.config.optimizer
    self.scheduler = self.config.scheduler
    self.module = self.config.model.module