
idmm

IDMMLoss(smoothing=0.1)

Bases: Module

IDMM loss described in https://arxiv.org/abs/2201.10728.

Source code in quadra/losses/ssl/idmm.py
def __init__(self, smoothing: float = 0.1):
    super().__init__()
    self.smoothing = smoothing

forward(p1, y1)

IDMM loss described in https://arxiv.org/abs/2201.10728.

Parameters:

  • p1 (Tensor) –

    Predicted class scores (logits) for z1, passed to F.cross_entropy as the input

  • y1 (Tensor) –

    Instance labels for z1, used as the cross-entropy target

Returns:

  • Tensor –

    IDMM loss

Source code in quadra/losses/ssl/idmm.py
def forward(
    self,
    p1: torch.Tensor,
    y1: torch.Tensor,
) -> torch.Tensor:
    """IDMM loss described in https://arxiv.org/abs/2201.10728.

    Args:
        p1: Prediction labels for `z1`
        y1: Instance labels for `z1`

    Returns:
        IDMM loss
    """
    return idmm_loss(
        p1,
        y1,
        self.smoothing,
    )
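A minimal usage sketch of the module form follows. To keep it runnable without installing quadra, it re-declares `idmm_loss` and `IDMMLoss` locally, mirroring the source listed on this page; in a real project you would import them from `quadra.losses.ssl.idmm` instead. The batch size, the number of instance classes, and the random inputs are assumptions made up for illustration.

```python
import torch
import torch.nn.functional as F
from torch import nn


def idmm_loss(p1: torch.Tensor, y1: torch.Tensor, smoothing: float = 0.1) -> torch.Tensor:
    # Local copy of the function documented on this page.
    return F.cross_entropy(p1, y1, label_smoothing=smoothing)


class IDMMLoss(nn.Module):
    # Local copy of the module documented on this page.
    def __init__(self, smoothing: float = 0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, p1: torch.Tensor, y1: torch.Tensor) -> torch.Tensor:
        return idmm_loss(p1, y1, self.smoothing)


torch.manual_seed(0)

# Hypothetical sizes: 4 samples, 10 instance classes.
batch_size, num_instances = 4, 10

# p1: unnormalized scores (logits), one row per sample.
p1 = torch.randn(batch_size, num_instances, requires_grad=True)
# y1: integer instance label for each sample.
y1 = torch.randint(0, num_instances, (batch_size,))

criterion = IDMMLoss(smoothing=0.1)
loss = criterion(p1, y1)  # scalar tensor (mean over the batch)
loss.backward()           # gradients flow back to p1

print(loss.item())
```

The result is a scalar because `F.cross_entropy` averages over the batch by default.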

idmm_loss(p1, y1, smoothing=0.1)

IDMM loss described in https://arxiv.org/abs/2201.10728.

Parameters:

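  • p1 (Tensor) –

    Predicted class scores (logits) for z1, passed to F.cross_entropy as the input

  • y1 (Tensor) –

    Instance labels for z1, used as the cross-entropy target

  • smoothing (float, default: 0.1 ) –

    Label-smoothing factor, passed to F.cross_entropy as label_smoothing. Defaults to 0.1.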

Returns:

  • Tensor –

    IDMM loss

Source code in quadra/losses/ssl/idmm.py
import torch
import torch.nn.functional as F


def idmm_loss(
    p1: torch.Tensor,
    y1: torch.Tensor,
    smoothing: float = 0.1,
) -> torch.Tensor:
    """IDMM loss described in https://arxiv.org/abs/2201.10728.

    Args:
        p1: Prediction labels for `z1`
        y1: Instance labels for `z1`
        smoothing: smoothing factor used for label smoothing.
            Defaults to 0.1.

    Returns:
        IDMM loss
    """
    loss = F.cross_entropy(p1, y1, label_smoothing=smoothing)
    return loss
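Since the function is a thin wrapper around cross-entropy with label smoothing, it may help to see what the smoothing factor does to the target. The sketch below is a plain-Python illustration, not quadra code: it assumes the usual formulation in which the one-hot target is mixed with a uniform distribution over the classes, and the helper name `smoothed_cross_entropy` is hypothetical.

```python
import math


def smoothed_cross_entropy(logits, target, smoothing=0.1):
    """Label-smoothed cross-entropy for a single sample (illustrative helper)."""
    num_classes = len(logits)

    # Numerically stable log-softmax over the logits.
    max_logit = max(logits)
    log_norm = max_logit + math.log(sum(math.exp(x - max_logit) for x in logits))
    log_probs = [x - log_norm for x in logits]

    # Smoothed target: (1 - smoothing) on the true class,
    # plus smoothing / num_classes spread over every class.
    uniform = smoothing / num_classes
    q = [
        uniform + (1.0 - smoothing) * (1.0 if k == target else 0.0)
        for k in range(num_classes)
    ]

    # Cross-entropy between the smoothed target and the predicted distribution.
    return -sum(qk * lp for qk, lp in zip(q, log_probs))


confident = [5.0, 0.0, 0.0]  # made-up logits strongly favouring class 0
plain = smoothed_cross_entropy(confident, target=0, smoothing=0.0)
smoothed = smoothed_cross_entropy(confident, target=0, smoothing=0.1)
print(plain, smoothed)
```

With a confident, correct prediction the smoothed loss is larger than the unsmoothed one, which is the intended effect: smoothing penalizes overconfident outputs.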