dino
DinoDistillationLoss(output_dim, max_epochs, warmup_teacher_temp=0.04, teacher_temp=0.07, warmup_teacher_temp_epochs=30, student_temp=0.1, center_momentum=0.9)
DINO distillation loss module.
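Parameters:
- output_dim (int) – Output dimension (size of the student and teacher outputs).
- max_epochs (int) – Maximum number of training epochs.
- warmup_teacher_temp (float) – Initial teacher temperature at the start of the warmup.
- teacher_temp (float) – Teacher temperature reached at the end of the warmup.
- warmup_teacher_temp_epochs (int) – Number of warmup epochs for the teacher temperature.
- student_temp (float) – Student temperature.
- center_momentum (float) – Momentum of the moving-average update of the center.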
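Together with max_epochs, the three teacher-temperature parameters describe a per-epoch teacher temperature. The sketch below assumes the schedule of the original DINO reference implementation (a linear ramp from warmup_teacher_temp to teacher_temp over warmup_teacher_temp_epochs, then constant up to max_epochs); it is not taken from the quadra source, and the helper teacher_temp_schedule is hypothetical.

```python
from typing import List


def teacher_temp_schedule(
    max_epochs: int,
    warmup_teacher_temp: float = 0.04,
    teacher_temp: float = 0.07,
    warmup_teacher_temp_epochs: int = 30,
) -> List[float]:
    """Hypothetical helper: one teacher temperature per epoch (warmup ramp, then constant)."""
    n = warmup_teacher_temp_epochs
    if not 0 <= n <= max_epochs:
        raise ValueError("warmup_teacher_temp_epochs must be in [0, max_epochs]")
    if n <= 1:
        # Zero or one warmup epoch: nothing to interpolate.
        warmup = [warmup_teacher_temp] * n
    else:
        # Linear ramp including both endpoints, like np.linspace(start, stop, n).
        step = (teacher_temp - warmup_teacher_temp) / (n - 1)
        warmup = [warmup_teacher_temp + i * step for i in range(n)]
    # After the warmup, the teacher temperature stays at teacher_temp.
    return warmup + [teacher_temp] * (max_epochs - n)


schedule = teacher_temp_schedule(max_epochs=100)
print(len(schedule), schedule[0], schedule[50])  # 100 0.04 0.07
```

Under this assumption, the temperature used at a given step is simply schedule[current_epoch], which is consistent with forward taking current_epoch as an argument.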
Source code in quadra/losses/ssl/dino.py
forward(current_epoch, student_output, teacher_output)
Runs the forward pass and returns the distillation loss for the given epoch.
Source code in quadra/losses/ssl/dino.py
update_center(teacher_output)
Update the center of the teacher output distribution.
Parameters:
- teacher_output (torch.Tensor) – Tensor of the teacher output.
Returns:
- None
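In DINO, the center is an exponential moving average of the batch mean of the teacher outputs, with center_momentum as the weight kept from the old center. The sketch below assumes that rule; update_center here is a standalone hypothetical function rather than the quadra method, and unlike the original DINO reference code it does not all-reduce the batch mean across processes.

```python
import torch


@torch.no_grad()
def update_center(
    center: torch.Tensor,          # (1, output_dim), running center
    teacher_output: torch.Tensor,  # (batch, output_dim), teacher logits
    center_momentum: float = 0.9,
) -> torch.Tensor:
    """Hypothetical helper: EMA update of the teacher center (single process)."""
    # Mean teacher output over the batch, kept 2-D so it broadcasts like center.
    batch_center = teacher_output.mean(dim=0, keepdim=True)
    # Exponential moving average: keep center_momentum of the old center.
    return center * center_momentum + batch_center * (1.0 - center_momentum)


center = torch.zeros(1, 4)
center = update_center(center, torch.ones(8, 4))
print(center)  # every entry is approximately 0.1
```

A high momentum such as the default 0.9 makes the center change slowly, so a single unusual batch has limited influence on it.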
Source code in quadra/losses/ssl/dino.py
dino_distillation_loss(student_output, teacher_output, center_vector, teacher_temp=0.04, student_temp=0.1)
Compute the DINO distillation loss.
Parameters:
- student_output (torch.Tensor) – Tensor of the student output.
- teacher_output (torch.Tensor) – Tensor of the teacher output.
- center_vector (torch.Tensor) – Center vector of the teacher output distribution.
- teacher_temp (float) – Teacher temperature.
- student_temp (float) – Student temperature.
Returns:
- torch.Tensor – The distillation loss.
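Assuming the standard DINO formulation, the loss for one pair of views is the cross-entropy H(P_t, P_s) = -sum_k P_t(k) log P_s(k), where P_t = softmax((teacher_output - center_vector) / teacher_temp) and P_s = softmax(student_output / student_temp), averaged over the batch. The function below is a hypothetical single-pair sketch (dino_loss_single_pair is not a quadra name); the original DINO reference code additionally averages this term over all pairs of different views (multi-crop), which the sketch omits.

```python
import torch
import torch.nn.functional as F


def dino_loss_single_pair(
    student_output: torch.Tensor,  # (batch, output_dim) student logits
    teacher_output: torch.Tensor,  # (batch, output_dim) teacher logits
    center_vector: torch.Tensor,   # (1, output_dim) running center
    teacher_temp: float = 0.04,
    student_temp: float = 0.1,
) -> torch.Tensor:
    """Hypothetical sketch: DINO cross-entropy for a single (teacher, student) view pair."""
    # Teacher: subtract the center, sharpen with a low temperature, block gradients.
    teacher_probs = F.softmax((teacher_output - center_vector) / teacher_temp, dim=-1).detach()
    # Student: log-probabilities at the student temperature.
    student_log_probs = F.log_softmax(student_output / student_temp, dim=-1)
    # Cross-entropy per sample, averaged over the batch.
    return torch.sum(-teacher_probs * student_log_probs, dim=-1).mean()


student = torch.randn(8, 16, requires_grad=True)
teacher = torch.randn(8, 16)
loss = dino_loss_single_pair(student, teacher, center_vector=torch.zeros(1, 16))
loss.backward()  # gradients reach the student only
```

Detaching the teacher probabilities reflects that, in DINO, the teacher is updated as a moving average of the student rather than by backpropagation; centering and sharpening together keep the teacher outputs from collapsing.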
Source code in quadra/losses/ssl/dino.py