ssl
BYOL(student, teacher, student_projection_mlp, student_prediction_mlp, teacher_projection_mlp, criterion, classifier=None, optimizer=None, lr_scheduler=None, lr_scheduler_interval='epoch', teacher_momentum=0.9995, teacher_momentum_cosine_decay=True)
Bases: SSLModule
BYOL module, inspired by https://arxiv.org/abs/2006.07733.
Parameters:
- student – student model.
- teacher – teacher model.
- student_projection_mlp – student projection MLP.
- student_prediction_mlp – student prediction MLP.
- teacher_projection_mlp – teacher projection MLP.
- criterion – loss function.
- classifier (Optional[ClassifierMixin], default: None) – Standard sklearn classifier.
- optimizer (Optional[Optimizer], default: None) – optimizer of the training. If None, a default Adam is used.
- lr_scheduler (Optional[object], default: None) – lr scheduler. If None, a default ReduceLROnPlateau is used.
- lr_scheduler_interval (Optional[str], default: 'epoch') – interval at which the lr scheduler is updated.
- teacher_momentum (float, default: 0.9995) – momentum of the teacher parameters.
- teacher_momentum_cosine_decay (Optional[bool], default: True) – whether to use cosine decay for the teacher momentum.
Source code in quadra/modules/ssl/byol.py
calculate_accuracy(batch)
Calculate accuracy on the given batch.
Source code in quadra/modules/ssl/byol.py
initialize_teacher()
Initialize the teacher from the state dict of the student, also checking that the student model requires gradients correctly.
Source code in quadra/modules/ssl/byol.py
optimizer_step(epoch, batch_idx, optimizer, optimizer_closure=None)
Override optimizer step to update the teacher parameters.
Source code in quadra/modules/ssl/byol.py
test_step(batch, *args)
Calculate accuracy on the test set for the given batch.
Source code in quadra/modules/ssl/byol.py
update_teacher()
Update the teacher given self.teacher_momentum by an exponential moving average of the student parameters, that is: theta_t = theta_t * tau + theta_s * (1 - tau), where theta_s and theta_t are the parameters of the student and the teacher model, while tau is the teacher momentum. If self.teacher_momentum_cosine_decay is True, the teacher momentum follows a cosine schedule from self.teacher_momentum to 1: tau = 1 - (1 - tau_0) * (cos(pi * t / T) + 1) / 2, where tau_0 is the base teacher momentum, t is the current step and T is the maximum number of steps.
Source code in quadra/modules/ssl/byol.py
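The update rule and its cosine schedule can be sketched in plain Python (a hedged illustration on flat lists of floats; the actual method operates on the torch parameters of the student and teacher models):

```python
import math

def cosine_teacher_momentum(base_tau: float, step: int, max_steps: int) -> float:
    # tau = 1 - (1 - tau_0) * (cos(pi * t / T) + 1) / 2:
    # equals base_tau at step 0 and reaches 1.0 at max_steps.
    return 1.0 - (1.0 - base_tau) * (math.cos(math.pi * step / max_steps) + 1.0) / 2.0

def ema_update(teacher: list, student: list, tau: float) -> list:
    # theta_t <- theta_t * tau + theta_s * (1 - tau), element-wise.
    return [t * tau + s * (1.0 - tau) for t, s in zip(teacher, student)]
```

With the default teacher_momentum of 0.9995 the teacher tracks the student very slowly early in training, and is effectively frozen by the final step, where tau reaches 1.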
BarlowTwins(model, projection_mlp, criterion, classifier=None, optimizer=None, lr_scheduler=None, lr_scheduler_interval='epoch')
Bases: SSLModule
BarlowTwins model.
Parameters:
- model (Module) – network module used to extract features.
- projection_mlp (Module) – module to project extracted features.
- criterion (Module) – SSL loss to be applied.
- classifier (Optional[ClassifierMixin], default: None) – Standard sklearn classifier. Defaults to None.
- optimizer (Optional[Optimizer], default: None) – optimizer of the training. If None, a default Adam is used. Defaults to None.
- lr_scheduler (Optional[object], default: None) – lr scheduler. If None, a default ReduceLROnPlateau is used. Defaults to None.
- lr_scheduler_interval (Optional[str], default: 'epoch') – interval at which the lr scheduler is updated. Defaults to "epoch".
Source code in quadra/modules/ssl/barlowtwins.py
Dino(student, teacher, student_projection_mlp, teacher_projection_mlp, criterion, freeze_last_layer=1, classifier=None, optimizer=None, lr_scheduler=None, lr_scheduler_interval='epoch', teacher_momentum=0.9995, teacher_momentum_cosine_decay=True)
Bases: BYOL
DINO pytorch-lightning module.
Parameters:
- student – student model.
- teacher – teacher model.
- student_projection_mlp – student projection MLP.
- teacher_projection_mlp – teacher projection MLP.
- criterion – loss function.
- freeze_last_layer – number of epochs during which the gradient of the student's last layer is zeroed. Default: 1.
- classifier (Optional[ClassifierMixin], default: None) – Standard sklearn classifier.
- optimizer (Optional[Optimizer], default: None) – optimizer of the training. If None, a default Adam is used.
- lr_scheduler (Optional[object], default: None) – lr scheduler. If None, a default ReduceLROnPlateau is used.
- lr_scheduler_interval (Optional[str], default: 'epoch') – interval at which the lr scheduler is updated.
- teacher_momentum (float, default: 0.9995) – momentum of the teacher parameters.
- teacher_momentum_cosine_decay (Optional[bool], default: True) – whether to use cosine decay for the teacher momentum.
Source code in quadra/modules/ssl/dino.py
cancel_gradients_last_layer(epoch, freeze_last_layer)
Zero out the gradient of the last layer, as specified in the paper.
Parameters:
- epoch (int) – current epoch.
- freeze_last_layer (int) – maximum freeze epoch: once epoch >= freeze_last_layer, the gradient of the last layer is no longer zeroed.
Source code in quadra/modules/ssl/dino.py
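The freezing rule can be sketched in plain Python (a hedged illustration where gradients are a name-to-value dict and the "last_layer" substring is an assumed naming convention; the real method zeroes the .grad tensors of the student's last layer):

```python
def cancel_gradients_last_layer(epoch: int, freeze_last_layer: int, grads: dict) -> dict:
    # While epoch < freeze_last_layer, zero the last-layer gradients;
    # from freeze_last_layer onwards, leave all gradients untouched.
    if epoch >= freeze_last_layer:
        return grads
    return {name: (0.0 if "last_layer" in name else g) for name, g in grads.items()}
```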
configure_gradient_clipping(optimizer, gradient_clip_val=None, gradient_clip_algorithm=None)
Configure gradient clipping for the optimizer.
Source code in quadra/modules/ssl/dino.py
initialize_teacher()
Initialize the teacher from the state dict of the student, also checking that the student model requires gradients correctly.
Source code in quadra/modules/ssl/dino.py
optimizer_step(epoch, batch_idx, optimizer, optimizer_closure=None)
Override optimizer step to update the teacher parameters.
Source code in quadra/modules/ssl/dino.py
169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 |
|
student_multicrop_forward(x)
Student forward on the multicrop images.
Parameters:
- x – input list of image crop batches.
Returns:
- Tensor – a tensor of shape NxBxD, where N is the number of crops (the length of the input list x), B is the batch size and D is the output dimension.
Source code in quadra/modules/ssl/dino.py
teacher_multicrop_forward(x)
Teacher forward on the multicrop images.
Parameters:
- x – input list of image crop batches.
Returns:
- Tensor – a tensor of shape NxBxD, where N is the number of crops (the length of the input list x), B is the batch size and D is the output dimension.
Source code in quadra/modules/ssl/dino.py
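The NxBxD output contract shared by both multicrop forwards can be sketched with nested lists (a hedged stand-in: embed below plays the role of the student or teacher network with its projection MLP):

```python
def multicrop_forward(x, embed):
    # x: list of N crop batches, each a list of B images.
    # Returns a nested list of shape N x B x D, where D is the
    # dimension of the vectors produced by embed.
    return [[embed(img) for img in batch] for batch in x]

# Toy embedding mapping an "image" (a number) to a D=2 vector.
def toy_embed(img):
    return [float(img), 2.0 * float(img)]
```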
IDMM(model, prediction_mlp, criterion, multiview_loss=True, mixup_fn=None, classifier=None, optimizer=None, lr_scheduler=None, lr_scheduler_interval='epoch')
Bases: SSLModule
IDMM model.
Parameters:
- model (Module) – backbone model.
- prediction_mlp (Module) – student prediction MLP.
- criterion (Module) – loss function.
- multiview_loss (bool, default: True) – whether to use the multiview loss as defined in https://arxiv.org/abs/2201.10728. Defaults to True.
- mixup_fn (Optional[Mixup], default: None) – the mixup/cutmix function to be applied to a batch of images. Defaults to None.
- classifier (Optional[ClassifierMixin], default: None) – Standard sklearn classifier.
- optimizer (Optional[Optimizer], default: None) – optimizer of the training. If None, a default Adam is used.
- lr_scheduler (Optional[object], default: None) – lr scheduler. If None, a default ReduceLROnPlateau is used.
- lr_scheduler_interval (Optional[str], default: 'epoch') – interval at which the lr scheduler is updated.
Source code in quadra/modules/ssl/idmm.py
SimCLR(model, projection_mlp, criterion, classifier=None, optimizer=None, lr_scheduler=None, lr_scheduler_interval='epoch')
Bases: SSLModule
SimCLR class.
Parameters:
- model (Module) – feature extractor as a pytorch torch.nn.Module.
- projection_mlp (Module) – projection head as a pytorch torch.nn.Module.
- criterion (Module) – SSL loss to be applied.
- classifier (Optional[ClassifierMixin], default: None) – Standard sklearn classifier. Defaults to None.
- optimizer (Optional[Optimizer], default: None) – optimizer of the training. If None, a default Adam is used. Defaults to None.
- lr_scheduler (Optional[object], default: None) – lr scheduler. If None, a default ReduceLROnPlateau is used. Defaults to None.
- lr_scheduler_interval (Optional[str], default: 'epoch') – interval at which the lr scheduler is updated. Defaults to "epoch".
Source code in quadra/modules/ssl/simclr.py
training_step(batch, batch_idx)
Parameters:
- batch (Tuple[Tuple[Tensor, Tensor], Tensor]) – the batch of data.
- batch_idx (int) – the index of the batch.
Returns:
- Tensor – the computed loss.
Source code in quadra/modules/ssl/simclr.py
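The flow implied by the batch type (two augmented views plus a target unused by the contrastive loss) can be sketched as follows; model, projection_mlp and criterion are toy stand-ins for the module's attributes, so treat this as an assumption about the flow, not the exact implementation:

```python
def training_step(batch, batch_idx, model, projection_mlp, criterion):
    # batch = ((view_a, view_b), target): two augmented views of the
    # same images; the target is not used by the SSL criterion.
    (view_a, view_b), _target = batch
    z_a = projection_mlp(model(view_a))
    z_b = projection_mlp(model(view_b))
    return criterion(z_a, z_b)
```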
SimSIAM(model, projection_mlp, prediction_mlp, criterion, classifier=None, optimizer=None, lr_scheduler=None, lr_scheduler_interval='epoch')
Bases: SSLModule
SimSIAM model.
Parameters:
- model (Module) – feature extractor as a pytorch torch.nn.Module.
- projection_mlp (Module) – optional projection head as a pytorch torch.nn.Module.
- prediction_mlp (Module) – optional prediction head as a pytorch torch.nn.Module.
- criterion (Module) – loss to be applied.
- classifier (Optional[ClassifierMixin], default: None) – Standard sklearn classifier.
- optimizer (Optional[Optimizer], default: None) – optimizer of the training. If None, a default Adam is used.
- lr_scheduler (Optional[object], default: None) – lr scheduler. If None, a default ReduceLROnPlateau is used.
- lr_scheduler_interval (Optional[str], default: 'epoch') – interval at which the lr scheduler is updated.
Source code in quadra/modules/ssl/simsiam.py
VICReg(model, projection_mlp, criterion, classifier=None, optimizer=None, lr_scheduler=None, lr_scheduler_interval='epoch')
Bases: SSLModule
VICReg model.
Parameters:
- model (Module) – network module used to extract features.
- projection_mlp (Module) – module to project extracted features.
- criterion (Module) – SSL loss to be applied.
- classifier (Optional[ClassifierMixin], default: None) – Standard sklearn classifier. Defaults to None.
- optimizer (Optional[Optimizer], default: None) – optimizer of the training. If None, a default Adam is used. Defaults to None.
- lr_scheduler (Optional[object], default: None) – lr scheduler. If None, a default ReduceLROnPlateau is used. Defaults to None.
- lr_scheduler_interval (Optional[str], default: 'epoch') – interval at which the lr scheduler is updated. Defaults to "epoch".
Source code in quadra/modules/ssl/vicreg.py