common

BYOLPredictionHead(input_dim, hidden_dim, output_dim)

Bases: ProjectionHead

Prediction head used for BYOL.

Source code in quadra/modules/ssl/common.py
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
    super().__init__(
        [
            (
                torch.nn.Linear(input_dim, hidden_dim, bias=False),
                torch.nn.BatchNorm1d(hidden_dim),
                torch.nn.ReLU(inplace=True),
            ),
            (torch.nn.Linear(hidden_dim, output_dim, bias=False), None, None),
        ]
    )

BYOLProjectionHead(input_dim, hidden_dim, output_dim)

Bases: ProjectionHead

Projection head used for BYOL.

Source code in quadra/modules/ssl/common.py
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
    super().__init__(
        [
            (
                torch.nn.Linear(input_dim, hidden_dim, bias=False),
                torch.nn.BatchNorm1d(hidden_dim),
                torch.nn.ReLU(inplace=True),
            ),
            (torch.nn.Linear(hidden_dim, output_dim, bias=False), None, None),
        ]
    )
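
The two BYOL heads are meant to be stacked on the online branch: the projector maps backbone features into the embedding space and the predictor maps the online projection onto the target projection. Below is a minimal usage sketch (not part of the library docs); the dimensions are illustrative, the import path is taken from the source location shown above, and it assumes the base ProjectionHead simply forwards inputs through its sequential layers.

import torch

from quadra.modules.ssl.common import BYOLPredictionHead, BYOLProjectionHead

# Illustrative dimensions: 2048-d backbone features, 4096-d hidden, 256-d embedding
projector = BYOLProjectionHead(input_dim=2048, hidden_dim=4096, output_dim=256)
predictor = BYOLPredictionHead(input_dim=256, hidden_dim=4096, output_dim=256)

features = torch.randn(8, 2048)     # batch of backbone features
projection = projector(features)    # (8, 256) online projection
prediction = predictor(projection)  # (8, 256) prediction of the target projection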

BarlowTwinsProjectionHead(input_dim, hidden_dim, output_dim)

Bases: ProjectionHead

Projection head used for Barlow Twins. "The projector network has three linear layers, each with 8192 output units. The first two layers of the projector are followed by a batch normalization layer and rectified linear units." https://arxiv.org/abs/2103.03230.

Source code in quadra/modules/ssl/common.py
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
    super().__init__(
        [
            (
                torch.nn.Linear(input_dim, hidden_dim, bias=False),
                torch.nn.BatchNorm1d(hidden_dim),
                torch.nn.ReLU(inplace=True),
            ),
            (
                torch.nn.Linear(hidden_dim, hidden_dim, bias=False),
                torch.nn.BatchNorm1d(hidden_dim),
                torch.nn.ReLU(inplace=True),
            ),
            (torch.nn.Linear(hidden_dim, output_dim, bias=False), None, None),
        ]
    )
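
A short usage sketch following the 8192-unit layout quoted above; the batch size, feature dimension, and import path are assumptions for illustration.

import torch

from quadra.modules.ssl.common import BarlowTwinsProjectionHead

# 2048-d backbone features projected to 8192-d embeddings, per the quoted paper
projector = BarlowTwinsProjectionHead(input_dim=2048, hidden_dim=8192, output_dim=8192)

z_a = projector(torch.randn(16, 2048))  # (16, 8192) embedding of the first view
z_b = projector(torch.randn(16, 2048))  # (16, 8192) embedding of the second view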

DinoProjectionHead(input_dim, output_dim, hidden_dim, use_bn=False, norm_last_layer=True, num_layers=3, bottleneck_dim=256)

Bases: Module

Projection head used for DINO. By default it does not use batch normalization; setting use_bn=True adds a batch norm layer after each hidden linear layer.

Parameters:

  • input_dim (int) –

    Input dimension for MLP head.

  • output_dim (int) –

    Output dimension (projection dimension) for MLP head.

  • hidden_dim (int) –

    Hidden dimension of the MLP.

  • use_bn (bool, default: False ) –

    Whether to apply batch normalization after each hidden linear layer. Defaults to False.

  • bottleneck_dim (int, default: 256 ) –

    Bottleneck dimension. Defaults to 256.

  • num_layers (int, default: 3 ) –

    Number of linear layers in the MLP (clamped to at least 1). Defaults to 3.

  • norm_last_layer (bool, default: True ) –

    If True, the weight-norm scale (weight_g) of the last layer is fixed at 1 and not trained. Defaults to True.

Source code in quadra/modules/ssl/common.py
def __init__(
    self,
    input_dim: int,
    output_dim: int,
    hidden_dim: int,
    use_bn: bool = False,
    norm_last_layer: bool = True,
    num_layers: int = 3,
    bottleneck_dim: int = 256,
):
    super().__init__()
    num_layers = max(num_layers, 1)
    self.mlp: nn.Linear | nn.Sequential
    if num_layers == 1:
        self.mlp = nn.Linear(input_dim, bottleneck_dim)
    else:
        layers: list[nn.Module] = [nn.Linear(input_dim, hidden_dim)]
        if use_bn:
            layers.append(nn.BatchNorm1d(hidden_dim))
        layers.append(nn.GELU())
        for _ in range(num_layers - 2):
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            if use_bn:
                layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.GELU())
        layers.append(nn.Linear(hidden_dim, bottleneck_dim))
        self.mlp = nn.Sequential(*layers)
    self.apply(self._init_weights)
    self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, output_dim, bias=False))
    self.last_layer.weight_g.data.fill_(1)
    if norm_last_layer:
        self.last_layer.weight_g.requires_grad = False
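
A minimal instantiation sketch with illustrative dimensions (the import path is taken from the source location above); only the constructor shown here is exercised.

from quadra.modules.ssl.common import DinoProjectionHead

head = DinoProjectionHead(
    input_dim=768,         # e.g. ViT-B feature size (illustrative)
    output_dim=4096,       # projection/prototype dimension (illustrative)
    hidden_dim=2048,
    use_bn=False,          # keep the MLP free of batch norm
    norm_last_layer=True,  # freeze the weight-norm scale of the last layer
)

# The head exposes an MLP down to bottleneck_dim plus a weight-normalized last layer
print(head.mlp)
print(head.last_layer)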

ExpanderReducer(input_dim, hidden_dim, output_dim)

Bases: ProjectionHead

Expander followed by a reducer.

Source code in quadra/modules/ssl/common.py
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
    super().__init__(
        [
            (
                torch.nn.Linear(input_dim, hidden_dim, bias=False),
                torch.nn.BatchNorm1d(hidden_dim),
                torch.nn.ReLU(inplace=True),
            ),
            (
                torch.nn.Linear(hidden_dim, output_dim, bias=False),
                torch.nn.BatchNorm1d(output_dim, affine=False),
                torch.nn.ReLU(inplace=True),
            ),
        ]
    )
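
A brief sketch of the expand-then-reduce pattern; the dimensions and import path are illustrative assumptions.

import torch

from quadra.modules.ssl.common import ExpanderReducer

# Expand 512-d features to a 2048-d hidden space, then reduce them to 128-d
head = ExpanderReducer(input_dim=512, hidden_dim=2048, output_dim=128)
out = head(torch.randn(32, 512))  # (32, 128)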

MultiCropModel(backbone, head)

Bases: Module

MultiCrop model for DINO augmentation.

It takes 2 global crops and N optional local crops as a single tensor.

Parameters:

  • backbone (Module) –

    Backbone model.

  • head (Module) –

    Head model.

Source code in quadra/modules/ssl/common.py
def __init__(self, backbone: nn.Module, head: nn.Module):
    super().__init__()
    self.backbone = backbone
    self.head = head
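
A minimal construction sketch, assuming a torchvision ResNet-18 backbone with its classifier removed and the DinoProjectionHead documented above; only the constructor shown here is exercised.

import torch.nn as nn
import torchvision

from quadra.modules.ssl.common import DinoProjectionHead, MultiCropModel

backbone = torchvision.models.resnet18(weights=None)
backbone.fc = nn.Identity()  # expose the 512-d pooled features instead of class logits

head = DinoProjectionHead(input_dim=512, output_dim=4096, hidden_dim=2048)
model = MultiCropModel(backbone=backbone, head=head)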

ProjectionHead(blocks)

Bases: Module

Base class for all projection and prediction heads.

Parameters:

  • blocks (list[tuple[Module | None, ...]]) –

    List of tuples, each denoting one block of the projection head MLP. Each tuple reads (linear_layer, batch_norm_layer, non_linearity_layer); batch_norm_layer and non_linearity_layer may each be None.

Source code in quadra/modules/ssl/common.py
def __init__(self, blocks: list[tuple[torch.nn.Module | None, ...]]):
    super().__init__()

    layers: list[nn.Module] = []
    for linear, batch_norm, non_linearity in blocks:
        if linear:
            layers.append(linear)
        if batch_norm:
            layers.append(batch_norm)
        if non_linearity:
            layers.append(non_linearity)
    self.layers = torch.nn.Sequential(*layers)
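
The subclasses above differ only in the block tuples they pass to this constructor, so a custom head can be assembled the same way. The sketch below assumes the base class forwards inputs through its layers attribute; names and dimensions are illustrative.

import torch

from quadra.modules.ssl.common import ProjectionHead

# Each tuple is (linear_layer, batch_norm_layer, non_linearity_layer); None entries are skipped
custom_head = ProjectionHead(
    [
        (
            torch.nn.Linear(384, 1024, bias=False),
            torch.nn.BatchNorm1d(1024),
            torch.nn.ReLU(inplace=True),
        ),
        (torch.nn.Linear(1024, 128, bias=False), None, None),
    ]
)

out = custom_head(torch.randn(4, 384))  # (4, 128)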

SimCLRPredictionHead(input_dim, hidden_dim, output_dim)

Bases: ProjectionHead

Prediction head used for SimCLR. "We set g(h) = W^(2) σ(W^(1) h), with the same input and output dimensionality (i.e. 2048)." https://arxiv.org/abs/2002.05709.

Source code in quadra/modules/ssl/common.py
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
    super().__init__(
        [
            (
                torch.nn.Linear(input_dim, hidden_dim, bias=False),
                torch.nn.BatchNorm1d(hidden_dim),
                torch.nn.ReLU(inplace=True),
            ),
            (torch.nn.Linear(hidden_dim, output_dim, bias=False), None, None),
        ]
    )

SimCLRProjectionHead(input_dim, hidden_dim, output_dim)

Bases: ProjectionHead

Projection head used for SimCLR. "We use a MLP with one hidden layer to obtain z_i = g(h) = W_2 * σ(W_1 * h), where σ is a ReLU non-linearity." https://arxiv.org/abs/2002.05709.

Source code in quadra/modules/ssl/common.py
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
    super().__init__(
        [
            (
                torch.nn.Linear(input_dim, hidden_dim),
                None,
                torch.nn.ReLU(inplace=True),
            ),
            (torch.nn.Linear(hidden_dim, output_dim), None, None),
        ]
    )
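
A usage sketch for the two SimCLR heads above; the 2048-d sizes follow the quoted convention, but all shapes and the import path are illustrative assumptions.

import torch

from quadra.modules.ssl.common import SimCLRPredictionHead, SimCLRProjectionHead

projector = SimCLRProjectionHead(input_dim=2048, hidden_dim=2048, output_dim=128)
predictor = SimCLRPredictionHead(input_dim=2048, hidden_dim=2048, output_dim=2048)

h = torch.randn(16, 2048)  # backbone representations
z = projector(h)           # (16, 128) contrastive embeddings
g_h = predictor(h)         # (16, 2048) same input and output dimensionality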

SimSiamPredictionHead(input_dim, hidden_dim, output_dim)

Bases: ProjectionHead

Prediction head used for SimSiam. "The prediction MLP (h) has BN applied to its hidden fc layers. Its output fc does not have BN (...) or ReLU. This MLP has 2 layers." https://arxiv.org/abs/2011.10566.

Source code in quadra/modules/ssl/common.py
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
    super().__init__(
        [
            (
                torch.nn.Linear(input_dim, hidden_dim, bias=False),
                torch.nn.BatchNorm1d(hidden_dim),
                torch.nn.ReLU(inplace=True),
            ),
            (torch.nn.Linear(hidden_dim, output_dim, bias=False), None, None),
        ]
    )

SimSiamProjectionHead(input_dim, hidden_dim, output_dim)

Bases: ProjectionHead

Projection head used for SimSiam. "The projection MLP (in f) has BN applied to each fully-connected (fc) layer, including its output fc. Its output fc has no ReLU. The hidden fc is 2048-d. This MLP has 3 layers." https://arxiv.org/abs/2011.10566.

Source code in quadra/modules/ssl/common.py
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
    super().__init__(
        [
            (
                torch.nn.Linear(input_dim, hidden_dim, bias=False),
                torch.nn.BatchNorm1d(hidden_dim),
                torch.nn.ReLU(inplace=True),
            ),
            (
                torch.nn.Linear(hidden_dim, hidden_dim, bias=False),
                torch.nn.BatchNorm1d(hidden_dim, affine=False),
                torch.nn.ReLU(inplace=True),
            ),
            (
                torch.nn.Linear(hidden_dim, output_dim, bias=False),
                torch.nn.BatchNorm1d(output_dim, affine=False),
                None,
            ),
        ]
    )
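
Putting the two SimSiam heads together: the 3-layer projector produces the embedding and the 2-layer predictor maps it onto the projection of the other view. A minimal sketch with illustrative dimensions and an import path assumed from the source location above.

import torch

from quadra.modules.ssl.common import SimSiamPredictionHead, SimSiamProjectionHead

projector = SimSiamProjectionHead(input_dim=2048, hidden_dim=2048, output_dim=2048)
predictor = SimSiamPredictionHead(input_dim=2048, hidden_dim=512, output_dim=2048)

z = projector(torch.randn(16, 2048))  # (16, 2048) projection of one view
p = predictor(z)                      # (16, 2048) prediction matched against the other view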