Skip to content

classification

ClassificationModule(model, criterion, optimizer=None, lr_scheduler=None, lr_scheduler_interval='epoch', gradcam=False)

Bases: BaseLightningModule

Lightning module for classification tasks.

Parameters:

  • model (Module) –

    Feature extractor as PyTorch torch.nn.Module

  • criterion (Module) –

    the loss to be applied as a PyTorch torch.nn.Module.

  • optimizer (None | Optimizer, default: None ) –

    optimizer of the training. Defaults to None.

  • lr_scheduler (None | object, default: None ) –

    Pytorch learning rate scheduler. If None a default ReduceLROnPlateau is used. Defaults to None.

  • lr_scheduler_interval (str | None, default: 'epoch' ) –

    the learning rate scheduler interval. Defaults to "epoch".

  • gradcam (bool, default: False ) –

    Whether to compute gradcam during prediction step

Source code in quadra/modules/classification/base.py
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
def __init__(
    self,
    model: nn.Module,
    criterion: nn.Module,
    optimizer: None | optim.Optimizer = None,
    lr_scheduler: None | object = None,
    lr_scheduler_interval: str | None = "epoch",
    gradcam: bool = False,
):
    """Lightning module for classification tasks.

    Args:
        model: Feature extractor as a PyTorch torch.nn.Module.
        criterion: The loss to be applied as a PyTorch torch.nn.Module.
        optimizer: Optimizer of the training. Defaults to None.
        lr_scheduler: PyTorch learning rate scheduler. If None a default
            ReduceLROnPlateau is used. Defaults to None.
        lr_scheduler_interval: The learning rate scheduler interval. Defaults to "epoch".
        gradcam: Whether to compute gradcam during prediction step. Defaults to False.
    """
    super().__init__(model, optimizer, lr_scheduler, lr_scheduler_interval)

    self.criterion = criterion
    self.gradcam = gradcam
    self.train_acc = torchmetrics.Accuracy()
    self.val_acc = torchmetrics.Accuracy()
    self.test_acc = torchmetrics.Accuracy()
    self.cam: GradCAM | None = None
    self.grad_rollout: VitAttentionGradRollout | None = None

    # Only check backbone compatibility when gradcam was actually requested:
    # previously the warning fired even with gradcam=False, which is spurious
    # and inconsistent with MultilabelClassificationModule.
    if self.gradcam and (
        not isinstance(self.model.features_extractor, timm.models.resnet.ResNet)
        and not is_vision_transformer(cast(BaseNetworkBuilder, self.model).features_extractor)
    ):
        log.warning(
            "Backbone not compatible with gradcam. Only timm ResNets, timm ViTs and TorchHub dinoViTs supported",
        )
        self.gradcam = False

    # Parameters' requires_grad flags saved in on_predict_start and restored
    # in on_predict_end when gradcam is enabled.
    self.original_requires_grads: list[bool] = []

on_predict_end()

If we computed gradcam, requires_grad values are reset to original value.

Source code in quadra/modules/classification/base.py
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
def on_predict_end(self) -> None:
    """If we computed gradcam, requires_grad values are reset to original value."""
    if self.gradcam:
        # Restore the requires_grad flags saved in on_predict_start.
        for p, original_requires_grad in zip(self.model.parameters(), self.original_requires_grads):
            p.requires_grad = original_requires_grad
        # Clear the saved flags: on_predict_start appends on every predict run,
        # so without this the list grows unboundedly across runs.
        self.original_requires_grads.clear()

        # We are using GradCAM package only for resnets at the moment
        if isinstance(self.model.features_extractor, timm.models.resnet.ResNet) and self.cam is not None:
            # Needed to solve jitting bug
            self.cam.activations_and_grads.release()
        elif (
            is_vision_transformer(cast(BaseNetworkBuilder, self.model).features_extractor)
            and self.grad_rollout is not None
        ):
            # Detach the forward/backward hooks registered by the ViT rollout helper.
            for handle in self.grad_rollout.f_hook_handles:
                handle.remove()
            for handle in self.grad_rollout.b_hook_handles:
                handle.remove()

    return super().on_predict_end()

on_predict_start()

If gradcam, prepares gradcam and saves params requires_grad state.

Source code in quadra/modules/classification/base.py
159
160
161
162
163
164
165
166
167
def on_predict_start(self) -> None:
    """If gradcam, prepares gradcam and saves params requires_grad state."""
    if self.gradcam:
        # Remember every parameter's requires_grad flag so on_predict_end can
        # restore the model to its pre-gradcam state.
        self.original_requires_grads.extend(param.requires_grad for param in self.model.parameters())
        self.prepare_gradcam()

    return super().on_predict_start()

predict_step(batch, batch_idx, dataloader_idx=0)

Prediction step.

Parameters:

  • batch (Any) –

    Tuple composed by (image, target)

  • batch_idx (int) –

    Batch index

  • dataloader_idx (int, default: 0 ) –

    Dataloader index

