Skip to content

base

BaseLightningModule(model, optimizer=None, lr_scheduler=None, lr_scheduler_interval='epoch')

Bases: LightningModule

Base lightning module.

Parameters:

  • model (Module) –

    Network module used to extract features.

  • optimizer (Optimizer | None, default: None ) –

    Optimizer used for training. If None, a default Adam is used.

  • lr_scheduler (object | None, default: None ) –

    lr scheduler. If None a default ReduceLROnPlateau is used.

Source code in quadra/modules/base.py
28
29
30
31
32
33
34
35
36
37
38
39
def __init__(
    self,
    model: nn.Module,
    optimizer: Optimizer | None = None,
    lr_scheduler: object | None = None,
    lr_scheduler_interval: str | None = "epoch",
):
    """Store the model and the optimization settings.

    Args:
        model: network module used to extract features.
        optimizer: training optimizer; a default Adam is created later if None.
        lr_scheduler: learning-rate scheduler; a default ReduceLROnPlateau is
            created later if None.
        lr_scheduler_interval: interval at which the scheduler steps
            (presumably "epoch" or "step" as used by Lightning — confirm).
    """
    super().__init__()
    # Wrap the model so its call signature can be inspected later.
    self.model = ModelSignatureWrapper(model)
    self.optimizer = optimizer
    self.schedulers = lr_scheduler
    self.lr_scheduler_interval = lr_scheduler_interval

configure_optimizers()

Get default optimizer if not passed a value.

Returns:

  • tuple[list[Any], list[dict[str, Any]]]

    optimizer and lr scheduler as Tuple containing a list of optimizers and a list of lr schedulers

Source code in quadra/modules/base.py
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
def configure_optimizers(self) -> tuple[list[Any], list[dict[str, Any]]]:
    """Return optimizers and schedulers, creating defaults where none were supplied.

    Returns:
        Tuple of ``([optimizer], [lr_scheduler_config])`` in the format
        Lightning expects from ``configure_optimizers``.
    """
    # Missing or falsy optimizer -> fall back to a plain Adam.
    if not getattr(self, "optimizer", None):
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=3e-4)

    # Missing or falsy scheduler -> fall back to ReduceLROnPlateau.
    if not getattr(self, "schedulers", None):
        self.schedulers = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, factor=0.5, patience=30)

    scheduler_config = {
        "scheduler": self.schedulers,
        "interval": self.lr_scheduler_interval,
        # ReduceLROnPlateau needs a metric; "strict" False tolerates it missing.
        "monitor": "val_loss",
        "strict": False,
    }
    return [self.optimizer], [scheduler_config]

forward(x)

Forward method Args: x: input tensor.

Returns:

