Skip to content

visualization

create_rgb_mask(mask, color_map, ignore_classes=None, ground_truth_mask=None)

Convert index mask to RGB mask.

Source code in quadra/utils/patch/visualization.py
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
def create_rgb_mask(
    mask: np.ndarray,
    color_map: Dict,
    ignore_classes: Optional[List[int]] = None,
    ground_truth_mask: Optional[np.ndarray] = None,
):
    """Convert index mask to RGB mask."""
    output_mask = np.zeros([mask.shape[0], mask.shape[1], 3])
    for c in np.unique(mask):
        if ignore_classes is not None and c in ignore_classes:
            continue

        output_mask[mask == c] = color_map[str(c)]
    if ignore_classes is not None and ground_truth_mask is not None:
        output_mask[np.isin(ground_truth_mask, ignore_classes)] = [0, 0, 0]

    return output_mask

plot_patch_reconstruction(reconstruction, idx_to_class, class_to_idx, ignore_classes=None, is_polygon=True)

Helper function for plotting the patch reconstruction.

Parameters:

  • reconstruction (Dict) –

    Dict following this structure { "file_path": str, "mask_path": str, "prediction": { "label": str, "points": [{"x": int, "y": int}] } } if is_polygon else { "file_path": str, "mask_path": str, "prediction": np.ndarray }

  • idx_to_class (Dict[int, str]) –

    Dictionary mapping indices to label names

  • class_to_idx (Dict[str, int]) –

    Dictionary mapping class names to indices

  • ignore_classes (Optional[List[int]], default: None ) –

    Eventually the classes to not plot

  • is_polygon (bool, default: True ) –

    Boolean indicating if the prediction is a polygon or a mask.

Returns:

  • Figure

    Matplotlib plot showing predicted patch regions and eventually gt

Source code in quadra/utils/patch/visualization.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
def plot_patch_reconstruction(
    reconstruction: Dict,
    idx_to_class: Dict[int, str],
    class_to_idx: Dict[str, int],
    ignore_classes: Optional[List[int]] = None,
    is_polygon: bool = True,
) -> Figure:
    """Helper function for plotting the patch reconstruction.

    Args:
        reconstruction: Dict following this structure
            {
                "file_path": str,
                "mask_path": str,
                "prediction": {
                    "label": str,
                    "points": [{"x": int, "y": int}]
                }
            } if is_polygon else
            {
                "file_path": str,
                "mask_path": str,
                "prediction": np.ndarray
            }
        idx_to_class: Dictionary mapping indices to label names
        class_to_idx: Dictionary mapping class names to indices
        ignore_classes: Eventually the classes to not plot
        is_polygon: Boolean indicating if the prediction is a polygon or a mask.

    Returns:
        Matplotlib plot showing predicted patch regions and eventually gt

    """
    cmap_name = "tab10"

    # 10 classes + good
    if len(idx_to_class.values()) > 11:
        cmap_name = "tab20"

    cmap = get_cmap(cmap_name)
    test_img = cv2.imread(reconstruction["image_path"])
    test_img = cv2.cvtColor(test_img, cv2.COLOR_BGR2RGB)
    gt_img = None

    if reconstruction["mask_path"] is not None and os.path.isfile(reconstruction["mask_path"]):
        gt_img = cv2.imread(reconstruction["mask_path"], 0)

    out = np.zeros((test_img.shape[0], test_img.shape[1]), dtype=np.uint8)

    if is_polygon:
        for _, region in enumerate(reconstruction["prediction"]):
            points = [[item["x"], item["y"]] for item in region["points"]]
            c_label = region["label"]

            out = cv2.drawContours(
                out,
                np.array([points], np.int32),
                -1,
                class_to_idx[c_label],
                thickness=cv2.FILLED,
            )
    else:
        out = reconstruction["prediction"]

    fig = plot_patch_results(
        image=test_img,
        prediction_image=out,
        ground_truth_image=gt_img,
        plot_original=True,
        ignore_classes=ignore_classes,
        save_path=None,
        class_to_idx=class_to_idx,
        cmap=cmap,
    )

    return fig

plot_patch_results(image, prediction_image, ground_truth_image, class_to_idx, plot_original=True, ignore_classes=None, image_height=10, save_path=None, cmap=get_cmap('tab20'))

