byol

BYOL(student, teacher, student_projection_mlp, student_prediction_mlp, teacher_projection_mlp, criterion, classifier=None, optimizer=None, lr_scheduler=None, lr_scheduler_interval='epoch', teacher_momentum=0.9995, teacher_momentum_cosine_decay=True)

Bases: SSLModule

BYOL module, inspired by "Bootstrap Your Own Latent" (https://arxiv.org/abs/2006.07733).

Parameters:

  • student (Module) –

    Student model.

  • teacher (Module) –

    Teacher model.

  • student_projection_mlp (Module) –

    Student projection MLP.

  • student_prediction_mlp (Module) –

    Student prediction MLP.

  • teacher_projection_mlp (Module) –

    Teacher projection MLP.

  • criterion (Module) –

    Loss function.

  • classifier (ClassifierMixin | None, default: None ) –

    Standard sklearn classifier.

  • optimizer (Optimizer | None, default: None ) –

    Optimizer used for training. If None, a default Adam is used.

  • lr_scheduler (object | None, default: None ) –

    Learning rate scheduler. If None, a default ReduceLROnPlateau is used.

  • lr_scheduler_interval (str | None, default: 'epoch' ) –

    Interval at which the lr scheduler is updated.

  • teacher_momentum (float, default: 0.9995 ) –

    Momentum of the teacher parameters.

  • teacher_momentum_cosine_decay (bool | None, default: True ) –

    Whether to use cosine decay for the teacher momentum.

Source code in quadra/modules/ssl/byol.py
def __init__(
    self,
    student: nn.Module,
    teacher: nn.Module,
    student_projection_mlp: nn.Module,
    student_prediction_mlp: nn.Module,
    teacher_projection_mlp: nn.Module,
    criterion: nn.Module,
    classifier: sklearn.base.ClassifierMixin | None = None,
    optimizer: Optimizer | None = None,
    lr_scheduler: object | None = None,
    lr_scheduler_interval: str | None = "epoch",
    teacher_momentum: float = 0.9995,
    teacher_momentum_cosine_decay: bool | None = True,
):
    super().__init__(
        model=student,
        criterion=criterion,
        classifier=classifier,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        lr_scheduler_interval=lr_scheduler_interval,
    )
    # Student model
    self.max_steps: int
    self.student_projection_mlp = student_projection_mlp
    self.student_prediction_mlp = student_prediction_mlp

    # Teacher model
    self.teacher = teacher
    self.teacher_projection_mlp = teacher_projection_mlp
    self.teacher_initialized = False
    self.teacher_momentum = teacher_momentum
    self.teacher_momentum_cosine_decay = teacher_momentum_cosine_decay

    self.initialize_teacher()
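
A minimal usage sketch, not the configuration shipped with quadra: it assumes a torchvision ResNet-18 backbone, small Linear/BatchNorm heads as projection and prediction MLPs, and an illustrative negative-cosine-similarity criterion; the BYOL import path is inferred from the source location shown above. optimizer and lr_scheduler are left as None, so the documented defaults (Adam and ReduceLROnPlateau) apply.

import copy

import torch
import torch.nn as nn
import torchvision

from quadra.modules.ssl.byol import BYOL  # import path inferred from the source location above


class NegativeCosineSimilarity(nn.Module):
    """Illustrative BYOL-style criterion: negative cosine similarity between prediction and target."""

    def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        return -nn.functional.cosine_similarity(prediction, target.detach(), dim=-1).mean()


# Student backbone producing 512-d embeddings; the teacher starts as a copy (see initialize_teacher).
student = torchvision.models.resnet18(weights=None)
student.fc = nn.Identity()
teacher = copy.deepcopy(student)

# Illustrative projection/prediction heads; the sizes are placeholders, not BYOL's original ones.
student_projection_mlp = nn.Sequential(nn.Linear(512, 1024), nn.BatchNorm1d(1024), nn.ReLU(), nn.Linear(1024, 256))
student_prediction_mlp = nn.Sequential(nn.Linear(256, 1024), nn.BatchNorm1d(1024), nn.ReLU(), nn.Linear(1024, 256))
teacher_projection_mlp = copy.deepcopy(student_projection_mlp)

module = BYOL(
    student=student,
    teacher=teacher,
    student_projection_mlp=student_projection_mlp,
    student_prediction_mlp=student_prediction_mlp,
    teacher_projection_mlp=teacher_projection_mlp,
    criterion=NegativeCosineSimilarity(),
    teacher_momentum=0.9995,
    teacher_momentum_cosine_decay=True,
)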

calculate_accuracy(batch)

Calculate accuracy on the given batch.

Source code in quadra/modules/ssl/byol.py
def calculate_accuracy(self, batch):
    """Calculate accuracy on the given batch."""
    images, labels = batch
    embedding = self.model(images).detach().cpu().numpy()
    predictions = self.classifier.predict(embedding)
    labels = labels.detach()
    acc = self.val_acc(torch.tensor(predictions, device=self.device), labels)

    return acc
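
Note that calculate_accuracy expects self.classifier to already be fitted; how and when quadra fits it during training is not shown here. Continuing the usage sketch above, a standalone illustration could fit a scikit-learn classifier on student embeddings manually (the LogisticRegression choice and the random placeholder data are purely illustrative):

import torch
from sklearn.linear_model import LogisticRegression

# Placeholder labelled data; in practice these would come from a labelled dataloader.
labelled_images = torch.randn(64, 3, 224, 224)
labels = torch.randint(0, 2, (64,))

with torch.no_grad():
    embeddings = module.model(labelled_images).cpu().numpy()

classifier = LogisticRegression(max_iter=1000)
classifier.fit(embeddings, labels.numpy())
module.classifier = classifier  # calculate_accuracy can now call self.classifier.predict(...)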

initialize_teacher()

Initialize the teacher from the student's state dict, also checking that the student model requires gradients correctly.

Source code in quadra/modules/ssl/byol.py
def initialize_teacher(self):
    """Initialize teacher from the state dict of the student one,
    checking also that student model requires greadient correctly.
    """
    self.teacher_projection_mlp.load_state_dict(self.student_projection_mlp.state_dict())
    for p in self.teacher_projection_mlp.parameters():
        p.requires_grad = False

    self.teacher.load_state_dict(self.model.state_dict())
    for p in self.teacher.parameters():
        p.requires_grad = False

    for p in self.student_projection_mlp.parameters():
        assert p.requires_grad is True
    for p in self.student_prediction_mlp.parameters():
        assert p.requires_grad is True

    self.teacher_initialized = True

optimizer_step(epoch, batch_idx, optimizer, optimizer_closure=None)

Override optimizer step to update the teacher parameters.

Source code in quadra/modules/ssl/byol.py
def optimizer_step(
    self,
    epoch: int,
    batch_idx: int,
    optimizer: Optimizer | LightningOptimizer,
    optimizer_closure: Callable[[], Any] | None = None,
) -> None:
    """Override optimizer step to update the teacher parameters."""
    super().optimizer_step(
        epoch,
        batch_idx,
        optimizer,
        optimizer_closure=optimizer_closure,
    )
    self.update_teacher()

test_step(batch, *args)

Calculate accuracy on the test set for the given batch.

Source code in quadra/modules/ssl/byol.py
def test_step(self, batch, *args: list[Any]) -> None:
    """Calculate accuracy on the test set for the given batch."""
    acc = self.calculate_accuracy(batch)
    self.log(name="test_acc", value=acc, on_step=False, on_epoch=True, prog_bar=True)
    return acc

update_teacher()

Update the teacher parameters as an exponential moving average of the student parameters, using self.teacher_momentum: theta_t <- theta_t * tau + theta_s * (1 - tau), where theta_s and theta_t are the parameters of the student and teacher models and tau is the teacher momentum. If self.teacher_momentum_cosine_decay is True, the teacher momentum follows a cosine schedule from self.teacher_momentum to 1: tau = 1 - (1 - tau) * (cos(pi * t / T) + 1) / 2, where t is the current step and T is the maximum number of steps.

Source code in quadra/modules/ssl/byol.py
def update_teacher(self):
    """Update teacher given `self.teacher_momentum` by an exponential moving average
    of the student parameters, that is: theta_t * tau + theta_s * (1 - tau), where
    `theta_{s,t}` are the parameters of the student and the teacher model, while `tau` is the
    teacher momentum. If `self.teacher_momentum_cosine_decay` is True, then the teacher
    momentum will follow a cosine scheduling from `self.teacher_momentum` to 1:
    tau = 1 - (1 - tau) * (cos(pi * t / T) + 1) / 2, where `t` is the current step and
    `T` is the max number of steps.
    """
    with torch.no_grad():
        if self.teacher_momentum_cosine_decay:
            teacher_momentum = (
                1
                - (1 - self.teacher_momentum)
                * (math.cos(math.pi * self.trainer.global_step / self.max_steps) + 1)
                / 2
            )
        else:
            teacher_momentum = self.teacher_momentum
        self.log("teacher_momentum", teacher_momentum, prog_bar=True)
        for student_ps, teacher_ps in zip(
            list(self.model.parameters()) + list(self.student_projection_mlp.parameters()),
            list(self.teacher.parameters()) + list(self.teacher_projection_mlp.parameters()),
        ):
            teacher_ps.data = teacher_ps.data * teacher_momentum + (1 - teacher_momentum) * student_ps.data
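
For intuition about the cosine schedule above, a small standalone sketch (the helper name is illustrative; in the module, t and T come from self.trainer.global_step and self.max_steps):

import math


def cosine_teacher_momentum(tau: float, step: int, max_steps: int) -> float:
    """Same schedule as update_teacher: starts at tau and reaches 1.0 at max_steps."""
    return 1 - (1 - tau) * (math.cos(math.pi * step / max_steps) + 1) / 2


tau = 0.9995
print(cosine_teacher_momentum(tau, 0, 1000))     # 0.9995   -> base momentum at the first step
print(cosine_teacher_momentum(tau, 500, 1000))   # 0.99975  -> halfway through training
print(cosine_teacher_momentum(tau, 1000, 1000))  # 1.0      -> teacher effectively frozen at the end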