Skip to content

segmentation

SegmentationDataset(image_paths, mask_paths, batch_size=None, object_masks=None, resize=224, mask_preprocess=None, labels=None, transform=None, mask_smoothing=False, defect_transform=None)

Bases: torch.utils.data.Dataset

Custom SegmentationDataset class for loading images and masks.

Parameters:

  • image_paths (List[str]) –

    List of paths to images.

  • mask_paths (List[str]) –

    List of paths to masks.

  • batch_size (Optional[int]) –

    Batch size.

  • object_masks (Optional[List[Union[np.ndarray, Any]]]) –

    List of paths to object masks.

  • resize (int) –

    Resize image to this size.

  • mask_preprocess (Optional[Callable]) –

    Preprocess mask.

  • labels (Optional[List[str]]) –

    List of labels.

  • transform (Optional[albumentations.Compose]) –

    Transformations to apply to images and masks.

  • mask_smoothing (bool) –

    Smooth mask.

  • defect_transform (Optional[albumentations.Compose]) –

    Transformations to apply to images and masks for defects.

Source code in quadra/datasets/segmentation.py
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
def __init__(
    self,
    image_paths: List[str],
    mask_paths: List[str],
    batch_size: Optional[int] = None,
    object_masks: Optional[List[Union[np.ndarray, Any]]] = None,
    resize: int = 224,
    mask_preprocess: Optional[Callable] = None,
    labels: Optional[List[str]] = None,
    transform: Optional[albumentations.Compose] = None,
    mask_smoothing: bool = False,
    defect_transform: Optional[albumentations.Compose] = None,
):
    """Store the dataset configuration; nothing is read from disk here.

    Args:
        image_paths: List of paths to images.
        mask_paths: List of paths to masks.
        batch_size: Batch size. When given, it is stored as the larger of
            ``batch_size`` and the number of images; ``None`` is kept as is.
        object_masks: List of object masks.
        resize: Resize image to this size.
        mask_preprocess: Callable used to preprocess the mask.
        labels: List of labels.
        transform: Transformations to apply to images and masks.
        mask_smoothing: Whether to smooth the mask (stored as ``smooth_mask``).
        defect_transform: Transformations to apply to images and masks for defects.
    """
    # Sample sources.
    self.image_paths = image_paths
    self.mask_paths = mask_paths
    self.object_masks = object_masks
    self.labels = labels
    self.data_len = len(self.image_paths)

    # Preprocessing and augmentation settings.
    self.resize = resize
    self.mask_preprocess = mask_preprocess
    self.smooth_mask = mask_smoothing
    self.transform = transform
    self.defect_transform = defect_transform

    # The stored batch size is never smaller than the number of images.
    if batch_size is None:
        self.batch_size = None
    else:
        self.batch_size = max(batch_size, self.data_len)

SegmentationDatasetMulticlass(image_paths, mask_paths, idx_to_class, batch_size=None, transform=None, one_hot=False)

Bases: torch.utils.data.Dataset

Custom SegmentationDataset class for loading images and multilabel masks.

Parameters:

  • image_paths (List[str]) –

    List of paths to images.

  • mask_paths (List[str]) –

    List of paths to masks.

  • idx_to_class (Dict) –

    Dict with the correspondence between mask indices and classes: {1: class_1, 2: class_2, ..., N: class_N}

  • batch_size (Optional[int]) –

    Batch size.

  • transform (Optional[albumentations.Compose]) –

    Transformations to apply to images and masks.

  • one_hot (bool) –

    If True, return a binary mask of shape (n_classes x H x W); otherwise return the labelled mask of shape H x W. SMP loss requires the second format.

Source code in quadra/datasets/segmentation.py
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
def __init__(
    self,
    image_paths: List[str],
    mask_paths: List[str],
    idx_to_class: Dict,
    batch_size: Optional[int] = None,
    transform: Optional[albumentations.Compose] = None,
    one_hot: bool = False,
):
    """Store the dataset configuration; nothing is read from disk here.

    Args:
        image_paths: List of paths to images.
        mask_paths: List of paths to masks.
        idx_to_class: Dict mapping mask indices to classes,
            e.g. ``{1: class_1, 2: class_2, ..., N: class_N}``.
        batch_size: Batch size. When given, it is stored as the larger of
            ``batch_size`` and the number of images; ``None`` is kept as is.
        transform: Transformations to apply to images and masks.
        one_hot: If True, samples carry the per-class mask stack; otherwise
            the single labelled mask.
    """
    # Sample sources and class mapping.
    self.image_paths = image_paths
    self.mask_paths = mask_paths
    self.idx_to_class = idx_to_class
    self.data_len = len(self.image_paths)

    # Output and augmentation settings.
    self.one_hot = one_hot
    self.transform = transform

    # The stored batch size is never smaller than the number of images.
    if batch_size is None:
        self.batch_size = None
    else:
        self.batch_size = max(batch_size, self.data_len)

__getitem__(index)

Get image and mask.

Source code in quadra/datasets/segmentation.py
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
def __getitem__(self, index):
    """Get image and mask.

    Args:
        index: Position of the sample. It is wrapped with ``index % data_len``,
            so positions past the number of images reuse earlier samples.

    Returns:
        A tuple ``(image, mask, 0)``. ``image`` is read with OpenCV, converted
        from BGR to RGB and, if ``self.transform`` is set, augmented together
        with the mask. ``mask`` is an integer array: the labelled ``H x W`` mask
        when ``self.one_hot`` is False, otherwise the ``C x H x W`` stack
        produced by ``self._preprocess_mask`` (background included). The third
        element is always the literal ``0``.

    Raises:
        StopIteration: When ``index`` equals the dataset length (``batch_size``
            if set, ``data_len`` otherwise), so plain iteration terminates.
    """
    # This is required to avoid an infinite loop when running the dataset outside of a dataloader
    dataset_len = self.data_len if self.batch_size is None else self.batch_size
    if index == dataset_len:
        raise StopIteration

    index = index % self.data_len
    image_path = self.image_paths[index]

    image = cv2.imread(str(image_path))
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    mask_path = self.mask_paths[index]
    # NaN is detected by value: an identity test against ``np.nan`` misses NaN floats that
    # are distinct objects, and such a value would make ``os.path.isfile`` raise TypeError.
    is_nan_path = isinstance(mask_path, (float, np.floating)) and np.isnan(mask_path)
    if mask_path is None or is_nan_path or not os.path.isfile(mask_path):
        # No usable mask file: fall back to an all-background mask matching the image size.
        mask = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
    else:
        mask = cv2.imread(str(mask_path), 0)

    # We go back to binary masks to avoid transformation errors
    mask = self._preprocess_mask(mask)

    if self.transform is not None:
        masks = list(mask)
        aug = self.transform(image=image, masks=masks)
        image = aug["image"]
        mask = np.stack(aug["masks"])  # C x H x W

    # We compute the single channel mask again
    # zero is the background
    if not self.one_hot:  # one hot is done by smp dice loss
        mask_out = np.zeros(mask.shape[1:])
        for i in range(1, mask.shape[0]):
            mask_out[mask[i] == 1] = i
        # mask_out shape -> HxW
    else:
        mask_out = mask
        # mask_out shape -> CxHxW where C is number of classes (included the background)

    return image, mask_out.astype(int), 0

__len__()

Returns the dataset length.

Source code in quadra/datasets/segmentation.py
232
233
234
235
236
237
def __len__(self):
    """Return the dataset length.

    Without a batch size this is the number of images; otherwise it is the
    larger of the number of images and the batch size.
    """
    return self.data_len if self.batch_size is None else max(self.data_len, self.batch_size)