dino

Dino(student, teacher, student_projection_mlp, teacher_projection_mlp, criterion, freeze_last_layer=1, classifier=None, optimizer=None, lr_scheduler=None, lr_scheduler_interval='epoch', teacher_momentum=0.9995, teacher_momentum_cosine_decay=True)

Bases: BYOL

DINO PyTorch Lightning module.

Parameters:

  • student (Module) –

    student model

  • teacher (Module) –

    teacher model

  • student_projection_mlp (Module) –

    student projection MLP

  • teacher_projection_mlp (Module) –

    teacher projection MLP

  • criterion (Module) –

    loss function

  • freeze_last_layer (int, default: 1 ) –

    number of epochs during which the gradients of the last layer of the student projection MLP are zeroed

  • classifier (ClassifierMixin | None, default: None ) –

    Standard sklearn classifier

  • optimizer (Optimizer | None, default: None ) –

    optimizer used for training. If None, a default Adam optimizer is used.

  • lr_scheduler (object | None, default: None ) –

    learning rate scheduler. If None, a default ReduceLROnPlateau scheduler is used.

  • lr_scheduler_interval (str | None, default: 'epoch' ) –

    interval at which the lr scheduler is updated.

  • teacher_momentum (float, default: 0.9995 ) –

    momentum of the teacher parameters

  • teacher_momentum_cosine_decay (bool | None, default: True ) –

    whether to use cosine decay for the teacher momentum

Source code in quadra/modules/ssl/dino.py
def __init__(
    self,
    student: nn.Module,
    teacher: nn.Module,
    student_projection_mlp: nn.Module,
    teacher_projection_mlp: nn.Module,
    criterion: nn.Module,
    freeze_last_layer: int = 1,
    classifier: sklearn.base.ClassifierMixin | None = None,
    optimizer: Optimizer | None = None,
    lr_scheduler: object | None = None,
    lr_scheduler_interval: str | None = "epoch",
    teacher_momentum: float = 0.9995,
    teacher_momentum_cosine_decay: bool | None = True,
):
    super().__init__(
        student=student,
        teacher=teacher,
        student_projection_mlp=student_projection_mlp,
        student_prediction_mlp=nn.Identity(),
        teacher_projection_mlp=teacher_projection_mlp,
        criterion=criterion,
        teacher_momentum=teacher_momentum,
        teacher_momentum_cosine_decay=teacher_momentum_cosine_decay,
        classifier=classifier,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        lr_scheduler_interval=lr_scheduler_interval,
    )
    self.freeze_last_layer = freeze_last_layer
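
The sketch below shows one way to instantiate the module. It is illustrative only: the backbone, the projection heads, and the criterion are placeholders rather than the components quadra configures in practice (the real DINO projection head and loss are not shown on this page).

# Construction sketch with placeholder components (not the ones shipped with quadra).
from collections import OrderedDict
import copy

from torch import nn
from torchvision.models import resnet18

from quadra.modules.ssl.dino import Dino


def make_head(in_dim: int = 512, hidden_dim: int = 2048, out_dim: int = 4096) -> nn.Module:
    # Hypothetical projection head; the real one follows the DINO paper and
    # registers its weight-normalized final linear layer under the name "last_layer".
    return nn.Sequential(
        OrderedDict(
            [
                ("hidden", nn.Linear(in_dim, hidden_dim)),
                ("act", nn.GELU()),
                ("last_layer", nn.Linear(hidden_dim, out_dim)),
            ]
        )
    )


backbone = resnet18(weights=None)
backbone.fc = nn.Identity()  # expose the 512-d embedding

module = Dino(
    student=backbone,
    teacher=copy.deepcopy(backbone),
    student_projection_mlp=make_head(),
    teacher_projection_mlp=make_head(),
    criterion=nn.CrossEntropyLoss(),  # placeholder for an actual DINO-style loss
    freeze_last_layer=1,
    teacher_momentum=0.9995,
)

The student backbone is stored on the module as model and the teacher as teacher, which the methods documented below rely on.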

cancel_gradients_last_layer(epoch, freeze_last_layer)

Zero out the gradient of the last layer, as specified in the paper.

Parameters:

  • epoch (int) –

    current epoch

  • freeze_last_layer (int) –

    maximum freeze epoch: once epoch >= freeze_last_layer, the gradient of the last layer is no longer zeroed

Source code in quadra/modules/ssl/dino.py
def cancel_gradients_last_layer(self, epoch: int, freeze_last_layer: int):
    """Zero out the gradient of the last layer, as specified in the paper.

    Args:
        epoch: current epoch
        freeze_last_layer: maximum freeze epoch: once `epoch` >= `freeze_last_layer`
            the gradient of the last layer is no longer zeroed
    """
    if epoch >= freeze_last_layer:
        return
    for n, p in self.student_projection_mlp.named_parameters():
        if "last_layer" in n:
            p.grad = None
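
Continuing the sketch above (whose placeholder head registers its final linear layer under the name "last_layer", the string this method looks for), the snippet below shows the freeze window in action:

import torch

# Build some gradients on the student projection head.
module.student_projection_mlp(torch.randn(8, 512)).sum().backward()

module.cancel_gradients_last_layer(epoch=0, freeze_last_layer=1)
head = module.student_projection_mlp
print(head.last_layer.weight.grad)          # None: dropped during the freeze window
print(head.hidden.weight.grad is not None)  # True: other layers keep their gradients

module.cancel_gradients_last_layer(epoch=1, freeze_last_layer=1)  # epoch >= 1: no-op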

configure_gradient_clipping(optimizer, gradient_clip_val=None, gradient_clip_algorithm=None)

Configure gradient clipping for the optimizer.

Source code in quadra/modules/ssl/dino.py
def configure_gradient_clipping(
    self,
    optimizer: Optimizer,
    gradient_clip_val: int | float | None = None,
    gradient_clip_algorithm: str | None = None,
):
    """Configure gradient clipping for the optimizer."""
    if gradient_clip_algorithm is not None and gradient_clip_val is not None:
        clip_gradients(self.model, gradient_clip_val)
        clip_gradients(self.student_projection_mlp, gradient_clip_val)
    self.cancel_gradients_last_layer(self.current_epoch, self.freeze_last_layer)
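
With PyTorch Lightning the two clipping arguments come from the Trainer, and the body above only applies clip_gradients to the backbone and the student projection MLP when both are non-None; the last-layer gradients are cancelled afterwards in any case. A minimal Trainer configuration (assuming the pytorch_lightning import; adapt it to your Lightning version):

import pytorch_lightning as pl

trainer = pl.Trainer(
    max_epochs=100,
    gradient_clip_val=3.0,           # forwarded to configure_gradient_clipping
    gradient_clip_algorithm="norm",  # must also be set, otherwise no clipping is applied
)
# trainer.fit(module, datamodule=...)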

initialize_teacher()

Initialize the teacher from the state dict of the student, also checking that the student model requires gradients correctly.

Source code in quadra/modules/ssl/dino.py
def initialize_teacher(self):
    """Initialize teacher from the state dict of the student one,
    checking also that student model requires greadient correctly.
    """
    self.teacher_projection_mlp.load_state_dict(self.student_projection_mlp.state_dict())
    for p in self.teacher_projection_mlp.parameters():
        p.requires_grad = False

    self.teacher.load_state_dict(self.model.state_dict())
    for p in self.teacher.parameters():
        p.requires_grad = False

    all_frozen = True
    for p in self.model.parameters():
        all_frozen = all_frozen and (not p.requires_grad)

    if all_frozen:
        log.warning(
            "All parameters of the student model are frozen, the model will not be trained, automatically"
            " unfreezing all the layers"
        )

        for p in self.model.parameters():
            p.requires_grad = True

    for name, p in self.student_projection_mlp.named_parameters():
        if name != "last_layer.weight_g":
            assert p.requires_grad is True

    self.teacher_initialized = True
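
Continuing the sketch above, calling the method (manually here, for illustration) leaves the teacher frozen while the student remains trainable:

module.initialize_teacher()

# Teacher backbone and projection head receive no gradients...
assert all(not p.requires_grad for p in module.teacher.parameters())
assert all(not p.requires_grad for p in module.teacher_projection_mlp.parameters())
# ...while the student backbone (stored as module.model) stays trainable.
assert all(p.requires_grad for p in module.model.parameters())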

optimizer_step(epoch, batch_idx, optimizer, optimizer_closure=None)

Override optimizer step to update the teacher parameters.

Source code in quadra/modules/ssl/dino.py
def optimizer_step(
    self,
    epoch: int,
    batch_idx: int,
    optimizer: Optimizer | LightningOptimizer,
    optimizer_closure: Callable[[], Any] | None = None,
) -> None:
    """Override optimizer step to update the teacher parameters."""
    super().optimizer_step(
        epoch,
        batch_idx,
        optimizer,
        optimizer_closure=optimizer_closure,
    )
    self.update_teacher()
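
update_teacher is inherited from the BYOL base class and is not reproduced on this page; in BYOL/DINO it is conventionally an exponential moving average of the student parameters with coefficient teacher_momentum, along the lines of the sketch below.

import torch
from torch import nn


@torch.no_grad()
def ema_update(teacher: nn.Module, student: nn.Module, momentum: float = 0.9995) -> None:
    # theta_teacher <- momentum * theta_teacher + (1 - momentum) * theta_student
    for p_t, p_s in zip(teacher.parameters(), student.parameters()):
        p_t.mul_(momentum).add_(p_s, alpha=1.0 - momentum)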

student_multicrop_forward(x)

Student forward on the multicrop images.

Parameters:

  • x (list[Tensor]) –

    List of torch.Tensor containing multicropped augmented images

Returns:

  • Tensor

    torch.Tensor: a tensor of shape NxBxD, where N is the number of crops (the length of the input list x), B is the batch size, and D is the output dimension

Source code in quadra/modules/ssl/dino.py
def student_multicrop_forward(self, x: list[torch.Tensor]) -> torch.Tensor:
    """Student forward on the multicrop imges.

    Args:
        x: List of torch.Tensor containing multicropped augmented images

    Returns:
        torch.Tensor: a tensor of shape NxBxD, where N is the number of crops
            corresponding to the length of the input list `x`, B is the batch size
            and D is the output dimension
    """
    n_crops = len(x)
    concatenated = torch.cat(x, dim=0)  # (n_samples * n_crops, C, H, W)
    embedding = self.model(concatenated)  # (n_samples * n_crops, in_dim)
    logits = self.student_projection_mlp(embedding)  # (n_samples * n_crops, out_dim)
    chunks = logits.chunk(n_crops)  # n_crops * (n_samples, out_dim)
    return chunks
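
Continuing the sketch above: the crops are concatenated along the batch dimension, so all tensors in the list must share the same spatial size, and one chunk per crop is returned.

import torch

crops = [torch.randn(4, 3, 224, 224) for _ in range(2)]  # two augmented views of a batch of 4 images
student_out = module.student_multicrop_forward(crops)

print(len(student_out))      # 2: one chunk per crop
print(student_out[0].shape)  # torch.Size([4, 4096]): batch size x projection output dimension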

teacher_multicrop_forward(x)

Teacher forward on the multicrop images.

Parameters:

  • x (list[Tensor]) –

    List of torch.Tensor containing multicropped augmented images

Returns:

  • Tensor

    torch.Tensor: a tensor of shape NxBxD, where N is the number of crops (the length of the input list x), B is the batch size, and D is the output dimension

Source code in quadra/modules/ssl/dino.py
def teacher_multicrop_forward(self, x: list[torch.Tensor]) -> torch.Tensor:
    """Teacher forward on the multicrop imges.

    Args:
        x: List of torch.Tensor containing multicropped augmented images

    Returns:
        torch.Tensor: a tensor of shape NxBxD, where N is the number of crops
            corresponding to the length of the input list `x`, B is the batch size
            and D is the output dimension
    """
    n_crops = len(x)
    concatenated = torch.cat(x, dim=0)  # (n_samples * n_crops, C, H, W)
    embedding = self.teacher(concatenated)  # (n_samples * n_crops, in_dim)
    logits = self.teacher_projection_mlp(embedding)  # (n_samples * n_crops, out_dim)
    chunks = logits.chunk(n_crops)  # n_crops * (n_samples, out_dim)
    return chunks
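
The two forwards are conventionally paired as in DINO: the teacher (under no_grad) sees only the global views while the student sees every view, and the criterion cross-compares the resulting chunks. The module's actual training_step is not reproduced on this page, so the pairing below is only a sketch; since each call concatenates its crop list, crops of different resolutions go through separate calls.

import torch

global_crops = [torch.randn(4, 3, 224, 224) for _ in range(2)]
local_crops = [torch.randn(4, 3, 96, 96) for _ in range(6)]

with torch.no_grad():
    teacher_out = module.teacher_multicrop_forward(global_crops)  # global views only

student_global = module.student_multicrop_forward(global_crops)   # one resolution per call
student_local = module.student_multicrop_forward(local_crops)

# A DINO-style criterion cross-compares teacher chunks with student chunks;
# the exact call signature of quadra's criterion is not shown here.
# loss = module.criterion(student_global + student_local, teacher_out)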