Skip to content

asl

AsymmetricLoss(gamma_neg=4, gamma_pos=0, m=0.05, eps=1e-08, disable_torch_grad_focal_loss=False, apply_sigmoid=True)

Bases: Module

Note: this is an optimized version that minimizes memory allocation and GPU uploading and favors in-place operations.

Parameters:

  • gamma_neg (float, default: 4 ) –

    gamma for negative samples

  • gamma_pos (float, default: 0 ) –

    gamma for positive samples

  • m (float, default: 0.05 ) –

    probability margin added to the negative-class probabilities (the result is clamped to at most 1)

  • eps (float, default: 1e-08 ) –

    lower bound applied to probabilities before taking the logarithm, to avoid log(0)

  • disable_torch_grad_focal_loss (bool, default: False ) –

    if True, the focal weighting term is computed with autograd disabled, so no gradient flows through it

  • apply_sigmoid (bool, default: True ) –

    if True, applies sigmoid to input before computing loss

Source code in quadra/losses/classification/asl.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
def __init__(
    self,
    gamma_neg: float = 4,
    gamma_pos: float = 0,
    m: float = 0.05,
    eps: float = 1e-8,
    disable_torch_grad_focal_loss: bool = False,
    apply_sigmoid: bool = True,
):
    """Store the loss hyper-parameters and declare the tensors cached by ``forward``.

    Args:
        gamma_neg: focusing exponent used for negative targets
        gamma_pos: focusing exponent used for positive targets
        m: margin added to the negative probabilities before they are clamped to 1
        eps: lower bound applied to probabilities before the logarithm is taken
        disable_torch_grad_focal_loss: if True, the focal weight is computed
            without autograd tracking
        apply_sigmoid: if True, a sigmoid is applied to the input of ``forward``
    """
    super().__init__()

    # Focusing exponents, one per target polarity.
    self.gamma_pos, self.gamma_neg = gamma_pos, gamma_neg

    # Clipping margin and numerical floor used by the cross-entropy term.
    self.m, self.eps = m, eps

    # Behaviour switches.
    self.apply_sigmoid = apply_sigmoid
    self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss

    # Declared (not assigned) here: ``forward`` stores its intermediate tensors on
    # these attributes. The stated intent is to limit per-iteration memory
    # allocation / gpu uploading and to encourage inplace operations.
    self.targets: torch.Tensor
    self.anti_targets: torch.Tensor
    self.xs_pos: torch.Tensor
    self.xs_neg: torch.Tensor
    self.asymmetric_w: torch.Tensor
    self.loss: torch.Tensor

forward(x, y)

Compute the asymmetric loss.

Parameters:

  • x (Tensor) –

    raw logits when apply_sigmoid is True; otherwise probabilities that have already been passed through a sigmoid

  • y (Tensor) –

    targets (multi-label binarized vector)

Returns:

  • Tensor –

    asymmetric loss, summed over all elements

Source code in quadra/losses/classification/asl.py
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Compute the asymmetric loss.

    Args:
        x: raw logits when ``apply_sigmoid`` is True, otherwise probabilities
            that have already been passed through a sigmoid
        y: targets (multi-label binarized vector)

    Returns:
        asymmetric loss, summed over all elements
    """
    self.targets = y
    self.anti_targets = 1 - y

    # Calculating Probabilities
    self.xs_pos = torch.sigmoid(x) if self.apply_sigmoid else x
    # New tensor, so the inplace ops below never touch the caller's input.
    self.xs_neg = 1.0 - self.xs_pos

    # Asymmetric clipping: shift the negative probabilities up by the margin
    # and cap them at 1.
    if self.m is not None and self.m > 0:
        self.xs_neg.add_(self.m).clamp_(max=1)

    # Basic CE calculation; eps keeps the logarithm away from log(0).
    self.loss = self.targets * torch.log(self.xs_pos.clamp(min=self.eps))
    self.loss.add_(self.anti_targets * torch.log(self.xs_neg.clamp(min=self.eps)))

    # Asymmetric Focusing
    if self.gamma_neg > 0 or self.gamma_pos > 0:
        # Optionally compute the focal weight without autograd tracking. The
        # context manager restores whatever grad mode the caller had (e.g. an
        # enclosing ``torch.no_grad()``) instead of force-enabling it, and it
        # also restores the mode if an exception is raised in between.
        track_focal_grad = torch.is_grad_enabled() and not self.disable_torch_grad_focal_loss
        with torch.set_grad_enabled(track_focal_grad):
            self.xs_pos = self.xs_pos * self.targets
            self.xs_neg = self.xs_neg * self.anti_targets
            self.asymmetric_w = torch.pow(
                1 - self.xs_pos - self.xs_neg, self.gamma_pos * self.targets + self.gamma_neg * self.anti_targets
            )
        self.loss *= self.asymmetric_w

    return -self.loss.sum()