dino
DinoDistillationLoss(output_dim, max_epochs, warmup_teacher_temp=0.04, teacher_temp=0.07, warmup_teacher_temp_epochs=30, student_temp=0.1, center_momentum=0.9)
Bases: Module
Dino distillation loss module.
Parameters:
- output_dim (int) – Output dimension.
- max_epochs (int) – Maximum number of training epochs.
- warmup_teacher_temp (float, default: 0.04) – Teacher temperature used during warm-up.
- teacher_temp (float, default: 0.07) – Teacher temperature.
- warmup_teacher_temp_epochs (int, default: 30) – Number of teacher temperature warm-up epochs.
- student_temp (float, default: 0.1) – Student temperature.
- center_momentum (float, default: 0.9) – Momentum used to update the center.
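The temperature parameters suggest a per-epoch teacher temperature schedule, although this page does not show how it is built. The following is a minimal sketch, assuming a linear warm-up from `warmup_teacher_temp` to `teacher_temp` over `warmup_teacher_temp_epochs`, followed by a constant value up to `max_epochs`. The helper name `teacher_temp_schedule` is hypothetical and not part of the documented API.

```python
def teacher_temp_schedule(
    max_epochs,
    warmup_teacher_temp=0.04,
    teacher_temp=0.07,
    warmup_teacher_temp_epochs=30,
):
    """Hypothetical helper: one teacher temperature per epoch.

    Assumption: linear warm-up, then constant. The actual module may differ.
    """
    warmup_epochs = min(warmup_teacher_temp_epochs, max_epochs)
    schedule = []
    for epoch in range(max_epochs):
        if epoch < warmup_epochs:
            # Fraction of the warm-up completed, 0.0 at the first epoch
            # and 1.0 at the last warm-up epoch.
            frac = epoch / max(warmup_epochs - 1, 1)
            temp = warmup_teacher_temp + frac * (teacher_temp - warmup_teacher_temp)
        else:
            temp = teacher_temp
        schedule.append(temp)
    return schedule


schedule = teacher_temp_schedule(max_epochs=100)
print(len(schedule), schedule[0], schedule[-1])
```

Under this assumption, the temperature for a given epoch is a lookup by epoch index, which would explain why `forward` takes `current_epoch` as an argument.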
Source code in quadra/losses/ssl/dino.py
forward(current_epoch, student_output, teacher_output)
Runs the forward pass for the given epoch, student output, and teacher output.
Source code in quadra/losses/ssl/dino.py
update_center(teacher_output)
Update the center of the teacher output distribution.
Parameters:
- teacher_output – Teacher output.
Returns:
- None
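The combination of `center_momentum` and `update_center` points to an exponential moving average of the teacher outputs. The sketch below is written in plain Python on nested lists, under the assumption that the new center is `center_momentum * center + (1 - center_momentum) * batch_mean`. The function `ema_center` is hypothetical and only illustrates the arithmetic. Unlike the documented method, which returns None, it returns the new center so that it has no side effects.

```python
def ema_center(center, teacher_batch, center_momentum=0.9):
    """Hypothetical helper: exponential moving average of the teacher output mean.

    center:        list of floats, one per output dimension
    teacher_batch: list of rows, each row a list of floats of the same length
    """
    n_rows = len(teacher_batch)
    # Mean of the teacher outputs over the batch, per output dimension.
    batch_mean = [
        sum(row[j] for row in teacher_batch) / n_rows for j in range(len(center))
    ]
    # Blend the previous center with the new batch mean.
    return [
        center_momentum * c + (1.0 - center_momentum) * m
        for c, m in zip(center, batch_mean)
    ]


center = [0.0, 0.0]
center = ema_center(center, [[1.0, 3.0], [3.0, 5.0]])
print(center)  # approximately [0.2, 0.4]
```

With a momentum of 0.9, each batch moves the center only 10% of the way toward the current batch mean, so the center changes slowly across steps.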
Source code in quadra/losses/ssl/dino.py
dino_distillation_loss(student_output, teacher_output, center_vector, teacher_temp=0.04, student_temp=0.1)
Compute the DINO distillation loss.
Parameters:
- student_output (Tensor) – Tensor of the student output.
- teacher_output (Tensor) – Tensor of the teacher output.
- center_vector (Tensor) – Center vector of the distribution.
- teacher_temp (float, default: 0.04) – Teacher temperature.
- student_temp (float, default: 0.1) – Student temperature.
Returns:
- Tensor – The computed loss.
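To make the role of each argument concrete, here is a minimal plain-Python sketch of a DINO-style distillation term for a single student/teacher pair of logit vectors. It assumes the usual formulation: the teacher logits are centered and divided by `teacher_temp` before a softmax, the student logits are divided by `student_temp` before a log-softmax, and the loss is the cross-entropy between the two. The name `dino_loss_single` and its helpers are hypothetical. The documented function operates on tensors, and how it batches and averages the terms is not shown on this page.

```python
import math


def softmax(logits):
    # Subtract the maximum for numerical stability.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def log_softmax(logits):
    # log-sum-exp with the maximum subtracted for numerical stability.
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def dino_loss_single(student_logits, teacher_logits, center,
                     teacher_temp=0.04, student_temp=0.1):
    """Hypothetical helper: cross-entropy between sharpened, centered teacher
    probabilities and student log-probabilities for one pair of outputs."""
    teacher_probs = softmax(
        [(t - c) / teacher_temp for t, c in zip(teacher_logits, center)]
    )
    student_log_probs = log_softmax([s / student_temp for s in student_logits])
    return -sum(p * lp for p, lp in zip(teacher_probs, student_log_probs))


center = [0.0, 0.0]
agree = dino_loss_single([2.0, 0.0], [2.0, 0.0], center)
disagree = dino_loss_single([0.0, 2.0], [2.0, 0.0], center)
print(agree < disagree)  # True: the loss is lower when the student matches the teacher
```

The lower default temperature for the teacher (0.04 versus 0.1) sharpens the teacher distribution relative to the student, while subtracting the center keeps any single output dimension from dominating.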
Source code in quadra/losses/ssl/dino.py