simclr
SimCLR(model, projection_mlp, criterion, classifier=None, optimizer=None, lr_scheduler=None, lr_scheduler_interval='epoch')
Bases: SSLModule
SimCLR class for contrastive self-supervised learning.
Parameters:
- model (Module) – Feature extractor, as a PyTorch torch.nn.Module.
- projection_mlp (Module) – Projection head, as a PyTorch torch.nn.Module.
- criterion (Module) – SSL loss to be applied.
- classifier (Optional[ClassifierMixin], default: None) – Standard sklearn classifier. Defaults to None.
- optimizer (Optional[Optimizer], default: None) – Optimizer used for training. If None, a default Adam optimizer is used. Defaults to None.
- lr_scheduler (Optional[object], default: None) – Learning-rate scheduler. If None, a default ReduceLROnPlateau scheduler is used. Defaults to None.
- lr_scheduler_interval (Optional[str], default: 'epoch') – Interval at which the learning-rate scheduler is updated. Defaults to "epoch".
Source code in quadra/modules/ssl/simclr.py
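The criterion is supplied by the caller. The original SimCLR method uses the NT-Xent (normalized temperature-scaled cross-entropy) contrastive loss, so the following is a minimal NumPy sketch of that loss. It assumes the projections of the two augmented views arrive as arrays of shape (N, D), with row i of each array forming a positive pair. It is not quadra's implementation: the function name nt_xent_loss and the default temperature of 0.5 are illustrative assumptions.

```python
import numpy as np


def nt_xent_loss(z1: np.ndarray, z2: np.ndarray, temperature: float = 0.5) -> float:
    """Hypothetical NT-Xent loss sketch (not quadra's criterion).

    z1, z2: projections of the two augmented views, shape (N, D).
    Row i of z1 and row i of z2 are the positive pair; the other
    2N - 2 samples in the batch act as negatives.
    """
    n = z1.shape[0]
    z = np.concatenate([z1, z2], axis=0)               # (2N, D)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)   # L2-normalize so dot product = cosine similarity
    sim = z @ z.T / temperature                        # (2N, 2N) similarity logits
    np.fill_diagonal(sim, -np.inf)                     # a sample is never its own negative
    # Index of each row's positive partner: i <-> i + N
    pos = np.concatenate([np.arange(n, 2 * n), np.arange(0, n)])
    # Numerically stable log-softmax over each row
    row_max = sim.max(axis=1, keepdims=True)
    log_denom = row_max[:, 0] + np.log(np.exp(sim - row_max).sum(axis=1))
    log_prob_pos = sim[np.arange(2 * n), pos] - log_denom
    # Average the cross-entropy over all 2N anchors
    return float(-log_prob_pos.mean())


if __name__ == "__main__":
    views = np.eye(2)
    print("matched views:", nt_xent_loss(views, views))
    print("mismatched views:", nt_xent_loss(views, views[::-1]))
```

Masking the diagonal with -inf removes self-similarity from the softmax denominator, which is what makes the loss contrast each sample only against its partner and the other samples in the batch. In the actual module the criterion would be a torch.nn.Module operating on tensors rather than a NumPy function.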
training_step(batch, batch_idx)
Parameters:
- batch (Tuple[Tuple[Tensor, Tensor], Tensor]) – The batch of data.
- batch_idx (int) – The index of the batch.
Returns:
- Tensor – The computed loss.
Source code in quadra/modules/ssl/simclr.py
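The flow of training_step can be sketched as follows. This is an assumption-based illustration, not quadra's source: it assumes the documented batch layout ((view_1, view_2), labels), that both views pass through the shared model and then projection_mlp, and that the labels and batch_idx are unused by the self-supervised objective. The function name simclr_training_step, and passing the modules as plain callables, are hypothetical choices made so the sketch runs with NumPy alone.

```python
from typing import Callable, Tuple

import numpy as np

Array = np.ndarray


def simclr_training_step(
    batch: Tuple[Tuple[Array, Array], Array],
    batch_idx: int,
    model: Callable[[Array], Array],
    projection_mlp: Callable[[Array], Array],
    criterion: Callable[[Array, Array], float],
) -> float:
    """Hypothetical sketch of one SimCLR training step.

    batch is ((view_1, view_2), labels). Labels and batch_idx are
    ignored here (assumption) because the objective is self-supervised.
    """
    (view_1, view_2), _labels = batch
    # The same feature extractor and projection head process both views
    z1 = projection_mlp(model(view_1))
    z2 = projection_mlp(model(view_2))
    # The criterion compares the two projections of each sample
    return criterion(z1, z2)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    w_enc = rng.normal(size=(8, 4))
    w_proj = rng.normal(size=(4, 3))

    def toy_model(x: Array) -> Array:
        # Hypothetical encoder: linear layer followed by ReLU
        return np.maximum(x @ w_enc, 0.0)

    def toy_projection(h: Array) -> Array:
        # Hypothetical linear projection head
        return h @ w_proj

    def toy_criterion(a: Array, b: Array) -> float:
        # Stand-in loss (mean squared distance), NOT a contrastive loss
        return float(((a - b) ** 2).mean())

    x = rng.normal(size=(5, 8))
    noisy = x + 0.01 * rng.normal(size=x.shape)
    batch = ((x, noisy), np.zeros(5))
    print(simclr_training_step(batch, 0, toy_model, toy_projection, toy_criterion))
```

The toy criterion only keeps the example short; with a contrastive criterion such as NT-Xent, the other samples in the batch also serve as negatives.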