Source code in quadra/modules/base.py
41
42
43
44
45
46
47
48
49
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Run the wrapped model on the input.

    Args:
        x: input tensor.

    Returns:
        The model output.
    """
    return self.model(x)

optimizer_zero_grad(epoch, batch_idx, optimizer, optimizer_idx=0)

Redefine optimizer zero grad.

Source code in quadra/modules/base.py
74
75
76
def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx: int = 0):
    """Zero the gradients by setting them to None instead of zero-filling them."""
    optimizer.zero_grad(set_to_none=True)

SSLModule(model, criterion, classifier=None, optimizer=None, lr_scheduler=None, lr_scheduler_interval='epoch')

Bases: BaseLightningModule

Base module for self supervised learning.

Parameters:

  • model (Module) –

    Network module used to extract features.

  • criterion (Module) –

    SSL loss to be applied

  • classifier (ClassifierMixin | None, default: None ) –

    Standard sklearn classifiers

  • optimizer (Optimizer | None, default: None ) –

    Optimizer used for training. If None, a default Adam is used.

  • lr_scheduler (object | None, default: None ) –

    lr scheduler. If None a default ReduceLROnPlateau is used.

Source code in quadra/modules/base.py
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
def __init__(
    self,
    model: nn.Module,
    criterion: nn.Module,
    classifier: sklearn.base.ClassifierMixin | None = None,
    optimizer: Optimizer | None = None,
    lr_scheduler: object | None = None,
    lr_scheduler_interval: str | None = "epoch",
):
    """Store the SSL loss and the downstream sklearn classifier.

    Args:
        model: network module used to extract features.
        criterion: SSL loss to be applied.
        classifier: standard sklearn classifier; a LogisticRegression is
            used when None.
        optimizer: training optimizer; a default Adam is used when None.
        lr_scheduler: learning-rate scheduler; a default ReduceLROnPlateau
            is used when None.
        lr_scheduler_interval: interval at which the scheduler steps.
    """
    super().__init__(model, optimizer, lr_scheduler, lr_scheduler_interval)
    self.criterion = criterion
    # Declared for type checkers only; assigned elsewhere before use.
    self.classifier_train_loader: torch.utils.data.DataLoader | None
    # Default probe: logistic regression on the extracted embeddings.
    self.classifier = classifier if classifier is not None else LogisticRegression(max_iter=10000, n_jobs=8, random_state=42)

    self.val_acc = torchmetrics.Accuracy()

calculate_accuracy(batch)

Calculate accuracy on a batch of data.

Source code in quadra/modules/base.py
123
124
125
126
127
128
129
130
131
132
133
def calculate_accuracy(self, batch):
    """Compute the accuracy of the sklearn classifier on one batch of embeddings."""
    images, labels = batch
    # Feature extraction only: no gradients needed.
    with torch.no_grad():
        features = self.model(images).cpu().numpy()

    predictions = self.classifier.predict(features)
    return self.val_acc(torch.tensor(predictions, device=self.device), labels.detach())

fit_estimator()

Fit a classifier on the embeddings extracted from the current trained model.

Source code in quadra/modules/base.py
109
110
111
112
113
114
115
116
117
118
119
120
121
def fit_estimator(self):
    """Fit the sklearn classifier on embeddings produced by the current model."""
    all_targets = []
    all_embeddings = []
    self.model.eval()
    # Extract embeddings for the whole classifier training set without gradients.
    with torch.no_grad():
        for images, target in self.classifier_train_loader:
            all_embeddings.append(self.model(images.to(self.device)))
            all_targets.append(target)
    y = torch.cat(all_targets, dim=0).cpu().numpy()
    x = torch.cat(all_embeddings, dim=0).cpu().numpy()
    self.classifier.fit(x, y)

SegmentationModel(model, loss_fun, optimizer=None, lr_scheduler=None)

Bases: BaseLightningModule

Generic segmentation model.

Parameters:

  • model (Module) –

    segmentation model to be used.

  • loss_fun (Callable) –

    loss function to be used.

  • optimizer (Optimizer | None, default: None ) –

    Optimizer to be used. Defaults to None.

  • lr_scheduler (object | None, default: None ) –

    lr scheduler to be used. Defaults to None.

Source code in quadra/modules/base.py
177
178
179
180
181
182
183
184
185
def __init__(
    self,
    model: torch.nn.Module,
    loss_fun: Callable,
    optimizer: Optimizer | None = None,
    lr_scheduler: object | None = None,
):
    """Store the segmentation loss on top of the base module setup.

    Args:
        model: segmentation model to be used.
        loss_fun: loss function to be used.
        optimizer: optimizer to be used. Defaults to None.
        lr_scheduler: lr scheduler to be used. Defaults to None.
    """
    super().__init__(model, optimizer, lr_scheduler)
    self.loss_fun = loss_fun

compute_loss(pred_masks, target_masks)

Compute loss Args: pred_masks: predicted masks target_masks: target masks.

Returns:

  • Tensor

    The computed loss

Source code in quadra/modules/base.py
216
217
218
219
220
221
222
223
224
225
226
227
def compute_loss(self, pred_masks: torch.Tensor, target_masks: torch.Tensor) -> torch.Tensor:
    """Apply the configured loss function to predicted and target masks.

    Args:
        pred_masks: predicted masks.
        target_masks: target masks.

    Returns:
        The computed loss tensor.
    """
    return self.loss_fun(pred_masks, target_masks)

forward(x)

Forward method Args: x: input tensor.

Returns:

Source code in quadra/modules/base.py
187
188
189
190
191
192
193
194
195
196
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Run the segmentation model on the input batch.

    Args:
        x: input tensor.

    Returns:
        The model output.
    """
    return self.model(x)

predict_step(batch, batch_idx, dataloader_idx=None)

Predict step.

Source code in quadra/modules/base.py
268
269
270
271
272
273
274
275
276
277
278
def predict_step(
    self,
    batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    batch_idx: int,
    dataloader_idx: int | None = None,
) -> Any:
    """Predict step."""
    # pylint: disable=unused-argument
    images, masks, labels = batch
    pred_masks = self(images)
    return images.cpu(), masks.cpu(), pred_masks.cpu(), labels.cpu()

