simclr

SimCLR(model, projection_mlp, criterion, classifier=None, optimizer=None, lr_scheduler=None, lr_scheduler_interval='epoch')

Bases: SSLModule

SimCLR class.

Parameters:

  • model (Module) –

    Feature extractor as a torch.nn.Module.

  • projection_mlp (Module) –

    Projection head as a torch.nn.Module.

  • criterion (Module) –

    SSL loss to be applied.

  • classifier (Optional[ClassifierMixin], default: None ) –

    Standard sklearn classifier. Defaults to None.

  • optimizer (Optional[Optimizer], default: None ) –

    Optimizer for training. If None, a default Adam is used. Defaults to None.

  • lr_scheduler (Optional[object], default: None ) –

    Learning-rate scheduler. If None, a default ReduceLROnPlateau is used. Defaults to None.

  • lr_scheduler_interval (Optional[str], default: 'epoch' ) –

    Interval at which the lr scheduler is updated. Defaults to "epoch".

Source code in quadra/modules/ssl/simclr.py
def __init__(
    self,
    model: nn.Module,
    projection_mlp: nn.Module,
    criterion: torch.nn.Module,
    classifier: Optional[sklearn.base.ClassifierMixin] = None,
    optimizer: Optional[torch.optim.Optimizer] = None,
    lr_scheduler: Optional[object] = None,
    lr_scheduler_interval: Optional[str] = "epoch",
):
    super().__init__(
        model,
        criterion,
        classifier,
        optimizer,
        lr_scheduler,
        lr_scheduler_interval,
    )
    self.projection_mlp = projection_mlp
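
SimCLR composes the feature extractor with the projection head, so its forward pass amounts to projecting the extracted features. A minimal sketch of that composition (module names and sizes below are illustrative, not quadra's defaults):

```python
import torch
from torch import nn

# Hypothetical components: a small feature extractor standing in for
# `model`, and a two-layer projection MLP standing in for `projection_mlp`.
backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 128), nn.ReLU())
projection_mlp = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 32))

# Equivalent of SimCLR's forward: project the backbone features.
images = torch.randn(8, 3, 32, 32)
embeddings = projection_mlp(backbone(images))
print(embeddings.shape)  # torch.Size([8, 32])
```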

training_step(batch, batch_idx)

Parameters:

  • batch (Tuple[Tuple[Tensor, Tensor], Tensor]) –

    The batch of data: two augmented views of the images and their labels.

  • batch_idx (int) –

    The index of the batch.

Returns:

  • Tensor

    The computed loss

Source code in quadra/modules/ssl/simclr.py
def training_step(
    self, batch: Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor], batch_idx: int
) -> torch.Tensor:
    """Args:
        batch: The batch of data
        batch_idx: The index of the batch.

    Returns:
        The computed loss
    """
    # pylint: disable=unused-argument
    (im_x, im_y), _ = batch
    emb_x = self(im_x)
    emb_y = self(im_y)
    loss = self.criterion(emb_x, emb_y)

    self.log(
        "loss",
        loss,
        on_epoch=True,
        on_step=True,
        logger=True,
        prog_bar=True,
    )
    return loss
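
The criterion applied to `emb_x` and `emb_y` above is, in the original SimCLR formulation, the NT-Xent contrastive loss. A self-contained numpy sketch of that loss (illustrative only, not quadra's implementation) shows the computation: embeddings of both views are L2-normalized, pairwise cosine similarities are scaled by a temperature, and each sample's positive is its counterpart from the other view.

```python
import numpy as np

def nt_xent_loss(emb_x: np.ndarray, emb_y: np.ndarray, temperature: float = 0.5) -> float:
    """Minimal NT-Xent loss over two views of shape (N, D)."""
    z = np.concatenate([emb_x, emb_y], axis=0)        # (2N, D)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)  # L2-normalize rows
    sim = z @ z.T / temperature                       # scaled cosine similarities
    np.fill_diagonal(sim, -np.inf)                    # exclude self-similarity
    n = emb_x.shape[0]
    # The positive for sample i is its augmented counterpart at i + n (and vice versa).
    pos = np.concatenate([np.arange(n, 2 * n), np.arange(n)])
    # Cross-entropy of each row's softmax against its positive index.
    log_prob = sim - np.log(np.exp(sim).sum(axis=1, keepdims=True))
    return float(-log_prob[np.arange(2 * n), pos].mean())
```

Identical views yield a low loss (every positive has the maximal similarity of 1), while unrelated embeddings yield a higher one.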