
idmm

IDMM(model, prediction_mlp, criterion, multiview_loss=True, mixup_fn=None, classifier=None, optimizer=None, lr_scheduler=None, lr_scheduler_interval='epoch')

Bases: SSLModule

IDMM model.

Parameters:

  • model (Module) –

    backbone model

  • prediction_mlp (Module) –

    student prediction MLP

  • criterion (Module) –

    loss function

  • multiview_loss (bool, default: True ) –

    whether to use the multiview loss as defined in https://arxiv.org/abs/2201.10728. Defaults to True.

  • mixup_fn (Optional[Mixup], default: None ) –

    the mixup/cutmix function to be applied to a batch of images. Defaults to None.

  • classifier (Optional[ClassifierMixin], default: None ) –

    Standard sklearn classifier

  • optimizer (Optional[Optimizer], default: None ) –

    optimizer used for training. If None, a default Adam optimizer is used.

  • lr_scheduler (Optional[object], default: None ) –

    learning-rate scheduler. If None, a default ReduceLROnPlateau scheduler is used.

  • lr_scheduler_interval (Optional[str], default: 'epoch' ) –

    interval at which the learning-rate scheduler is stepped. Defaults to 'epoch'.
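The `mixup_fn` parameter expects a `timm.data.Mixup` instance. To make the idea concrete, the plain-Python sketch below shows what batch-level mixup does: every sample is blended with its partner in the reversed batch (as timm's `Mixup` does in its default "batch" mode), and the one-hot targets are blended with the same weight `lam`. It is a simplified illustration, not timm's or IDMM's code: it omits CutMix and label smoothing, and using per-sample instance indices as targets is an assumption chosen to illustrate instance discrimination, not necessarily how IDMM builds its targets. The helper names `one_hot` and `batch_mixup` are hypothetical.

```python
import random
from typing import List, Tuple

Vector = List[float]


def one_hot(index: int, num_classes: int) -> Vector:
    """Return a one-hot vector of length num_classes with a 1.0 at index."""
    vec = [0.0] * num_classes
    vec[index] = 1.0
    return vec


def batch_mixup(
    batch: List[Vector], targets: List[int], num_classes: int, lam: float
) -> Tuple[List[Vector], List[Vector]]:
    """Blend every sample (and its one-hot target) with its partner in the
    reversed batch, using the same weight lam for inputs and targets."""
    if not 0.0 <= lam <= 1.0:
        raise ValueError("lam must be in [0, 1]")
    flipped = batch[::-1]
    flipped_targets = targets[::-1]
    # Element-wise lam * x + (1 - lam) * x_partner for every sample.
    mixed = [
        [lam * a + (1.0 - lam) * b for a, b in zip(x, y)]
        for x, y in zip(batch, flipped)
    ]
    # Same convex combination applied to the one-hot targets.
    soft_targets = [
        [lam * a + (1.0 - lam) * b
         for a, b in zip(one_hot(t, num_classes), one_hot(u, num_classes))]
        for t, u in zip(targets, flipped_targets)
    ]
    return mixed, soft_targets


if __name__ == "__main__":
    images = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [4.0, 0.0]]
    # Assumption: each sample is its own "class" (instance discrimination).
    instance_ids = list(range(len(images)))
    lam = random.betavariate(1.0, 1.0)  # mixing weight drawn from Beta(alpha, alpha)
    mixed, soft = batch_mixup(images, instance_ids, num_classes=len(images), lam=lam)
    print(mixed)
    print(soft)
```

For example, with `lam=0.75` the first sample is mixed with the last one and becomes `[1.75, 0.0]`, and its target becomes `[0.75, 0.0, 0.0, 0.25]`.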

Source code in quadra/modules/ssl/idmm.py
def __init__(
    self,
    model: torch.nn.Module,
    prediction_mlp: torch.nn.Module,
    criterion: torch.nn.Module,
    multiview_loss: bool = True,
    mixup_fn: Optional[timm.data.Mixup] = None,
    classifier: Optional[sklearn.base.ClassifierMixin] = None,
    optimizer: Optional[torch.optim.Optimizer] = None,
    lr_scheduler: Optional[object] = None,
    lr_scheduler_interval: Optional[str] = "epoch",
):

    super().__init__(
        model,
        criterion,
        classifier,
        optimizer,
        lr_scheduler,
        lr_scheduler_interval,
    )
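    # self.save_hyperparameters()
    # IDMM-specific components, stored on top of what SSLModule sets up
    self.prediction_mlp = prediction_mlp  # student prediction MLP
    self.mixup_fn = mixup_fn  # optional mixup/cutmix applied to image batches
    self.multiview_loss = multiview_loss  # whether the multiview loss is used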
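The constructor forwards `optimizer`, `lr_scheduler` and `lr_scheduler_interval` unchanged to `SSLModule`. Assuming `SSLModule` is a PyTorch Lightning module (this page does not state it), these values would typically end up in the dictionary that `configure_optimizers` returns, where the scheduler's `"interval"` key is either `"epoch"` or `"step"`. The sketch below builds that dictionary; `build_lightning_optim_config` is a hypothetical helper, and the `monitor` argument is included because schedulers such as ReduceLROnPlateau need a metric to track.

```python
from typing import Any, Dict, Optional

# Lightning accepts these two values for the scheduler "interval" key.
VALID_INTERVALS = {"epoch", "step"}


def build_lightning_optim_config(
    optimizer: Any,
    lr_scheduler: Optional[Any],
    interval: str = "epoch",
    monitor: Optional[str] = None,
) -> Dict[str, Any]:
    """Hypothetical helper: package an optimizer and an optional scheduler in
    the dictionary format that LightningModule.configure_optimizers may return."""
    if lr_scheduler is None:
        return {"optimizer": optimizer}
    if interval not in VALID_INTERVALS:
        raise ValueError(
            f"interval must be one of {sorted(VALID_INTERVALS)}, got {interval!r}"
        )
    scheduler_config: Dict[str, Any] = {
        "scheduler": lr_scheduler,
        "interval": interval,  # step once per epoch or once per optimizer step
        "frequency": 1,
    }
    # Schedulers such as ReduceLROnPlateau need a metric to monitor.
    if monitor is not None:
        scheduler_config["monitor"] = monitor
    return {"optimizer": optimizer, "lr_scheduler": scheduler_config}
```

Inside a LightningModule this might be called as `build_lightning_optim_config(opt, sched, self.lr_scheduler_interval, monitor="val_loss")`, where `"val_loss"` is a hypothetical metric name that must match one the module actually logs.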