vicreg

VICReg(model, projection_mlp, criterion, classifier=None, optimizer=None, lr_scheduler=None, lr_scheduler_interval='epoch')

Bases: SSLModule

VICReg (Variance-Invariance-Covariance Regularization) model.

Parameters:

  • model (Module) –

    Network module used to extract features.

  • projection_mlp (Module) –

    Module to project extracted features

  • criterion (Module) –

    SSL loss to be applied.

  • classifier (ClassifierMixin | None, default: None ) –

    Standard sklearn classifier. Defaults to None.

  • optimizer (Optimizer | None, default: None ) –

    Optimizer used for training. If None, a default Adam optimizer is used. Defaults to None.

  • lr_scheduler (object | None, default: None ) –

    Learning rate scheduler. If None, a default ReduceLROnPlateau is used. Defaults to None.

  • lr_scheduler_interval (str | None, default: 'epoch' ) –

    Interval at which the lr scheduler is updated. Defaults to "epoch".
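
The following is a minimal sketch of how the three required arguments might fit together, assuming plain PyTorch only. The names `ToyVICRegLoss`, `backbone`, `view1` and `view2`, the layer sizes, and the loss weights are all hypothetical and are not taken from quadra. The loss shown is a simplified variance-invariance-covariance loss written for illustration, not the library's own criterion, and the forward pass is an assumption about how the module uses its parts, since this page only shows the constructor.

```python
import torch
import torch.nn.functional as F
from torch import nn


class ToyVICRegLoss(nn.Module):
    """Hypothetical stand-in for `criterion` (not quadra's implementation)."""

    def __init__(self, lambd=25.0, mu=25.0, nu=1.0, eps=1e-4):
        super().__init__()
        self.lambd, self.mu, self.nu, self.eps = lambd, mu, nu, eps

    def forward(self, z1, z2):
        n, d = z1.shape
        # Invariance: embeddings of the two views should match.
        invariance = F.mse_loss(z1, z2)
        # Center each embedding dimension over the batch.
        z1 = z1 - z1.mean(dim=0)
        z2 = z2 - z2.mean(dim=0)
        # Variance: hinge that pushes each dimension's std towards at least 1.
        std1 = torch.sqrt(z1.var(dim=0) + self.eps)
        std2 = torch.sqrt(z2.var(dim=0) + self.eps)
        variance = F.relu(1 - std1).mean() / 2 + F.relu(1 - std2).mean() / 2
        # Covariance: penalize off-diagonal entries of the covariance matrix.
        cov1 = (z1.T @ z1) / (n - 1)
        cov2 = (z2.T @ z2) / (n - 1)
        off1 = cov1 - torch.diag(torch.diag(cov1))
        off2 = cov2 - torch.diag(torch.diag(cov2))
        covariance = off1.pow(2).sum() / d + off2.pow(2).sum() / d
        return self.lambd * invariance + self.mu * variance + self.nu * covariance


torch.manual_seed(0)

# Hypothetical feature extractor (`model`) and projector (`projection_mlp`).
backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 32), nn.ReLU())
projection_mlp = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 64))
criterion = ToyVICRegLoss()

# Two "augmented views" of the same batch; here just noise for illustration.
view1 = torch.randn(16, 3, 8, 8)
view2 = view1 + 0.1 * torch.randn_like(view1)

# Assumed flow: extract features, project them, compare the two projections.
z1 = projection_mlp(backbone(view1))
z2 = projection_mlp(backbone(view2))
loss = criterion(z1, z2)

# With quadra installed, the same objects would be passed to the constructor
# documented above, e.g.:
# VICReg(model=backbone, projection_mlp=projection_mlp, criterion=criterion)
```

Leaving `optimizer` and `lr_scheduler` as None would fall back to the defaults described in the parameter list (Adam and ReduceLROnPlateau).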

Source code in quadra/modules/ssl/vicreg.py
def __init__(
    self,
    model: nn.Module,
    projection_mlp: nn.Module,
    criterion: nn.Module,
    classifier: sklearn.base.ClassifierMixin | None = None,
    optimizer: optim.Optimizer | None = None,
    lr_scheduler: object | None = None,
    lr_scheduler_interval: str | None = "epoch",
):
    super().__init__(
        model,
        criterion,
        classifier,
        optimizer,
        lr_scheduler,
        lr_scheduler_interval,
    )
    # self.save_hyperparameters()
    self.projection_mlp = projection_mlp
    self.criterion = criterion