dino

DinoDistillationLoss(output_dim, max_epochs, warmup_teacher_temp=0.04, teacher_temp=0.07, warmup_teacher_temp_epochs=30, student_temp=0.1, center_momentum=0.9)

Bases: Module

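DINO distillation loss module. It holds a per-epoch teacher temperature schedule and a running center of the teacher output, which is updated on every forward pass.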

Parameters:

  • output_dim (int) –

    Output dimension; sets the size of the center buffer, which has shape (1, output_dim).

  • max_epochs (int) –

    Total number of training epochs; sets the length of the teacher temperature schedule.

  • warmup_teacher_temp (float, default: 0.04 ) –

    Teacher temperature at the start of the warmup.

  • teacher_temp (float, default: 0.07 ) –

    Teacher temperature reached at the end of the warmup and held for the remaining epochs.

  • warmup_teacher_temp_epochs (int, default: 30 ) –

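    Number of epochs over which the teacher temperature moves linearly from warmup_teacher_temp to teacher_temp. Must be smaller than max_epochs; values below 30 log a warning.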

  • student_temp (float, default: 0.1 ) –

    Temperature by which the student output is divided before the log-softmax.

  • center_momentum (float, default: 0.9 ) –

    Momentum of the moving average used to update the center.

Source code in quadra/losses/ssl/dino.py
def __init__(
    self,
    output_dim: int,
    max_epochs: int,
    warmup_teacher_temp: float = 0.04,
    teacher_temp: float = 0.07,
    warmup_teacher_temp_epochs: int = 30,
    student_temp: float = 0.1,
    center_momentum: float = 0.9,
):
    super().__init__()
    self.student_temp = student_temp
    self.center_momentum = center_momentum
    self.center: torch.Tensor
    # We warm up the teacher temperature because too high a temperature
    # makes training unstable at the beginning.

    if warmup_teacher_temp_epochs >= max_epochs:
        raise ValueError(
            f"Number of warmup epochs ({warmup_teacher_temp_epochs}) must be smaller than max_epochs ({max_epochs})"
        )

    if warmup_teacher_temp_epochs < 30:
        log.warning("Warmup teacher epochs is very small (< 30). This may cause instabilities in the training")

    self.teacher_temp_schedule = np.concatenate(
        (
            np.linspace(warmup_teacher_temp, teacher_temp, warmup_teacher_temp_epochs),
            np.ones(max_epochs - warmup_teacher_temp_epochs) * teacher_temp,
        )
    )
    self.register_buffer("center", torch.zeros(1, output_dim))
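The schedule built in the constructor can be inspected on its own. The sketch below repeats the same NumPy expression outside the class, using the default temperatures and an arbitrarily chosen `max_epochs = 100` (an assumption made only for this illustration).

```python
import numpy as np

# Defaults taken from the constructor signature; max_epochs is an arbitrary choice.
warmup_teacher_temp = 0.04
teacher_temp = 0.07
warmup_teacher_temp_epochs = 30
max_epochs = 100

schedule = np.concatenate(
    (
        # Linear ramp over the warmup epochs (both endpoints included).
        np.linspace(warmup_teacher_temp, teacher_temp, warmup_teacher_temp_epochs),
        # Constant teacher temperature for the remaining epochs.
        np.ones(max_epochs - warmup_teacher_temp_epochs) * teacher_temp,
    )
)

print(len(schedule))                       # one entry per epoch
print(schedule[0], schedule[29], schedule[99])
```

Because `forward` looks up `teacher_temp_schedule[current_epoch]`, the epoch index has to stay below `max_epochs`; an index of `max_epochs` or more raises an `IndexError`.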

forward(current_epoch, student_output, teacher_output)

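Computes the DINO distillation loss with the teacher temperature scheduled for current_epoch, then updates the center from teacher_output.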

Source code in quadra/losses/ssl/dino.py
def forward(
    self,
    current_epoch: int,
    student_output: torch.Tensor,
    teacher_output: torch.Tensor,
) -> torch.Tensor:
    """Compute the loss at the scheduled teacher temperature, then update the center."""
    teacher_temp = self.teacher_temp_schedule[current_epoch]
    loss = dino_distillation_loss(
        student_output,
        teacher_output,
        center_vector=self.center,
        teacher_temp=teacher_temp,
        student_temp=self.student_temp,
    )

    self.update_center(teacher_output)
    return loss

update_center(teacher_output)

Update the center of the teacher output distribution.

Parameters:

  • teacher_output (Tensor) –

    teacher output.

Returns:

  • None

    None

Source code in quadra/losses/ssl/dino.py
@torch.no_grad()
def update_center(self, teacher_output: torch.Tensor) -> None:
    """Update the center of the teacher output distribution.

    Args:
        teacher_output: teacher output.

    Returns:
        None
    """
    # TODO: check if this is correct
    # torch.cat expects a list of tensors but teacher_output is a tensor
    batch_center = torch.cat(teacher_output).mean(dim=0, keepdim=True)  # type: ignore[call-overload]
    self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)
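The center update is an exponential moving average of the per-batch mean of the teacher output. The sketch below runs one update step outside the class. It assumes `teacher_output` is a list of per-view tensors of shape `(n_samples, output_dim)`, which is what `torch.cat` expects; the values and sizes are hypothetical.

```python
import torch

output_dim = 4
center_momentum = 0.9
center = torch.zeros(1, output_dim)  # same initial value as the registered buffer

# Two hypothetical teacher views, each of shape (n_samples, output_dim).
teacher_output = [torch.ones(3, output_dim), 3 * torch.ones(3, output_dim)]

with torch.no_grad():
    # Concatenate the views along the batch dimension and average over samples.
    batch_center = torch.cat(teacher_output).mean(dim=0, keepdim=True)  # (1, output_dim)
    # Moving average: keep 90% of the old center, add 10% of the batch mean.
    center = center * center_momentum + batch_center * (1 - center_momentum)

print(center)
```

With a momentum of 0.9 the center moves only a tenth of the way towards each new batch mean, so a single unusual batch has limited effect on it.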

dino_distillation_loss(student_output, teacher_output, center_vector, teacher_temp=0.04, student_temp=0.1)

Compute the DINO distillation loss.

Parameters:

  • student_output (Tensor) –

    Student output, iterated over view by view.

  • teacher_output (Tensor) –

    Teacher output, iterated over view by view.

  • center_vector (Tensor) –

    Center vector subtracted from the teacher output before the softmax.

  • teacher_temp (float, default: 0.04 ) –

    Teacher temperature.

  • student_temp (float, default: 0.1 ) –

    Student temperature.

Returns:

  • Tensor

    The computed loss
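Read off the source listing, the returned value can be written as the formula below. The notation is introduced here for illustration and does not appear in the original docstring: t^{(i)} and s^{(j)} are the teacher and student outputs for views i and j, c is center_vector, \tau_t and \tau_s are the two temperatures, n is the number of samples, and \mathcal{P} is the set of (teacher view, student view) index pairs with i \neq j.

```latex
\[
P_t^{(i)} = \operatorname{softmax}\!\left(\frac{t^{(i)} - c}{\tau_t}\right),
\qquad
P_s^{(j)} = \operatorname{softmax}\!\left(\frac{s^{(j)}}{\tau_s}\right)
\]
\[
\mathcal{L} = \frac{1}{|\mathcal{P}|}
\sum_{(i,j) \in \mathcal{P}}
\frac{1}{n} \sum_{m=1}^{n}
\left( - \sum_{k} P_{t}^{(i)}(x_m)_k \, \log P_{s}^{(j)}(x_m)_k \right),
\qquad
\mathcal{P} = \{(i, j) : i \neq j\}
\]
```

Each term is a cross-entropy between the teacher distribution for one view and the student distribution for a different view; no gradient flows through the teacher term, since it is detached.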

Source code in quadra/losses/ssl/dino.py
def dino_distillation_loss(
    student_output: torch.Tensor,
    teacher_output: torch.Tensor,
    center_vector: torch.Tensor,
    teacher_temp: float = 0.04,
    student_temp: float = 0.1,
) -> torch.Tensor:
    """Compute the DINO distillation loss.

    Args:
        student_output: student output, iterated over view by view
        teacher_output: teacher output, iterated over view by view
        center_vector: center vector subtracted from the teacher output
        teacher_temp: teacher temperature
        student_temp: student temperature.

    Returns:
        The computed loss
    """
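    # Scale each view by its temperature; the teacher output is centered first.
    student_scaled = [s / student_temp for s in student_output]
    teacher_scaled = [(t - center_vector) / teacher_temp for t in teacher_output]

    # Student log-probabilities and detached teacher probabilities, per view.
    student_sm = [F.log_softmax(s, dim=-1) for s in student_scaled]
    teacher_sm = [F.softmax(t, dim=-1).detach() for t in teacher_scaled]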

    total_loss = torch.tensor(0.0, device=student_output[0].device)
    n_loss_terms = torch.tensor(0.0, device=student_output[0].device)

    for t_ix, t in enumerate(teacher_sm):
        for s_ix, s in enumerate(student_sm):
            if t_ix == s_ix:
                continue

            loss = torch.sum(-t * s, dim=-1)  # (n_samples,)
            total_loss += loss.mean()  # scalar
            n_loss_terms += 1

    total_loss /= n_loss_terms
    return total_loss
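One way to sanity-check the computation above: if every student view outputs all-zero logits, the student distribution is uniform and each cross-entropy term equals log(output_dim), whatever the teacher predicts. The sketch below checks this with a condensed restatement of the function; `dino_loss_sketch` is a hypothetical name used only here, and the view counts and sizes are arbitrary.

```python
import math

import torch
import torch.nn.functional as F


def dino_loss_sketch(student_output, teacher_output, center, teacher_temp=0.04, student_temp=0.1):
    """Condensed restatement of the listing above, for illustration only."""
    student_lsm = [F.log_softmax(s / student_temp, dim=-1) for s in student_output]
    teacher_sm = [F.softmax((t - center) / teacher_temp, dim=-1).detach() for t in teacher_output]
    # One cross-entropy term per (teacher view, student view) pair with different indices.
    terms = [
        torch.sum(-t * s, dim=-1).mean()
        for t_ix, t in enumerate(teacher_sm)
        for s_ix, s in enumerate(student_lsm)
        if t_ix != s_ix
    ]
    return torch.stack(terms).mean()


torch.manual_seed(0)
n_samples, output_dim = 5, 8
center = torch.zeros(1, output_dim)

# Two teacher views with arbitrary logits, three student views with all-zero logits.
teacher_output = [torch.randn(n_samples, output_dim) for _ in range(2)]
student_output = [torch.zeros(n_samples, output_dim) for _ in range(3)]

loss = dino_loss_sketch(student_output, teacher_output, center)
print(loss.item(), math.log(output_dim))
```

With two teacher views and three student views, four pairs survive the `t_ix != s_ix` filter, so the result is the mean of four terms.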