
byol

BYOLRegressionLoss

Bases: Module

BYOL regression loss module: a thin wrapper whose forward pass calls byol_regression_loss.

forward(x, y)

Compute the BYOL regression loss.

Parameters:

  • x (Tensor) –

    First Tensor

  • y (Tensor) –

    Second Tensor

Returns:

  • Tensor

    BYOL regression loss

Source code in quadra/losses/ssl/byol.py
def forward(
    self,
    x: torch.Tensor,
    y: torch.Tensor,
) -> torch.Tensor:
    """Compute the BYOL regression loss.

    Args:
        x: First Tensor
        y: Second Tensor

    Returns:
        BYOL regression loss
    """
    return byol_regression_loss(x, y)
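BYOL applies this loss symmetrically: each augmented view's online prediction regresses the other view's target projection, and the target branch receives no gradient. Below is a minimal sketch of that pattern. It assumes 2-D `(batch, features)` tensors. The helper `symmetric_byol_loss` and all variable names are hypothetical. The module and function are re-declared locally, copying the source on this page, so the snippet runs without quadra installed.

```python
import torch
import torch.nn.functional as F
from torch import nn


def byol_regression_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Same computation as the quadra function documented below:
    # L2-normalize both inputs, then 2 - 2 * mean cosine similarity.
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    return 2 - 2 * (x * y).sum(dim=1).mean()


class BYOLRegressionLoss(nn.Module):
    """Local stand-in mirroring the quadra class documented above."""

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return byol_regression_loss(x, y)


def symmetric_byol_loss(
    criterion: nn.Module,
    online_pred_1: torch.Tensor,
    online_pred_2: torch.Tensor,
    target_proj_1: torch.Tensor,
    target_proj_2: torch.Tensor,
) -> torch.Tensor:
    # Hypothetical helper: each view's online prediction regresses the OTHER
    # view's target projection; .detach() blocks gradients into the target branch.
    loss_12 = criterion(online_pred_1, target_proj_2.detach())
    loss_21 = criterion(online_pred_2, target_proj_1.detach())
    return loss_12 + loss_21


torch.manual_seed(0)
batch, dim = 8, 16
p1 = torch.randn(batch, dim, requires_grad=True)  # online predictions, view 1
p2 = torch.randn(batch, dim, requires_grad=True)  # online predictions, view 2
z1 = torch.randn(batch, dim, requires_grad=True)  # target projections, view 1
z2 = torch.randn(batch, dim, requires_grad=True)  # target projections, view 2

criterion = BYOLRegressionLoss()
loss = symmetric_byol_loss(criterion, p1, p2, z1, z2)
loss.backward()
print(loss.item())      # a scalar between 0 and 8 (each term lies in [0, 4])
print(z1.grad is None)  # True: no gradient reaches the target projections
```

Detaching the targets keeps them fixed within a step. In BYOL the target network is instead updated as an exponential moving average of the online network, which this sketch does not show.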

byol_regression_loss(x, y)

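Compute the BYOL regression loss. Both inputs are L2-normalized along their last dimension, and the result is 2 minus twice the cosine similarity of corresponding samples, averaged over the batch.

Parameters:

  • x (Tensor) –

    First tensor

  • y (Tensor) –

    Second tensor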

Returns:

  • Tensor

    Scalar tensor holding the BYOL regression loss
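Because both inputs are L2-normalized, the returned value equals the batch mean of the squared Euclidean distance between the normalized vectors. The derivation below assumes 2-D inputs of shape $(B, D)$. For such inputs, the `dim=-1` normalization and the `dim=1` sum act on the same axis.

```latex
\hat{x}_i = \frac{x_i}{\lVert x_i \rVert_2}, \qquad
\hat{y}_i = \frac{y_i}{\lVert y_i \rVert_2}

\lVert \hat{x}_i - \hat{y}_i \rVert_2^2
  = \lVert \hat{x}_i \rVert_2^2 + \lVert \hat{y}_i \rVert_2^2 - 2\,\langle \hat{x}_i, \hat{y}_i \rangle
  = 2 - 2\,\langle \hat{x}_i, \hat{y}_i \rangle

\mathcal{L}(x, y)
  = \frac{1}{B} \sum_{i=1}^{B} \lVert \hat{x}_i - \hat{y}_i \rVert_2^2
  = 2 - \frac{2}{B} \sum_{i=1}^{B} \langle \hat{x}_i, \hat{y}_i \rangle
  \in [0, 4]
```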

Source code in quadra/losses/ssl/byol.py
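import torch
import torch.nn.functional as F


def byol_regression_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Compute the BYOL regression loss.

    Args:
        x: First tensor.
        y: Second tensor.

    Returns:
        Scalar tensor holding the BYOL regression loss.
    """
    # L2-normalize each sample so the dot product below is a cosine similarity
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    # Per-sample cosine similarity (dim=1 is the feature axis for 2-D inputs),
    # averaged over the batch; same direction gives 0, opposite gives 4
    return 2 - 2 * (x * y).sum(dim=1).mean()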
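The snippet below runs a quick numerical check of two properties. First, the loss equals the batch-mean squared distance between the L2-normalized samples. Second, it ranges from 0 (same direction) to 4 (opposite directions). It assumes 2-D `(batch, features)` inputs. The function is copied from the source above so the snippet is self-contained, and the random inputs are arbitrary.

```python
import torch
import torch.nn.functional as F


def byol_regression_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Copy of the function above, so this snippet runs without quadra installed
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    return 2 - 2 * (x * y).sum(dim=1).mean()


torch.manual_seed(0)
x = torch.randn(4, 32)
y = torch.randn(4, 32)

# Squared Euclidean distance between the L2-normalized rows, averaged over the batch
xn, yn = F.normalize(x, dim=-1), F.normalize(y, dim=-1)
mse_of_normalized = (xn - yn).pow(2).sum(dim=1).mean()
print(torch.allclose(byol_regression_loss(x, y), mse_of_normalized, atol=1e-6))  # True

# Scaling an input does not change the loss, because both inputs are normalized
print(byol_regression_loss(x, 3.0 * x))  # approximately 0
print(byol_regression_loss(x, -x))       # approximately 4
```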