step(batch)

Compute loss Args: batch: batch.

Returns:

Source code in quadra/modules/base.py
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
def step(self, batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
    """Run a forward pass and align prediction/target mask shapes.

    Args:
        batch: tuple of (images, target_masks, labels); labels are ignored here.

    Returns:
        Tuple of (predicted masks, target masks) with matching shapes.

    Raises:
        ValueError: if the predicted and target mask shapes do not match.
    """
    images, target_masks, _ = batch
    pred_masks = self(images)
    # Add a channel dimension when the model or dataset yields 3D masks.
    if len(pred_masks.shape) == 3:
        pred_masks = pred_masks.unsqueeze(1)
    if len(target_masks.shape) == 3:
        target_masks = target_masks.unsqueeze(1)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if pred_masks.shape != target_masks.shape:
        raise ValueError(
            f"Predicted mask shape {tuple(pred_masks.shape)} does not match "
            f"target mask shape {tuple(target_masks.shape)}"
        )

    return pred_masks, target_masks

test_step(batch, batch_idx)

Test step.

Source code in quadra/modules/base.py
255
256
257
258
259
260
261
262
263
264
265
266
def test_step(self, batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int):
    """Test step."""
    # pylint: disable=unused-argument
    pred_masks, target_masks = self.step(batch)
    loss = self.compute_loss(pred_masks, target_masks)
    self.log_dict(
        {"test_loss": loss},
        on_step=True,
        on_epoch=True,
        prog_bar=True,
    )
    return loss

training_step(batch, batch_idx)

Training step.

Source code in quadra/modules/base.py
229
230
231
232
233
234
235
236
237
238
239
240
def training_step(self, batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int):
    """Compute and log the training loss for one batch."""
    # pylint: disable=unused-argument
    predictions, targets = self.step(batch)
    loss = self.compute_loss(predictions, targets)
    self.log_dict({"loss": loss}, on_step=True, on_epoch=True, prog_bar=True)
    return loss

validation_step(batch, batch_idx)

Validation step.

Source code in quadra/modules/base.py
242
243
244
245
246
247
248
249
250
251
252
253
def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx):
    """Compute and log the validation loss for one batch."""
    # pylint: disable=unused-argument
    predictions, targets = self.step(batch)
    loss = self.compute_loss(predictions, targets)
    self.log_dict({"val_loss": loss}, on_step=True, on_epoch=True, prog_bar=True)
    return loss

SegmentationModelMulticlass(model, loss_fun, optimizer=None, lr_scheduler=None)

Bases: SegmentationModel

Generic multiclass segmentation model.

Parameters:

  • model (Module) –

    segmentation model to be used.

  • loss_fun (Callable) –

    loss function to be used.

  • optimizer (Optimizer | None, default: None ) –

    Optimizer to be used. Defaults to None.

  • lr_scheduler (object | None, default: None ) –

    lr scheduler to be used. Defaults to None.

Source code in quadra/modules/base.py
291
292
293
294
295
296
297
298
def __init__(
    self,
    model: torch.nn.Module,
    loss_fun: Callable,
    optimizer: Optimizer | None = None,
    lr_scheduler: object | None = None,
):
    """Initialize the multiclass segmentation module.

    Delegates all setup to SegmentationModel; only the `step` behavior differs.

    Args:
        model: segmentation model to be used.
        loss_fun: loss function to be used.
        optimizer: optimizer to be used. Defaults to None.
        lr_scheduler: lr scheduler to be used. Defaults to None.
    """
    super().__init__(model=model, loss_fun=loss_fun, optimizer=optimizer, lr_scheduler=lr_scheduler)

step(batch)

Compute step Args: batch: batch.

Returns:

Source code in quadra/modules/base.py
300
301
302
303
304
305
306
307
308
309
310
311
312
def step(self, batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
    """Forward the images and return predictions with their target masks.

    Unlike the binary variant, no channel-dimension alignment or shape
    check is performed here.

    Args:
        batch: tuple of (images, target_masks, labels); labels are ignored.

    Returns:
        Tuple of (predicted masks, target masks), shapes unchanged.
    """
    images, target_masks, _ = batch
    return self(images), target_masks