vicreg
VICRegLoss(lambd, mu, nu=1, gamma=1)
Bases: Module
VICReg (Variance-Invariance-Covariance Regularization) loss module.
Parameters:
- lambd (float) – lambda multiplier for redundancy term.
- mu (float) – mu multiplier for similarity term.
- nu (float, default: 1) – nu multiplier for variance term.
- gamma (float, default: 1) – gamma multiplier for covariance term.
Source code in quadra/losses/ssl/vicreg.py
forward(z1, z2)
Computes VICReg loss.
Source code in quadra/losses/ssl/vicreg.py
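In a training loop, a loss module like this is typically called on the projector outputs of two augmented views of the same batch. The sketch below is a self-contained stand-in, not quadra's code: the class name VICRegLossSketch and its keyword names are hypothetical, and it assigns the weights the way the VICReg paper does (invariance, variance, covariance, plus a target standard deviation for the variance hinge), which may not match how lambd, mu, nu and gamma are wired in quadra.

```python
import torch
import torch.nn.functional as F
from torch import nn


class VICRegLossSketch(nn.Module):
    """Hypothetical VICReg loss module (not quadra's VICRegLoss).

    Assumes the paper's weighting: inv_weight scales the invariance term,
    var_weight the variance term, cov_weight the covariance term, and
    target_std is the std the variance hinge pushes each feature towards.
    """

    def __init__(self, inv_weight=25.0, var_weight=25.0, cov_weight=1.0,
                 target_std=1.0, eps=1e-4):
        super().__init__()
        self.inv_weight = inv_weight
        self.var_weight = var_weight
        self.cov_weight = cov_weight
        self.target_std = target_std
        self.eps = eps

    def _variance(self, z: torch.Tensor) -> torch.Tensor:
        # Hinge on each feature's batch std; zero once std >= target_std.
        std = torch.sqrt(z.var(dim=0) + self.eps)
        return F.relu(self.target_std - std).mean()

    def _covariance(self, z: torch.Tensor) -> torch.Tensor:
        # Sum of squared off-diagonal entries of the batch covariance, divided by d.
        n, d = z.shape
        z = z - z.mean(dim=0)
        cov = (z.T @ z) / (n - 1)
        off_diag = ~torch.eye(d, dtype=torch.bool, device=z.device)
        return cov[off_diag].pow(2).sum() / d

    def forward(self, z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
        # Invariance: mean squared error over all n * d entries.
        inv = F.mse_loss(z1, z2)
        var = self._variance(z1) + self._variance(z2)
        cov = self._covariance(z1) + self._covariance(z2)
        return self.inv_weight * inv + self.var_weight * var + self.cov_weight * cov
```

In a training step one would call loss_fn(z1, z2) on the two views' embeddings and backpropagate the returned scalar. Constant (collapsed) embeddings still receive a large loss through the variance hinge even though their invariance term is zero.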
vicreg_loss(z1, z2, lambd, mu, nu=1, gamma=1)
VICReg loss described in https://arxiv.org/abs/2105.04906.
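For reference, the objective in the cited paper can be written as below, with Z and Z' the two batches of n embeddings of dimension d, z^j the j-th feature across the batch, and ε a small constant for numerical stability. This restates the paper's notation only: there λ weights the invariance (similarity) term, μ the variance term, ν the covariance term, and γ is a target standard deviation (the paper reports λ = μ = 25, ν = 1 and γ = 1). Whether this function's lambd, mu, nu and gamma follow the same mapping is not confirmed here, since the parameter descriptions name the terms differently.

```latex
\ell(Z, Z') = \lambda\, s(Z, Z') + \mu\,\bigl[v(Z) + v(Z')\bigr] + \nu\,\bigl[c(Z) + c(Z')\bigr]

s(Z, Z') = \frac{1}{n} \sum_{i=1}^{n} \lVert z_i - z'_i \rVert_2^2

v(Z) = \frac{1}{d} \sum_{j=1}^{d} \max\Bigl(0,\ \gamma - \sqrt{\mathrm{Var}(z^j) + \epsilon}\Bigr)

C(Z) = \frac{1}{n-1} \sum_{i=1}^{n} (z_i - \bar{z})(z_i - \bar{z})^\top,
\qquad
c(Z) = \frac{1}{d} \sum_{i \neq j} \bigl[C(Z)\bigr]_{i,j}^2
```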
Parameters:
- z1 (Tensor) – First augmented, normalized features (i.e. f(T(x))). The normalization can be obtained with z1_norm = (z1 - z1.mean(0)) / z1.std(0).
- z2 (Tensor) – Second augmented, normalized features (i.e. f(T(x))). The normalization can be obtained with z2_norm = (z2 - z2.mean(0)) / z2.std(0).
- lambd (float) – lambda multiplier for redundancy term.
- mu (float) – mu multiplier for similarity term.
- nu (float, default: 1) – nu multiplier for variance term.
- gamma (float, default: 1) – gamma multiplier for covariance term.
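The normalization suggested for z1 and z2 standardizes each embedding dimension across the batch. A minimal sketch of that formula follows; the function name standardize and the optional eps guard against zero-variance features are hypothetical additions, not part of the documented API.

```python
import torch


def standardize(z: torch.Tensor, eps: float = 0.0) -> torch.Tensor:
    # z has shape (batch, features); statistics are taken over the batch (dim 0).
    # eps is a hypothetical guard: with eps=0 a constant feature gives NaN (0 / 0).
    return (z - z.mean(0)) / (z.std(0) + eps)
```

After this step every feature has zero batch mean and unit (unbiased) batch standard deviation.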
Returns:
- Tensor – VICReg loss.
Source code in quadra/losses/ssl/vicreg.py
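To see what each weighted term measures, for example when logging them separately, the three terms can be computed without weights. The helper below is a hypothetical sketch following the paper's definitions (invariance as mean squared error, variance as a hinge on each feature's standard deviation with target target_std, covariance as the scaled sum of squared off-diagonal covariances). It is not quadra's vicreg_loss, and which of lambd, mu, nu and gamma scales each term in quadra is an assumption left to the reader to check against the source.

```python
import torch
import torch.nn.functional as F


def vicreg_terms(z1: torch.Tensor, z2: torch.Tensor,
                 target_std: float = 1.0, eps: float = 1e-4) -> dict:
    """Unweighted VICReg terms for two (n, d) embedding batches (hypothetical helper)."""
    n, d = z1.shape

    def variance(z):
        # Mean over features of the hinge max(0, target_std - std).
        std = torch.sqrt(z.var(dim=0) + eps)
        return F.relu(target_std - std).mean()

    def covariance(z):
        # Sum of squared off-diagonal covariances, divided by d.
        zc = z - z.mean(dim=0)
        cov = (zc.T @ zc) / (n - 1)
        mask = ~torch.eye(d, dtype=torch.bool, device=z.device)
        return cov[mask].pow(2).sum() / d

    return {
        "invariance": F.mse_loss(z1, z2),
        "variance": variance(z1) + variance(z2),
        "covariance": covariance(z1) + covariance(z2),
    }


def combine(terms: dict, w_inv: float, w_var: float, w_cov: float) -> torch.Tensor:
    # Weighted sum of the three terms; the weight names are hypothetical.
    return w_inv * terms["invariance"] + w_var * terms["variance"] + w_cov * terms["covariance"]
```

Identical views give a zero invariance term, collapsed views show up in the variance term, and features that copy each other (redundant dimensions) inflate the covariance term.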