barlowtwins
BarlowTwins(model, projection_mlp, criterion, classifier=None, optimizer=None, lr_scheduler=None, lr_scheduler_interval='epoch')
Bases: SSLModule
BarlowTwins model.
Parameters:
- model (Module) – Network module used to extract features.
- projection_mlp (Module) – Module that projects the extracted features.
- criterion (Module) – SSL loss to be applied.
- classifier (ClassifierMixin | None, default: None) – Standard sklearn classifier. Defaults to None.
- optimizer (Optimizer | None, default: None) – Optimizer used for training. If None, a default Adam optimizer is used. Defaults to None.
- lr_scheduler (object | None, default: None) – Learning-rate scheduler. If None, a default ReduceLROnPlateau scheduler is used. Defaults to None.
- lr_scheduler_interval (str | None, default: 'epoch') – Interval at which the learning-rate scheduler is updated. Defaults to "epoch".
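To show what a Barlow Twins `criterion` typically computes, here is a minimal pure-Python sketch of the standard Barlow Twins objective. It assumes that two augmented views of a batch are each passed through `model` and then `projection_mlp`, producing embeddings `z_a` and `z_b` of shape (batch, dim). The loss standardizes each embedding dimension over the batch, builds the cross-correlation matrix, pushes its diagonal toward 1 (invariance) and its off-diagonal entries toward 0 (redundancy reduction). The function names `cross_correlation` and `barlow_twins_loss` and the weight `lambd` are hypothetical and not part of quadra; the loss actually passed as `criterion` may differ in normalization, scaling, or default weight.

```python
import math
from typing import List

Matrix = List[List[float]]


def _standardize_columns(z: Matrix, eps: float = 1e-5) -> Matrix:
    """Give each embedding dimension zero mean and unit (biased) std over the batch."""
    n, d = len(z), len(z[0])
    out = [[0.0] * d for _ in range(n)]
    for j in range(d):
        col = [z[i][j] for i in range(n)]
        mean = sum(col) / n
        std = math.sqrt(sum((x - mean) ** 2 for x in col) / n)
        for i in range(n):
            out[i][j] = (z[i][j] - mean) / (std + eps)
    return out


def cross_correlation(z_a: Matrix, z_b: Matrix) -> Matrix:
    """Return the (dim x dim) cross-correlation matrix C = A^T B / N of standardized embeddings."""
    a = _standardize_columns(z_a)
    b = _standardize_columns(z_b)
    n, d = len(a), len(a[0])
    return [
        [sum(a[k][i] * b[k][j] for k in range(n)) / n for j in range(d)]
        for i in range(d)
    ]


def barlow_twins_loss(z_a: Matrix, z_b: Matrix, lambd: float = 5e-3) -> float:
    """Hypothetical Barlow Twins loss: sum_i (1 - C_ii)^2 + lambd * sum_{i != j} C_ij^2."""
    c = cross_correlation(z_a, z_b)
    d = len(c)
    on_diag = sum((1.0 - c[i][i]) ** 2 for i in range(d))
    off_diag = sum(c[i][j] ** 2 for i in range(d) for j in range(d) if i != j)
    return on_diag + lambd * off_diag


if __name__ == "__main__":
    # Toy projected embeddings for a batch of 4 samples with 2 uncorrelated dimensions.
    z = [[1.0, 1.0], [1.0, -1.0], [-1.0, 1.0], [-1.0, -1.0]]
    z_flipped = [[-x for x in row] for row in z]
    print(barlow_twins_loss(z, z))          # near 0: identical views
    print(barlow_twins_loss(z, z_flipped))  # near 8: anti-correlated views
```

Identical views yield a diagonal close to 1 and a loss near zero, while anti-correlated views are heavily penalized. The small `eps` only guards against division by zero for constant dimensions.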
Source code in quadra/modules/ssl/barlowtwins.py