hyperspherical
TLHyperspherical(model, optimizer=None, lr_scheduler=None, align_weight=1, unifo_weight=1, classifier_weight=1, align_loss_type=AlignLoss.L2, classifier_loss=False, num_classes=None)
Bases: BaseLightningModule
Hyperspherical model: maps features extracted from a pretrained backbone onto a hypersphere.
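Mapping features onto a hypersphere amounts to L2-normalising each feature vector so that it has unit length. The sketch below shows only that projection step. It uses NumPy rather than PyTorch to stay self-contained. `to_hypersphere` is a hypothetical helper, not part of quadra, and the actual module may apply further layers before normalising.

```python
import numpy as np

def to_hypersphere(features: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Project each row of `features` onto the unit hypersphere (L2 normalisation).

    Hypothetical helper for illustration; `eps` guards against division by zero.
    """
    norms = np.linalg.norm(features, axis=1, keepdims=True)
    return features / np.maximum(norms, eps)

# Example: two 2-D feature vectors, standing in for backbone outputs
z = to_hypersphere(np.array([[3.0, 4.0], [0.0, 2.0]]))
print(z)                          # [[0.6 0.8]
                                  #  [0.  1. ]]
print(np.linalg.norm(z, axis=1))  # [1. 1.]
```

After normalisation, distances between embeddings depend only on their directions. This is what lets alignment and uniformity objectives be defined on the sphere.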
Parameters:
- model (nn.Module) – Feature extractor, given as a PyTorch torch.nn.Module.
- optimizer (Optional[optim.Optimizer]) – Optimizer used for training. If None, a default Adam optimizer is used.
- lr_scheduler (Optional[object]) – Learning-rate scheduler. If None, a default ReduceLROnPlateau scheduler is used.
- align_weight (float) – Weight of the align loss component of the hyperspherical loss. Defaults to 1.
- unifo_weight (float) – Weight of the uniformity loss component of the hyperspherical loss. Defaults to 1.
- classifier_weight (float) – Weight of the classifier loss component of the hyperspherical loss. Defaults to 1.
- align_loss_type (AlignLoss) – Which type of align loss to use. Defaults to AlignLoss.L2.
- classifier_loss (bool) – Whether to compute a classification loss and add it to the hyperspherical loss. If True, model.classifier must be defined. Defaults to False.
- num_classes (Optional[int]) – Number of classes for a classification problem. Defaults to None.
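The weight parameters suggest that the total loss is a weighted sum of an alignment term, a uniformity term, and (when classifier_loss is True) a classification term. The exact formulas used by quadra are not shown here. The sketch below assumes the alignment/uniformity losses of Wang and Isola (2020), with AlignLoss.L2 taken to mean the mean squared Euclidean distance between the two views of each sample, and a cross-entropy classifier term. All function names are hypothetical, and NumPy stands in for PyTorch.

```python
import numpy as np

def normalize(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    # Project each row onto the unit hypersphere.
    return x / np.maximum(np.linalg.norm(x, axis=1, keepdims=True), eps)

def align_loss_l2(z1: np.ndarray, z2: np.ndarray, alpha: float = 2.0) -> float:
    # Assumed L2 alignment: mean of ||z1_i - z2_i||^alpha over positive pairs.
    return float(np.mean(np.linalg.norm(z1 - z2, axis=1) ** alpha))

def uniform_loss(z: np.ndarray, t: float = 2.0) -> float:
    # Log of the mean Gaussian potential exp(-t * ||z_i - z_j||^2) over distinct pairs.
    sq_dists = np.sum((z[:, None, :] - z[None, :, :]) ** 2, axis=-1)
    i, j = np.triu_indices(len(z), k=1)
    return float(np.log(np.mean(np.exp(-t * sq_dists[i, j]))))

def cross_entropy(logits: np.ndarray, labels: np.ndarray) -> float:
    # Numerically stable softmax cross-entropy, averaged over the batch.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(-np.mean(log_probs[np.arange(len(labels)), labels]))

def hyperspherical_loss(z1, z2, align_weight=1.0, unifo_weight=1.0,
                        classifier_weight=1.0, logits=None, labels=None) -> float:
    # z1, z2: features of two augmented views of the same batch (hypothetical inputs).
    z1, z2 = normalize(z1), normalize(z2)
    loss = align_weight * align_loss_l2(z1, z2)
    # Uniformity is averaged over both views.
    loss += unifo_weight * (uniform_loss(z1) + uniform_loss(z2)) / 2
    # Optional classification term, mirroring classifier_loss=True.
    if logits is not None and labels is not None:
        loss += classifier_weight * cross_entropy(logits, labels)
    return loss

# Identical orthonormal views: alignment is 0, every pairwise squared distance is 2,
# so uniformity is log(exp(-2 * 2)) = -4.
print(hyperspherical_loss(np.eye(3), np.eye(3)))  # approximately -4.0
```

Raising align_weight pulls the two views of each sample together. Raising unifo_weight spreads the batch more evenly over the sphere. In the assumed formulation, these two forces trade off against each other.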
Source code in quadra/modules/ssl/hyperspherical.py