ssl

BYOLRegressionLoss

Bases: Module

BYOL regression loss module.

forward(x, y)

Compute the BYOL regression loss.

Parameters:

  • x (Tensor) –

    First Tensor

  • y (Tensor) –

    Second Tensor

Returns:

  • Tensor

    BYOL regression loss

Source code in quadra/losses/ssl/byol.py
def forward(
    self,
    x: torch.Tensor,
    y: torch.Tensor,
) -> torch.Tensor:
    """Compute the BYOL regression loss.

    Args:
        x: First Tensor
        y: Second Tensor

    Returns:
        BYOL regression loss
    """
    return byol_regression_loss(x, y)
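
The body of byol_regression_loss is not shown on this page. As a rough illustration, assuming it follows the formulation in the BYOL paper (the squared distance between L2-normalized vectors, which equals 2 minus twice their cosine similarity, averaged over the batch), the arithmetic can be sketched without any dependencies. The name byol_regression_sketch is hypothetical and is not part of quadra.

```python
import math


def byol_regression_sketch(x, y):
    """Mean over the batch of 2 - 2 * cosine_similarity(x_i, y_i).

    Assumption: mirrors the loss in the BYOL paper, not necessarily
    quadra's byol_regression_loss. x and y are lists of equal-length vectors.
    """
    total = 0.0
    for xi, yi in zip(x, y):
        dot = sum(a * b for a, b in zip(xi, yi))
        norm_x = math.sqrt(sum(a * a for a in xi))
        norm_y = math.sqrt(sum(b * b for b in yi))
        total += 2.0 - 2.0 * dot / (norm_x * norm_y)
    return total / len(x)


aligned = byol_regression_sketch([[1.0, 0.0]], [[2.0, 0.0]])     # same direction
orthogonal = byol_regression_sketch([[1.0, 0.0]], [[0.0, 3.0]])  # 90 degrees apart
opposite = byol_regression_sketch([[1.0, 0.0]], [[-1.0, 0.0]])   # opposite direction
```

Under this assumption the loss ranges from 0 for perfectly aligned vectors to 4 for opposite ones, and it ignores vector magnitude.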

BarlowTwinsLoss(lambd)

Bases: Module

Barlow Twins loss.

Parameters:

  • lambd (float) –

    Lambda coefficient of the loss, forwarded to barlowtwins_loss.

Source code in quadra/losses/ssl/barlowtwins.py
def __init__(self, lambd: float):
    super().__init__()
    self.lambd = lambd

forward(z1, z2)

Compute the BarlowTwins loss.

Source code in quadra/losses/ssl/barlowtwins.py
def forward(self, z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
    """Compute the BarlowTwins loss."""
    return barlowtwins_loss(z1, z2, self.lambd)
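
The body of barlowtwins_loss is not shown here. A minimal dependency-free sketch, assuming it follows the Barlow Twins paper: standardize each feature over the batch, build the cross-correlation matrix between the two views, push its diagonal toward 1, and penalize off-diagonal entries with weight lambd. The helper names standardize and barlow_twins_sketch are hypothetical.

```python
import math


def standardize(z):
    """Standardize each feature column to zero mean and unit (population) std.

    Assumes no column is constant, otherwise the std is zero.
    """
    n, d = len(z), len(z[0])
    means = [sum(row[j] for row in z) / n for j in range(d)]
    stds = [math.sqrt(sum((row[j] - means[j]) ** 2 for row in z) / n) for j in range(d)]
    return [[(row[j] - means[j]) / stds[j] for j in range(d)] for row in z]


def barlow_twins_sketch(z1, z2, lambd):
    """Assumed Barlow Twins objective; not necessarily quadra's barlowtwins_loss."""
    n, d = len(z1), len(z1[0])
    a, b = standardize(z1), standardize(z2)
    # Cross-correlation matrix between the features of the two views.
    c = [[sum(a[k][i] * b[k][j] for k in range(n)) / n for j in range(d)] for i in range(d)]
    on_diag = sum((1.0 - c[i][i]) ** 2 for i in range(d))
    off_diag = sum(c[i][j] ** 2 for i in range(d) for j in range(d) if i != j)
    return on_diag + lambd * off_diag


view = [[1.0, 1.0], [1.0, -1.0], [-1.0, 1.0], [-1.0, -1.0]]
flipped = [[-v for v in row] for row in view]
same_loss = barlow_twins_sketch(view, view, lambd=0.005)
flipped_loss = barlow_twins_sketch(view, flipped, lambd=0.005)
```

With identical, decorrelated views the cross-correlation matrix is the identity and the sketch returns 0. Negating one view turns the diagonal into -1, so each of the two features contributes (1 - (-1))^2 = 4.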

DinoDistillationLoss(output_dim, max_epochs, warmup_teacher_temp=0.04, teacher_temp=0.07, warmup_teacher_temp_epochs=30, student_temp=0.1, center_momentum=0.9)

Bases: Module

DINO distillation loss module.

Parameters:

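  • output_dim (int) –

    Dimension of the student and teacher outputs; also the size of the center buffer.

  • max_epochs (int) –

    Total number of training epochs; sets the length of the teacher temperature schedule.

  • warmup_teacher_temp (float, default: 0.04 ) –

    Teacher temperature at the start of the warmup.

  • teacher_temp (float, default: 0.07 ) –

    Teacher temperature reached at the end of the warmup and held for the remaining epochs.

  • warmup_teacher_temp_epochs (int, default: 30 ) –

    Number of warmup epochs for the teacher temperature. Must be smaller than max_epochs; values below 30 log a warning.

  • student_temp (float, default: 0.1 ) –

    Student temperature.

  • center_momentum (float, default: 0.9 ) –

    Momentum of the moving average used to update the center.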

Source code in quadra/losses/ssl/dino.py
def __init__(
    self,
    output_dim: int,
    max_epochs: int,
    warmup_teacher_temp: float = 0.04,
    teacher_temp: float = 0.07,
    warmup_teacher_temp_epochs: int = 30,
    student_temp: float = 0.1,
    center_momentum: float = 0.9,
):
    super().__init__()
    self.student_temp = student_temp
    self.center_momentum = center_momentum
    self.center: torch.Tensor
    # We warm up the teacher temperature because too high a temperature
    # makes training unstable at the beginning.

    if warmup_teacher_temp_epochs >= max_epochs:
        raise ValueError(
            f"Number of warmup epochs ({warmup_teacher_temp_epochs}) must be smaller than max_epochs ({max_epochs})"
        )

    if warmup_teacher_temp_epochs < 30:
        log.warning("Warmup teacher epochs is very small (< 30). This may cause instabilities in the training")

    self.teacher_temp_schedule = np.concatenate(
        (
            np.linspace(warmup_teacher_temp, teacher_temp, warmup_teacher_temp_epochs),
            np.ones(max_epochs - warmup_teacher_temp_epochs) * teacher_temp,
        )
    )
    self.register_buffer("center", torch.zeros(1, output_dim))
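
The schedule built in the constructor is a linear ramp followed by a constant plateau. The sketch below reproduces the same shape with the standard library only, as a stand-in for the np.linspace and np.ones calls above. The function name teacher_temp_schedule is hypothetical and is not part of quadra.

```python
def teacher_temp_schedule(warmup_teacher_temp, teacher_temp, warmup_epochs, max_epochs):
    """Linear ramp over the warmup epochs, then a constant teacher temperature."""
    if warmup_epochs >= max_epochs:
        raise ValueError("warmup_epochs must be smaller than max_epochs")
    if warmup_epochs == 1:
        # A single warmup epoch only holds the starting value.
        ramp = [warmup_teacher_temp]
    else:
        step = (teacher_temp - warmup_teacher_temp) / (warmup_epochs - 1)
        ramp = [warmup_teacher_temp + i * step for i in range(warmup_epochs)]
    # One entry per epoch: warmup ramp followed by the plateau.
    return ramp + [teacher_temp] * (max_epochs - warmup_epochs)


schedule = teacher_temp_schedule(0.04, 0.07, warmup_epochs=30, max_epochs=100)
```

The schedule holds exactly max_epochs entries, and forward indexes it with current_epoch, so current_epoch is expected to stay below max_epochs.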

forward(current_epoch, student_output, teacher_output)

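Compute the DINO distillation loss using the teacher temperature scheduled for current_epoch, then update the center.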

Source code in quadra/losses/ssl/dino.py
def forward(
    self,
    current_epoch: int,
    student_output: torch.Tensor,
    teacher_output: torch.Tensor,
) -> torch.Tensor:
    """Runs forward."""
    teacher_temp = self.teacher_temp_schedule[current_epoch]
    loss = dino_distillation_loss(
        student_output,
        teacher_output,
        center_vector=self.center,
        teacher_temp=teacher_temp,
        student_temp=self.student_temp,
    )

    self.update_center(teacher_output)
    return loss

update_center(teacher_output)

Update the center of the teacher output distribution.

Parameters:

  • teacher_output (Tensor) –

    teacher output.

Returns:

  • None

    None

Source code in quadra/losses/ssl/dino.py
@torch.no_grad()
def update_center(self, teacher_output: torch.Tensor) -> None:
    """Update the center of the teacher output distribution.

    Args:
        teacher_output: teacher output.

    Returns:
        None
    """
    # TODO: check if this is correct
    # torch.cat expects a list of tensors but teacher_output is a tensor
    batch_center = torch.cat(teacher_output).mean(dim=0, keepdim=True)  # type: ignore[call-overload]
    self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)
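
The last line of update_center is an exponential moving average of the per-batch mean. A small sketch on plain Python lists shows the arithmetic. The name update_center_sketch is hypothetical, and the batch is assumed to be a list of equally sized rows.

```python
def update_center_sketch(center, batch_rows, momentum=0.9):
    """Return the EMA-updated center given the teacher output rows of one batch."""
    n = len(batch_rows)
    # Mean over the batch dimension, one value per output feature.
    batch_center = [sum(row[j] for row in batch_rows) / n for j in range(len(center))]
    return [c * momentum + b * (1.0 - momentum) for c, b in zip(center, batch_center)]


center = [0.0, 0.0]  # the buffer starts at zeros, as in __init__
center = update_center_sketch(center, [[1.0, 3.0], [3.0, 5.0]])
```

Starting from zeros with momentum 0.9, a batch whose mean is [2, 4] moves the center to roughly [0.2, 0.4]. A higher momentum makes the center react more slowly to each batch.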

IDMMLoss(smoothing=0.1)

Bases: Module

IDMM loss described in https://arxiv.org/abs/2201.10728.

Source code in quadra/losses/ssl/idmm.py
def __init__(self, smoothing: float = 0.1):
    super().__init__()
    self.smoothing = smoothing

forward(p1, y1)

IDMM loss described in https://arxiv.org/abs/2201.10728.

Parameters:

  • p1 (Tensor) –

    Prediction labels for z1

  • y1 (Tensor) –

    Instance labels for z1

Returns:

  • Tensor

    IDMM loss

Source code in quadra/losses/ssl/idmm.py
def forward(
    self,
    p1: torch.Tensor,
    y1: torch.Tensor,
) -> torch.Tensor:
    """IDMM loss described in https://arxiv.org/abs/2201.10728.

    Args:
        p1: Prediction labels for `z1`
        y1: Instance labels for `z1`

    Returns:
        IDMM loss
    """
    return idmm_loss(
        p1,
        y1,
        self.smoothing,
    )

SimCLRLoss(temperature=1.0)

Bases: Module

SimCLR loss module.

Parameters:

  • temperature (float, default: 1.0 ) –

    Temperature of the SimCLR loss.

Source code in quadra/losses/ssl/simclr.py
def __init__(self, temperature: float = 1.0):
    super().__init__()
    self.temperature = temperature

forward(x1, x2)

Forward pass of the loss.

Source code in quadra/losses/ssl/simclr.py
def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
    """Forward pass of the loss."""
    return simclr_loss(x1, x2, temperature=self.temperature)
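
The body of simclr_loss is not shown here. A minimal sketch, assuming it implements the NT-Xent objective from the SimCLR paper: each embedding is L2-normalized, its positive is the other view of the same sample, and every other embedding in the batch acts as a negative. The names normalize and nt_xent_sketch are hypothetical.

```python
import math


def normalize(v):
    """Scale a vector to unit L2 norm."""
    norm = math.sqrt(sum(a * a for a in v))
    return [a / norm for a in v]


def nt_xent_sketch(x1, x2, temperature=1.0):
    """Assumed NT-Xent objective; not necessarily quadra's simclr_loss."""
    n = len(x1)
    z = [normalize(v) for v in x1 + x2]  # 2N embeddings: view 1, then view 2
    total = 0.0
    for i in range(2 * n):
        positive = (i + n) % (2 * n)  # index of the other view of the same sample
        logits = [sum(a * b for a, b in zip(z[i], z[k])) / temperature for k in range(2 * n)]
        # The denominator sums over every embedding except the anchor itself.
        denom = sum(math.exp(logits[k]) for k in range(2 * n) if k != i)
        total += -math.log(math.exp(logits[positive]) / denom)
    return total / (2 * n)


x1 = [[1.0, 0.0], [0.0, 1.0]]
matched = nt_xent_sketch(x1, [[1.0, 0.0], [0.0, 1.0]], temperature=0.5)
swapped = nt_xent_sketch(x1, [[0.0, 1.0], [1.0, 0.0]], temperature=0.5)
```

When the second view matches the first, the loss is lower than when the pairs are swapped. Lowering the temperature sharpens the softmax and widens that gap.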

SimSIAMLoss

Bases: Module

SimSiam loss module.

forward(p1, p2, z1, z2)

Compute the SimSIAM loss.

Source code in quadra/losses/ssl/simsiam.py
def forward(self, p1: torch.Tensor, p2: torch.Tensor, z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
    """Compute the SimSIAM loss."""
    return simsiam_loss(p1, p2, z1, z2)
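
The body of simsiam_loss is not shown here. Assuming it follows the SimSiam paper, with p1 and p2 the predictor outputs and z1 and z2 the projector outputs of the two views, the symmetrized negative cosine similarity would read:

```latex
\mathcal{L} = \frac{1}{2}\,\mathcal{D}\bigl(p_1, \operatorname{stopgrad}(z_2)\bigr)
            + \frac{1}{2}\,\mathcal{D}\bigl(p_2, \operatorname{stopgrad}(z_1)\bigr),
\qquad
\mathcal{D}(p, z) = -\frac{p}{\lVert p \rVert_2} \cdot \frac{z}{\lVert z \rVert_2}
```

Under this assumption the loss is bounded between -1 and 1, and the stop-gradient means each z term is treated as a constant target during backpropagation.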

VICRegLoss(lambd, mu, nu=1, gamma=1)

Bases: Module

VICReg (Variance-Invariance-Covariance Regularization) loss module.

Parameters:

  • lambd (float) –

    lambda multiplier for redundancy term.

  • mu (float) –

    mu multiplier for similarity term.

  • nu (float, default: 1 ) –

    nu multiplier for variance term. Default: 1.

  • gamma (float, default: 1 ) –

    gamma multiplier for covariance term. Default: 1.

Source code in quadra/losses/ssl/vicreg.py
def __init__(
    self,
    lambd: float,
    mu: float,
    nu: float = 1,
    gamma: float = 1,
):
    super().__init__()
    self.lambd = lambd
    self.mu = mu
    self.nu = nu
    self.gamma = gamma

forward(z1, z2)

Computes VICReg loss.

Source code in quadra/losses/ssl/vicreg.py
def forward(self, z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
    """Computes VICReg loss."""
    return vicreg_loss(z1, z2, self.lambd, self.mu, self.nu, self.gamma)