byol
BYOL(student, teacher, student_projection_mlp, student_prediction_mlp, teacher_projection_mlp, criterion, classifier=None, optimizer=None, lr_scheduler=None, lr_scheduler_interval='epoch', teacher_momentum=0.9995, teacher_momentum_cosine_decay=True)
Bases: SSLModule
BYOL module, inspired by https://arxiv.org/abs/2006.07733.
Parameters:
- student – student model.
- teacher – teacher model.
- student_projection_mlp – student projection MLP.
- student_prediction_mlp – student prediction MLP.
- teacher_projection_mlp – teacher projection MLP.
- criterion – loss function.
- classifier (Optional[ClassifierMixin], default: None) – Standard sklearn classifier.
- optimizer (Optional[Optimizer], default: None) – optimizer of the training. If None, a default Adam is used.
- lr_scheduler (Optional[object], default: None) – lr scheduler. If None, a default ReduceLROnPlateau is used.
- lr_scheduler_interval (Optional[str], default: 'epoch') – interval at which the lr scheduler is updated.
- teacher_momentum (float, default: 0.9995) – momentum of the teacher parameters.
- teacher_momentum_cosine_decay (Optional[bool], default: True) – whether to use cosine decay for the teacher momentum.
Source code in quadra/modules/ssl/byol.py
calculate_accuracy(batch)
Calculate accuracy on the given batch.
Source code in quadra/modules/ssl/byol.py
initialize_teacher()
Initialize the teacher from the state dict of the student, also checking that the student model requires gradients correctly.
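The initialization described above can be sketched in plain PyTorch. This is only an illustration under stated assumptions, not quadra's implementation: the helper name `init_teacher_from_student` is hypothetical, the two networks are assumed to share the same architecture, and freezing the teacher's parameters is an assumption based on the usual BYOL setup, where the teacher is never trained by backpropagation.

```python
import torch
from torch import nn


def init_teacher_from_student(student: nn.Module, teacher: nn.Module) -> None:
    """Hypothetical helper: copy the student weights into the teacher."""
    # The student is the network being optimized, so every parameter
    # is expected to require gradients.
    if not all(p.requires_grad for p in student.parameters()):
        raise RuntimeError("All student parameters should require gradients.")

    # Start the teacher from exactly the same weights as the student.
    teacher.load_state_dict(student.state_dict())

    # Assumption: the teacher is only changed through the moving average,
    # so its parameters are excluded from gradient computation.
    for p in teacher.parameters():
        p.requires_grad = False


# Toy networks standing in for the real student and teacher backbones.
student = nn.Linear(4, 2)
teacher = nn.Linear(4, 2)
init_teacher_from_student(student, teacher)
```

After the call, both networks hold identical weights, but only the student's parameters would be visible to an optimizer that filters on `requires_grad`.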
Source code in quadra/modules/ssl/byol.py
optimizer_step(epoch, batch_idx, optimizer, optimizer_closure=None)
Override optimizer step to update the teacher parameters.
Source code in quadra/modules/ssl/byol.py
test_step(batch, *args)
Calculate accuracy on the test set for the given batch.
Source code in quadra/modules/ssl/byol.py
update_teacher()
Update the teacher, given self.teacher_momentum, as an exponential moving average of the student parameters, that is: theta_t * tau + theta_s * (1 - tau), where theta_s and theta_t are the parameters of the student and the teacher model, while tau is the teacher momentum. If self.teacher_momentum_cosine_decay is True, the teacher momentum follows a cosine schedule from self.teacher_momentum to 1: tau = 1 - (1 - tau) * (cos(pi * t / T) + 1) / 2, where tau on the right-hand side is the base value self.teacher_momentum, t is the current step and T is the max number of steps.
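To make the two formulas concrete, here is a minimal sketch in pure Python that operates on plain lists of floats instead of model tensors. The function names `cosine_teacher_momentum` and `ema_update` are hypothetical and chosen for illustration; this is not the code in quadra/modules/ssl/byol.py.

```python
from __future__ import annotations

import math


def cosine_teacher_momentum(base_tau: float, step: int, max_steps: int) -> float:
    """tau = 1 - (1 - base_tau) * (cos(pi * t / T) + 1) / 2."""
    return 1.0 - (1.0 - base_tau) * (math.cos(math.pi * step / max_steps) + 1.0) / 2.0


def ema_update(teacher: list[float], student: list[float], tau: float) -> list[float]:
    """Return theta_t * tau + theta_s * (1 - tau) for each parameter pair."""
    return [t * tau + s * (1.0 - tau) for t, s in zip(teacher, student)]


base_tau = 0.9995  # default teacher_momentum from the signature above
max_steps = 1000   # illustrative value for T

# The schedule starts at the base momentum and reaches 1 at the last step.
tau_start = cosine_teacher_momentum(base_tau, 0, max_steps)
tau_end = cosine_teacher_momentum(base_tau, max_steps, max_steps)

# With tau = 0.5 the updated teacher is the midpoint of the two parameter sets.
new_teacher = ema_update([1.0, 0.0], [0.0, 2.0], 0.5)
```

As tau approaches 1 the teacher changes less and less per step, so with cosine decay enabled the teacher is effectively frozen by the end of training.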
Source code in quadra/modules/ssl/byol.py