Skip to content

segmentation

smooth_mask(mask)

Smooths a segmentation mask by replacing each connected component with its normalized distance transform.

Parameters:

- `mask` (`np.ndarray`): Input mask.

Returns:

- `np.ndarray`: Smoothed mask.

Source code in `quadra/utils/segmentation.py` (lines 10–31):
def smooth_mask(mask: np.ndarray) -> np.ndarray:
    """Smooth a segmentation mask using per-component distance transforms.

    Each connected region of ``mask`` (as returned by ``skimage.measure.label``)
    is replaced by its Euclidean distance map (from ``medial_axis``), raised to
    the power ``1 / 2.2`` and min-max normalized so that its largest value is 1.
    Values therefore grow from the region's boundary towards its interior.

    Args:
        mask: Input mask. Pixels equal to 0 are background and are 0 in the
            output.

    Returns:
        Smoothed mask with the same shape as ``mask``: the sum of the
        normalized per-component maps, multiplied elementwise by ``mask``.
    """
    labeled_mask = skimage.measure.label(mask)
    output_mask = np.zeros_like(mask).astype(np.float32)

    # Start at 1: label 0 is the background. Its distance map is 0 on every
    # foreground pixel, so it never survived the final multiplication by
    # ``mask``. When the mask had no background at all, that map was constant,
    # and the min-max normalization divided 0 by 0, filling the output with NaN.
    for component_id in range(1, int(labeled_mask.max()) + 1):
        component_mask = labeled_mask == component_id
        _, distance = medial_axis(component_mask, return_distance=True)

        # An exponent below 1 compresses large distances relative to small ones.
        component_mask_norm = distance ** (1 / 2.2)

        min_value = component_mask_norm.min()
        value_range = component_mask_norm.max() - min_value
        if value_range == 0:
            # A constant map carries no shape information; skip it rather than
            # dividing by zero.
            continue
        output_mask += (component_mask_norm - min_value) / value_range

    return output_mask * mask