classification

SklearnClassificationTrainer(input_shape, backbone, random_state=42, classifier=LogisticRegression, iteration_over_training=1)

Class to configure and run classification, using torch for feature extraction and sklearn to fit a classifier.

Parameters:

  • input_shape (list) –

    Expected input shape as [H, W, C]

  • backbone (Module) –

    The feature extractor

  • random_state (int, default: 42 ) –

    Seed to fix randomness

  • classifier (ClassifierMixin, default: LogisticRegression ) –

    Classification model

  • iteration_over_training (int, default: 1 ) –

    Number of iterations over the training set during feature extraction

Source code in quadra/trainers/classification.py
def __init__(
    self,
    input_shape: list,
    backbone: torch.nn.Module,
    random_state: int = 42,
    classifier: ClassifierMixin = LogisticRegression,
    iteration_over_training: int = 1,
) -> None:
    super().__init__()

    try:
        self.classifier = classifier(max_iter=1e4, random_state=random_state)
    except Exception:
        self.classifier = classifier

    self.input_shape = input_shape
    self.random_state = random_state
    self.iteration_over_training = iteration_over_training
    self.backbone = backbone
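
A minimal construction sketch (the torchvision backbone below is illustrative, not part of quadra; any torch.nn.Module that maps image batches to flat feature vectors works):

import torch
import torchvision
from sklearn.linear_model import LogisticRegression

from quadra.trainers.classification import SklearnClassificationTrainer

# Hypothetical backbone: a ResNet-18 with its classification head removed,
# so forward() returns (N, 512) feature vectors.
resnet = torchvision.models.resnet18(weights="DEFAULT")
backbone = torch.nn.Sequential(*list(resnet.children())[:-1], torch.nn.Flatten())

trainer = SklearnClassificationTrainer(
    input_shape=[224, 224, 3],  # [H, W, C]
    backbone=backbone,
    random_state=42,
    classifier=LogisticRegression,  # passed as a class, instantiated internally
)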

change_backbone(backbone)

Update feature extractor.

Source code in quadra/trainers/classification.py
def change_backbone(self, backbone: torch.nn.Module):
    """Update feature extractor."""
    self.backbone = backbone
    self.backbone.eval()

change_classifier(classifier)

Update classifier.

Source code in quadra/trainers/classification.py
def change_classifier(self, classifier: ClassifierMixin):
    """Update classifier."""
    self.classifier = classifier
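
Both setters can be used to swap components after construction; a small sketch (the SVC choice and the new_backbone module are illustrative):

from sklearn.svm import SVC

# Swap in any sklearn-style classifier instance.
trainer.change_classifier(SVC(probability=True, random_state=42))

# Swap the feature extractor; the new backbone is put in eval mode.
trainer.change_backbone(new_backbone)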

fit(train_dataloader=None, train_features=None, train_labels=None)

Fit classifier on training set.

Source code in quadra/trainers/classification.py
def fit(
    self,
    train_dataloader: DataLoader | None = None,
    train_features: ndarray | None = None,
    train_labels: ndarray | None = None,
):
    """Fit classifier on training set."""
    # Extract feature
    if self.backbone is None:
        raise AssertionError("You must set a model before running execution")

    if train_dataloader is not None:  # train_features is None or train_labels is None:
        log.info("Extracting features from training set")
        train_features, train_labels, _ = get_feature(
            feature_extractor=self.backbone,
            dl=train_dataloader,
            iteration_over_training=self.iteration_over_training,
            gradcam=False,
        )
    else:
        log.info("Using cached features for training set")
        # With the current implementation cached features are not sorted
        # Even though it doesn't seem to change anything
        if train_features is None or train_labels is None:
            raise AssertionError("Train features and labels must be provided when using cached data")
        permuted_indices = np.random.RandomState(seed=self.random_state).permutation(train_features.shape[0])
        train_features = train_features[permuted_indices]
        train_labels = train_labels[permuted_indices]

    log.info("Fitting classifier on %d features", len(train_features))  # type: ignore[arg-type]
    self.classifier.fit(train_features, train_labels)
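
A usage sketch of the two supported code paths (train_loader, cached_features, and cached_labels are assumed to exist):

# Path 1: extract features from a dataloader, then fit the classifier.
trainer.fit(train_dataloader=train_loader)

# Path 2: reuse previously extracted features; they are shuffled with
# a RandomState seeded by random_state before fitting.
trainer.fit(train_features=cached_features, train_labels=cached_labels)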

test(test_dataloader, test_labels=None, test_features=None, class_to_keep=None, idx_to_class=None, predict_proba=True, gradcam=False)

Test classifier on test set.

Parameters:

  • test_dataloader (DataLoader) –

    Test dataloader

  • test_labels (ndarray | None, default: None ) –

    Test labels

  • test_features (ndarray | None, default: None ) –

    Optional test features, used when cached data is available

  • class_to_keep (list[int] | None, default: None ) –

    List of classes to keep

  • idx_to_class (dict[int, str] | None, default: None ) –

    Dictionary mapping class indices to class names

  • predict_proba (bool, default: True ) –

    If True, also predict the probability for each test image

  • gradcam (bool, default: False ) –

    Whether to compute gradcams

Returns:

  • cl_rep – Classification report

  • pd_cm – Confusion matrix dataframe

  • accuracy – Test accuracy

  • res – Test results

  • cams – Gradcams

Source code in quadra/trainers/classification.py
def test(
    self,
    test_dataloader: DataLoader,
    test_labels: ndarray | None = None,
    test_features: ndarray | None = None,
    class_to_keep: list[int] | None = None,
    idx_to_class: dict[int, str] | None = None,
    predict_proba: bool = True,
    gradcam: bool = False,
) -> (
    tuple[str | dict, DataFrame, float, DataFrame, np.ndarray | None]
    | tuple[None, None, None, DataFrame, np.ndarray | None]
):
    """Test classifier on test set.

    Args:
        test_dataloader: Test dataloader
        test_labels: test labels
        test_features: Optional test features used when cache data is available
        class_to_keep: list of class to keep
        idx_to_class: dictionary mapping class index to class name
        predict_proba: if True, predict also probability for each test image
        gradcam: Whether to compute gradcam

    Returns:
        cl_rep: Classification report
        pd_cm: Confusion matrix dataframe
        accuracy: Test accuracy
        res: Test results
        cams: Gradcams
    """
    cams = None
    # Extract feature
    if test_features is None:
        log.info("Extracting features from test set")
        test_features, final_test_labels, cams = get_feature(
            feature_extractor=self.backbone,
            dl=test_dataloader,
            gradcam=gradcam,
            classifier=self.classifier,
            input_shape=(self.input_shape[2], self.input_shape[0], self.input_shape[1]),
        )
    else:
        if test_labels is None:
            raise ValueError("Test labels must be provided when using cached data")
        log.info("Using cached features for test set")
        final_test_labels = test_labels

    # Run classifier
    log.info("Predict classifier on test set")
    test_prediction_label = self.classifier.predict(test_features)
    if predict_proba:
        test_probability = self.classifier.predict_proba(test_features)
        test_probability = test_probability.max(axis=1)

    if class_to_keep is not None:
        if idx_to_class is None:
            raise ValueError("You must provide `idx_to_class` and `test_labels` when using `class_to_keep`")
        filtered_test_labels = [int(x) if idx_to_class[x] in class_to_keep else -1 for x in final_test_labels]
    else:
        filtered_test_labels = cast(list[int], final_test_labels.tolist())

    if not hasattr(test_dataloader.dataset, "x"):
        raise ValueError("Current dataset doesn't provide an `x` attribute")

    res = pd.DataFrame(
        {
            "sample": list(test_dataloader.dataset.x),
            "real_label": final_test_labels,
            "pred_label": test_prediction_label,
        }
    )

    if not all(t == -1 for t in filtered_test_labels):
        test_real_label_cm = np.array(filtered_test_labels)
        if cams is not None:
            cams = cams[test_real_label_cm != -1]  # TODO: Is class_to_keep still used?
        pred_labels_cm = np.array(test_prediction_label)[test_real_label_cm != -1]
        test_real_label_cm = test_real_label_cm[test_real_label_cm != -1].astype(pred_labels_cm.dtype)
        cl_rep, pd_cm, accuracy = get_results(test_real_label_cm, pred_labels_cm, idx_to_class)

        if predict_proba:
            res["probability"] = test_probability

        return cl_rep, pd_cm, accuracy, res, cams

    return None, None, None, res, cams
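
A sketch of consuming the return values (test_loader and the label mapping are assumptions; note the dataset behind test_loader must expose an `x` attribute with the sample paths):

cl_rep, pd_cm, accuracy, res, cams = trainer.test(
    test_dataloader=test_loader,
    idx_to_class={0: "good", 1: "bad"},  # illustrative mapping
    predict_proba=True,
)
if accuracy is not None:
    print(f"Test accuracy: {accuracy:.3f}")
# res columns: sample, real_label, pred_label (+ probability when predict_proba=True)
print(res.head())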