
barlowtwins

BarlowTwins(model, projection_mlp, criterion, classifier=None, optimizer=None, lr_scheduler=None, lr_scheduler_interval='epoch')

Bases: SSLModule

BarlowTwins self-supervised learning model.

Parameters:

  • model (Module) –

    Network module used to extract features.

  • projection_mlp (Module) –

    Module used to project the extracted features.

  • criterion (Module) –

    SSL loss to be applied

  • classifier (ClassifierMixin | None, default: None ) –

    Standard sklearn classifier. Defaults to None.

  • optimizer (Optimizer | None, default: None ) –

    Optimizer used for training. If None, a default Adam optimizer is used. Defaults to None.

  • lr_scheduler (object | None, default: None ) –

    Learning-rate scheduler. If None, a default ReduceLROnPlateau scheduler is used. Defaults to None.

  • lr_scheduler_interval (str | None, default: 'epoch' ) –

    Interval at which the learning-rate scheduler is updated. Defaults to "epoch".
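The criterion is supplied by the caller, so this page does not fix its exact form. For orientation, below is a minimal pure-Python sketch of the objective described in the Barlow Twins paper. Each embedding dimension is standardized across the batch, the cross-correlation matrix C between the two views' embeddings is computed, and the loss sums (1 - C_ii)^2 over the diagonal plus a weight lambd times C_ij^2 over the off-diagonal entries. The function names, the list-of-lists input format and the default lambd=5e-3 are illustrative assumptions, not quadra's API. A real criterion would operate on torch.Tensor batches produced by projection_mlp.

```python
import math
from typing import Sequence


def standardize(column: Sequence[float], eps: float = 1e-9) -> list[float]:
    """Shift a column to zero mean and scale it by its population std."""
    n = len(column)
    mean = sum(column) / n
    var = sum((v - mean) ** 2 for v in column) / n
    std = math.sqrt(var + eps)  # eps avoids division by zero on constant columns
    return [(v - mean) / std for v in column]


def barlow_twins_loss(
    z1: Sequence[Sequence[float]],
    z2: Sequence[Sequence[float]],
    lambd: float = 5e-3,  # assumed off-diagonal weight, for illustration only
) -> float:
    """Hypothetical sketch of the Barlow Twins objective on two batches.

    z1 and z2 hold the projected embeddings of two augmented views of the
    same N samples, each as N rows of D floats.
    """
    n = len(z1)
    d = len(z1[0])
    # Standardize every embedding dimension along the batch axis.
    cols1 = [standardize([row[k] for row in z1]) for k in range(d)]
    cols2 = [standardize([row[k] for row in z2]) for k in range(d)]
    # Cross-correlation matrix: C[i][j] = (1/N) * sum_b z1n[b][i] * z2n[b][j].
    c = [
        [sum(a * b for a, b in zip(cols1[i], cols2[j])) / n for j in range(d)]
        for i in range(d)
    ]
    # Invariance term: push the diagonal towards 1.
    on_diag = sum((1.0 - c[i][i]) ** 2 for i in range(d))
    # Redundancy-reduction term: push the off-diagonal entries towards 0.
    off_diag = sum(c[i][j] ** 2 for i in range(d) for j in range(d) if i != j)
    return on_diag + lambd * off_diag
```

With two identical views whose dimensions are mutually uncorrelated, the diagonal of C is about 1 and the off-diagonal is 0, so the loss is near zero. Swapping the two dimensions in one view moves all correlation off the diagonal, and the loss rises to about 2 + 2 * lambd.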

Source code in quadra/modules/ssl/barlowtwins.py
def __init__(
    self,
    model: nn.Module,
    projection_mlp: nn.Module,
    criterion: nn.Module,
    classifier: sklearn.base.ClassifierMixin | None = None,
    optimizer: optim.Optimizer | None = None,
    lr_scheduler: object | None = None,
    lr_scheduler_interval: str | None = "epoch",
):
    super().__init__(model, criterion, classifier, optimizer, lr_scheduler, lr_scheduler_interval)
    # self.save_hyperparameters()
    self.projection_mlp = projection_mlp
    self.criterion = criterion