
simsiam

SimSIAMLoss

Bases: Module

SimSIAM loss module: a thin torch.nn.Module wrapper around simsiam_loss.

forward(p1, p2, z1, z2)

Compute the SimSIAM loss.

Source code in quadra/losses/ssl/simsiam.py
def forward(self, p1: torch.Tensor, p2: torch.Tensor, z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
    """Compute the SimSIAM loss."""
    return simsiam_loss(p1, p2, z1, z2)
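Because SimSIAMLoss only forwards its four inputs to simsiam_loss, it can be used as a criterion like any other PyTorch loss. The sketch below is a minimal illustration, assuming batch-first 2D feature tensors. To run without quadra installed, it re-defines both names locally with bodies copied from the source on this page. The batch and feature sizes are arbitrary.

```python
import torch
import torch.nn.functional as F
from torch import nn


def simsiam_loss(p1, p2, z1, z2):
    # Local copy of the expression shown on this page.
    return -(F.cosine_similarity(p1, z2).mean() + F.cosine_similarity(p2, z1).mean()) * 0.5


class SimSIAMLoss(nn.Module):
    """Local stand-in mirroring the wrapper documented here."""

    def forward(self, p1, p2, z1, z2):
        return simsiam_loss(p1, p2, z1, z2)


criterion = SimSIAMLoss()
batch, dim = 8, 16
p1, p2, z1, z2 = (torch.randn(batch, dim) for _ in range(4))
loss = criterion(p1, p2, z1, z2)
print(loss.shape)  # torch.Size([]): the loss is a scalar
```

Keeping the math in a plain function and the module as a wrapper lets the same computation be called directly or registered as a criterion.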

simsiam_loss(p1, p2, z1, z2)

SimSIAM loss (symmetrized negative cosine similarity between predicted and projected features), as described in https://arxiv.org/abs/2011.10566.

Parameters:

  • p1 (Tensor) –

    First predicted features (i.e. h(f(T(x1))))

  • p2 (Tensor) –

    Second predicted features (i.e. h(f(T'(x2))))

  • z1 (Tensor) –

    First projected features (i.e. f(T(x1)))

  • z2 (Tensor) –

    Second projected features (i.e. f(T'(x2)))
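The parameters above come from one forward pass over two augmented views: the encoder f produces the projected features z, and the predictor h maps them to p = h(z). The sketch below is a minimal illustration of how the four tensors can be produced. The networks f and h and the noise-based augment function are hypothetical stand-ins, not quadra's models. The .detach() on z reflects the stop-gradient used in the SimSiam paper. simsiam_loss does not apply it itself, and this page does not show where, or whether, quadra applies it.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-ins: f = backbone + projector, h = predictor MLP.
f = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 16))
h = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 16))


def augment(x):
    # Hypothetical augmentation T: additive Gaussian noise.
    return x + 0.1 * torch.randn_like(x)


x = torch.randn(4, 32)                   # batch of 4 inputs, 32 features each
z1, z2 = f(augment(x)), f(augment(x))    # projected features
p1, p2 = h(z1), h(z2)                    # predicted features

# Stop-gradient on z (paper convention); the loss expression mirrors simsiam_loss.
loss = -(F.cosine_similarity(p1, z2.detach()).mean()
         + F.cosine_similarity(p2, z1.detach()).mean()) * 0.5
loss.backward()
```

Gradients still reach f through the p branch (p = h(f(...))), so both networks are updated even though z is detached.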

Returns:

  • Tensor

    SimSIAM loss
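Written out, the returned value is shown below. B is the batch size, the superscript (i) indexes batch rows, and the mean is taken over the batch as in the source. F.cosine_similarity additionally clamps the denominator with a small eps, which is omitted here. The paper wraps z in a stop-gradient, which does not change the value.

$$
\cos(a, b) = \frac{a \cdot b}{\lVert a \rVert_2 \, \lVert b \rVert_2},
\qquad
\mathcal{L} = -\frac{1}{2B} \sum_{i=1}^{B} \left[ \cos\!\left(p_1^{(i)}, z_2^{(i)}\right) + \cos\!\left(p_2^{(i)}, z_1^{(i)}\right) \right]
$$

Since each cosine lies in [-1, 1], the loss lies in [-1, 1]. It reaches -1 when every prediction points in the same direction as the projection of the other view.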

Source code in quadra/losses/ssl/simsiam.py
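import torch
import torch.nn.functional as F


def simsiam_loss(
    p1: torch.Tensor,
    p2: torch.Tensor,
    z1: torch.Tensor,
    z2: torch.Tensor,
) -> torch.Tensor:
    """SimSIAM loss described in https://arxiv.org/abs/2011.10566.

    Args:
        p1: First predicted features (i.e. h(f(T(x1))))
        p2: Second predicted features (i.e. h(f(T'(x2))))
        z1: First projected features (i.e. f(T(x1)))
        z2: Second projected features (i.e. f(T'(x2)))

    Returns:
        SimSIAM loss
    """
    # Symmetrized negative cosine similarity, averaged over the batch.
    # F.cosine_similarity compares along dim=1 (the feature dimension) by default.
    # No stop-gradient is applied here: z1 and z2 are used exactly as passed in.
    return -(F.cosine_similarity(p1, z2).mean() + F.cosine_similarity(p2, z1).mean()) * 0.5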
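A few quick checks make the behaviour of the expression above concrete. Identical pairs give -1, anti-parallel pairs give +1, and positive rescaling has no effect because only directions matter. Gradients reach z unless the caller detaches it. The snippet is a minimal sketch that copies the expression locally so it runs without quadra. The tensor sizes and seed are arbitrary.

```python
import torch
import torch.nn.functional as F


def simsiam_loss(p1, p2, z1, z2):
    # Same expression as simsiam_loss above: symmetrized negative cosine similarity.
    return -(F.cosine_similarity(p1, z2).mean() + F.cosine_similarity(p2, z1).mean()) * 0.5


torch.manual_seed(0)
p = torch.randn(8, 16)  # batch of 8 feature vectors of size 16

aligned = simsiam_loss(p, p, p, p)             # identical pairs: cosine 1, loss -1
opposed = simsiam_loss(p, p, -p, -p)           # anti-parallel pairs: cosine -1, loss +1
scaled = simsiam_loss(p, p, 3.0 * p, 3.0 * p)  # positive rescaling leaves cosine unchanged

# simsiam_loss applies no stop-gradient itself: gradients reach z1/z2.
p1, p2, z1, z2 = (torch.randn(8, 16, requires_grad=True) for _ in range(4))
simsiam_loss(p1, p2, z1, z2).backward()
grads_without_detach = z1.grad is not None

# Detaching z (the paper's stop-gradient) leaves gradients only on the p side.
q1, q2, w1, w2 = (torch.randn(8, 16, requires_grad=True) for _ in range(4))
simsiam_loss(q1, q2, w1.detach(), w2.detach()).backward()
grads_with_detach = w1.grad is None and q1.grad is not None
```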