hyperspherical

AlignLoss

Bases: Enum

Enum of the supported alignment loss types: AlignLoss.L2 and AlignLoss.COSINE.
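
As a rough illustration, the two variants can be sketched in plain Python on unit-normalized embedding pairs. The real implementations live in quadra's loss module and operate on torch tensors; the bodies below are illustrative assumptions, not the library code:

```python
def l2_align_loss(za, zb, alpha=2):
    # Mean alpha-powered Euclidean distance between positive pairs
    # (AlignLoss.L2): pulls matched embeddings together on the sphere.
    return sum(
        sum((a - b) ** 2 for a, b in zip(x, y)) ** (alpha / 2)
        for x, y in zip(za, zb)
    ) / len(za)

def cosine_align_loss(za, zb):
    # One minus the mean cosine similarity (AlignLoss.COSINE): zero when
    # each pair of unit vectors coincides.
    return 1 - sum(
        sum(a * b for a, b in zip(x, y)) for x, y in zip(za, zb)
    ) / len(za)
```

Both are minimized when paired embeddings coincide; they differ in how they penalize pairs that drift apart on the sphere.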

TLHyperspherical(model, optimizer=None, lr_scheduler=None, align_weight=1, unifo_weight=1, classifier_weight=1, align_loss_type=AlignLoss.L2, classifier_loss=False, num_classes=None)

Bases: BaseLightningModule

Hyperspherical model: maps features extracted from a pretrained backbone onto a hypersphere.

Parameters:

  • model (nn.Module) –

    Feature extractor as pytorch torch.nn.Module

  • optimizer (Optional[optim.Optimizer]) –

    Optimizer used for training. If None, a default Adam is used.

  • lr_scheduler (Optional[object]) –

    Learning rate scheduler. If None, a default ReduceLROnPlateau is used.

  • align_weight (float) –

    Weight of the alignment loss component in the hyperspherical loss. Defaults to 1.

  • unifo_weight (float) –

    Weight of the uniformity loss component in the hyperspherical loss. Defaults to 1.

  • classifier_weight (float) –

    Weight of the classifier loss component in the hyperspherical loss. Defaults to 1.

  • align_loss_type (AlignLoss) –

    Which type of align loss to use. Defaults to AlignLoss.L2.

  • classifier_loss (bool) –

    Whether to compute a classifier loss to 'enhance' the hyperspherical loss with a classification term. If True, model.classifier must be defined. Defaults to False.

  • num_classes (Optional[int]) –

    Number of classes for a classification problem. Defaults to None.
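
The three weight parameters combine the individual loss terms linearly. A minimal sketch of that weighting (a hypothetical helper for illustration; the module itself performs this combination inside its training step):

```python
def hyperspherical_loss(align, uniform, align_weight=1.0, unifo_weight=1.0,
                        classifier=None, classifier_weight=1.0):
    # Weighted sum of the loss components, mirroring the constructor's
    # align_weight / unifo_weight / classifier_weight parameters.
    total = align_weight * align + unifo_weight * uniform
    if classifier is not None:
        # Only added when classifier_loss=True and model.classifier exists.
        total += classifier_weight * classifier
    return total
```

Setting a weight to 0 effectively disables the corresponding term.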

Source code in quadra/modules/ssl/hyperspherical.py
def __init__(
    self,
    model: nn.Module,
    optimizer: Optional[optim.Optimizer] = None,
    lr_scheduler: Optional[object] = None,
    align_weight: float = 1,
    unifo_weight: float = 1,
    classifier_weight: float = 1,
    align_loss_type: AlignLoss = AlignLoss.L2,
    classifier_loss: bool = False,
    num_classes: Optional[int] = None,
):
    super().__init__(model, optimizer, lr_scheduler)
    self.align_loss_fun: Union[
        Callable[[torch.Tensor, torch.Tensor, int], torch.Tensor],
        Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    ]
    self.align_weight = align_weight
    self.unifo_weight = unifo_weight
    self.classifier_weight = classifier_weight
    self.align_loss_type = align_loss_type
    if align_loss_type == AlignLoss.L2:
        self.align_loss_fun = loss.align_loss
    elif align_loss_type == AlignLoss.COSINE:
        self.align_loss_fun = loss.cosine_align_loss
    else:
        raise ValueError("The align loss must be one of 'AlignLoss.L2' (L2 distance) or AlignLoss.COSINE")

    if classifier_loss:
        if model.classifier is None:
            raise AssertionError("Classifier is not defined")

    self.classifier_loss = classifier_loss
    self.num_classes = num_classes
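
For context, the uniformity term weighted by unifo_weight typically follows the standard hyperspherical formulation: the log of the mean pairwise Gaussian potential between embeddings. A pure-Python sketch under that assumption (the actual function lives in quadra's loss module and operates on batched tensors):

```python
import math

def uniform_loss(z, t=2.0):
    # Log of the mean Gaussian potential over all distinct pairs of
    # embeddings: lower values mean a more uniform spread on the sphere.
    n = len(z)
    pots = [
        math.exp(-t * sum((a - b) ** 2 for a, b in zip(z[i], z[j])))
        for i in range(n)
        for j in range(n)
        if i != j
    ]
    return math.log(sum(pots) / len(pots))
```

Minimizing this term pushes embeddings apart, while the alignment term pulls positive pairs together; the weights balance the two objectives.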