
focal

BinaryFocalLossWithLogits(alpha, gamma=2.0, reduction='none')

Bases: Module

Criterion that computes Focal loss.

According to :cite:`lin2018focal`, the Focal loss is computed as follows:

.. math::

    \text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)

where:
  • :math:`p_t` is the model's estimated probability for each class.
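
To make the formula concrete, here is a minimal sketch (illustrative only, not part of the library) that evaluates it for a single binary logit, using the same convention as the implementation shown further down: alpha weights the positive class, 1 - alpha the negative class, and p_t comes from a sigmoid of the logit.

import torch

x = torch.tensor(0.8)      # raw logit
y = torch.tensor(1.0)      # binary target in {0, 1}
alpha, gamma = 0.25, 2.0

p = torch.sigmoid(x)                                      # estimated probability of the positive class
loss_pos = -alpha * (1 - p) ** gamma * torch.log(p)       # term used where y == 1 (p_t = p)
loss_neg = -(1 - alpha) * p ** gamma * torch.log(1 - p)   # term used where y == 0 (p_t = 1 - p)
fl = y * loss_pos + (1 - y) * loss_neg                    # focal loss for this single element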

Parameters:

  • alpha (float) –

    Weighting factor for the rare class :math:\alpha \in [0, 1].

  • gamma (float, default: 2.0 ) –

    Focusing parameter :math:\gamma >= 0.

  • reduction (str, default: 'none' ) –

    Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. 'none': no reduction will be applied, 'mean': the sum of the output will be divided by the number of elements in the output, 'sum': the output will be summed.

Shape
  • Input: :math:(N, *).
  • Target: :math:(N, *).
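
As a rough illustration of the shapes involved (a sketch assuming the import path quadra.losses.classification.focal), reduction='none' keeps the input shape, while 'mean' and 'sum' collapse it to a scalar:

import torch
from quadra.losses.classification.focal import BinaryFocalLossWithLogits  # assumed import path

x = torch.randn(4, 3, 5)                    # logits of shape (N, *)
y = torch.randint(0, 2, (4, 3, 5)).float()  # binary targets with matching shape
per_element = BinaryFocalLossWithLogits(alpha=0.25, reduction="none")(x, y)  # shape (4, 3, 5)
averaged = BinaryFocalLossWithLogits(alpha=0.25, reduction="mean")(x, y)     # 0-dim tensor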

Examples:

>>> kwargs = {"alpha": 0.25, "gamma": 2.0, "reduction": 'mean'}
>>> loss = BinaryFocalLossWithLogits(**kwargs)
>>> input = torch.randn(1, 3, 5, requires_grad=True)
>>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(2)
>>> output = loss(input, target)
>>> output.backward()
Source code in quadra/losses/classification/focal.py
def __init__(self, alpha: float, gamma: float = 2.0, reduction: str = "none") -> None:
    super().__init__()
    self.alpha: float = alpha
    self.gamma: float = gamma
    self.reduction: str = reduction

forward(input_tensor, target)

Forward call computation.

Source code in quadra/losses/classification/focal.py
def forward(self, input_tensor: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Forward call computation."""
    return binary_focal_loss_with_logits(input_tensor, target, self.alpha, self.gamma, self.reduction)

FocalLoss(alpha, gamma=2.0, reduction='none', eps=None)

Bases: Module

Criterion that computes Focal loss.

According to :cite:`lin2018focal`, the Focal loss is computed as follows:

.. math::

    \text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)

where:
  • :math:`p_t` is the model's estimated probability for each class.
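
In the multi-class case, p_t is the softmax probability assigned to the true class. A minimal sketch of the formula for a single example (illustrative only, not the library code):

import torch
import torch.nn.functional as F

logits = torch.tensor([1.2, -0.3, 0.4])  # unnormalized scores for C = 3 classes
target = 0                               # index of the true class
alpha, gamma = 0.5, 2.0

p_t = F.softmax(logits, dim=0)[target]             # estimated probability of the true class
fl = -alpha * (1 - p_t) ** gamma * torch.log(p_t)  # focal loss for this single example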

Parameters:

  • alpha (float) –

    Weighting factor :math:\alpha \in [0, 1].

  • gamma (float, default: 2.0 ) –

    Focusing parameter :math:\gamma >= 0.

  • reduction (str, default: 'none' ) –

    Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. 'none': no reduction will be applied, 'mean': the sum of the output will be divided by the number of elements in the output, 'sum': the output will be summed.

  • eps (Optional[float], default: None ) –

    Deprecated: scalar to enforce numerical stability. This is no longer used.

Shape
  • Input: :math:(N, C, *) where C = number of classes.
  • Target: :math:(N, *) where each value is :math:0 ≤ targets[i] ≤ C−1.
Example

>>> N = 5  # num_classes
>>> kwargs = {"alpha": 0.5, "gamma": 2.0, "reduction": 'mean'}
>>> criterion = FocalLoss(**kwargs)
>>> input = torch.randn(1, N, 3, 5, requires_grad=True)
>>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
>>> output = criterion(input, target)
>>> output.backward()

Source code in quadra/losses/classification/focal.py
def __init__(self, alpha: float, gamma: float = 2.0, reduction: str = "none", eps: Optional[float] = None) -> None:
    super().__init__()
    self.alpha: float = alpha
    self.gamma: float = gamma
    self.reduction: str = reduction
    self.eps: Optional[float] = eps

forward(input_tensor, target)

Forward call computation.

Source code in quadra/losses/classification/focal.py
def forward(self, input_tensor: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Forward call computation."""
    return focal_loss(input_tensor, target, self.alpha, self.gamma, self.reduction, self.eps)

binary_focal_loss_with_logits(input_tensor, target, alpha=0.25, gamma=2.0, reduction='none', eps=None)

Function that computes Binary Focal loss.

.. math::

    \text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)

where:
  • :math:`p_t` is the model's estimated probability for each class.

Parameters:

  • input_tensor (Tensor) –

    input data tensor of arbitrary shape.

  • target (Tensor) –

    the target tensor with shape matching input.

  • alpha (float, default: 0.25 ) –

    Weighting factor for the rare class :math:\alpha \in [0, 1].

  • gamma (float, default: 2.0 ) –

    Focusing parameter :math:\gamma >= 0.

  • reduction (str, default: 'none' ) –

    Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. 'none': no reduction will be applied, 'mean': the sum of the output will be divided by the number of elements in the output, 'sum': the output will be summed.

  • eps (Optional[float], default: None ) –

    Deprecated: scalar for numerical stability when dividing. This is no longer used.

Returns:

  • Tensor

    the computed loss.

Examples:

>>> kwargs = {"alpha": 0.25, "gamma": 2.0, "reduction": 'mean'}
>>> logits = torch.tensor([[[6.325]],[[5.26]],[[87.49]]])
>>> labels = torch.tensor([[[1.]],[[1.]],[[0.]]])
>>> binary_focal_loss_with_logits(logits, labels, **kwargs)
tensor(21.8725)
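
As a follow-up to the example above (a sketch assuming the import path quadra.losses.classification.focal), keeping reduction='none' shows that the confidently mis-classified third sample dominates the mean:

import torch
from quadra.losses.classification.focal import binary_focal_loss_with_logits  # assumed import path

logits = torch.tensor([[[6.325]], [[5.26]], [[87.49]]])
labels = torch.tensor([[[1.0]], [[1.0]], [[0.0]]])
per_element = binary_focal_loss_with_logits(logits, labels, alpha=0.25, gamma=2.0, reduction="none")
# The first two entries are close to zero and the third is roughly 65.6,
# so the mean reported above (21.8725) is driven almost entirely by the wrong prediction.
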
Source code in quadra/losses/classification/focal.py
def binary_focal_loss_with_logits(
    input_tensor: torch.Tensor,
    target: torch.Tensor,
    alpha: float = 0.25,
    gamma: float = 2.0,
    reduction: str = "none",
    eps: Optional[float] = None,
) -> torch.Tensor:
    r"""Function that computes Binary Focal loss.

    .. math::

        \text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)

    where:
       - :math:`p_t` is the model's estimated probability for each class.

    Args:
        input_tensor: input data tensor of arbitrary shape.
        target: the target tensor with shape matching input.
        alpha: Weighting factor for the rare class :math:`\alpha \in [0, 1]`.
        gamma: Focusing parameter :math:`\gamma >= 0`.
        reduction: Specifies the reduction to apply to the
            output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction
            will be applied, ``'mean'``: the sum of the output will be divided by
            the number of elements in the output, ``'sum'``: the output will be
            summed.
        eps: Deprecated: scalar for numerically stability when dividing. This is no longer used.

    Returns:
        the computed loss.

    Examples:
        >>> kwargs = {"alpha": 0.25, "gamma": 2.0, "reduction": 'mean'}
        >>> logits = torch.tensor([[[6.325]],[[5.26]],[[87.49]]])
        >>> labels = torch.tensor([[[1.]],[[1.]],[[0.]]])
        >>> binary_focal_loss_with_logits(logits, labels, **kwargs)
        tensor(21.8725)
    """
    if eps is not None and not torch.jit.is_scripting():
        warnings.warn(
            "`binary_focal_loss_with_logits` has been reworked for improved numerical stability "
            "and the `eps` argument is no longer necessary",
            DeprecationWarning,
            stacklevel=2,
        )

    if not isinstance(input_tensor, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(input_tensor)}")

    if not len(input_tensor.shape) >= 2:
        raise ValueError(f"Invalid input shape, we expect BxCx*. Got: {input_tensor.shape}")

    if input_tensor.size(0) != target.size(0):
        raise ValueError(
            f"Expected input batch_size ({input_tensor.size(0)}) to match target batch_size ({target.size(0)})."
        )

    probs_pos = torch.sigmoid(input_tensor)
    probs_neg = torch.sigmoid(-input_tensor)
    loss_tmp = -alpha * torch.pow(probs_neg, gamma) * target * F.logsigmoid(input_tensor) - (1 - alpha) * torch.pow(
        probs_pos, gamma
    ) * (1.0 - target) * F.logsigmoid(-input_tensor)

    if reduction == "none":
        loss = loss_tmp
    elif reduction == "mean":
        loss = torch.mean(loss_tmp)
    elif reduction == "sum":
        loss = torch.sum(loss_tmp)
    else:
        raise NotImplementedError(f"Invalid reduction mode: {reduction}")
    return loss

focal_loss(input_tensor, target, alpha, gamma=2.0, reduction='none', eps=None)

Criterion that computes Focal loss.

According to :cite:`lin2018focal`, the Focal loss is computed as follows:

.. math::

    \text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)

where:
  • :math:`p_t` is the model's estimated probability for each class.

Parameters:

  • input_tensor (Tensor) –

    Logits tensor with shape :math:(N, C, *) where C = number of classes.

  • target (Tensor) –

    Labels tensor with shape :math:(N, *) where each value is :math:0 ≤ targets[i] ≤ C−1.

  • alpha (float) –

    Weighting factor :math:\alpha \in [0, 1].

  • gamma (float, default: 2.0 ) –

    Focusing parameter :math:\gamma >= 0.

  • reduction (str, default: 'none' ) –

    Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. 'none': no reduction will be applied, 'mean': the sum of the output will be divided by the number of elements in the output, 'sum': the output will be summed.

  • eps (Optional[float], default: None ) –

    Deprecated: scalar to enforce numerical stability. This is no longer used.

Returns:

  • Tensor

    The computed loss.

Example

>>> N = 5  # num_classes
>>> input = torch.randn(1, N, 3, 5, requires_grad=True)
>>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
>>> output = focal_loss(input, target, alpha=0.5, gamma=2.0, reduction='mean')
>>> output.backward()
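
A sanity-check sketch that follows from the formula above (an assumption based on the math, not a documented guarantee; import path assumed): with gamma=0 and alpha=1 the focusing term (1 - p_t)^gamma vanishes, so the focal loss reduces to ordinary cross entropy, up to the tiny eps added inside one_hot.

import torch
import torch.nn.functional as F
from quadra.losses.classification.focal import focal_loss  # assumed import path

N = 5
input = torch.randn(2, N, 3, 5)
target = torch.randint(0, N, (2, 3, 5))
fl = focal_loss(input, target, alpha=1.0, gamma=0.0, reduction="mean")
ce = F.cross_entropy(input, target, reduction="mean")
assert torch.allclose(fl, ce, atol=1e-4)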

Source code in quadra/losses/classification/focal.py
def focal_loss(
    input_tensor: torch.Tensor,
    target: torch.Tensor,
    alpha: float,
    gamma: float = 2.0,
    reduction: str = "none",
    eps: Optional[float] = None,
) -> torch.Tensor:
    r"""Criterion that computes Focal loss.

    According to :cite:`lin2018focal`, the Focal loss is computed as follows:

    .. math::

        \text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)

    Where:
       - :math:`p_t` is the model's estimated probability for each class.

    Args:
        input_tensor: Logits tensor with shape :math:`(N, C, *)` where C = number of classes.
        target: Labels tensor with shape :math:`(N, *)` where each value is :math:`0 ≤ targets[i] ≤ C−1`.
        alpha: Weighting factor :math:`\alpha \in [0, 1]`.
        gamma: Focusing parameter :math:`\gamma >= 0`.
        reduction: Specifies the reduction to apply to the
            output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction
            will be applied, ``'mean'``: the sum of the output will be divided by
            the number of elements in the output, ``'sum'``: the output will be
            summed.
        eps: Deprecated: scalar to enforce numerical stability. This is no longer used.

    Returns:
        The computed loss.

    Example:
        >>> N = 5  # num_classes
        >>> input = torch.randn(1, N, 3, 5, requires_grad=True)
        >>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
        >>> output = focal_loss(input, target, alpha=0.5, gamma=2.0, reduction='mean')
        >>> output.backward()
    """
    if eps is not None and not torch.jit.is_scripting():
        warnings.warn(
            "`focal_loss` has been reworked for improved numerical stability "
            "and the `eps` argument is no longer necessary",
            DeprecationWarning,
            stacklevel=2,
        )

    if not isinstance(input_tensor, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(input_tensor)}")

    if not len(input_tensor.shape) >= 2:
        raise ValueError(f"Invalid input shape, we expect BxCx*. Got: {input_tensor.shape}")

    if input_tensor.size(0) != target.size(0):
        raise ValueError(
            f"Expected input batch_size ({input_tensor.size(0)}) to match target batch_size ({target.size(0)})."
        )

    n = input_tensor.size(0)
    out_size = (n,) + input_tensor.size()[2:]
    if target.size()[1:] != input_tensor.size()[2:]:
        raise ValueError(f"Expected target size {out_size}, got {target.size()}")

    if not input_tensor.device == target.device:
        raise ValueError(f"input and target must be in the same device. Got: {input_tensor.device} and {target.device}")

    # compute softmax over the classes axis
    input_soft: torch.Tensor = F.softmax(input_tensor, dim=1)
    log_input_soft: torch.Tensor = F.log_softmax(input_tensor, dim=1)

    # create the labels one hot tensor
    target_one_hot: torch.Tensor = one_hot(
        target, num_classes=input_tensor.shape[1], device=input_tensor.device, dtype=input_tensor.dtype
    )

    # compute the actual focal loss
    weight = torch.pow(-input_soft + 1.0, gamma)

    focal = -alpha * weight * log_input_soft
    loss_tmp = torch.einsum("bc...,bc...->b...", (target_one_hot, focal))

    if reduction == "none":
        loss = loss_tmp
    elif reduction == "mean":
        loss = torch.mean(loss_tmp)
    elif reduction == "sum":
        loss = torch.sum(loss_tmp)
    else:
        raise NotImplementedError(f"Invalid reduction mode: {reduction}")
    return loss

one_hot(labels, num_classes, device=None, dtype=None, eps=1e-06)

Convert an integer label x-D tensor to a one-hot (x+1)-D tensor.

Parameters:

  • labels (Tensor) –

    tensor with labels of shape :math:(N, *), where N is batch size. Each value is an integer representing correct classification.

  • num_classes (int) –

    number of classes in labels.

  • device (Optional[device], default: None ) –

    the desired device of returned tensor.

  • dtype (Optional[dtype], default: None ) –

    the desired data type of returned tensor.

  • eps (float, default: 1e-06 ) –

    a value added to the returned tensor.

Returns:

  • Tensor

    the labels in one-hot tensor of shape :math:`(N, C, *)`.

Examples:

>>> labels = torch.LongTensor([[[0, 1], [2, 0]]])
>>> one_hot(labels, num_classes=3)
tensor([[[[1.0000e+00, 1.0000e-06],
          [1.0000e-06, 1.0000e+00]],

         [[1.0000e-06, 1.0000e+00],
          [1.0000e-06, 1.0000e-06]],

         [[1.0000e-06, 1.0000e-06],
          [1.0000e+00, 1.0000e-06]]]])
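
A rough equivalence sketch (my reading of the documented behaviour, not library code): the result matches torch.nn.functional.one_hot with the class dimension moved to position 1, cast to float, and eps added everywhere.

import torch
import torch.nn.functional as F

labels = torch.LongTensor([[[0, 1], [2, 0]]])             # shape (N, *) = (1, 2, 2)
reference = F.one_hot(labels, num_classes=3)              # shape (1, 2, 2, 3)
reference = reference.permute(0, 3, 1, 2).float() + 1e-6  # shape (1, 3, 2, 2), matching the output above
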
Source code in quadra/losses/classification/focal.py
def one_hot(
    labels: torch.Tensor,
    num_classes: int,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    eps: float = 1e-6,
) -> torch.Tensor:
    r"""Convert an integer label x-D tensor to a one-hot (x+1)-D tensor.

    Args:
        labels: tensor with labels of shape :math:`(N, *)`, where N is batch size.
            Each value is an integer representing correct classification.
        num_classes: number of classes in labels.
        device: the desired device of returned tensor.
        dtype: the desired data type of returned tensor.
        eps: a value added to the returned tensor.

    Returns:
        the labels in one hot tensor of shape :math:`(N, C, *)`,

    Examples:
        >>> labels = torch.LongTensor([[[0, 1], [2, 0]]])
        >>> one_hot(labels, num_classes=3)
        tensor([[[[1.0000e+00, 1.0000e-06],
                  [1.0000e-06, 1.0000e+00]],
        <BLANKLINE>
                 [[1.0000e-06, 1.0000e+00],
                  [1.0000e-06, 1.0000e-06]],
        <BLANKLINE>
                 [[1.0000e-06, 1.0000e-06],
                  [1.0000e+00, 1.0000e-06]]]])

    """
    if not isinstance(labels, torch.Tensor):
        raise TypeError(f"Input labels type is not a torch.Tensor. Got {type(labels)}")

    if not labels.dtype == torch.int64:
        raise ValueError(f"labels must be of the same dtype torch.int64. Got: {labels.dtype}")

    if num_classes < 1:
        raise ValueError(f"The number of classes must be bigger than one. Got: {num_classes}")

    shape = labels.shape
    one_hot_output = torch.zeros((shape[0], num_classes) + shape[1:], device=device, dtype=dtype)

    return one_hot_output.scatter_(1, labels.unsqueeze(1), 1.0) + eps