Function used to plot the image predicted Args: prediction_image: The prediction image image: The original image to plot ground_truth_image: The ground truth image class_to_idx: Dictionary mapping class names to indices plot_original: Boolean to plot the original image ignore_classes: The classes to ignore, default is 0 image_height: The height of the output figure save_path: The path to save the figure cmap: The colormap to use.

Returns:

  • Figure

    The matplotlib figure

Source code in quadra/utils/patch/visualization.py
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
def plot_patch_results(
    image: np.ndarray,
    prediction_image: np.ndarray,
    ground_truth_image: Optional[np.ndarray],
    class_to_idx: Dict[str, int],
    plot_original: bool = True,
    ignore_classes: Optional[List[int]] = None,
    image_height: int = 10,
    save_path: Optional[str] = None,
    cmap: Colormap = get_cmap("tab20"),
) -> Figure:
    """Function used to plot the image predicted
    Args:
        prediction_image: The prediction image
        image: The original image to plot
        ground_truth_image: The ground truth image
        class_to_idx: Dictionary mapping class names to indices
        plot_original: Boolean to plot the original image
        ignore_classes: The classes to ignore, default is 0
        image_height: The height of the output figure
        save_path: The path to save the figure
        cmap: The colormap to use.

    Returns:
        The matplotlib figure
    """
    if ignore_classes is None:
        ignore_classes = [0]

    image = image[0 : prediction_image.shape[0], 0 : prediction_image.shape[1], :]
    idx_to_class = {v: k for k, v in class_to_idx.items()}

    if ignore_classes is not None:
        class_to_idx = {k: v for k, v in class_to_idx.items() if v not in ignore_classes}

    class_idxs = list(class_to_idx.values())

    cmap = {str(c): tuple(int(i * 255) for i in cmap(c / len(class_idxs))[:-1]) for c in class_idxs}
    output_images = []
    titles = []

    if plot_original:
        output_images.append(image)
        titles.append("Original Image")

    if ground_truth_image is not None:
        ground_truth_image = ground_truth_image[0 : prediction_image.shape[0], 0 : prediction_image.shape[1]]
        ground_truth_mask = create_rgb_mask(ground_truth_image, cmap, ignore_classes=ignore_classes)
        output_images.append(ground_truth_mask)
        titles.append("Ground Truth Mask")

    prediction_mask = create_rgb_mask(
        prediction_image,
        cmap,
        ignore_classes=ignore_classes,
    )

    output_images.append(prediction_mask)
    titles.append("Prediction Mask")
    if ignore_classes is not None and ground_truth_image is not None:
        prediction_mask = create_rgb_mask(
            prediction_image, cmap, ignore_classes=ignore_classes, ground_truth_mask=ground_truth_image
        )

        ignored_classes_str = [idx_to_class[c] for c in ignore_classes]
        prediction_title = f"Prediction Mask \n (Ignoring Ground Truth Class: {ignored_classes_str})"
        output_images.append(prediction_mask)
        titles.append(prediction_title)

    fig, axs = plt.subplots(
        ncols=len(output_images),
        nrows=1,
        figsize=(len(output_images) * image_height, image_height),
        squeeze=False,
        facecolor="white",
    )

    for i, output_image in enumerate(output_images):
        axs[0, i].imshow(show_mask_on_image(image, output_image))
        axs[0, i].set_title(titles[i])
        axs[0, i].axis("off")

    custom_lines = [Line2D([0], [0], color=tuple(i / 255.0 for i in cmap[str(c)]), lw=4) for c in class_idxs]
    custom_labels = list(class_to_idx.keys())
    axs[0, -1].legend(custom_lines, custom_labels, loc="center left", bbox_to_anchor=(1.01, 0.81), borderaxespad=0)
    if save_path is not None:
        plt.savefig(save_path, bbox_inches="tight")
        plt.close()

    return fig

show_mask_on_image(image, mask)

Plot mask on top of the original image.

Source code in quadra/utils/patch/visualization.py
 95
 96
 97
 98
 99
100
101
def show_mask_on_image(image: np.ndarray, mask: np.ndarray):
    """Plot mask on top of the original image."""
    image = image.astype(np.float32) / 255
    mask = mask.astype(np.float32) / 255
    out = mask + image.astype(np.float32)
    out = out / np.max(out)
    return np.uint8(255 * out)