dino
Dino(student, teacher, student_projection_mlp, teacher_projection_mlp, criterion, freeze_last_layer=1, classifier=None, optimizer=None, lr_scheduler=None, lr_scheduler_interval='epoch', teacher_momentum=0.9995, teacher_momentum_cosine_decay=True)
Bases: BYOL
DINO PyTorch Lightning module.
Parameters:
- student – student model
- teacher – teacher model
- student_projection_mlp – student projection MLP
- teacher_projection_mlp – teacher projection MLP
- criterion – loss function
- freeze_last_layer – number of initial epochs during which the gradient of the last layer is zeroed (see cancel_gradients_last_layer). Default: 1
- classifier (Optional[ClassifierMixin], default: None) – standard sklearn classifier
- optimizer (Optional[Optimizer], default: None) – optimizer used for training. If None, a default Adam optimizer is used.
- lr_scheduler (Optional[object], default: None) – learning-rate scheduler. If None, a default ReduceLROnPlateau is used.
- lr_scheduler_interval (Optional[str], default: 'epoch') – interval at which the learning-rate scheduler is updated.
- teacher_momentum (float, default: 0.9995) – momentum of the teacher parameters
- teacher_momentum_cosine_decay (Optional[bool], default: True) – whether to apply cosine decay to the teacher momentum
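When teacher_momentum_cosine_decay is enabled, the momentum is not kept constant over training. The sketch below assumes the cosine schedule used in BYOL and DINO, m_k = 1 - (1 - m_base) * (cos(pi * k / K) + 1) / 2, which moves the momentum from teacher_momentum at step 0 to 1.0 at the last step K. The helper name `cosine_teacher_momentum` is hypothetical, and the exact formula Dino uses is an assumption.

```python
import math


def cosine_teacher_momentum(base_momentum: float, step: int, max_steps: int) -> float:
    """Hypothetical helper: cosine schedule for the EMA teacher momentum.

    Assumes the BYOL/DINO-style schedule that goes from `base_momentum`
    at step 0 to 1.0 at `max_steps`.
    """
    if max_steps <= 0:
        raise ValueError("max_steps must be positive")
    # Clamp progress to [0, 1] so steps past max_steps keep momentum at 1.0.
    progress = min(max(step / max_steps, 0.0), 1.0)
    # The gap (1 - m_base) is scaled by a cosine factor that goes from 1 to 0.
    return 1.0 - (1.0 - base_momentum) * (math.cos(math.pi * progress) + 1.0) / 2.0


if __name__ == "__main__":
    for step in (0, 500, 1000):
        print(step, round(cosine_teacher_momentum(0.9995, step, 1000), 6))
```

With the default teacher_momentum of 0.9995, the teacher becomes progressively slower to follow the student, which stabilizes its targets late in training.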
Source code in quadra/modules/ssl/dino.py
cancel_gradients_last_layer(epoch, freeze_last_layer)
Zero out the gradient of the last layer, as specified in the paper.
Parameters:
- epoch (int) – current epoch
- freeze_last_layer (int) – maximum freeze epoch: if epoch >= freeze_last_layer, the gradient of the last layer is no longer frozen
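A minimal sketch of the behaviour described above. It assumes the last layer can be found by a parameter name containing `last_layer`, the naming used by the head in the reference DINO implementation; whether quadra uses the same name is an assumption, and `ToyHead` is a hypothetical module. Setting a gradient to None makes the standard PyTorch optimizers skip that parameter for the step.

```python
import torch
from torch import nn


class ToyHead(nn.Module):
    """Hypothetical projection head whose final layer is named `last_layer`."""

    def __init__(self, in_dim: int = 8, out_dim: int = 4) -> None:
        super().__init__()
        self.mlp = nn.Linear(in_dim, in_dim)
        self.last_layer = nn.Linear(in_dim, out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.last_layer(torch.relu(self.mlp(x)))


def cancel_gradients_last_layer(head: nn.Module, epoch: int, freeze_last_layer: int) -> None:
    """Drop the gradients of `last_layer` while epoch < freeze_last_layer."""
    if epoch >= freeze_last_layer:
        return  # freezing period is over: keep all gradients
    for name, param in head.named_parameters():
        if "last_layer" in name:
            # A None gradient makes standard torch optimizers skip this parameter.
            param.grad = None


if __name__ == "__main__":
    head = ToyHead()
    head(torch.randn(2, 8)).sum().backward()
    cancel_gradients_last_layer(head, epoch=0, freeze_last_layer=1)
    print(head.last_layer.weight.grad is None, head.mlp.weight.grad is None)
```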
Source code in quadra/modules/ssl/dino.py
configure_gradient_clipping(optimizer, gradient_clip_val=None, gradient_clip_algorithm=None)
Configure gradient clipping for the optimizer.
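The docstring does not say which clipping algorithm Dino applies. As a hedged illustration, the sketch below assumes the PyTorch Lightning convention in which gradient_clip_algorithm is either "norm" (the default when None) or "value", mapped to `torch.nn.utils.clip_grad_norm_` and `torch.nn.utils.clip_grad_value_`. The helper name `clip_gradients` is hypothetical and is not quadra's implementation.

```python
from typing import Optional

import torch
from torch import nn


def clip_gradients(
    model: nn.Module,
    gradient_clip_val: Optional[float] = None,
    gradient_clip_algorithm: Optional[str] = None,
) -> None:
    """Hypothetical helper: clip gradients in place, Lightning-style."""
    if gradient_clip_val is None:
        return  # clipping disabled
    params = [p for p in model.parameters() if p.grad is not None]
    algorithm = gradient_clip_algorithm or "norm"
    if algorithm == "norm":
        # Rescale all gradients so their global L2 norm is at most gradient_clip_val.
        torch.nn.utils.clip_grad_norm_(params, max_norm=gradient_clip_val)
    elif algorithm == "value":
        # Clamp every gradient element to [-gradient_clip_val, gradient_clip_val].
        torch.nn.utils.clip_grad_value_(params, clip_value=gradient_clip_val)
    else:
        raise ValueError(f"Unknown gradient_clip_algorithm: {algorithm!r}")
```

Norm clipping preserves the gradient direction, while value clipping bounds each element independently and can change the direction.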
Source code in quadra/modules/ssl/dino.py
initialize_teacher()
Initialize the teacher from the state dict of the student, also checking that the student model correctly requires gradients.
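A minimal sketch of this initialization, assuming the teacher has the same architecture as the student. The function name and the exact check are illustrative, not quadra's code. The teacher's parameters are frozen because in DINO the teacher is updated only through the exponential moving average of the student weights, never by backpropagation.

```python
import torch
from torch import nn


def initialize_teacher(student: nn.Module, teacher: nn.Module) -> None:
    """Hypothetical sketch: copy student weights into the teacher and freeze it."""
    # Every student parameter is expected to be trainable before training starts.
    if not all(p.requires_grad for p in student.parameters()):
        raise ValueError("All student parameters must require gradients")
    teacher.load_state_dict(student.state_dict())
    # The teacher is updated by EMA only, so it needs no gradients.
    for param in teacher.parameters():
        param.requires_grad = False
```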
Source code in quadra/modules/ssl/dino.py
optimizer_step(epoch, batch_idx, optimizer, optimizer_closure=None)
Override the optimizer step to also update the teacher parameters.
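The teacher update after each optimizer step is an exponential moving average (EMA) of the student weights, teacher = m * teacher + (1 - m) * student, where m is the teacher momentum. The sketch below assumes student and teacher share the same architecture and is written as a standalone, hypothetical function rather than the actual method.

```python
import torch
from torch import nn


@torch.no_grad()
def update_teacher(student: nn.Module, teacher: nn.Module, momentum: float) -> None:
    """Hypothetical sketch: EMA update teacher <- m * teacher + (1 - m) * student."""
    for p_s, p_t in zip(student.parameters(), teacher.parameters()):
        # In-place update keeps the teacher tensors (and any references to them) intact.
        p_t.mul_(momentum).add_(p_s.detach(), alpha=1.0 - momentum)
```

A momentum of 1.0 leaves the teacher unchanged, and values close to 1 (such as the default 0.9995) make it follow the student slowly.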
Source code in quadra/modules/ssl/dino.py
student_multicrop_forward(x)
Student forward pass on the multicrop images.
Parameters:
- x – input list of crops; its length is the number of crops N
Returns:
- Tensor – a tensor of shape N x B x D, where N is the number of crops (the length of the input list x), B is the batch size and D is the output dimension
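A minimal sketch of the shape contract above: each crop batch in the input list goes through the student and its projection MLP, and the N outputs are stacked along a new first dimension. The function signature and the toy modules in the usage are assumptions; the sketch processes crops one by one, so crops of different resolutions are handled, but it does not reproduce any batching optimization the real implementation may use.

```python
from typing import List

import torch
from torch import nn


def student_multicrop_forward(
    student: nn.Module, student_projection_mlp: nn.Module, x: List[torch.Tensor]
) -> torch.Tensor:
    """Hypothetical sketch: return an N x B x D tensor for a list of N crop batches."""
    # Crops may have different resolutions, so each one is processed separately.
    outputs = [student_projection_mlp(student(crop)) for crop in x]
    return torch.stack(outputs, dim=0)


if __name__ == "__main__":
    # Toy backbone: adaptive pooling makes it resolution-independent.
    backbone = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.AdaptiveAvgPool2d(1), nn.Flatten())
    head = nn.Linear(8, 16)
    crops = [torch.randn(5, 3, 32, 32) for _ in range(2)] + [torch.randn(5, 3, 16, 16) for _ in range(4)]
    print(student_multicrop_forward(backbone, head, crops).shape)
```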
Source code in quadra/modules/ssl/dino.py
teacher_multicrop_forward(x)
Teacher forward pass on the multicrop images.
Parameters:
- x – input list of crops; its length is the number of crops N
Returns:
- Tensor – a tensor of shape N x B x D, where N is the number of crops (the length of the input list x), B is the batch size and D is the output dimension
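The teacher counterpart follows the same N x B x D contract. This sketch assumes the teacher needs no gradients (it is updated by EMA), so the forward pass runs under `torch.no_grad()`; the function signature is hypothetical. In the DINO paper only the global crops are passed through the teacher, so N here may be smaller than for the student.

```python
from typing import List

import torch
from torch import nn


@torch.no_grad()
def teacher_multicrop_forward(
    teacher: nn.Module, teacher_projection_mlp: nn.Module, x: List[torch.Tensor]
) -> torch.Tensor:
    """Hypothetical sketch: gradient-free N x B x D forward of the teacher."""
    # No autograd graph is built, which saves memory for the target branch.
    outputs = [teacher_projection_mlp(teacher(crop)) for crop in x]
    return torch.stack(outputs, dim=0)
```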
Source code in quadra/modules/ssl/dino.py