dino
Dino(student, teacher, student_projection_mlp, teacher_projection_mlp, criterion, freeze_last_layer=1, classifier=None, optimizer=None, lr_scheduler=None, lr_scheduler_interval='epoch', teacher_momentum=0.9995, teacher_momentum_cosine_decay=True)
Bases: BYOL
DINO PyTorch Lightning module.
Parameters:
- student – student model
- teacher – teacher model
- student_projection_mlp – student projection MLP
- teacher_projection_mlp – teacher projection MLP
- criterion – loss function
- freeze_last_layer (int, default: 1) – number of epochs during which the gradient of the student's last layer is frozen
- classifier (Optional[ClassifierMixin], default: None) – standard sklearn classifier
- optimizer (Optional[Optimizer], default: None) – optimizer for the training. If None, a default Adam is used.
- lr_scheduler (Optional[object], default: None) – learning rate scheduler. If None, a default ReduceLROnPlateau is used.
- lr_scheduler_interval (Optional[str], default: 'epoch') – interval at which the lr scheduler is updated
- teacher_momentum (float, default: 0.9995) – momentum of the teacher parameters
- teacher_momentum_cosine_decay (Optional[bool], default: True) – whether to use cosine decay for the teacher momentum
Source code in quadra/modules/ssl/dino.py
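A minimal construction sketch is shown below. The backbone, projection head and criterion are illustrative stand-ins (a torchvision ResNet-18 and a plain two-layer MLP), not the classes shipped with quadra; only the `Dino` constructor arguments themselves come from the signature above.

```python
import copy

from torch import nn
from torchvision.models import resnet18

from quadra.modules.ssl.dino import Dino

# Illustrative backbone: any feature extractor returning a flat embedding works here.
def make_backbone() -> nn.Module:
    model = resnet18()
    model.fc = nn.Identity()  # expose the 512-d features instead of class logits
    return model

# Illustrative projection head; a real DINO head is deeper and ends in a weight-normalized layer.
def make_projection_mlp(in_dim: int = 512, out_dim: int = 4096) -> nn.Module:
    return nn.Sequential(nn.Linear(in_dim, 2048), nn.GELU(), nn.Linear(2048, out_dim))

student = make_backbone()
teacher = copy.deepcopy(student)  # initialize_teacher() copies the student state into the teacher

module = Dino(
    student=student,
    teacher=teacher,
    student_projection_mlp=make_projection_mlp(),
    teacher_projection_mlp=make_projection_mlp(),
    criterion=nn.CrossEntropyLoss(),  # placeholder: use a DINO distillation loss in practice
    # optimizer / lr_scheduler left as None: a default Adam and ReduceLROnPlateau are used
    teacher_momentum=0.9995,
    teacher_momentum_cosine_decay=True,
)
```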
cancel_gradients_last_layer(epoch, freeze_last_layer)
Zero out the gradient of the last layer, as specified in the paper.
Parameters:
- epoch (int) – current epoch
- freeze_last_layer (int) – maximum freeze epoch: if epoch >= freeze_last_layer, the gradient of the last layer is no longer frozen
Source code in quadra/modules/ssl/dino.py
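The sketch below illustrates the freezing logic described above. The standalone function, the explicit `model` argument and the `"last_layer"` parameter-name convention are assumptions for illustration; the module method operates on its own student model.

```python
from torch import nn

def cancel_gradients_last_layer_sketch(epoch: int, model: nn.Module, freeze_last_layer: int) -> None:
    """Drop the last-layer gradients while epoch < freeze_last_layer (illustrative only)."""
    if epoch >= freeze_last_layer:
        return  # past the freeze window: let the last layer train normally
    for name, param in model.named_parameters():
        if "last_layer" in name:
            param.grad = None  # removing the gradient skips the update for these weights
```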
configure_gradient_clipping(optimizer, optimizer_idx, gradient_clip_val=None, gradient_clip_algorithm=None)
Configure gradient clipping for the optimizer.
Source code in quadra/modules/ssl/dino.py
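This hook receives the clipping settings configured on the Lightning `Trainer`; the sketch below shows how they reach it during `fit`. The datamodule name is a placeholder, and `module` refers to the construction example above.

```python
from pytorch_lightning import Trainer

# gradient_clip_val / gradient_clip_algorithm set on the Trainer are forwarded to
# configure_gradient_clipping at every optimizer step.
trainer = Trainer(max_epochs=100, gradient_clip_val=3.0, gradient_clip_algorithm="norm")
trainer.fit(module, datamodule=my_datamodule)  # my_datamodule: placeholder multi-crop datamodule
```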
initialize_teacher()
Initialize the teacher from the state dict of the student, also checking that the student model requires gradients correctly.
Source code in quadra/modules/ssl/dino.py
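A minimal sketch of the initialization described above: copy the student weights into the teacher and disable teacher gradients. This is illustrative and not necessarily the exact check quadra performs.

```python
from torch import nn

def initialize_teacher_sketch(student: nn.Module, teacher: nn.Module) -> None:
    """Illustrative: sync teacher weights from the student and freeze the teacher."""
    teacher.load_state_dict(student.state_dict())
    for param in teacher.parameters():
        param.requires_grad = False  # the teacher is updated only through the EMA of the student
    # Basic sanity check on the student side: its parameters must still require gradients.
    assert any(p.requires_grad for p in student.parameters()), "student has no trainable parameters"
```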
optimizer_step(epoch, batch_idx, optimizer, optimizer_idx=0, optimizer_closure=None, on_tpu=False, using_lbfgs=False)
Override optimizer_step to update the teacher model.
Source code in quadra/modules/ssl/dino.py
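After the student weights are updated, the teacher is moved towards the student with an exponential moving average. The sketch below shows an EMA update and a cosine momentum schedule in the style of the DINO paper; the exact schedule and where it is applied inside quadra's `optimizer_step` are not spelled out on this page.

```python
import math

import torch
from torch import nn

@torch.no_grad()
def ema_update_sketch(student: nn.Module, teacher: nn.Module, momentum: float) -> None:
    """Illustrative EMA: teacher <- momentum * teacher + (1 - momentum) * student."""
    for p_s, p_t in zip(student.parameters(), teacher.parameters()):
        p_t.mul_(momentum).add_(p_s.detach(), alpha=1.0 - momentum)

def cosine_momentum_sketch(base_momentum: float, step: int, max_steps: int) -> float:
    """Illustrative cosine schedule: ramps the momentum from base_momentum towards 1.0."""
    return 1.0 - (1.0 - base_momentum) * (math.cos(math.pi * step / max_steps) + 1.0) / 2.0
```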
student_multicrop_forward(x)
Student forward on the multicrop images.
Parameters:
- x – list of image crop batches
Returns:
- Tensor – a tensor of shape NxBxD, where N is the number of crops (the length of the input list x), B is the batch size and D is the output dimension
Source code in quadra/modules/ssl/dino.py
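A shape-oriented usage sketch, continuing the construction example above; the number of crops, crop resolutions and batch size are arbitrary example values.

```python
import torch

# Hypothetical multi-crop batch: 2 global crops at 224x224 and 6 local crops at 96x96, batch size 8.
global_crops = [torch.randn(8, 3, 224, 224) for _ in range(2)]
local_crops = [torch.randn(8, 3, 96, 96) for _ in range(6)]

student_out = module.student_multicrop_forward(global_crops + local_crops)
print(student_out.shape)  # expected (8, 8, D): N=8 crops, B=8 images per crop, D output dims
```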
teacher_multicrop_forward(x)
Teacher forward on the multicrop images.
Parameters:
- x – list of image crop batches
Returns:
- Tensor – a tensor of shape NxBxD, where N is the number of crops (the length of the input list x), B is the batch size and D is the output dimension
Source code in quadra/modules/ssl/dino.py
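The teacher-side counterpart of the example above. In the standard DINO recipe the teacher only processes the global crops; whether the stacked outputs can be fed directly to the configured criterion depends on the loss implementation, so that last step is only indicative.

```python
# Continuing the example above: the teacher typically sees only the global crops.
teacher_out = module.teacher_multicrop_forward(global_crops)
print(teacher_out.shape)  # expected (2, B, D): one slice per global crop

# Indicative only: a DINO-style criterion compares student outputs on all crops
# against teacher outputs on the global crops.
# loss = criterion(student_out, teacher_out)
```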