
simsiam

SimSIAM(model, projection_mlp, prediction_mlp, criterion, classifier=None, optimizer=None, lr_scheduler=None, lr_scheduler_interval='epoch')

Bases: SSLModule

SimSiam (Simple Siamese) self-supervised learning model.

Parameters:

  • model (torch.nn.Module) –

    Feature extractor (backbone), as a PyTorch torch.nn.Module.

  • projection_mlp (torch.nn.Module) –

    Projection head, as a PyTorch torch.nn.Module.

  • prediction_mlp (torch.nn.Module) –

    Prediction head, as a PyTorch torch.nn.Module.

  • criterion (torch.nn.Module) –

    Loss function to apply.

  • classifier (Optional[sklearn.base.ClassifierMixin]) –

    Optional standard scikit-learn classifier.

  • optimizer (Optional[torch.optim.Optimizer]) –

    Optimizer used for training. If None, a default Adam optimizer is used.

  • lr_scheduler (Optional[object]) –

    Learning-rate scheduler. If None, a default ReduceLROnPlateau scheduler is used.

  • lr_scheduler_interval (Optional[str]) –

    Interval at which the learning-rate scheduler is stepped (default "epoch").
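As an illustration of what could be passed as projection_mlp, prediction_mlp and criterion, here is a minimal sketch following the design of the SimSiam paper: a projection MLP, a bottleneck prediction MLP, and a symmetric negative cosine similarity loss with stop-gradient. The helpers make_projection_mlp and make_prediction_mlp, the SimSiamLoss class, the layer sizes and the criterion signature criterion(p1, p2, z1, z2) are all hypothetical; quadra's own heads and loss may be built and called differently.

```python
import torch
import torch.nn.functional as F
from torch import nn


def make_projection_mlp(in_dim: int, out_dim: int) -> nn.Module:
    # Hypothetical helper: a simplified SimSiam-style projection head
    # (Linear-BN-ReLU followed by Linear-BN, no final activation).
    return nn.Sequential(
        nn.Linear(in_dim, out_dim, bias=False),
        nn.BatchNorm1d(out_dim),
        nn.ReLU(inplace=True),
        nn.Linear(out_dim, out_dim, bias=False),
        nn.BatchNorm1d(out_dim),
    )


def make_prediction_mlp(dim: int, hidden_dim: int) -> nn.Module:
    # Hypothetical helper: bottleneck predictor mapping dim -> hidden_dim -> dim.
    return nn.Sequential(
        nn.Linear(dim, hidden_dim, bias=False),
        nn.BatchNorm1d(hidden_dim),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_dim, dim),
    )


class SimSiamLoss(nn.Module):
    # Hypothetical criterion: symmetric negative cosine similarity.
    # detach() is the stop-gradient applied to the target projections z1, z2.
    def forward(self, p1, p2, z1, z2):
        loss_1 = -F.cosine_similarity(p1, z2.detach(), dim=-1).mean()
        loss_2 = -F.cosine_similarity(p2, z1.detach(), dim=-1).mean()
        return 0.5 * (loss_1 + loss_2)
```

With a 512-dimensional backbone output, one might pass projection_mlp=make_projection_mlp(512, 2048), prediction_mlp=make_prediction_mlp(2048, 512) and criterion=SimSiamLoss(). The loss lies in [-1, 1] and reaches -1 when each prediction is aligned with the other view's projection.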

Source code in quadra/modules/ssl/simsiam.py
def __init__(
    self,
    model: torch.nn.Module,
    projection_mlp: torch.nn.Module,
    prediction_mlp: torch.nn.Module,
    criterion: torch.nn.Module,
    classifier: Optional[sklearn.base.ClassifierMixin] = None,
    optimizer: Optional[torch.optim.Optimizer] = None,
    lr_scheduler: Optional[object] = None,
    lr_scheduler_interval: Optional[str] = "epoch",
):

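    # Pass the shared components (backbone, loss, classifier, optimizer, LR scheduler) to SSLModule
    super().__init__(
        model,
        criterion,
        classifier,
        optimizer,
        lr_scheduler,
        lr_scheduler_interval,
    )
    # self.save_hyperparameters()
    # SimSiam-specific heads: projection MLP and prediction MLP
    self.projection_mlp = projection_mlp
    self.prediction_mlp = prediction_mlp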
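The constructor only stores the two heads; how they are combined is defined elsewhere in the class. As a hedged sketch (not quadra's actual training step), a SimSiam step on two augmented views x1 and x2 typically chains backbone, projection head and prediction head as below. The function simsiam_step, the toy modules and the tensor sizes are hypothetical, and the loss is written inline rather than through a criterion object.

```python
import torch
import torch.nn.functional as F
from torch import nn


def simsiam_step(model, projection_mlp, prediction_mlp, x1, x2):
    # Hypothetical helper, not quadra's training_step.
    # Encoder f = backbone + projection head, applied to both views.
    z1 = projection_mlp(model(x1))
    z2 = projection_mlp(model(x2))
    # Predictor h maps each projection towards the other view's projection.
    p1 = prediction_mlp(z1)
    p2 = prediction_mlp(z2)
    # Symmetric negative cosine similarity; detach() acts as stop-gradient on targets.
    return -0.5 * (
        F.cosine_similarity(p1, z2.detach(), dim=-1).mean()
        + F.cosine_similarity(p2, z1.detach(), dim=-1).mean()
    )


# Toy usage: two random tensors (4 samples, 3x8x8) stand in for two augmented views.
backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 16))
projector = nn.Sequential(nn.Linear(16, 32), nn.BatchNorm1d(32))
predictor = nn.Sequential(nn.Linear(32, 8), nn.ReLU(), nn.Linear(8, 32))
x1, x2 = torch.randn(4, 3, 8, 8), torch.randn(4, 3, 8, 8)
loss = simsiam_step(backbone, projector, predictor, x1, x2)
loss.backward()
```

The stop-gradient is the key design choice: the SimSiam paper reports that removing it leads to collapsed representations, even though the architecture is otherwise unchanged.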