barlowtwins
BarlowTwinsLoss(lambd)
Bases: Module
BarlowTwins loss.
Parameters:
- lambd (float) – lambda of the loss.
Source code in quadra/losses/ssl/barlowtwins.py
forward(z1, z2)
Compute the BarlowTwins loss.
Source code in quadra/losses/ssl/barlowtwins.py
barlowtwins_loss(z1, z2, lambd)
BarlowTwins loss described in https://arxiv.org/abs/2103.03230.
Parameters:
- z1 (Tensor) – First augmented normalized features (i.e. f(T(x))). The normalization can be obtained with z1_norm = (z1 - z1.mean(0)) / z1.std(0)
- z2 (Tensor) – Second augmented normalized features (i.e. f(T(x))). The normalization can be obtained with z2_norm = (z2 - z2.mean(0)) / z2.std(0)
- lambd (float) – lambda multiplier for redundancy term.
Returns:
- Tensor – BarlowTwins loss
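The parameters above can be illustrated with a minimal sketch of the loss as formulated in the referenced paper: an invariance term on the diagonal of the cross-correlation matrix plus a `lambd`-weighted redundancy term on the off-diagonal. This is not the library's actual source. The names `normalize_batch` and `barlowtwins_loss_sketch` are hypothetical, and the `lambd` value is arbitrary and chosen only for illustration.

```python
import torch


def normalize_batch(z: torch.Tensor) -> torch.Tensor:
    """Standardize each feature dimension across the batch (dim 0),
    matching the formula given in the parameter descriptions."""
    return (z - z.mean(0)) / z.std(0)


def barlowtwins_loss_sketch(z1: torch.Tensor, z2: torch.Tensor, lambd: float) -> torch.Tensor:
    """Hypothetical re-implementation, assuming z1 and z2 are already
    normalized and have shape (batch_size, feature_dim)."""
    batch_size = z1.shape[0]
    # Cross-correlation matrix between the two views: (feature_dim, feature_dim).
    c = z1.T @ z2 / batch_size
    diag = torch.diagonal(c)
    # Invariance term: push diagonal entries towards 1.
    invariance = (diag - 1).pow(2).sum()
    # Redundancy term: push off-diagonal entries towards 0.
    redundancy = c.pow(2).sum() - diag.pow(2).sum()
    return invariance + lambd * redundancy


# Usage with random stand-ins for the two augmented feature batches.
torch.manual_seed(0)
feats1 = normalize_batch(torch.randn(8, 4))
feats2 = normalize_batch(torch.randn(8, 4))
loss = barlowtwins_loss_sketch(feats1, feats2, lambd=5e-3)
print(loss.item())
```

The sketch returns a scalar tensor, so it can be passed directly to `backward()` in a training loop. Whether the library function normalizes its inputs internally is not stated here, so the sketch normalizes explicitly before the call.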
Source code in quadra/losses/ssl/barlowtwins.py