Source code in quadra/modules/classification/base.py
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
    """Prediction step.

    Args:
        batch: Tuple composed by (image, target)
        batch_idx: Batch index
        dataloader_idx: Dataloader index

    Returns:
        Tuple containing:
            predicted_classes: indexes of predicted classes
            grayscale_cam: gray scale gradcams (None when gradcam is disabled
                or no supported gradcam handler is available)
            probabilities: maximum softmax probability per sample
    """
    im, _ = batch
    outputs = self(im)
    probs = torch.softmax(outputs, dim=1)
    # Hoisted: used for both the predicted class indexes and the confidences.
    max_probs = torch.max(probs, dim=1)
    predicted_classes = max_probs.indices.tolist()

    # Default to None so the return value is always bound, even when gradcam
    # is enabled but neither backbone branch below matches (previously this
    # raised UnboundLocalError).
    grayscale_cam = None
    if self.gradcam:
        # inference_mode set to false because gradcam needs gradients
        with torch.inference_mode(False):
            im = im.clone()

            if isinstance(self.model.features_extractor, timm.models.resnet.ResNet) and self.cam:
                grayscale_cam = self.cam(input_tensor=im, targets=None)
            elif (
                is_vision_transformer(cast(BaseNetworkBuilder, self.model).features_extractor) and self.grad_rollout
            ):
                # The ViT rollout map is low resolution: upscale it bilinearly
                # (order=1) to the input's spatial size.
                grayscale_cam_low_res = self.grad_rollout(input_tensor=im, targets_list=predicted_classes)
                orig_shape = grayscale_cam_low_res.shape
                new_shape = (orig_shape[0], im.shape[2], im.shape[3])
                zoom_factors = tuple(np.array(new_shape) / np.array(orig_shape))
                grayscale_cam = ndimage.zoom(grayscale_cam_low_res, zoom_factors, order=1)

    return predicted_classes, grayscale_cam, max_probs.values.tolist()

prepare_gradcam()

Instantiate gradcam handlers.

Source code in quadra/modules/classification/base.py
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
def prepare_gradcam(self) -> None:
    """Instantiate gradcam handlers."""
    features_extractor = cast(BaseNetworkBuilder, self.model).features_extractor

    if isinstance(features_extractor, timm.models.resnet.ResNet):
        last_block = features_extractor.layer4[-1]

        # Keep the GradCAM backend on the same device as the model.
        model_device = next(self.model.parameters()).device
        self.cam = GradCAM(
            model=self.model,
            target_layers=[last_block],
            use_cuda=model_device.type == "cuda",
        )
        # GradCAM needs gradients flowing through the target layer.
        for param in last_block.parameters():
            param.requires_grad = True
    elif is_vision_transformer(features_extractor):
        self.grad_rollout = VitAttentionGradRollout(self.model)
    else:
        log.warning("Gradcam not implemented for this backbone, it won't be computed")
        self.original_requires_grads.clear()
        self.gradcam = False

MultilabelClassificationModule(model, criterion, optimizer=None, lr_scheduler=None, lr_scheduler_interval='epoch', gradcam=False)

Bases: BaseLightningModule

Multilabel classification module: train a generic classification model for a multilabel problem.

Parameters:

  • model (Sequential) –

    Feature extractor as PyTorch torch.nn.Module

  • criterion (Module) –

    the loss to be applied as a PyTorch torch.nn.Module.

  • optimizer (None | Optimizer, default: None ) –

    optimizer of the training. Defaults to None.

  • lr_scheduler (None | object, default: None ) –

    Pytorch learning rate scheduler. If None a default ReduceLROnPlateau is used. Defaults to None.

  • lr_scheduler_interval (str | None, default: 'epoch' ) –

    the learning rate scheduler interval. Defaults to "epoch".

  • gradcam (bool, default: False ) –

    Whether to compute gradcam during prediction step. Defaults to False.

Source code in quadra/modules/classification/base.py
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
def __init__(
    self,
    model: nn.Sequential,
    criterion: nn.Module,
    optimizer: None | optim.Optimizer = None,
    lr_scheduler: None | object = None,
    lr_scheduler_interval: str | None = "epoch",
    gradcam: bool = False,
):
    """Train a classification model for a multilabel problem.

    Args:
        model: Feature extractor as a PyTorch nn.Sequential.
        criterion: The loss to be applied as a PyTorch torch.nn.Module.
        optimizer: Optimizer of the training. Defaults to None.
        lr_scheduler: PyTorch learning rate scheduler. If None a default
            ReduceLROnPlateau is used. Defaults to None.
        lr_scheduler_interval: The learning rate scheduler interval. Defaults to "epoch".
        gradcam: Whether to compute gradcam during prediction step. Defaults to False.
    """
    super().__init__(model, optimizer, lr_scheduler, lr_scheduler_interval)
    self.criterion = criterion
    self.gradcam = gradcam
    # Always define self.cam so later attribute access never raises
    # AttributeError when gradcam is disabled or the backbone is unsupported.
    self.cam: GradCAM | None = None

    # TODO: can we use gradcam with more backbones?
    if self.gradcam:
        if not isinstance(model[0].features_extractor, timm.models.resnet.ResNet):
            log.warning(
                "Backbone must be compatible with gradcam, at the moment only ResNets supported, disabling gradcam"
            )
            self.gradcam = False
        else:
            target_layers = [model[0].features_extractor.layer4[-1]]
            self.cam = GradCAM(model=model, target_layers=target_layers, use_cuda=torch.cuda.is_available())