ssl

BYOL(config, checkpoint_path=None, run_test=False, **kwargs)

Bases: SSL

BYOL model as a pytorch_lightning.LightningModule.

Parameters:

  • config (DictConfig) –

    the main config

  • checkpoint_path (Optional[str], default: None ) –

    If specified, the task returns a trained model with weights loaded from the given checkpoint path. Defaults to None.

  • run_test (bool, default: False ) –

    Whether to run the final test.

  • **kwargs (Any, default: {} ) –

    Additional keyword arguments forwarded to the parent SSL constructor.

Source code in quadra/tasks/ssl.py
def __init__(
    self,
    config: DictConfig,
    checkpoint_path: Optional[str] = None,
    run_test: bool = False,
    **kwargs: Any,
):
    super().__init__(
        config=config,
        checkpoint_path=checkpoint_path,
        run_test=run_test,
        **kwargs,
    )
    self.student_model: nn.Module
    self.teacher_model: nn.Module
    self.student_projection_mlp: nn.Module
    self.student_prediction_mlp: nn.Module
    self.teacher_projection_mlp: nn.Module

learnable_parameters()

Get the learnable parameters.

Source code in quadra/tasks/ssl.py
def learnable_parameters(self) -> List[nn.Parameter]:
    """Get the learnable parameters."""
    return list(
        list(self.student_model.parameters())
        + list(self.student_projection_mlp.parameters())
        + list(self.student_prediction_mlp.parameters()),
    )

prepare()

Prepare the experiment.

Source code in quadra/tasks/ssl.py
def prepare(self) -> None:
    """Prepare the experiment."""
    super().prepare()
    self.student_model = hydra.utils.instantiate(self.config.model.student)
    # the teacher is built from the same config entry as the student, so the two
    # networks share an architecture; in BYOL the teacher is typically updated as
    # an EMA of the student rather than by gradient descent
    self.teacher_model = hydra.utils.instantiate(self.config.model.student)
    self.student_projection_mlp = hydra.utils.instantiate(self.config.model.projection_mlp)
    self.student_prediction_mlp = hydra.utils.instantiate(self.config.model.prediction_mlp)
    self.teacher_projection_mlp = hydra.utils.instantiate(self.config.model.projection_mlp)
    self.optimizer = self.config.optimizer
    self.scheduler = self.config.scheduler
    self.module = self.config.model.module
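
prepare() above reads a fixed set of config keys; note that the teacher is instantiated from the same config.model.student entry as the student. A minimal sketch of the expected layout, where every _target_ value is a placeholder assumption rather than a class shipped by quadra:

from omegaconf import OmegaConf

# Hypothetical fragment: only the key names are taken from prepare() above,
# all _target_ values are illustrative placeholders
cfg = OmegaConf.create(
    {
        "model": {
            "student": {"_target_": "my_project.nets.Backbone"},
            "projection_mlp": {"_target_": "my_project.nets.MLP"},
            "prediction_mlp": {"_target_": "my_project.nets.MLP"},
            "module": "...",  # LightningModule config consumed by the parent SSL task
        },
        "optimizer": {"_target_": "torch.optim.SGD", "lr": 0.05},
        "scheduler": "...",
    }
)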

Barlow(config, checkpoint_path=None, run_test=False)

Bases: SimCLR

Barlow model as a pytorch_lightning.LightningModule.

Parameters:

  • config (DictConfig) –

    the main config

  • checkpoint_path (Optional[str], default: None ) –

    If specified, the task returns a trained model with weights loaded from the given checkpoint path. Defaults to None.

  • run_test (bool, default: False ) –

    Whether to run the final test.

Source code in quadra/tasks/ssl.py
def __init__(
    self,
    config: DictConfig,
    checkpoint_path: Optional[str] = None,
    run_test: bool = False,
):
    super().__init__(config=config, checkpoint_path=checkpoint_path, run_test=run_test)

prepare()

Prepare the experiment.

Source code in quadra/tasks/ssl.py
def prepare(self) -> None:
    """Prepare the experiment."""
    super(SimCLR, self).prepare()
    self.backbone = hydra.utils.instantiate(self.config.model.model)

    with open_dict(self.config):
        self.config.model.projection_mlp.hidden_dim = (
            self.config.model.projection_mlp.hidden_dim * self.config.model.projection_mlp_mult
        )
        self.config.model.projection_mlp.output_dim = (
            self.config.model.projection_mlp.output_dim * self.config.model.projection_mlp_mult
        )
    self.projection_mlp = hydra.utils.instantiate(self.config.model.projection_mlp)
    self.optimizer = self.config.optimizer
    self.scheduler = self.config.scheduler
    self.module = self.config.model.module
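
The open_dict block widens the projector in place by projection_mlp_mult. As a worked example with illustrative numbers (the original Barlow Twins paper uses a very wide 8192-dim projector, but these values are not quadra defaults):

# Illustrative arithmetic only; real values come from the experiment config
hidden_dim, output_dim, mult = 2048, 2048, 4
print(hidden_dim * mult, output_dim * mult)  # 8192 8192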

DINO(config, checkpoint_path=None, run_test=False)

Bases: SSL

DINO model as a pytorch_lightning.LightningModule.

Parameters:

  • config (DictConfig) –

    the main config

  • checkpoint_path (Optional[str], default: None ) –

    If specified, the task returns a trained model with weights loaded from the given checkpoint path. Defaults to None.

  • run_test (bool, default: False ) –

    Whether to run the final test.

Source code in quadra/tasks/ssl.py
def __init__(
    self,
    config: DictConfig,
    checkpoint_path: Optional[str] = None,
    run_test: bool = False,
):
    super().__init__(config=config, checkpoint_path=checkpoint_path, run_test=run_test)
    self.student_model: nn.Module
    self.teacher_model: nn.Module
    self.student_projection_mlp: nn.Module
    self.teacher_projection_mlp: nn.Module

learnable_parameters()

Get the learnable parameters.

Source code in quadra/tasks/ssl.py
def learnable_parameters(self) -> List[nn.Parameter]:
    """Get the learnable parameters."""
    return list(
        list(self.student_model.parameters()) + list(self.student_projection_mlp.parameters()),
    )

prepare()

Prepare the experiment.

Source code in quadra/tasks/ssl.py
def prepare(self) -> None:
    """Prepare the experiment."""
    super().prepare()
    self.student_model = cast(nn.Module, hydra.utils.instantiate(self.config.model.student))
    # the teacher reuses the student's architecture config; its parameters are
    # excluded from learnable_parameters() above
    self.teacher_model = cast(nn.Module, hydra.utils.instantiate(self.config.model.student))
    self.student_projection_mlp = cast(nn.Module, hydra.utils.instantiate(self.config.model.student_projection_mlp))
    self.teacher_projection_mlp = cast(nn.Module, hydra.utils.instantiate(self.config.model.teacher_projection_mlp))
    self.optimizer = self.config.optimizer
    self.scheduler = self.config.scheduler
    self.module = self.config.model.module
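
Unlike BYOL above, which builds both heads from a single model.projection_mlp entry, DINO reads separate keys for the two heads, so they can be configured independently. The key names below are taken directly from prepare():

# Keys read by DINO.prepare(); the teacher again reuses the student backbone config
cfg.model.student                  # backbone architecture (student and teacher)
cfg.model.student_projection_mlp   # head attached to the student
cfg.model.teacher_projection_mlp   # head attached to the teacher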

EmbeddingVisualization(config, model_path, report_folder='embeddings', embedding_image_size=None)

Bases: Task

Visualization task for learned embeddings.

Parameters:

  • config (DictConfig) –

    The loaded experiment config

  • model_path (str) –

    The path to a deployment model

  • report_folder (str, default: 'embeddings' ) –

    Where to save the embeddings

  • embedding_image_size (Optional[int], default: None ) –

    If not None, rescale the images associated with the embeddings to this size. TensorBoard saves a large sprite on disk containing all the images arranged in a grid; if the sprite is too big it cannot be loaded in the browser. Rescaling the output images from the model input size to something smaller solves this, as the worked example below shows. The field is an int so images are always rescaled to a square.
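
As a worked example of the sprite-size problem (TensorBoard packs the label images into a single square grid, and its embedding projector is commonly reported to reject sprites larger than 8192 pixels per side; both details are assumptions about TensorBoard, not quadra):

import math

n_images = 10_000
grid = math.ceil(math.sqrt(n_images))   # images per side of the square sprite
print(grid * 224)  # 22400 px per side at a 224 input size: too large to load
print(grid * 64)   # 6400 px per side with embedding_image_size=64: fits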

Source code in quadra/tasks/ssl.py
def __init__(
    self,
    config: DictConfig,
    model_path: str,
    report_folder: str = "embeddings",
    embedding_image_size: Optional[int] = None,
):
    super().__init__(config=config)

    self.config = config
    self.metadata = {
        "report_files": [],
    }
    self.model_path = model_path
    self._device = utils.get_device()
    log.info("Current device: %s", self._device)

    self.report_folder = report_folder
    if self.model_path is None:
        raise ValueError(
            "Model path cannot be found! Please specify it in the config or pass it as an argument"
            " for evaluation"
        )
    self.embeddings_path = os.path.join(self.model_path, self.report_folder)
    if not os.path.exists(self.embeddings_path):
        os.makedirs(self.embeddings_path)
    self.embedding_writer = SummaryWriter(self.embeddings_path)
    self.writer_step = 0  # for tensorboard
    self.embedding_image_size = embedding_image_size
    self._deployment_model: BaseEvaluationModel
    self.deployment_model_type: str

deployment_model property writable

Get the deployment model.

prepare()

Prepare the evaluation.

Source code in quadra/tasks/ssl.py
def prepare(self) -> None:
    """Prepare the evaluation."""
    super().prepare()
    # assigning the path invokes the writable deployment_model property setter,
    # which loads the deployment model from disk
    self.deployment_model = self.model_path

test()

Run embeddings extraction.

Source code in quadra/tasks/ssl.py
@torch.no_grad()
def test(self) -> None:
    """Run embeddings extraction."""
    self.datamodule.setup("fit")
    # the fit stage is set up only to recover the index-to-class mapping
    idx_to_class = self.datamodule.val_dataset.idx_to_class
    self.datamodule.setup("test")
    dataloader = self.datamodule.test_dataloader()
    images = []
    metadata: List[Tuple[int, str, str]] = []
    embeddings = []
    # per-channel normalization statistics reshaped for broadcasting over (N, C, H, W)
    std = torch.tensor(self.config.transforms.std).view(1, -1, 1, 1)
    mean = torch.tensor(self.config.transforms.mean).view(1, -1, 1, 1)
    # second handle to the dataloader, used below to recover sample file paths
    dl = self.datamodule.test_dataloader()
    counter = 0

    # inspect only the first parameter to detect a half-precision deployment model
    is_half_precision = False
    for param in self.deployment_model.parameters():
        if param.dtype == torch.half:
            is_half_precision = True
        break

    for batch in tqdm(dataloader):
        im, target = batch
        if is_half_precision:
            im = im.half()

        x = self.deployment_model(im.to(self.device))
        targets = [int(t.item()) for t in target]
        class_names = [idx_to_class[t.item()] for t in target]
        file_paths = [s[0] for s in dl.dataset.samples[counter : counter + len(im)]]
        embeddings.append(x.cpu())
        # undo the input normalization so the sprite shows natural-looking images
        im = im * std
        im += mean

        if self.embedding_image_size is not None:
            im = interpolate(im, self.embedding_image_size)

        images.append(im.cpu())
        metadata.extend(zip(targets, class_names, file_paths))
        counter += len(im)
    images = torch.cat(images, dim=0)
    embeddings = torch.cat(embeddings, dim=0)
    self.embedding_writer.add_embedding(
        embeddings,
        metadata=metadata,
        label_img=images,
        global_step=self.writer_step,
        metadata_header=["class", "class_name", "path"],
    )
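
Once test() has run, the projector files live under model_path/report_folder (the embeddings_path built in the constructor). Launching TensorBoard against that directory, e.g. tensorboard --logdir <model_path>/embeddings with the default report folder, should expose the embeddings in the Projector tab.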

SSL(config, run_test=False, report=False, checkpoint_path=None)

Bases: LightningTask

SSL Task.

Parameters:

  • config (DictConfig) –

    The experiment configuration

  • checkpoint_path (Optional[str], default: None ) –

    The path to the checkpoint to load the model from. Defaults to None.

  • report (bool, default: False ) –

    Whether to create the report

  • run_test (bool, default: False ) –

    Whether to run the final test.

Source code in quadra/tasks/ssl.py
def __init__(
    self,
    config: DictConfig,
    run_test: bool = False,
    report: bool = False,
    checkpoint_path: Optional[str] = None,
):
    super().__init__(
        config=config,
        checkpoint_path=checkpoint_path,
        run_test=run_test,
        report=report,
    )
    self._backbone: nn.Module
    self._optimizer: torch.optim.Optimizer
    self._lr_scheduler: torch.optim.lr_scheduler._LRScheduler
    self.export_folder = "deployment_model"

optimizer: torch.optim.Optimizer property writable

Get the optimizer.

scheduler: torch.optim.lr_scheduler._LRScheduler property writable

Get the scheduler.

export()

Deploy a model ready for production.

Source code in quadra/tasks/ssl.py
def export(self) -> None:
    """Deploy a model ready for production."""
    half_precision = "16" in self.trainer.precision

    input_shapes = self.config.export.input_shapes

    model_json, export_paths = export_model(
        config=self.config,
        model=self.module.model,
        export_folder=self.export_folder,
        half_precision=half_precision,
        input_shapes=input_shapes,
        idx_to_class=None,
    )

    if len(export_paths) == 0:
        return

    with open(os.path.join(self.export_folder, "model.json"), "w") as f:
        json.dump(model_json, f, cls=utils.HydraEncoder)

learnable_parameters()

Get the learnable parameters.

Source code in quadra/tasks/ssl.py
def learnable_parameters(self) -> List[nn.Parameter]:
    """Get the learnable parameters."""
    raise NotImplementedError("This method must be implemented by the subclass")

test()

Test the model.

Source code in quadra/tasks/ssl.py
def test(self) -> None:
    """Test the model."""
    if self.run_test and not self.config.trainer.get("fast_dev_run"):
        log.info("Starting testing!")
        log.info("Using last epoch's weights for testing.")
        self.trainer.test(datamodule=self.datamodule, model=self.module, ckpt_path=None)
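
Taken together, a minimal driver sketch that touches only the methods documented on this page; in practice quadra tasks are normally launched through the framework's experiment entry point, and the training loop itself lives in the parent LightningTask:

# Sketch only: assumes cfg is a composed Hydra DictConfig for an SSL experiment
task = SimCLR(config=cfg, run_test=True)
task.prepare()   # instantiate backbone, projection head, optimizer, scheduler, module
# ... fitting is orchestrated by the LightningTask machinery ...
task.test()      # runs trainer.test with the last epoch's weights
task.export()    # writes the deployment model and model.json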

SimCLR(config, checkpoint_path=None, run_test=False)

Bases: SSL

SimCLR model as a pytorch_lightning.LightningModule.

Parameters:

  • config (DictConfig) –

    the main config

  • checkpoint_path (Optional[str], default: None ) –

    If specified, the task returns a trained model with weights loaded from the given checkpoint path. Defaults to None.

  • run_test (bool, default: False ) –

    Whether to run the final test.

Source code in quadra/tasks/ssl.py
def __init__(
    self,
    config: DictConfig,
    checkpoint_path: Optional[str] = None,
    run_test: bool = False,
):
    super().__init__(config=config, checkpoint_path=checkpoint_path, run_test=run_test)
    self.backbone: nn.Module
    self.projection_mlp: nn.Module

learnable_parameters()

Get the learnable parameters.

Source code in quadra/tasks/ssl.py
def learnable_parameters(self) -> List[nn.Parameter]:
    """Get the learnable parameters."""
    return list(self.backbone.parameters()) + list(self.projection_mlp.parameters())
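
Because learnable_parameters() returns a flat list of nn.Parameter, it can be handed straight to any torch optimizer. A one-line sketch, assuming task is a prepared SimCLR instance as in the driver sketch above (the hyperparameters are illustrative; quadra itself builds the optimizer from config.optimizer):

import torch

optimizer = torch.optim.SGD(task.learnable_parameters(), lr=0.05, momentum=0.9)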

prepare()

Prepare the experiment.

Source code in quadra/tasks/ssl.py
def prepare(self) -> None:
    """Prepare the experiment."""
    super().prepare()
    self.backbone = hydra.utils.instantiate(self.config.model.model)
    self.projection_mlp = hydra.utils.instantiate(self.config.model.projection_mlp)
    self.optimizer = self.config.optimizer
    self.scheduler = self.config.scheduler
    self.module = self.config.model.module

Simsiam(config, checkpoint_path=None, run_test=False)

Bases: SSL

Simsiam model as a pytorch_lightning.LightningModule.

Parameters:

  • config (DictConfig) –

    the main config

  • checkpoint_path (Optional[str], default: None ) –

    If specified, the task returns a trained model with weights loaded from the given checkpoint path. Defaults to None.

  • run_test (bool, default: False ) –

    Whether to run the final test.

Source code in quadra/tasks/ssl.py
def __init__(
    self,
    config: DictConfig,
    checkpoint_path: Optional[str] = None,
    run_test: bool = False,
):
    super().__init__(config=config, checkpoint_path=checkpoint_path, run_test=run_test)
    self.backbone: nn.Module
    self.projection_mlp: nn.Module
    self.prediction_mlp: nn.Module

module: LightningModule property writable

Get the module of the model.

learnable_parameters()

Get the learnable parameters.

Source code in quadra/tasks/ssl.py
def learnable_parameters(self) -> List[nn.Parameter]:
    """Get the learnable parameters."""
    return list(
        list(self.backbone.parameters())
        + list(self.projection_mlp.parameters())
        + list(self.prediction_mlp.parameters()),
    )

prepare()

Prepare the experiment.

Source code in quadra/tasks/ssl.py
def prepare(self) -> None:
    """Prepare the experiment."""
    super().prepare()
    self.backbone = hydra.utils.instantiate(self.config.model.model)
    self.projection_mlp = hydra.utils.instantiate(self.config.model.projection_mlp)
    self.prediction_mlp = hydra.utils.instantiate(self.config.model.prediction_mlp)
    self.optimizer = self.config.optimizer
    self.scheduler = self.config.scheduler
    self.module = self.config.model.module
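
For orientation, the three modules prepared above compose as in the standard SimSiam formulation. The following is a sketch of the usual forward pass, not code taken from quadra's training module:

# x1, x2 are two augmented views of the same image
z1 = projection_mlp(backbone(x1))
z2 = projection_mlp(backbone(x2))
p1 = prediction_mlp(z1)
p2 = prediction_mlp(z2)
# the SimSiam loss compares p1 with z2.detach() and p2 with z1.detach();
# the detach() implements the stop-gradient that prevents collapse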