ssl
BYOLRegressionLoss
Bases: Module
BYOL regression loss module.
forward(x, y)
Compute the BYOL regression loss.
Parameters:
- x
- y
Returns:
- Tensor – BYOL regression loss.
Source code in quadra/losses/ssl/byol.py
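For reference, here is a minimal NumPy sketch of the regression loss from the BYOL paper: both vectors are L2-normalized and the loss is 2 - 2 * cos(x, y) per sample, which equals the squared distance between the unit vectors. The function name `byol_regression_loss` is hypothetical, and quadra's PyTorch module may differ in details such as the batch reduction.

```python
import numpy as np


def byol_regression_loss(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Per-sample BYOL regression loss: 2 - 2 * cosine_similarity(x, y)."""
    # L2-normalize both inputs along the feature (last) dimension.
    x = x / np.linalg.norm(x, axis=-1, keepdims=True)
    y = y / np.linalg.norm(y, axis=-1, keepdims=True)
    # For unit vectors this equals the squared Euclidean distance ||x - y||^2.
    return 2 - 2 * np.sum(x * y, axis=-1)


prediction = np.array([[1.0, 0.0], [0.0, 1.0]])
target = np.array([[2.0, 0.0], [1.0, 0.0]])
print(byol_regression_loss(prediction, target))  # [0. 2.]
```

The loss ranges from 0 (same direction) to 4 (opposite directions); averaging over the batch gives a scalar training objective.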
BarlowTwinsLoss(lambd)
Bases: Module
Barlow Twins loss.
Parameters:
- lambd (float) – lambda of the loss, i.e. the weight of the off-diagonal (redundancy-reduction) term.
Source code in quadra/losses/ssl/barlowtwins.py
forward(z1, z2)
Compute the BarlowTwins loss.
Source code in quadra/losses/ssl/barlowtwins.py
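A minimal NumPy sketch of the Barlow Twins objective as defined in the original paper: the cross-correlation matrix of the batch-standardized embeddings is pushed towards the identity, with `lambd` weighting the off-diagonal term. The function `barlow_twins_loss` is hypothetical and written outside PyTorch; quadra's implementation may normalize or scale differently.

```python
import numpy as np


def barlow_twins_loss(z1: np.ndarray, z2: np.ndarray, lambd: float) -> float:
    """Barlow Twins loss for two (batch, dim) embedding matrices."""
    n = z1.shape[0]
    # Standardize each feature across the batch (zero mean, unit variance).
    z1n = (z1 - z1.mean(axis=0)) / z1.std(axis=0)
    z2n = (z2 - z2.mean(axis=0)) / z2.std(axis=0)
    # Cross-correlation matrix between the two views, shape (dim, dim).
    c = z1n.T @ z2n / n
    diag = np.diagonal(c)
    # Invariance term: push diagonal entries towards 1.
    on_diag = np.sum((diag - 1.0) ** 2)
    # Redundancy-reduction term: push off-diagonal entries towards 0.
    off_diag = np.sum(c**2) - np.sum(diag**2)
    return float(on_diag + lambd * off_diag)


rng = np.random.default_rng(0)
z_a, z_b = rng.normal(size=(64, 8)), rng.normal(size=(64, 8))
print(barlow_twins_loss(z_a, z_b, lambd=5e-3))
```

The paper uses λ = 5e-3; with `lambd=0` only the invariance term remains.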
DinoDistillationLoss(output_dim, max_epochs, warmup_teacher_temp=0.04, teacher_temp=0.07, warmup_teacher_temp_epochs=30, student_temp=0.1, center_momentum=0.9)
Bases: Module
DINO distillation loss module.
Parameters:
- output_dim (int) – output dimension.
- max_epochs (int) – maximum number of training epochs.
- warmup_teacher_temp (float, default: 0.04) – teacher temperature during warmup.
- teacher_temp (float, default: 0.07) – teacher temperature after warmup.
- warmup_teacher_temp_epochs (int, default: 30) – number of teacher temperature warmup epochs.
- student_temp (float, default: 0.1) – student temperature.
- center_momentum (float, default: 0.9) – momentum of the teacher center update.
Source code in quadra/losses/ssl/dino.py
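The `warmup_*` parameters suggest a per-epoch teacher temperature schedule. The sketch below assumes the schedule of the reference DINO implementation: a linear increase from `warmup_teacher_temp` to `teacher_temp` over `warmup_teacher_temp_epochs`, constant afterwards. `teacher_temp_schedule` is a hypothetical helper, and it assumes `max_epochs >= warmup_teacher_temp_epochs`.

```python
import numpy as np


def teacher_temp_schedule(
    max_epochs: int,
    warmup_teacher_temp: float = 0.04,
    teacher_temp: float = 0.07,
    warmup_teacher_temp_epochs: int = 30,
) -> np.ndarray:
    """Per-epoch teacher temperature: linear warmup, then constant (assumed schedule)."""
    # Linear ramp; np.linspace includes both endpoints.
    warmup = np.linspace(warmup_teacher_temp, teacher_temp, warmup_teacher_temp_epochs)
    # Remaining epochs keep the final teacher temperature.
    constant = np.full(max_epochs - warmup_teacher_temp_epochs, teacher_temp)
    return np.concatenate((warmup, constant))


schedule = teacher_temp_schedule(max_epochs=100)
# schedule[epoch] is the teacher temperature used at that epoch.
```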
forward(current_epoch, student_output, teacher_output)
Compute the DINO distillation loss.
Source code in quadra/losses/ssl/dino.py
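A NumPy sketch of the distillation loss, assuming it follows the DINO paper: teacher outputs are centered and sharpened with the teacher temperature, student outputs are softened with `student_temp`, and the cross-entropy is averaged over all teacher/student pairs that see different views. `dino_loss` and its list-of-views inputs are hypothetical; in quadra, `current_epoch` presumably selects the teacher temperature, whereas here it is passed explicitly.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def log_softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))


def dino_loss(student_views, teacher_views, center, teacher_temp, student_temp=0.1):
    """Cross-entropy between teacher and student distributions over different views."""
    # Teacher: subtract the running center, then sharpen with a low temperature.
    # (In an autograd framework the teacher branch would receive no gradient.)
    teacher_probs = [softmax((t - center) / teacher_temp) for t in teacher_views]
    student_logp = [log_softmax(s / student_temp) for s in student_views]
    total, n_terms = 0.0, 0
    for iq, q in enumerate(teacher_probs):
        for v, logp in enumerate(student_logp):
            if v == iq:
                continue  # skip pairs where teacher and student see the same view
            total += np.mean(np.sum(-q * logp, axis=-1))
            n_terms += 1
    return float(total / n_terms)
```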
update_center(teacher_output)
Update the center of the teacher output distribution.
Parameters:
- teacher_output – teacher output.
Returns:
- None
Source code in quadra/losses/ssl/dino.py
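A sketch of the center update, assuming the exponential moving average used in the DINO paper with `center_momentum` as the EMA factor. The reference implementation also averages the batch mean across processes in distributed training, which this hypothetical helper omits.

```python
import numpy as np


def update_center(center: np.ndarray, teacher_output: np.ndarray, center_momentum: float = 0.9) -> np.ndarray:
    """EMA of the batch mean of the teacher output (single-process sketch)."""
    # Mean over the batch dimension, kept 2-D so it broadcasts against outputs.
    batch_center = teacher_output.mean(axis=0, keepdims=True)
    return center * center_momentum + batch_center * (1.0 - center_momentum)


center = np.zeros((1, 4))
center = update_center(center, np.ones((8, 4)))  # every entry becomes 0.1
```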
IDMMLoss(smoothing=0.1)
Bases: Module
IDMM loss described in https://arxiv.org/abs/2201.10728.
Source code in quadra/losses/ssl/idmm.py
forward(p1, y1)
Compute the IDMM loss described in https://arxiv.org/abs/2201.10728.
Parameters:
- p1
- y1
Returns:
- Tensor – IDMM loss.
Source code in quadra/losses/ssl/idmm.py
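This page does not describe the IDMM loss beyond its reference. Given the `smoothing` parameter, one plausible reading (an assumption, not confirmed here) is a cross-entropy with label smoothing between predicted logits `p1` and integer targets `y1`. The sketch below shows only that generic building block, following PyTorch's label-smoothing convention; the function name is hypothetical, so check `quadra/losses/ssl/idmm.py` before relying on it.

```python
import numpy as np


def label_smoothed_cross_entropy(logits: np.ndarray, targets: np.ndarray, smoothing: float = 0.1) -> float:
    """Mean cross-entropy against (1 - smoothing) * one_hot + smoothing / num_classes."""
    n, k = logits.shape
    # Stable log-softmax over the class dimension.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Smoothed target distribution: uniform mass plus extra mass on the true class.
    target_dist = np.full((n, k), smoothing / k)
    target_dist[np.arange(n), targets] += 1.0 - smoothing
    return float(np.mean(-np.sum(target_dist * log_probs, axis=1)))
```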
SimCLRLoss(temperature=1.0)
Bases: Module
SimCLR loss module.
Parameters:
- temperature (float, default: 1.0) – temperature of the SimCLR loss.
Source code in quadra/losses/ssl/simclr.py
forward(x1, x2)
Compute the SimCLR loss.
Source code in quadra/losses/ssl/simclr.py
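SimCLR's contrastive loss (NT-Xent in the SimCLR paper) can be sketched in NumPy: both views are L2-normalized, pairwise cosine similarities are divided by `temperature`, and each embedding must identify its counterpart from the other view among the 2N - 1 other embeddings in the batch. `nt_xent_loss` is a hypothetical name, and whether quadra normalizes inside the loss is not stated on this page.

```python
import numpy as np


def nt_xent_loss(x1: np.ndarray, x2: np.ndarray, temperature: float = 1.0) -> float:
    """NT-Xent loss for two (N, dim) batches of paired embeddings."""
    n = x1.shape[0]
    z = np.concatenate([x1, x2], axis=0)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)  # cosine similarity via dot product
    sim = z @ z.T / temperature
    np.fill_diagonal(sim, -np.inf)  # an embedding is never its own candidate
    # Row i < N pairs with i + N, row i >= N pairs with i - N.
    pos = np.concatenate([np.arange(n, 2 * n), np.arange(n)])
    # Stable log-sum-exp over each row's candidates.
    row_max = sim.max(axis=1, keepdims=True)
    log_denom = np.log(np.exp(sim - row_max).sum(axis=1)) + row_max[:, 0]
    return float(np.mean(log_denom - sim[np.arange(2 * n), pos]))
```

Subtracting the row maximum before exponentiating keeps the log-sum-exp numerically stable for small temperatures.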
SimSIAMLoss
Bases: Module
SimSiam loss module.
forward(p1, p2, z1, z2)
Compute the SimSiam loss.
Source code in quadra/losses/ssl/simsiam.py
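A NumPy sketch of the symmetrized negative-cosine loss from the SimSiam paper, reading `p1`, `p2` as predictor outputs and `z1`, `z2` as projector outputs of the two views. The helper names are hypothetical; NumPy has no autograd, so the stop-gradient on `z1` and `z2` that SimSiam requires is only noted in a comment.

```python
import numpy as np


def negative_cosine(p: np.ndarray, z: np.ndarray) -> float:
    """Negative mean cosine similarity between rows of p and z."""
    p = p / np.linalg.norm(p, axis=1, keepdims=True)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)
    return float(-np.mean(np.sum(p * z, axis=1)))


def simsiam_loss(p1: np.ndarray, p2: np.ndarray, z1: np.ndarray, z2: np.ndarray) -> float:
    # In an autograd framework z1 and z2 would be detached (stop-gradient).
    # Each prediction is compared with the projection of the *other* view.
    return 0.5 * negative_cosine(p1, z2) + 0.5 * negative_cosine(p2, z1)
```

The loss lies in [-1, 1] and reaches -1 when each prediction points in the same direction as the other view's projection.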
VICRegLoss(lambd, mu, nu=1, gamma=1)
Bases: Module
VICReg (Variance-Invariance-Covariance Regularization) loss module.
Parameters:
- lambd (float) – lambda multiplier for the redundancy term.
- mu (float) – mu multiplier for the similarity term.
- nu (float, default: 1) – nu multiplier for the variance term.
- gamma (float, default: 1) – gamma multiplier for the covariance term.
Source code in quadra/losses/ssl/vicreg.py
forward(z1, z2)
Compute the VICReg loss.
Source code in quadra/losses/ssl/vicreg.py
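A NumPy sketch of the VICReg objective as formulated in the paper: an invariance (mean squared error) term, a hinge on the per-feature standard deviation, and a penalty on off-diagonal covariances, with per-view terms summed as in the paper's equation. The mapping from the paper's weights to quadra's `lambd`, `mu`, `nu` and `gamma` is not given on this page, so the sketch uses its own hypothetical names `sim_coeff`, `std_coeff` and `cov_coeff`, defaulting to the paper's 25, 25 and 1.

```python
import numpy as np


def vicreg_loss(z1, z2, sim_coeff=25.0, std_coeff=25.0, cov_coeff=1.0, eps=1e-4):
    """VICReg loss for two (N, dim) embedding batches (paper formulation)."""
    n, d = z1.shape
    # Invariance: mean squared error between the two views.
    invariance = np.mean((z1 - z2) ** 2)

    def variance(z):
        # Hinge pushing each feature's standard deviation up to 1.
        std = np.sqrt(z.var(axis=0, ddof=1) + eps)
        return np.mean(np.maximum(0.0, 1.0 - std))

    def covariance(z):
        # Sum of squared off-diagonal covariances, scaled by 1 / dim.
        zc = z - z.mean(axis=0)
        cov = zc.T @ zc / (n - 1)
        off_diag = cov - np.diag(np.diag(cov))
        return np.sum(off_diag**2) / d

    return float(
        sim_coeff * invariance
        + std_coeff * (variance(z1) + variance(z2))
        + cov_coeff * (covariance(z1) + covariance(z2))
    )
```

With collapsed (constant) embeddings the variance term is maximal, which is exactly the failure mode the hinge is meant to prevent.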