barlowtwins
BarlowTwinsLoss(lambd)
Bases: Module
BarlowTwins loss.
Parameters:
- lambd (float) – lambda of the loss, i.e. the multiplier for the redundancy term.
Source code in quadra/losses/ssl/barlowtwins.py
forward(z1, z2)
Compute the BarlowTwins loss.
Source code in quadra/losses/ssl/barlowtwins.py
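As a rough illustration of how a module of this kind is typically used, the sketch below defines a hypothetical `BarlowTwinsLossSketch` class that stores `lambd` and applies it in `forward`. It is not the quadra implementation: the class name, the batch size, the embedding size and the value `5e-3` are all assumptions, and the loss body follows the formulation in the Barlow Twins paper.

```python
import torch
from torch import nn


class BarlowTwinsLossSketch(nn.Module):
    """Hypothetical stand-in for BarlowTwinsLoss (not the quadra class)."""

    def __init__(self, lambd: float):
        super().__init__()
        self.lambd = lambd

    def forward(self, z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
        # Assumes z1 and z2 are already standardized along the batch dimension.
        c = z1.T @ z2 / z1.size(0)  # cross-correlation matrix, shape (dim, dim)
        diag = torch.diagonal(c)
        on_diag = (diag - 1).pow(2).sum()  # invariance term
        off_diag = c.pow(2).sum() - diag.pow(2).sum()  # redundancy term
        return on_diag + self.lambd * off_diag


torch.manual_seed(0)
# Stand-ins for the embeddings of two augmented views (assumed shapes).
emb1 = torch.randn(64, 16, requires_grad=True)
emb2 = torch.randn(64, 16, requires_grad=True)

# Standardize each feature over the batch before calling the loss.
z1 = (emb1 - emb1.mean(0)) / emb1.std(0)
z2 = (emb2 - emb2.mean(0)) / emb2.std(0)

criterion = BarlowTwinsLossSketch(lambd=5e-3)
loss = criterion(z1, z2)
loss.backward()  # gradients flow back to the embeddings
```

Keeping `lambd` on the module means the training loop only has to pass the two embedding tensors.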
barlowtwins_loss(z1, z2, lambd)
BarlowTwins loss described in https://arxiv.org/abs/2103.03230.
Parameters:
- z1 (Tensor) – First augmented normalized features (i.e. f(T(x))). The normalization can be obtained with z1_norm = (z1 - z1.mean(0)) / z1.std(0)
- z2 (Tensor) – Second augmented normalized features (i.e. f(T(x))). The normalization can be obtained with z2_norm = (z2 - z2.mean(0)) / z2.std(0)
- lambd (float) – lambda multiplier for redundancy term.
Returns:
- Tensor – BarlowTwins loss
Source code in quadra/losses/ssl/barlowtwins.py
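The computation described above can be sketched as follows. This is a minimal version based on the formulation in the referenced paper, not a copy of the quadra source: `standardize` and `barlowtwins_loss_sketch` are hypothetical names, and dividing the cross-correlation matrix by the batch size is an assumption taken from the paper.

```python
import torch


def standardize(z: torch.Tensor) -> torch.Tensor:
    """Normalize each feature over the batch, as in the parameter description."""
    return (z - z.mean(0)) / z.std(0)


def barlowtwins_loss_sketch(z1: torch.Tensor, z2: torch.Tensor, lambd: float) -> torch.Tensor:
    """Hypothetical re-implementation for normalized (batch, dim) inputs."""
    batch_size = z1.size(0)
    # Cross-correlation between the features of the two views.
    c = (z1.T @ z2) / batch_size
    diag = torch.diagonal(c)
    # Invariance term: diagonal entries should be close to 1.
    invariance = (diag - 1).pow(2).sum()
    # Redundancy term: off-diagonal entries should be close to 0.
    redundancy = c.pow(2).sum() - diag.pow(2).sum()
    return invariance + lambd * redundancy


torch.manual_seed(0)
z = standardize(torch.randn(256, 8))
other = standardize(torch.randn(256, 8))

# Identical views: the diagonal is close to 1, so the invariance term is small.
loss_same = barlowtwins_loss_sketch(z, z, lambd=0.0)
# Unrelated views: the diagonal is close to 0, so the invariance term is large.
loss_diff = barlowtwins_loss_sketch(z, other, lambd=5e-3)
```

With `lambd=0.0` only the invariance term contributes, which makes it easy to check that identical normalized views give a loss near zero.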