
hyperspherical

align_loss(x, y, alpha=2)

Mean over rows of the L2 distance between paired features, raised to the power alpha: mean(||x_i - y_i||_2^alpha).

Parameters:

  • x (Tensor) –

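    first feature tensor of shape (N, D); row i is paired with row i of y

  • y (Tensor) –

    second feature tensor, same shape as x

  • alpha (int, default: 2) –

    exponent applied to each row-wise L2 distance before averaging.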

Returns:

  • Tensor

    align loss (a scalar tensor)

Source code in quadra/losses/ssl/hyperspherical.py
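import torch


def align_loss(x: torch.Tensor, y: torch.Tensor, alpha: int = 2) -> torch.Tensor:
    """Mean(l2^alpha): mean over rows of ||x_i - y_i||_2 ** alpha.

    Args:
        x: first feature tensor of shape (N, D)
        y: second feature tensor of shape (N, D), row-paired with x
        alpha: exponent applied to each row-wise L2 distance.

    Returns:
        Align loss
    """
    # Row-wise L2 distance between paired features, shape (N,)
    norm = torch.norm(x - y, p=2, dim=1)
    return torch.mean(torch.pow(norm, alpha))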
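To make the formula concrete, the sketch below restates it with only the Python standard library and evaluates it on two tiny hand-made batches. It is not quadra's implementation; `align_loss_reference` and the `Matrix` alias are hypothetical names used only for illustration, and the inputs are made up.

```python
import math
from typing import Sequence

Matrix = Sequence[Sequence[float]]


def align_loss_reference(x: Matrix, y: Matrix, alpha: int = 2) -> float:
    """Hypothetical plain-Python restatement of mean(||x_i - y_i||_2 ** alpha)."""
    if len(x) != len(y) or len(x) == 0:
        raise ValueError("x and y must be non-empty with the same number of rows")
    total = 0.0
    for xi, yi in zip(x, y):
        # L2 distance between the i-th pair of rows (what torch.norm(..., dim=1) yields per row)
        dist = math.sqrt(sum((a - b) ** 2 for a, b in zip(xi, yi)))
        total += dist**alpha
    # Average over the N rows (torch.mean over the per-row values)
    return total / len(x)


x = [[1.0, 0.0], [0.0, 1.0]]
y = [[0.0, 1.0], [0.0, 1.0]]
print(align_loss_reference(x, y))           # distances sqrt(2) and 0 -> mean(2, 0) ≈ 1.0
print(align_loss_reference(x, y, alpha=1))  # mean(sqrt(2), 0) ≈ 0.7071
```

Identical inputs give a loss of 0, and alpha controls how strongly large row distances are weighted relative to small ones.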

cosine_align_loss(x, y)

Computes the mean cosine distance between paired rows of x and y: mean(1 - cosine_similarity(x, y)).

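Parameters:

  • x (Tensor) –

    first feature tensor of shape (N, D); row i is paired with row i of y

  • y (Tensor) –

    second feature tensor, same shape as x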

Returns:

  • Tensor

    cosine align loss

Source code in quadra/losses/ssl/hyperspherical.py
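import torch
from torch.nn.functional import cosine_similarity


def cosine_align_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Computes mean of cosine distance based on similarity mean(1 - cosine_similarity).

    Args:
        x: first feature tensor of shape (N, D)
        y: second feature tensor of shape (N, D), row-paired with x

    Returns:
        cosine align loss
    """
    # Cosine distance between paired rows, shape (N,)
    cos = 1 - cosine_similarity(x, y, dim=1)
    return torch.mean(cos)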
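The sketch below restates cosine_align_loss with the standard library and contrasts it with the squared-L2 form of align_loss. Unlike align_loss, the cosine version ignores the scale of each row. For unit-norm rows, ||u - v||^2 = 2 - 2 cos(u, v), so align_loss with alpha=2 equals exactly twice cosine_align_loss. The helpers `_cosine`, `cosine_align_loss_reference` and `mean_squared_l2` are hypothetical names, not quadra functions. `_cosine` assumes non-zero rows, whereas torch's `cosine_similarity` also clamps norms with a small `eps`.

```python
import math
from typing import Sequence

Vector = Sequence[float]
Matrix = Sequence[Sequence[float]]


def _cosine(a: Vector, b: Vector) -> float:
    # Assumes non-zero rows; torch's cosine_similarity additionally clamps norms with a small eps
    dot = sum(p * q for p, q in zip(a, b))
    return dot / (math.sqrt(sum(p * p for p in a)) * math.sqrt(sum(q * q for q in b)))


def cosine_align_loss_reference(x: Matrix, y: Matrix) -> float:
    """Hypothetical plain-Python restatement of mean(1 - cosine_similarity(x_i, y_i))."""
    return sum(1.0 - _cosine(xi, yi) for xi, yi in zip(x, y)) / len(x)


def mean_squared_l2(x: Matrix, y: Matrix) -> float:
    """Mean squared L2 distance per row, i.e. the align_loss formula with alpha=2."""
    return sum(sum((p - q) ** 2 for p, q in zip(xi, yi)) for xi, yi in zip(x, y)) / len(x)


# Row 0 is orthogonal (cosine distance 1); row 1 differs only in scale (cosine distance 0)
x = [[1.0, 0.0], [0.0, 1.0]]
y = [[0.0, 1.0], [0.0, 2.0]]
print(cosine_align_loss_reference(x, y))  # (1 + 0) / 2 = 0.5

# For unit-norm rows the two losses differ by exactly a factor of 2
u = [[1.0, 0.0], [0.6, 0.8]]
v = [[0.0, 1.0], [0.8, 0.6]]
print(mean_squared_l2(u, v), 2 * cosine_align_loss_reference(u, v))  # both ≈ 1.04
```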

uniform_loss(x, t=2.0)

log(mean(exp(-t * dist_p2))), where dist_p2 is the squared L2 distance between rows i and j of x, taken over all pairs i < j.

Parameters:

  • x (Tensor) –

    feature tensor of shape (N, D); distances are computed between its rows

  • t (float, default: 2.0) –

    temperature scaling the squared pairwise distances.

Returns:

  • Tensor

    uniform loss (a scalar tensor)

Source code in quadra/losses/ssl/hyperspherical.py
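import torch


def uniform_loss(x: torch.Tensor, t: float = 2.0) -> torch.Tensor:
    """log(mean(exp(-t * dist_p2))), with dist_p2 the squared pairwise L2 distance.

    Args:
        x: feature tensor of shape (N, D)
        t: temperature applied to the squared pairwise distances.

    Returns:
        Uniform loss
    """
    # torch.pdist returns the N * (N - 1) / 2 L2 distances between rows i < j
    sq_dists = torch.pow(torch.pdist(x, p=2), 2)
    return torch.log(torch.mean(torch.exp(sq_dists * -t)))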
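As a sanity check on the formula, the sketch below evaluates it with the standard library on two made-up point sets. One set is spread evenly on the unit circle and the other has fully collapsed. `uniform_loss_reference` is a hypothetical name and not part of quadra. It squares distances directly instead of squaring `torch.pdist` output, which gives the same value up to rounding. Because exp(-t * d^2) <= 1 for t > 0, the loss is never positive. It reaches its maximum of 0 when all rows coincide, and lower values mean the rows are more spread out.

```python
import math
from itertools import combinations
from typing import Sequence

Matrix = Sequence[Sequence[float]]


def uniform_loss_reference(x: Matrix, t: float = 2.0) -> float:
    """Hypothetical plain-Python restatement of log(mean(exp(-t * d_ij^2))) over pairs i < j."""
    # Same pairs as torch.pdist: each unordered pair of rows once, N * (N - 1) / 2 in total
    sq_dists = [sum((a - b) ** 2 for a, b in zip(xi, xj)) for xi, xj in combinations(x, 2)]
    if not sq_dists:
        raise ValueError("x needs at least two rows")
    return math.log(sum(math.exp(-t * d) for d in sq_dists) / len(sq_dists))


# Four points spread evenly on the unit circle: 4 pairs at d^2 = 2, 2 pairs at d^2 = 4
spread = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]]
# Three identical points: every pairwise distance is 0
collapsed = [[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]]

print(uniform_loss_reference(spread))     # log((4e^-4 + 2e^-8) / 6) ≈ -4.396
print(uniform_loss_reference(collapsed))  # log(1) = 0.0, the largest value possible for t > 0
```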