ssl

BYOL(student, teacher, student_projection_mlp, student_prediction_mlp, teacher_projection_mlp, criterion, classifier=None, optimizer=None, lr_scheduler=None, lr_scheduler_interval='epoch', teacher_momentum=0.9995, teacher_momentum_cosine_decay=True)

Bases: SSLModule

BYOL module, inspired by https://arxiv.org/abs/2006.07733.

Parameters:

  • student

    student model.

  • teacher

    teacher model.

  • student_projection_mlp

    student projection MLP.

  • student_prediction_mlp

    student prediction MLP.

  • teacher_projection_mlp

    teacher projection MLP.

  • criterion

    loss function.

  • classifier (ClassifierMixin | None, default: None ) –

    Standard sklearn classifier.

  • optimizer (Optimizer | None, default: None ) –

    optimizer used for training. If None, a default Adam is used.

  • lr_scheduler (object | None, default: None ) –

    lr scheduler. If None, a default ReduceLROnPlateau is used.

  • lr_scheduler_interval (str | None, default: 'epoch' ) –

    interval at which the lr scheduler is updated.

  • teacher_momentum (float, default: 0.9995 ) –

    momentum of the teacher parameters.

  • teacher_momentum_cosine_decay (bool | None, default: True ) –

    whether to use cosine decay for the teacher momentum. Default: True

Source code in quadra/modules/ssl/byol.py
def __init__(
    self,
    student: nn.Module,
    teacher: nn.Module,
    student_projection_mlp: nn.Module,
    student_prediction_mlp: nn.Module,
    teacher_projection_mlp: nn.Module,
    criterion: nn.Module,
    classifier: sklearn.base.ClassifierMixin | None = None,
    optimizer: Optimizer | None = None,
    lr_scheduler: object | None = None,
    lr_scheduler_interval: str | None = "epoch",
    teacher_momentum: float = 0.9995,
    teacher_momentum_cosine_decay: bool | None = True,
):
    super().__init__(
        model=student,
        criterion=criterion,
        classifier=classifier,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        lr_scheduler_interval=lr_scheduler_interval,
    )
    # Student model
    self.max_steps: int
    self.student_projection_mlp = student_projection_mlp
    self.student_prediction_mlp = student_prediction_mlp

    # Teacher model
    self.teacher = teacher
    self.teacher_projection_mlp = teacher_projection_mlp
    self.teacher_initialized = False
    self.teacher_momentum = teacher_momentum
    self.teacher_momentum_cosine_decay = teacher_momentum_cosine_decay

    self.initialize_teacher()
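
The constructor above can be exercised with plain torch building blocks. The sketch below is a minimal, hypothetical instantiation (not taken from quadra): the backbone, MLP heads, loss and import path are assumptions, and the real quadra projection/prediction heads and BYOL criterion may differ. Note that the teacher must share the student's architecture, since initialize_teacher copies the student's state dict into it.

# Hypothetical usage sketch: backbone/MLPs/criterion are plain torch stand-ins,
# not quadra's own classes; the import path is inferred from the source location above.
import copy

import torch
import torch.nn as nn
from sklearn.linear_model import LogisticRegression

from quadra.modules.ssl.byol import BYOL  # assumed import path


def make_mlp(in_dim: int, hidden_dim: int, out_dim: int) -> nn.Module:
    # Simple head: Linear -> BatchNorm -> ReLU -> Linear
    return nn.Sequential(
        nn.Linear(in_dim, hidden_dim),
        nn.BatchNorm1d(hidden_dim),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_dim, out_dim),
    )


class NegativeCosineSimilarity(nn.Module):
    # Stand-in for a BYOL-style loss: -cos(student prediction, teacher projection)
    def forward(self, p: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        return -nn.functional.cosine_similarity(p, z.detach(), dim=-1).mean()


# Student and teacher need identical architectures: the teacher is initialized from
# the student's state dict and afterwards updated only by EMA (see update_teacher).
student = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 512))
teacher = copy.deepcopy(student)

byol = BYOL(
    student=student,
    teacher=teacher,
    student_projection_mlp=make_mlp(512, 1024, 256),
    student_prediction_mlp=make_mlp(256, 1024, 256),
    teacher_projection_mlp=make_mlp(512, 1024, 256),
    criterion=NegativeCosineSimilarity(),
    classifier=LogisticRegression(max_iter=1000),  # sklearn classifier used only for evaluation
)

Like the other modules on this page, the resulting object is a pytorch-lightning module and can be passed to a standard Lightning Trainer.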

calculate_accuracy(batch)

Calculate accuracy on the given batch.

Source code in quadra/modules/ssl/byol.py
def calculate_accuracy(self, batch):
    """Calculate accuracy on the given batch."""
    images, labels = batch
    embedding = self.model(images).detach().cpu().numpy()
    predictions = self.classifier.predict(embedding)
    labels = labels.detach()
    acc = self.val_acc(torch.tensor(predictions, device=self.device), labels)

    return acc

initialize_teacher()

Initialize the teacher from the student's state dict, also checking that the student model requires gradients correctly.

Source code in quadra/modules/ssl/byol.py
def initialize_teacher(self):
    """Initialize teacher from the state dict of the student one,
    checking also that student model requires greadient correctly.
    """
    self.teacher_projection_mlp.load_state_dict(self.student_projection_mlp.state_dict())
    for p in self.teacher_projection_mlp.parameters():
        p.requires_grad = False

    self.teacher.load_state_dict(self.model.state_dict())
    for p in self.teacher.parameters():
        p.requires_grad = False

    for p in self.student_projection_mlp.parameters():
        assert p.requires_grad is True
    for p in self.student_prediction_mlp.parameters():
        assert p.requires_grad is True

    self.teacher_initialized = True

optimizer_step(epoch, batch_idx, optimizer, optimizer_closure=None)

Override optimizer step to update the teacher parameters.

Source code in quadra/modules/ssl/byol.py
def optimizer_step(
    self,
    epoch: int,
    batch_idx: int,
    optimizer: Optimizer | LightningOptimizer,
    optimizer_closure: Callable[[], Any] | None = None,
) -> None:
    """Override optimizer step to update the teacher parameters."""
    super().optimizer_step(
        epoch,
        batch_idx,
        optimizer,
        optimizer_closure=optimizer_closure,
    )
    self.update_teacher()

test_step(batch, *args)

Calculate accuracy on the test set for the given batch.

Source code in quadra/modules/ssl/byol.py
def test_step(self, batch, *args: list[Any]) -> None:
    """Calculate accuracy on the test set for the given batch."""
    acc = self.calculate_accuracy(batch)
    self.log(name="test_acc", value=acc, on_step=False, on_epoch=True, prog_bar=True)
    return acc

update_teacher()

Update the teacher parameters as an exponential moving average of the student parameters: theta_t <- theta_t * tau + theta_s * (1 - tau), where theta_s and theta_t are the student and teacher parameters and tau is the teacher momentum (self.teacher_momentum). If self.teacher_momentum_cosine_decay is True, the teacher momentum follows a cosine schedule from self.teacher_momentum to 1: tau = 1 - (1 - tau_base) * (cos(pi * t / T) + 1) / 2, where tau_base is self.teacher_momentum, t is the current step and T is the maximum number of steps.

Source code in quadra/modules/ssl/byol.py
def update_teacher(self):
    """Update teacher given `self.teacher_momentum` by an exponential moving average
    of the student parameters, that is: theta_t * tau + theta_s * (1 - tau), where
    `theta_{s,t}` are the parameters of the student and the teacher model, while `tau` is the
    teacher momentum. If `self.teacher_momentum_cosine_decay` is True, then the teacher
    momentum will follow a cosine scheduling from `self.teacher_momentum` to 1:
    tau = 1 - (1 - tau) * (cos(pi * t / T) + 1) / 2, where `t` is the current step and
    `T` is the max number of steps.
    """
    with torch.no_grad():
        if self.teacher_momentum_cosine_decay:
            teacher_momentum = (
                1
                - (1 - self.teacher_momentum)
                * (math.cos(math.pi * self.trainer.global_step / self.max_steps) + 1)
                / 2
            )
        else:
            teacher_momentum = self.teacher_momentum
        self.log("teacher_momentum", teacher_momentum, prog_bar=True)
        for student_ps, teacher_ps in zip(
            list(self.model.parameters()) + list(self.student_projection_mlp.parameters()),
            list(self.teacher.parameters()) + list(self.teacher_projection_mlp.parameters()),
        ):
            teacher_ps.data = teacher_ps.data * teacher_momentum + (1 - teacher_momentum) * student_ps.data
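
To make the cosine schedule concrete, the snippet below evaluates the same formula outside the module for assumed values teacher_momentum=0.9995 and max_steps=10000: the momentum starts at the base value and grows monotonically towards 1, so the teacher tracks the student less and less as training progresses.

# Standalone evaluation of the schedule used in update_teacher, with assumed values
# for the base momentum and the total number of training steps.
import math

base_tau = 0.9995   # teacher_momentum
max_steps = 10_000  # self.max_steps

for step in (0, 2_500, 5_000, 7_500, 10_000):
    tau = 1 - (1 - base_tau) * (math.cos(math.pi * step / max_steps) + 1) / 2
    print(f"step {step:>6d}: tau = {tau:.6f}")
# tau is 0.999500 at step 0, 0.999750 at the halfway point, and 1.000000 at the final step.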

BarlowTwins(model, projection_mlp, criterion, classifier=None, optimizer=None, lr_scheduler=None, lr_scheduler_interval='epoch')

Bases: SSLModule

BarlowTwins model.

Parameters:

  • model (Module) –

    Network module used to extract features

  • projection_mlp (Module) –

    Module to project extracted features

  • criterion (Module) –

    SSL loss to be applied

  • classifier (ClassifierMixin | None, default: None ) –

    Standard sklearn classifier. Defaults to None.

  • optimizer (Optimizer | None, default: None ) –

    optimizer used for training. If None, a default Adam is used. Defaults to None.

  • lr_scheduler (object | None, default: None ) –

    lr scheduler. If None, a default ReduceLROnPlateau is used. Defaults to None.

  • lr_scheduler_interval (str | None, default: 'epoch' ) –

    interval at which the lr scheduler is updated. Defaults to "epoch".

Source code in quadra/modules/ssl/barlowtwins.py
def __init__(
    self,
    model: nn.Module,
    projection_mlp: nn.Module,
    criterion: nn.Module,
    classifier: sklearn.base.ClassifierMixin | None = None,
    optimizer: optim.Optimizer | None = None,
    lr_scheduler: object | None = None,
    lr_scheduler_interval: str | None = "epoch",
):
    super().__init__(model, criterion, classifier, optimizer, lr_scheduler, lr_scheduler_interval)
    # self.save_hyperparameters()
    self.projection_mlp = projection_mlp
    self.criterion = criterion
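
BarlowTwins follows the same backbone + projection head + criterion pattern as SimCLR, SimSIAM and VICReg below. A minimal, hypothetical instantiation sketch (the backbone, head sizes, criterion and import path are assumptions, not quadra defaults):

# Hypothetical instantiation; the criterion here is only a placeholder, in practice an
# nn.Module implementing the Barlow Twins cross-correlation loss would be passed.
import torch.nn as nn

from quadra.modules.ssl.barlowtwins import BarlowTwins  # assumed import path

backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 512))
projection_head = nn.Sequential(
    nn.Linear(512, 2048),
    nn.BatchNorm1d(2048),
    nn.ReLU(inplace=True),
    nn.Linear(2048, 2048),
)

barlow = BarlowTwins(
    model=backbone,
    projection_mlp=projection_head,
    criterion=nn.MSELoss(),  # placeholder loss taking the two projected views
)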

Dino(student, teacher, student_projection_mlp, teacher_projection_mlp, criterion, freeze_last_layer=1, classifier=None, optimizer=None, lr_scheduler=None, lr_scheduler_interval='epoch', teacher_momentum=0.9995, teacher_momentum_cosine_decay=True)

Bases: BYOL

DINO pytorch-lightning module.

Parameters:

  • student

    student model

  • teacher

    teacher model

  • student_projection_mlp

    student projection MLP

  • teacher_projection_mlp

    teacher projection MLP

  • criterion

    loss function

  • freeze_last_layer

    number of initial epochs during which the gradient of the last layer of the student projection MLP is zeroed (see cancel_gradients_last_layer). Default: 1

  • classifier (ClassifierMixin | None, default: None ) –

    Standard sklearn classifier

  • optimizer (Optimizer | None, default: None ) –

    optimizer used for training. If None, a default Adam is used.

  • lr_scheduler (object | None, default: None ) –

    lr scheduler. If None, a default ReduceLROnPlateau is used.

  • lr_scheduler_interval (str | None, default: 'epoch' ) –

    interval at which the lr scheduler is updated.

  • teacher_momentum (float, default: 0.9995 ) –

    momentum of the teacher parameters

  • teacher_momentum_cosine_decay (bool | None, default: True ) –

    whether to use cosine decay for the teacher momentum

Source code in quadra/modules/ssl/dino.py
def __init__(
    self,
    student: nn.Module,
    teacher: nn.Module,
    student_projection_mlp: nn.Module,
    teacher_projection_mlp: nn.Module,
    criterion: nn.Module,
    freeze_last_layer: int = 1,
    classifier: sklearn.base.ClassifierMixin | None = None,
    optimizer: Optimizer | None = None,
    lr_scheduler: object | None = None,
    lr_scheduler_interval: str | None = "epoch",
    teacher_momentum: float = 0.9995,
    teacher_momentum_cosine_decay: bool | None = True,
):
    super().__init__(
        student=student,
        teacher=teacher,
        student_projection_mlp=student_projection_mlp,
        student_prediction_mlp=nn.Identity(),
        teacher_projection_mlp=teacher_projection_mlp,
        criterion=criterion,
        teacher_momentum=teacher_momentum,
        teacher_momentum_cosine_decay=teacher_momentum_cosine_decay,
        classifier=classifier,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        lr_scheduler_interval=lr_scheduler_interval,
    )
    self.freeze_last_layer = freeze_last_layer

cancel_gradients_last_layer(epoch, freeze_last_layer)

Zero out the gradient of the last layer, as specified in the paper.

Parameters:

  • epoch (int) –

    current epoch

  • freeze_last_layer (int) –

    maximum freeze epoch: if epoch >= freeze_last_layer then the gradient of the last layer will not be frozen

Source code in quadra/modules/ssl/dino.py
def cancel_gradients_last_layer(self, epoch: int, freeze_last_layer: int):
    """Zero out the gradient of the last layer, as specified in the paper.

    Args:
        epoch: current epoch
        freeze_last_layer: maximum freeze epoch: if `epoch` >= `freeze_last_layer`
            then the gradient of the last layer will not be frozen
    """
    if epoch >= freeze_last_layer:
        return
    for n, p in self.student_projection_mlp.named_parameters():
        if "last_layer" in n:
            p.grad = None
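
As a standalone illustration (toy modules, not quadra code), the snippet below reproduces the logic above on a head whose final linear layer is registered under the name last_layer: before freeze_last_layer epochs have elapsed, its gradient is dropped so the optimizer leaves it untouched.

# Toy head with a submodule named "last_layer", mirroring the name match above.
import torch
import torch.nn as nn

head = nn.Sequential()
head.add_module("hidden", nn.Linear(8, 8))
head.add_module("last_layer", nn.Linear(8, 4))

head(torch.randn(2, 8)).sum().backward()  # populate gradients

epoch, freeze_last_layer = 0, 1
if epoch < freeze_last_layer:  # same condition as cancel_gradients_last_layer
    for name, p in head.named_parameters():
        if "last_layer" in name:
            p.grad = None  # drop the gradient so this epoch does not update the layer

print(head.hidden.weight.grad is None)      # False: still trained
print(head.last_layer.weight.grad is None)  # True: frozen for this epoch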

configure_gradient_clipping(optimizer, gradient_clip_val=None, gradient_clip_algorithm=None)

Configure gradient clipping for the optimizer.

Source code in quadra/modules/ssl/dino.py
def configure_gradient_clipping(
    self,
    optimizer: Optimizer,
    gradient_clip_val: int | float | None = None,
    gradient_clip_algorithm: str | None = None,
):
    """Configure gradient clipping for the optimizer."""
    if gradient_clip_algorithm is not None and gradient_clip_val is not None:
        clip_gradients(self.model, gradient_clip_val)
        clip_gradients(self.student_projection_mlp, gradient_clip_val)
    self.cancel_gradients_last_layer(self.current_epoch, self.freeze_last_layer)

initialize_teacher()

Initialize the teacher from the student's state dict, also checking that the student model requires gradients correctly.

Source code in quadra/modules/ssl/dino.py
def initialize_teacher(self):
    """Initialize teacher from the state dict of the student one,
    checking also that student model requires greadient correctly.
    """
    self.teacher_projection_mlp.load_state_dict(self.student_projection_mlp.state_dict())
    for p in self.teacher_projection_mlp.parameters():
        p.requires_grad = False

    self.teacher.load_state_dict(self.model.state_dict())
    for p in self.teacher.parameters():
        p.requires_grad = False

    all_frozen = True
    for p in self.model.parameters():
        all_frozen = all_frozen and (not p.requires_grad)

    if all_frozen:
        log.warning(
            "All parameters of the student model are frozen, the model will not be trained, automatically"
            " unfreezing all the layers"
        )

        for p in self.model.parameters():
            p.requires_grad = True

    for name, p in self.student_projection_mlp.named_parameters():
        if name != "last_layer.weight_g":
            assert p.requires_grad is True

    self.teacher_initialized = True

optimizer_step(epoch, batch_idx, optimizer, optimizer_closure=None)

Override optimizer step to update the teacher parameters.

Source code in quadra/modules/ssl/dino.py
def optimizer_step(
    self,
    epoch: int,
    batch_idx: int,
    optimizer: Optimizer | LightningOptimizer,
    optimizer_closure: Callable[[], Any] | None = None,
) -> None:
    """Override optimizer step to update the teacher parameters."""
    super().optimizer_step(
        epoch,
        batch_idx,
        optimizer,
        optimizer_closure=optimizer_closure,
    )
    self.update_teacher()

student_multicrop_forward(x)

Student forward on the multicrop images.

Parameters:

  • x (list[Tensor]) –

    List of torch.Tensor containing multicropped augmented images

Returns:

  • Tensor

    torch.Tensor: a tensor of shape NxBxD, where N is the number of crops corresponding to the length of the input list x, B is the batch size, and D is the output dimension

Source code in quadra/modules/ssl/dino.py
def student_multicrop_forward(self, x: list[torch.Tensor]) -> torch.Tensor:
    """Student forward on the multicrop imges.

    Args:
        x: List of torch.Tensor containing multicropped augmented images

    Returns:
        torch.Tensor: a tensor of shape NxBxD, where N is the number of crops
            corresponding to the length of the input list `x`, B is the batch size
            and D is the output dimension
    """
    n_crops = len(x)
    concatenated = torch.cat(x, dim=0)  # (n_samples * n_crops, C, H, W)
    embedding = self.model(concatenated)  # (n_samples * n_crops, in_dim)
    logits = self.student_projection_mlp(embedding)  # (n_samples * n_crops, out_dim)
    chunks = logits.chunk(n_crops)  # n_crops * (n_samples, out_dim)
    return chunks

teacher_multicrop_forward(x)

Teacher forward on the multicrop images.

Parameters:

  • x (list[Tensor]) –

    List of torch.Tensor containing multicropped augmented images

Returns:

  • Tensor

    torch.Tensor: a tensor of shape NxBxD, where N is the number of crops corresponding to the length of the input list x, B is the batch size, and D is the output dimension

Source code in quadra/modules/ssl/dino.py
def teacher_multicrop_forward(self, x: list[torch.Tensor]) -> torch.Tensor:
    """Teacher forward on the multicrop imges.

    Args:
        x: List of torch.Tensor containing multicropped augmented images

    Returns:
        torch.Tensor: a tensor of shape NxBxD, where N is the number of crops
            corresponding to the length of the input list `x`, B is the batch size
            and D is the output dimension
    """
    n_crops = len(x)
    concatenated = torch.cat(x, dim=0)  # (n_samples * n_crops, C, H, W)
    embedding = self.teacher(concatenated)  # (n_samples * n_crops, in_dim)
    logits = self.teacher_projection_mlp(embedding)  # (n_samples * n_crops, out_dim)
    chunks = logits.chunk(n_crops)  # n_crops * (n_samples, out_dim)
    return chunks
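
The shape bookkeeping of the two multicrop forwards can be reproduced standalone with dummy modules (not quadra components): crops of equal resolution are concatenated along the batch dimension, projected once, and split back into one chunk per crop.

# Standalone shape sketch; the crops must share the same resolution because
# torch.cat concatenates them along the batch dimension.
import torch
import torch.nn as nn

backbone = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten())  # dummy feature extractor -> (B, 3)
projection_head = nn.Linear(3, 8)                                # dummy projection head -> (B, 8)

crops = [torch.randn(4, 3, 96, 96) for _ in range(3)]  # 3 crops, batch size 4
concatenated = torch.cat(crops, dim=0)                 # (12, 3, 96, 96)
logits = projection_head(backbone(concatenated))       # (12, 8)
chunks = logits.chunk(len(crops))                      # 3 tensors of shape (4, 8), one per crop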

IDMM(model, prediction_mlp, criterion, multiview_loss=True, mixup_fn=None, classifier=None, optimizer=None, lr_scheduler=None, lr_scheduler_interval='epoch')

Bases: SSLModule

IDMM model.

Parameters:

  • model (Module) –

    backbone model

  • prediction_mlp (Module) –

    student prediction MLP

  • criterion (Module) –

    loss function

  • multiview_loss (bool, default: True ) –

    whether to use the multiview loss as defined in https://arxiv.org/abs/2201.10728. Defaults to True.

  • mixup_fn (Mixup | None, default: None ) –

    the mixup/cutmix function to be applied to a batch of images. Defaults to None.

  • classifier (ClassifierMixin | None, default: None ) –

    Standard sklearn classifier

  • optimizer (Optimizer | None, default: None ) –

    optimizer used for training. If None, a default Adam is used.

  • lr_scheduler (object | None, default: None ) –

    lr scheduler. If None, a default ReduceLROnPlateau is used.

  • lr_scheduler_interval (str | None, default: 'epoch' ) –

    interval at which the lr scheduler is updated.

Source code in quadra/modules/ssl/idmm.py
def __init__(
    self,
    model: torch.nn.Module,
    prediction_mlp: torch.nn.Module,
    criterion: torch.nn.Module,
    multiview_loss: bool = True,
    mixup_fn: timm.data.Mixup | None = None,
    classifier: sklearn.base.ClassifierMixin | None = None,
    optimizer: torch.optim.Optimizer | None = None,
    lr_scheduler: object | None = None,
    lr_scheduler_interval: str | None = "epoch",
):
    super().__init__(
        model,
        criterion,
        classifier,
        optimizer,
        lr_scheduler,
        lr_scheduler_interval,
    )
    # self.save_hyperparameters()
    self.prediction_mlp = prediction_mlp
    self.mixup_fn = mixup_fn
    self.multiview_loss = multiview_loss
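
A minimal, hypothetical instantiation sketch; the backbone, prediction head size, criterion and Mixup settings are assumptions rather than quadra defaults.

# Hypothetical instantiation; the criterion is a placeholder and the import path is
# inferred from the source location above. timm's Mixup is optional (mixup_fn=None).
import torch.nn as nn
from timm.data import Mixup

from quadra.modules.ssl.idmm import IDMM  # assumed import path

backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 512))
prediction_head = nn.Linear(512, 4096)  # e.g. one logit per training instance

idmm = IDMM(
    model=backbone,
    prediction_mlp=prediction_head,
    criterion=nn.CrossEntropyLoss(),  # placeholder instance-discrimination loss
    multiview_loss=True,
    mixup_fn=Mixup(mixup_alpha=0.8, cutmix_alpha=1.0, num_classes=4096),
)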

SimCLR(model, projection_mlp, criterion, classifier=None, optimizer=None, lr_scheduler=None, lr_scheduler_interval='epoch')

Bases: SSLModule

SIMCLR class.

Parameters:

  • model (Module) –

    Feature extractor as pytorch torch.nn.Module

  • projection_mlp (Module) –

    projection head as pytorch torch.nn.Module

  • criterion (Module) –

    SSL loss to be applied

  • classifier (ClassifierMixin | None, default: None ) –

    Standard sklearn classifier. Defaults to None.

  • optimizer (Optimizer | None, default: None ) –

    optimizer used for training. If None, a default Adam is used. Defaults to None.

  • lr_scheduler (object | None, default: None ) –

    lr scheduler. If None, a default ReduceLROnPlateau is used. Defaults to None.

  • lr_scheduler_interval (str | None, default: 'epoch' ) –

    interval at which the lr scheduler is updated. Defaults to "epoch".

Source code in quadra/modules/ssl/simclr.py
def __init__(
    self,
    model: nn.Module,
    projection_mlp: nn.Module,
    criterion: torch.nn.Module,
    classifier: sklearn.base.ClassifierMixin | None = None,
    optimizer: torch.optim.Optimizer | None = None,
    lr_scheduler: object | None = None,
    lr_scheduler_interval: str | None = "epoch",
):
    super().__init__(
        model,
        criterion,
        classifier,
        optimizer,
        lr_scheduler,
        lr_scheduler_interval,
    )
    self.projection_mlp = projection_mlp

training_step(batch, batch_idx)

Run a training step on a batch of two augmented views and compute the SSL loss.

Parameters:

  • batch (tuple[tuple[Tensor, Tensor], Tensor]) –

    The batch of data

  • batch_idx (int) –

    The index of the batch.

Returns:

  • Tensor

    The computed loss

Source code in quadra/modules/ssl/simclr.py
def training_step(
    self, batch: tuple[tuple[torch.Tensor, torch.Tensor], torch.Tensor], batch_idx: int
) -> torch.Tensor:
    """Args:
        batch: The batch of data
        batch_idx: The index of the batch.

    Returns:
        The computed loss
    """
    # pylint: disable=unused-argument
    (im_x, im_y), _ = batch
    emb_x = self(im_x)
    emb_y = self(im_y)
    loss = self.criterion(emb_x, emb_y)

    self.log(
        "loss",
        loss,
        on_epoch=True,
        on_step=True,
        logger=True,
        prog_bar=True,
    )
    return loss
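
The training step expects each batch shaped as ((view_1, view_2), labels): two augmented views of the same images plus labels that the contrastive loss ignores. A toy batch matching that structure (sizes are arbitrary):

# Toy batch mirroring the ((im_x, im_y), _) unpacking in training_step above.
import torch

batch_size = 8
view_1 = torch.randn(batch_size, 3, 224, 224)  # first augmented view
view_2 = torch.randn(batch_size, 3, 224, 224)  # second augmented view of the same images
labels = torch.zeros(batch_size, dtype=torch.long)

batch = ((view_1, view_2), labels)
(im_x, im_y), _ = batch  # same unpacking as training_step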

SimSIAM(model, projection_mlp, prediction_mlp, criterion, classifier=None, optimizer=None, lr_scheduler=None, lr_scheduler_interval='epoch')

Bases: SSLModule

SimSIAM model.

Parameters:

  • model (Module) –

    Feature extractor as pytorch torch.nn.Module

  • projection_mlp (Module) –

    optional projection head as pytorch torch.nn.Module

  • prediction_mlp (Module) –

    optional prediction head as pytorch torch.nn.Module

  • criterion (Module) –

    loss to be applied.

  • classifier (ClassifierMixin | None, default: None ) –

    Standard sklearn classifier.

  • optimizer (Optimizer | None, default: None ) –

    optimizer used for training. If None, a default Adam is used.

  • lr_scheduler (object | None, default: None ) –

    lr scheduler. If None, a default ReduceLROnPlateau is used.

  • lr_scheduler_interval (str | None, default: 'epoch' ) –

    interval at which the lr scheduler is updated.

Source code in quadra/modules/ssl/simsiam.py
def __init__(
    self,
    model: torch.nn.Module,
    projection_mlp: torch.nn.Module,
    prediction_mlp: torch.nn.Module,
    criterion: torch.nn.Module,
    classifier: sklearn.base.ClassifierMixin | None = None,
    optimizer: torch.optim.Optimizer | None = None,
    lr_scheduler: object | None = None,
    lr_scheduler_interval: str | None = "epoch",
):
    super().__init__(
        model,
        criterion,
        classifier,
        optimizer,
        lr_scheduler,
        lr_scheduler_interval,
    )
    # self.save_hyperparameters()
    self.projection_mlp = projection_mlp
    self.prediction_mlp = prediction_mlp

VICReg(model, projection_mlp, criterion, classifier=None, optimizer=None, lr_scheduler=None, lr_scheduler_interval='epoch')

Bases: SSLModule

VICReg model.

Parameters:

  • model (Module) –

    Network module used to extract features

  • projection_mlp (Module) –

    Module to project extracted features

  • criterion (Module) –

    SSL loss to be applied.

  • classifier (ClassifierMixin | None, default: None ) –

    Standard sklearn classifier. Defaults to None.

  • optimizer (Optimizer | None, default: None ) –

    optimizer used for training. If None, a default Adam is used. Defaults to None.

  • lr_scheduler (object | None, default: None ) –

    lr scheduler. If None, a default ReduceLROnPlateau is used. Defaults to None.

  • lr_scheduler_interval (str | None, default: 'epoch' ) –

    interval at which the lr scheduler is updated. Defaults to "epoch".

Source code in quadra/modules/ssl/vicreg.py
def __init__(
    self,
    model: nn.Module,
    projection_mlp: nn.Module,
    criterion: nn.Module,
    classifier: sklearn.base.ClassifierMixin | None = None,
    optimizer: optim.Optimizer | None = None,
    lr_scheduler: object | None = None,
    lr_scheduler_interval: str | None = "epoch",
):
    super().__init__(
        model,
        criterion,
        classifier,
        optimizer,
        lr_scheduler,
        lr_scheduler_interval,
    )
    # self.save_hyperparameters()
    self.projection_mlp = projection_mlp
    self.criterion = criterion