vicreg

VICRegLoss(lambd, mu, nu=1, gamma=1)

Bases: Module

VICReg (Variance-Invariance-Covariance Regularization) loss module.

Parameters:

  • lambd (float) –

    Weight of the similarity (invariance) term.

  • mu (float) –

    Weight of the variance term.

  • nu (float, default: 1) –

    Weight of the covariance term. Default: 1.

  • gamma (float, default: 1) –

    Target standard deviation used in the variance hinge term. Default: 1.

Source code in quadra/losses/ssl/vicreg.py
def __init__(
    self,
    lambd: float,
    mu: float,
    nu: float = 1,
    gamma: float = 1,
):
    super().__init__()
    self.lambd = lambd
    self.mu = mu
    self.nu = nu
    self.gamma = gamma

forward(z1, z2)

Computes VICReg loss.

Source code in quadra/losses/ssl/vicreg.py
def forward(self, z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
    """Computes VICReg loss."""
    return vicreg_loss(z1, z2, self.lambd, self.mu, self.nu, self.gamma)
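Because VICRegLoss only stores the four weights and forwards them to vicreg_loss, it can be used like any other criterion in a training step. The sketch below is self-contained so it runs without quadra installed: `vicreg_loss_sketch` and `VICRegLossSketch` are hypothetical local stand-ins that mirror the source on this page (the covariance term uses a boolean mask instead of the flatten/view trick, which selects the same entries). The weights lambd=25, mu=25, nu=1 are the values reported in the VICReg paper, used here only as an illustrative starting point. With quadra installed, the real class would presumably be imported from the module `quadra.losses.ssl.vicreg`, judging by the source path shown.

```python
import torch
import torch.nn.functional as F
from torch import nn


def vicreg_loss_sketch(
    z1: torch.Tensor,
    z2: torch.Tensor,
    lambd: float,
    mu: float,
    nu: float = 1.0,
    gamma: float = 1.0,
    eps: float = 1e-4,
) -> torch.Tensor:
    """Hypothetical stand-in mirroring vicreg_loss on this page."""
    n, d = z1.shape
    # Variance: hinge on each dimension's standard deviation, per view
    std_z1 = torch.sqrt(z1.var(dim=0) + eps)
    std_z2 = torch.sqrt(z2.var(dim=0) + eps)
    var_loss = F.relu(gamma - std_z1).mean() + F.relu(gamma - std_z2).mean()
    # Invariance: mean squared error between the two views
    sim_loss = F.mse_loss(z1, z2)
    # Covariance: squared off-diagonal covariance entries, scaled by 1 / d
    z1c = z1 - z1.mean(dim=0)
    z2c = z2 - z2.mean(dim=0)
    cov_z1 = (z1c.T @ z1c) / (n - 1)
    cov_z2 = (z2c.T @ z2c) / (n - 1)
    off_diag = ~torch.eye(d, dtype=torch.bool, device=z1.device)
    cov_loss = cov_z1[off_diag].pow(2).sum() / d + cov_z2[off_diag].pow(2).sum() / d
    return lambd * sim_loss + mu * var_loss + nu * cov_loss


class VICRegLossSketch(nn.Module):
    """Hypothetical stand-in following VICRegLoss: store weights, delegate in forward."""

    def __init__(self, lambd: float, mu: float, nu: float = 1.0, gamma: float = 1.0):
        super().__init__()
        self.lambd = lambd
        self.mu = mu
        self.nu = nu
        self.gamma = gamma

    def forward(self, z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
        return vicreg_loss_sketch(z1, z2, self.lambd, self.mu, self.nu, self.gamma)


torch.manual_seed(0)
criterion = VICRegLossSketch(lambd=25.0, mu=25.0, nu=1.0)

# Two views of a batch: 256 samples, 32-dimensional embeddings; z2 is a noisy copy of z1
z1 = torch.randn(256, 32, requires_grad=True)
z2 = (z1.detach() + 0.1 * torch.randn(256, 32)).requires_grad_()
loss = criterion(z1, z2)
loss.backward()  # populates z1.grad and z2.grad

# Fully collapsed embeddings: every sample maps to the same vector
collapsed = torch.ones(256, 32)
collapsed_loss = criterion(collapsed, collapsed)
print(f"loss {loss.item():.3f} vs collapsed {collapsed_loss.item():.3f}")
```

The collapsed batch is penalized only by the variance term: with zero variance, each view contributes relu(1 - sqrt(1e-4)) = 0.99, so the loss is 25 * (0.99 + 0.99) = 49.5, well above the loss for the well-spread, nearly invariant views.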

vicreg_loss(z1, z2, lambd, mu, nu=1, gamma=1)

VICReg loss described in https://arxiv.org/abs/2105.04906.
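Reading the source of vicreg_loss further down this page, the returned value appears to correspond to the objective below. This is a reconstruction from the code, not a transcription of the paper. Here Z and Z' stand for z1 and z2 (shape n × d), z_i is the i-th row, z^j the j-th column, ε = 10^{-4}, and Var is the unbiased sample variance computed by torch.var's default. Because the similarity term uses mse_loss with its default mean reduction, it averages over all n·d entries.

```latex
\begin{aligned}
\ell(Z, Z') &= \lambda\, s(Z, Z') + \mu \big[ v(Z) + v(Z') \big] + \nu \big[ c(Z) + c(Z') \big] \\
s(Z, Z') &= \frac{1}{n d} \sum_{i=1}^{n} \lVert z_i - z'_i \rVert_2^2 \\
v(Z) &= \frac{1}{d} \sum_{j=1}^{d} \max\!\Big( 0,\; \gamma - \sqrt{\operatorname{Var}(z^{j}) + \epsilon} \Big) \\
C(Z) &= \frac{1}{n - 1} \sum_{i=1}^{n} (z_i - \bar{z})(z_i - \bar{z})^{\top} \\
c(Z) &= \frac{1}{d} \sum_{i \neq j} \big[ C(Z) \big]_{ij}^{2}
\end{aligned}
```

In this form λ weights the invariance term, μ the variance term and ν the covariance term, while γ acts as the target standard deviation inside the hinge rather than as a multiplier.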

Parameters:

  • z1 (Tensor) –

    First batch of normalized features from an augmented view (i.e. f(T(x))). The normalization can be obtained with z1_norm = (z1 - z1.mean(0)) / z1.std(0).

  • z2 (Tensor) –

    Second batch of normalized features from a differently augmented view (i.e. f(T'(x))). The normalization can be obtained with z2_norm = (z2 - z2.mean(0)) / z2.std(0).

  • lambd (float) –

    Weight of the similarity (invariance) term.

  • mu (float) –

    Weight of the variance term.

  • nu (float, default: 1) –

    Weight of the covariance term. Default: 1.

  • gamma (float, default: 1) –

    Target standard deviation used in the variance hinge term. Default: 1.
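The normalization suggested for z1 and z2 standardizes each feature over the batch dimension (dim 0), so every embedding dimension gets its own mean and standard deviation. A minimal sketch follows; `standardize` is a hypothetical helper name, and the variance-hinge line reuses the expression from vicreg_loss with gamma = 1.

```python
import torch


def standardize(z: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: per-dimension standardization across the batch."""
    # Statistics are taken over dim 0 (the batch): one mean and one std per feature
    return (z - z.mean(0)) / z.std(0)


torch.manual_seed(0)
z1 = 3.0 * torch.randn(128, 16) + 5.0  # features with mean ~5 and std ~3
z1_norm = standardize(z1)

# Variance hinge from vicreg_loss with gamma = 1, evaluated on the standardized batch
var_term = torch.relu(1.0 - torch.sqrt(z1_norm.var(dim=0) + 0.0001)).mean()
print(var_term.item())
```

Since z.std(0) and z.var(dim=0) both use the unbiased estimator, each standardized dimension has a standard deviation of 1, so sqrt(1 + 0.0001) exceeds gamma = 1 and the variance hinge evaluates to zero for such inputs.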

Returns:

  • Tensor –

    VICReg loss (a scalar tensor).

Source code in quadra/losses/ssl/vicreg.py
def vicreg_loss(
    z1: torch.Tensor,
    z2: torch.Tensor,
    lambd: float,
    mu: float,
    nu: float = 1,
    gamma: float = 1,
) -> torch.Tensor:
    """VICReg loss described in https://arxiv.org/abs/2105.04906.

    Args:
        z1: First `augmented` normalized features (i.e. f(T(x))). The normalization can be obtained with
            z1_norm = (z1 - z1.mean(0)) / z1.std(0)
        z2: Second `augmented` normalized features (i.e. f(T'(x))). The normalization can be obtained with
            z2_norm = (z2 - z2.mean(0)) / z2.std(0)
        lambd: lambda multiplier for the similarity (invariance) term.
        mu: mu multiplier for the variance term.
        nu: nu multiplier for the covariance term. Default: 1
        gamma: target standard deviation in the variance hinge term. Default: 1

    Returns:
        VICReg loss
    """
    # Variance loss
    std_z1 = torch.sqrt(z1.var(dim=0) + 0.0001)
    std_z2 = torch.sqrt(z2.var(dim=0) + 0.0001)
    v_z1 = torch.nn.functional.relu(gamma - std_z1).mean()
    v_z2 = torch.nn.functional.relu(gamma - std_z2).mean()
    var_loss = v_z1 + v_z2

    # Similarity loss
    sim_loss = torch.nn.functional.mse_loss(z1, z2)

    # Covariance loss
    n = z1.size(0)
    d = z1.size(1)
    z1 = z1 - z1.mean(dim=0)
    z2 = z2 - z2.mean(dim=0)
    cov_z1 = (z1.T @ z1) / (n - 1)
    cov_z2 = (z2.T @ z2) / (n - 1)
    off_diagonal_cov_z1 = cov_z1.flatten()[:-1].view(d - 1, d + 1)[:, 1:].flatten()
    off_diagonal_cov_z2 = cov_z2.flatten()[:-1].view(d - 1, d + 1)[:, 1:].flatten()
    cov_loss = off_diagonal_cov_z1.pow_(2).sum() / d + off_diagonal_cov_z2.pow_(2).sum() / d

    return lambd * sim_loss + mu * var_loss + nu * cov_loss
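The flatten/view expression in the covariance block extracts the off-diagonal entries of a d × d matrix without building a mask. Dropping the last element of the flattened matrix (the final diagonal entry) leaves d² - 1 = (d - 1)(d + 1) values; reshaped to (d - 1, d + 1), each row starts with a diagonal entry, so slicing off column 0 keeps exactly the off-diagonal entries in row-major order. The sketch below wraps the expression in a hypothetical helper `off_diagonal` and checks it against an explicit boolean mask.

```python
import torch


def off_diagonal(m: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: off-diagonal entries of a square matrix, row-major order."""
    if m.dim() != 2 or m.size(0) != m.size(1):
        raise ValueError("expects a square 2-D matrix")
    d = m.size(0)
    # Drop the last (diagonal) element; the remaining d*d - 1 values reshape to
    # (d - 1, d + 1) with every remaining diagonal entry in column 0, then drop it.
    return m.flatten()[:-1].view(d - 1, d + 1)[:, 1:].flatten()


m = torch.arange(9.0).view(3, 3)
print(off_diagonal(m))  # tensor([1., 2., 3., 5., 6., 7.])

# Equivalent, more explicit boolean-mask formulation
mask = ~torch.eye(3, dtype=torch.bool)
print(m[mask])  # tensor([1., 2., 3., 5., 6., 7.])
```

The reshaping trick relies on the flattened tensor being contiguous so that `view` succeeds; the covariance matrices built in vicreg_loss come from a matrix product and are contiguous, while the mask version works for any layout at the cost of allocating the d × d mask.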