20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142 | def save_classification_result(
results: pd.DataFrame,
output_folder: str,
confusion_matrix: Optional[pd.DataFrame],
accuracy: float,
test_dataloader: DataLoader,
reconstructions: List[Dict],
config: DictConfig,
output: DictConfig,
ignore_classes: Optional[List[int]] = None,
):
"""Save classification results.
Args:
results: Dataframe containing the classification results
output_folder: Folder where to save the results
confusion_matrix: Confusion matrix
accuracy: Accuracy of the model
test_dataloader: Dataloader used for testing
reconstructions: List of dictionaries containing polygons or masks
config: Experiment configuration
output: Output configuration
ignore_classes: Eventual classes to ignore during reconstruction plot. Defaults to None.
"""
# Save csv
results.to_csv(os.path.join(output_folder, "test_results.csv"), index_label="index")
if confusion_matrix is not None:
# Save confusion matrix
disp = ConfusionMatrixDisplay(
confusion_matrix=np.array(confusion_matrix),
display_labels=[x.replace("pred:", "") for x in confusion_matrix.columns.to_list()],
)
disp.plot(include_values=True, cmap=plt.cm.Greens, ax=None, colorbar=False, xticks_rotation=90)
plt.title(f"Confusion Matrix (Accuracy: {(accuracy * 100):.2f}%)")
plt.savefig(
os.path.join(output_folder, "test_confusion_matrix.png"),
bbox_inches="tight",
pad_inches=0,
dpi=300,
)
plt.close()
if output.example:
if not hasattr(test_dataloader.dataset, "idx_to_class"):
raise ValueError("The provided dataset does not have an attribute 'idx_to_class")
idx_to_class = test_dataloader.dataset.idx_to_class
# Get misclassified samples
example_folder = os.path.join(output_folder, "example")
if not os.path.isdir(example_folder):
os.makedirs(example_folder)
# Skip if no no ground truth is available
if not all(results["real_label"] == -1):
for v in np.unique([results["real_label"], results["pred_label"]]):
if v == -1:
continue
k = idx_to_class[v]
if ignore_classes is not None and v in ignore_classes:
continue
plot_classification_results(
test_dataloader.dataset,
unorm=UnNormalize(mean=config.transforms.mean, std=config.transforms.std),
pred_labels=results["pred_label"].to_numpy(),
test_labels=results["real_label"].to_numpy(),
class_name=k,
original_folder=example_folder,
idx_to_class=idx_to_class,
pred_class_to_plot=v,
what="con",
rows=output.get("rows", 3),
cols=output.get("cols", 2),
figsize=output.get("figsize", (20, 20)),
)
plot_classification_results(
test_dataloader.dataset,
unorm=UnNormalize(mean=config.transforms.mean, std=config.transforms.std),
pred_labels=results["pred_label"].to_numpy(),
test_labels=results["real_label"].to_numpy(),
class_name=k,
original_folder=example_folder,
idx_to_class=idx_to_class,
pred_class_to_plot=v,
what="dis",
rows=output.get("rows", 3),
cols=output.get("cols", 2),
figsize=output.get("figsize", (20, 20)),
)
for counter, reconstruction in enumerate(reconstructions):
is_polygon = True
if isinstance(reconstruction["prediction"], np.ndarray):
is_polygon = False
if is_polygon:
if len(reconstruction["prediction"]) == 0:
continue
else:
if reconstruction["prediction"].sum() == 0:
continue
if counter > 5:
break
to_plot = plot_patch_reconstruction(
reconstruction,
idx_to_class,
class_to_idx=test_dataloader.dataset.class_to_idx, # type: ignore[attr-defined]
ignore_classes=ignore_classes,
is_polygon=is_polygon,
)
if to_plot:
output_name = f"reconstruction_{os.path.splitext(os.path.basename(reconstruction['file_name']))[0]}.png"
plt.savefig(os.path.join(example_folder, output_name), bbox_inches="tight", pad_inches=0)
plt.close()
|