Skip to content

visualization

UnNormalize(mean, std)

Unnormalize a tensor image with mean and standard deviation.

Source code in quadra/utils/visualization.py (lines 29–31)
def __init__(self, mean, std):
    """Remember the per-channel statistics used to undo normalization.

    Args:
        mean: Per-channel means; ``__call__`` adds them back after scaling.
        std: Per-channel standard deviations; ``__call__`` multiplies each channel by them.
    """
    self.mean, self.std = mean, std

__call__(tensor, make_copy=True)

Call function to unnormalize a tensor image with mean and standard deviation.

Parameters:

  • tensor (Tensor) –

    Normalized tensor image of size (C, H, W) to be unnormalized.

  • make_copy (bool, default: True ) –

    Whether to unnormalize a copy of the tensor (True) or modify the input tensor in place (False).

Source code in quadra/utils/visualization.py (lines 33–49)
def __call__(self, tensor: torch.Tensor, make_copy: bool = True) -> torch.Tensor:
    """Call function to unnormalize a tensor image with mean and standard deviation.

    Each channel ``t`` is transformed as ``t * std + mean``, the inverse of the usual
    normalization ``(t - mean) / std``.

    Args:
        tensor (Tensor): Normalized tensor image of size (C, H, W) to be unnormalized.
        make_copy (bool): If True, operate on a detached clone and leave ``tensor`` untouched;
            if False, ``tensor`` is modified in place.

    Returns:
        Tensor: The unnormalized image. When ``make_copy`` is False this is ``tensor`` itself.
    """
    new_t = tensor.detach().clone() if make_copy else tensor
    # zip stops at the shortest sequence, so only channels with a matching mean/std are changed.
    for channel, m, s in zip(new_t, self.mean, self.std):
        # In-place ops on the per-channel view write straight through to new_t.
        channel.mul_(s).add_(m)
    return new_t

create_grid_figure(images, nrows, ncols, file_path, bounds, row_names=None, fig_size=(12, 8))

Create a grid figure with images.

Parameters:

  • images (Iterable[Iterable[ndarray]]) –

    Images to plot, given as one iterable of images per grid row.

  • nrows (int) –

    Number of rows in the grid.

  • ncols (int) –

    Number of columns in the grid.

  • file_path (str) –

    Path to save the figure.

  • row_names (Optional[Iterable[str]], default: None ) –

    Row names. Defaults to None.

  • fig_size (Tuple[int, int], default: (12, 8) ) –

    Figure size. Defaults to (12, 8).

  • bounds (Optional[List[Tuple[float, float]]]) –

    Per-row (vmin, vmax) display bounds passed to imshow for every image in that row.

Source code in quadra/utils/visualization.py (lines 52–89)
def create_grid_figure(
    images: Iterable[Iterable[np.ndarray]],
    nrows: int,
    ncols: int,
    file_path: str,
    bounds: Optional[List[Tuple[float, float]]] = None,
    row_names: Optional[Iterable[str]] = None,
    fig_size: Tuple[int, int] = (12, 8),
):
    """Create a grid figure with images and save it to ``file_path``.

    The figure is drawn with the non-interactive "Agg" backend. The previously active backend
    is restored afterwards, even if plotting or saving raises.

    Args:
        images (Iterable[Iterable[np.ndarray]]): Rows of images; ``images[i]`` is drawn in grid
            row ``i``. Images of shape (1, H, W) are squeezed to (H, W) before plotting.
        nrows (int): Number of rows in the grid.
        ncols (int): Number of columns in the grid.
        file_path (str): Path to save the figure.
        bounds (Optional[List[Tuple[float, float]]], optional): Per-row ``(vmin, vmax)`` passed
            to ``imshow``. Defaults to None, meaning matplotlib autoscales each image.
        row_names (Optional[Iterable[str]], optional): Labels for the rows, drawn as y-axis
            labels of the first column. Defaults to None.
        fig_size (Tuple[int, int], optional): Figure size. Defaults to (12, 8).
    """
    default_plt_backend = plt.get_backend()
    plt.switch_backend("Agg")
    try:
        # squeeze=False keeps ``axes`` 2D even for a single row/column.
        _, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=fig_size, squeeze=False)
        for i, row in enumerate(images):
            for j, image in enumerate(row):
                if len(image.shape) == 3 and image.shape[0] == 1:
                    image = image[0]
                vmin, vmax = (None, None) if bounds is None else (bounds[i][0], bounds[i][1])
                axes[i][j].imshow(image, vmin=vmin, vmax=vmax)
                axes[i][j].get_xaxis().set_ticks([])
                axes[i][j].get_yaxis().set_ticks([])
        if row_names is not None:
            for row_ax, name in zip(axes[:, 0], row_names):
                row_ax.set_ylabel(name, rotation=90)

        plt.tight_layout()
        plt.savefig(file_path, bbox_inches="tight", dpi=300, facecolor="white", transparent=False)
    finally:
        # Always release the figure and restore the caller's backend.
        plt.close()
        plt.switch_backend(default_plt_backend)

create_visualization_dataset(dataset)

Create a visualization dataset: a deep copy of the dataset whose Normalize and ToTensorV2 transforms are replaced with NoOp.

Source code in quadra/utils/visualization.py (lines 92–119)
def create_visualization_dataset(dataset: torch.utils.data.Dataset):
    """Create a visualization dataset by updating transforms.

    Deep-copies ``dataset`` and, in the copy, replaces every ``Normalize`` and ``ToTensorV2``
    found in its ``transform`` (recursing into albumentations composes, lists and dicts) with
    ``NoOp(p=1)``. Presumably this makes samples come out un-normalized and not converted to
    tensors, for plotting; TODO confirm against callers. ``dataset`` itself is not modified.

    Args:
        dataset: Dataset to copy. It must expose a ``transform`` attribute. Reading it is not
            guarded, so a dataset without one raises ``AttributeError``.

    Returns:
        The deep-copied dataset with its ``transform`` replaced by the converted transform.

    Raises:
        ValueError: If ``dataset`` is not a ``torch.utils.data.Dataset`` or its ``transform``
            is None. The deep copy of ``dataset`` is made before these checks.
    """

    def convert_transforms(transforms: Any):
        """Handle different types of transforms.

        Recursively replace ``Normalize``/``ToTensorV2`` with ``NoOp(p=1)``. Composes and dicts
        are updated in place. Lists (including ``ListConfig``) are rebuilt as plain lists.
        Anything else is returned unchanged.
        """
        if isinstance(transforms, albumentations.BaseCompose):
            # Convert the compose's children and write them back onto the same compose object.
            transforms.transforms = convert_transforms(transforms.transforms)
        # NOTE(review): TransformsSeqType is used as an isinstance target. If it is a subscripted
        # typing alias, this check may raise TypeError for inputs that are not list/ListConfig.
        # Confirm its definition.
        if isinstance(transforms, (list, ListConfig, TransformsSeqType)):
            transforms = [convert_transforms(t) for t in transforms]
        if isinstance(transforms, (dict, DictConfig)):
            # Only existing keys are reassigned, so mutating while iterating is safe here.
            for tname, t in transforms.items():
                transforms[tname] = convert_transforms(t)
        if isinstance(transforms, (Normalize, ToTensorV2)):
            return NoOp(p=1)
        return transforms

    new_dataset = copy.deepcopy(dataset)
    # TODO: Create dataset class that has a transform attribute, we can then use isinstance
    if isinstance(dataset, torch.utils.data.Dataset):
        # A separate deep copy of the original transform is converted (convert_transforms mutates
        # composes/dicts in place) and then replaces new_dataset's own copied transform.
        transform = copy.deepcopy(dataset.transform)  # type: ignore[attr-defined]
        if transform is not None:
            new_transforms = convert_transforms(transform)
            new_dataset.transform = new_transforms  # type: ignore[attr-defined]
        else:
            raise ValueError(f"The dataset transform {type(transform)} is not supported")
    else:
        raise ValueError(f"The dataset type {dataset} is not supported")
    return new_dataset

plot_classification_results(test_dataset, pred_labels, test_labels, class_name, original_folder, gradcam_folder=None, grayscale_cams=None, unorm=None, idx_to_class=None, what=None, real_class_to_plot=None, pred_class_to_plot=None, rows=1, cols=4, figsize=(20, 20), gradcam=False)

Plot and save images extracted from classification. If gradcam is True, the same images with a Grad-CAM heatmap overlaid on the original image are also saved.

Parameters:

  • test_dataset (Dataset) –

    Test dataset

  • pred_labels (ndarray) –

    Predicted labels

  • test_labels (ndarray) –

    Test labels

  • class_name (str) –

    Name of the examples' class

  • original_folder (str) –

    Folder where original examples will be saved

  • gradcam_folder (Optional[str], default: None ) –

    Folder in which gradcam examples will be saved

  • grayscale_cams (Optional[ndarray], default: None ) –

    Grayscale gradcams (ordered as pred_labels and test_labels)

  • unorm (Optional[Callable[[Tensor], Tensor]], default: None ) –

    Function used to unnormalize each image before plotting (e.g. an UnNormalize instance)

  • idx_to_class (Optional[Dict], default: None ) –

    Dictionary of class conversion

  • what (Optional[str], default: None ) –

    Either "dis" (plot only misclassified samples) or "con" (plot only correctly classified samples). If None, no agreement filter is applied.

  • real_class_to_plot (Optional[int], default: None ) –

    Real class to plot.

  • pred_class_to_plot (Optional[int], default: None ) –

    Pred class to plot.

  • rows (Optional[int], default: 1 ) –

    How many rows in the plot there will be.

  • cols (int, default: 4 ) –

    How many cols in the plot there will be.

  • figsize (Tuple[int, int], default: (20, 20) ) –

    The figure size.

  • gradcam (bool, default: False ) –

    Whether to save also the gradcam version of the examples

Source code in quadra/utils/visualization.py (lines 244–414)
def _lookup_label_name(idx_to_class: Optional[Dict], label: Any) -> Any:
    """Return ``idx_to_class[label]`` when it can be resolved, otherwise ``label`` unchanged."""
    if idx_to_class is None:
        return label
    try:
        return idx_to_class[label]
    except (KeyError, IndexError, TypeError):
        # Unknown labels are shown as-is rather than aborting the whole plot.
        return label


def _select_classification_samples(
    pred_labels: np.ndarray,
    test_labels: np.ndarray,
    grayscale_cams: Optional[np.ndarray],
    real_class_to_plot: Optional[int],
    pred_class_to_plot: Optional[int],
    what: Optional[str],
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]]:
    """Filter samples by true class, predicted class and agreement, keeping dataset indices aligned."""
    pred_labels = np.asarray(pred_labels)
    test_labels = np.asarray(test_labels)
    # Build one combined mask over the full arrays so that sample_idx always refers to dataset
    # positions, even when several filters are active at once.
    keep = np.ones(len(test_labels), dtype=bool)
    if real_class_to_plot is not None:
        keep &= test_labels == real_class_to_plot
    if pred_class_to_plot is not None:
        keep &= pred_labels == pred_class_to_plot
    if what is not None:
        if what == "dis":
            keep &= pred_labels != test_labels
        elif what == "con":
            keep &= pred_labels == test_labels
        else:
            raise AssertionError(f"{what} not a valid plot type. Must be con or dis")

    sample_idx = np.arange(len(test_labels))[keep]
    if grayscale_cams is not None:
        grayscale_cams = np.asarray(grayscale_cams)[keep]
    return sample_idx, pred_labels[keep], test_labels[keep], grayscale_cams


def _to_display_image(image: np.ndarray) -> np.ndarray:
    """Scale [0, 1] images to [0, 255], cast to int and move channels last for ``imshow``."""
    if image.max() <= 1:
        image = image * 255
    image = image.astype(int)

    if len(image.shape) == 3:
        if image.shape[0] == 1:
            image = image[0]
        elif image.shape[0] == 3:
            image = image.transpose((1, 2, 0))
    return image


def plot_classification_results(
    test_dataset: torch.utils.data.Dataset,
    pred_labels: np.ndarray,
    test_labels: np.ndarray,
    class_name: str,
    original_folder: str,
    gradcam_folder: Optional[str] = None,
    grayscale_cams: Optional[np.ndarray] = None,
    unorm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
    idx_to_class: Optional[Dict] = None,
    what: Optional[str] = None,
    real_class_to_plot: Optional[int] = None,
    pred_class_to_plot: Optional[int] = None,
    rows: Optional[int] = 1,
    cols: int = 4,
    figsize: Tuple[int, int] = (20, 20),
    gradcam: bool = False,
) -> None:
    """Plot and save images extracted from classification. If gradcam is True, same images
    with a gradcam heatmap (layered on original image) will also be saved.

    Samples are selected by ``real_class_to_plot`` and/or ``pred_class_to_plot`` (both filters
    apply when both are given), optionally restricted by ``what``, then shuffled. At most
    ``rows * cols`` of them are drawn (all of them when ``rows`` is None or 0).

    Args:
        test_dataset: Test dataset; ``test_dataset[i]`` must return ``(image, target)``.
        pred_labels: Predicted labels
        test_labels: Test labels
        class_name: Name of the examples' class (used in the output file name)
        original_folder: Folder where original examples will be saved
        gradcam_folder: Folder in which gradcam examples will be saved
        grayscale_cams: Grayscale gradcams (ordered as pred_labels and test_labels)
        unorm: Function used to unnormalize each image before plotting
        idx_to_class: Dictionary of class conversion; labels missing from it are shown raw
        what: "dis" to keep only misclassified samples, "con" to keep only correct ones,
            None to keep both
        real_class_to_plot: Real class to plot.
        pred_class_to_plot: Pred class to plot.
        rows: How many rows in the plot there will be.
        cols: How many cols in the plot there will be.
        figsize: The figure size.
        gradcam: Whether to save also the gradcam version of the examples

    Raises:
        ValueError: If gradcam inputs are missing, or both class filters are None.
        AssertionError: If ``what`` is not None, "dis" or "con".
    """
    if gradcam:
        if grayscale_cams is None:
            raise ValueError("gradcam is True but grayscale_cams is None")
        if gradcam_folder is None:
            raise ValueError("gradcam is True but gradcam_folder is None")

    if pred_class_to_plot is None and real_class_to_plot is None:
        raise ValueError("'real_class_to_plot' and 'pred_class_to_plot' must not be both None")

    sample_idx, pred_labels, test_labels, grayscale_cams = _select_classification_samples(
        pred_labels,
        test_labels,
        grayscale_cams if gradcam else None,
        real_class_to_plot,
        pred_class_to_plot,
        what,
    )

    if len(sample_idx) == 0:
        print("Nothing to plot")
        return

    # Shuffle so that, when ``rows`` limits the grid, a random subset is plotted.
    idx_random = random.sample(range(len(sample_idx)), len(sample_idx))
    sample_idx = sample_idx[idx_random]
    pred_labels = pred_labels[idx_random]
    test_labels = test_labels[idx_random]
    if grayscale_cams is not None:
        grayscale_cams = grayscale_cams[idx_random]

    cordant_chunks = list(_chunks(sample_idx, cols))
    total_rows = len(cordant_chunks) if not rows else len(cordant_chunks[:rows])
    modality_list = ["original", "gradcam"] if gradcam else ["original"]

    for modality in modality_list:
        fig = plt.figure(figsize=figsize)
        grid = ImageGrid(
            fig,
            111,  # similar to subplot(111)
            nrows_ncols=(total_rows, cols),
            axes_pad=(0.2, 0.5),
        )
        for i, ax in enumerate(grid):
            if i >= len(sample_idx):
                break
            pred_label = _lookup_label_name(idx_to_class, pred_labels[i])
            test_label = _lookup_label_name(idx_to_class, test_labels[i])

            ax.axis("off")
            ax.set_title(f"True: {str(test_label)}\nPred {str(pred_label)}")
            image, _ = test_dataset[sample_idx[i]]

            if unorm is not None:
                image = np.array(unorm(image))
            # Work with numpy from here on; both display paths below expect an ndarray.
            if isinstance(image, torch.Tensor):
                image = image.cpu().numpy()

            if modality == "gradcam" and grayscale_cams is not None:
                rgb_cam = show_cam_on_image(
                    np.transpose(image, (1, 2, 0)), grayscale_cams[i], use_rgb=True, image_weight=0.7
                )
                ax.imshow(rgb_cam, cmap="gray")
            else:
                ax.imshow(_to_display_image(image), cmap="gray")

        for item in grid:
            item.axis("off")

        save_folder: str = ""
        if modality == "gradcam" and gradcam_folder is not None:
            save_folder = gradcam_folder
        elif modality == "original":
            save_folder = original_folder
        else:
            log.warning("modality %s has no corresponding folder", modality)
            plt.close(fig)
            return

        plt.savefig(
            os.path.join(save_folder, f"{what}cordant_{class_name}_" + modality + ".png"),
            bbox_inches="tight",
            pad_inches=0,
        )
        plt.close(fig)

plot_multiclass_prediction(image, prediction_image, ground_truth_image, class_to_idx, plot_original=True, ignore_class=0, image_height=10, save_path=None, color_map='tab20')

Plot an image together with its ground-truth and predicted multiclass masks.

Parameters:

  • image (ndarray) –

    The image to plot

  • prediction_image (ndarray) –

    The prediction image

  • ground_truth_image (ndarray) –

    The ground truth image

  • class_to_idx (Dict[str, int]) –

    The class to idx mapping

  • plot_original (bool, default: True ) –

    Whether to plot the original image

  • ignore_class (Optional[int], default: 0 ) –

    The class to ignore

  • image_height (int, default: 10 ) –

    The height of the output figure

  • save_path (Optional[str], default: None ) –

    The path to save the figure

  • color_map (str, default: 'tab20' ) –

    The color map to use. Defaults to "tab20".

Source code in quadra/utils/visualization.py (lines 171–241)
def plot_multiclass_prediction(
    image: np.ndarray,
    prediction_image: np.ndarray,
    ground_truth_image: np.ndarray,
    class_to_idx: Dict[str, int],
    plot_original: bool = True,
    ignore_class: Optional[int] = 0,
    image_height: int = 10,
    save_path: Optional[str] = None,
    color_map: str = "tab20",
) -> None:
    """Function used to plot the image predicted.

    Panels are drawn side by side, each passed through ``show_mask_on_image`` with ``image``:
    the original image (if ``plot_original``), the ground-truth mask, the prediction mask and,
    when ``ignore_class`` is not None, the prediction mask with pixels whose ground truth is
    ``ignore_class`` blacked out. A legend mapping colors to class names is attached to the
    last panel. The figure is saved and closed only when ``save_path`` is given.

    Args:
        image: The image to plot; it is cropped to the first two dimensions of ``prediction_image``
        prediction_image: The prediction image (class index per pixel)
        ground_truth_image: The ground truth image (class index per pixel)
        class_to_idx: The class to idx mapping
        plot_original: Whether to plot the original image
        ignore_class: The class to ignore
        image_height: The height of the output figure
        save_path: The path to save the figure
        color_map: The color map to use. Defaults to "tab20".
    """
    image = image[0 : prediction_image.shape[0], 0 : prediction_image.shape[1], :]
    class_idxs = list(class_to_idx.values())
    cm = get_cmap(color_map)
    # One RGB color per class (alpha dropped), scaled to 0-255 and keyed by str(class index).
    cmap = {
        str(c): tuple(int(channel * 255) for channel in cm(c / len(class_idxs))[:-1]) for c in class_idxs
    }
    output_images = []
    titles = []
    if plot_original:
        output_images.append(image)
        titles.append("Original Image")

    ground_truth_mask = reconstruct_multiclass_mask(ground_truth_image, image.shape, cmap, ignore_class=ignore_class)
    output_images.append(ground_truth_mask)
    titles.append("Ground Truth Mask")

    prediction_mask = reconstruct_multiclass_mask(
        prediction_image,
        image.shape,
        cmap,
        ignore_class=ignore_class,
    )
    output_images.append(prediction_mask)
    titles.append("Prediction Mask")
    if ignore_class is not None:
        prediction_mask = reconstruct_multiclass_mask(
            prediction_image, image.shape, cmap, ignore_class=ignore_class, ground_truth_mask=ground_truth_image
        )
        prediction_title = f"Prediction Mask \n (Ignoring Ground Truth Class: {ignore_class})"
        output_images.append(prediction_mask)
        titles.append(prediction_title)

    _, axs = plt.subplots(
        ncols=len(output_images),
        nrows=1,
        figsize=(len(output_images) * image_height, image_height),
        squeeze=False,
        facecolor="white",
    )
    # enumerate is required: iterating output_images directly would try to unpack each array.
    for i, (output_image, title) in enumerate(zip(output_images, titles)):
        axs[0, i].imshow(show_mask_on_image(image, output_image))
        axs[0, i].set_title(title)
        axs[0, i].axis("off")
    custom_lines = [Line2D([0], [0], color=tuple(v / 255.0 for v in cmap[str(c)]), lw=4) for c in class_idxs]
    custom_labels = list(class_to_idx.keys())
    axs[0, -1].legend(custom_lines, custom_labels, loc="center left", bbox_to_anchor=(1.01, 0.81), borderaxespad=0)
    if save_path is not None:
        plt.savefig(save_path, bbox_inches="tight")
        plt.close()

reconstruct_multiclass_mask(mask, image_shape, color_map, ignore_class=None, ground_truth_mask=None)

Reconstruct a multiclass mask from a single channel mask.

Parameters:

  • mask (ndarray) –

    A single channel mask.

  • image_shape (Tuple[int, ...]) –

    The shape of the image.

  • color_map (Dict[str, Tuple[int, ...]]) –

    Mapping from the string form of each class index to the RGB color used for that class.

  • ignore_class (Optional[int], default: None ) –

    The class to ignore. Defaults to None.

  • ground_truth_mask (Optional[ndarray], default: None ) –

    The ground truth mask. Defaults to None.

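Returns:

  • ndarray

    Array of shape image_shape in which each pixel holds the color of its class; ignored pixels are left at 0.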

Source code in quadra/utils/visualization.py (lines 139–168)
def reconstruct_multiclass_mask(
    mask: np.ndarray,
    image_shape: Tuple[int, ...],
    color_map: Dict[str, Tuple[int, ...]],
    ignore_class: Optional[int] = None,
    ground_truth_mask: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Reconstruct a multiclass (colored) mask from a single channel mask.

    Args:
        mask (np.ndarray): A single channel mask whose values are class indices.
        image_shape (Tuple[int, ...]): Shape of the output mask. Its leading dimensions must
            match ``mask.shape``; the trailing dimension receives each class color.
        color_map (Dict[str, Tuple[int, ...]]): Mapping from ``str(class_index)`` to the color
            written for that class. It is indexed with string keys, so a matplotlib colormap
            object cannot be passed here.
        ignore_class (Optional[int], optional): Class left uncolored (zero). Defaults to None.
        ground_truth_mask (Optional[np.ndarray], optional): When given together with
            ``ignore_class``, pixels whose ground truth is ``ignore_class`` are set to zero.
            Defaults to None.

    Returns:
        np.ndarray: Float array of shape ``image_shape`` holding the class colors.

    Raises:
        KeyError: If a class present in ``mask`` (other than ``ignore_class``) has no entry
            in ``color_map``.
    """
    output_mask = np.zeros(image_shape)
    for c in np.unique(mask):
        if ignore_class is not None and c == ignore_class:
            continue

        output_mask[mask == c] = color_map[str(c)]

    if ignore_class is not None and ground_truth_mask is not None:
        output_mask[ground_truth_mask == ignore_class] = [0, 0, 0]

    return output_mask

show_mask_on_image(image, mask)

Show a mask on an image.

Parameters:

  • image (ndarray) –

    The image.

  • mask (ndarray) –

    The color mask to overlay on the image.

Returns:

  • ndarray

    np.ndarray: The image with the mask.

Source code in quadra/utils/visualization.py (lines 122–136)
def show_mask_on_image(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Show a mask on an image.

    Both inputs are divided by 255 and summed. The sum is rescaled so that its maximum maps to
    255, and the result is returned as ``uint8``.

    Args:
        image (np.ndarray): The image (divided by 255 internally).
        mask (np.ndarray): The mask, broadcastable against ``image`` (divided by 255 internally).

    Returns:
        np.ndarray: The image with the mask. If the sum is zero everywhere (both inputs black),
        an all-zero image is returned instead of NaN-derived garbage.
    """
    image = image.astype(np.float32) / 255
    mask = mask.astype(np.float32) / 255
    out = mask + image
    peak = np.max(out)
    # Avoid 0 / 0 -> NaN, whose cast to uint8 is undefined, when both inputs are black.
    if peak > 0:
        out = out / peak
    return (255 * out).astype(np.uint8)