Skip to content

oxford_pet

OxfordPetSegmentationDataModule(data_path, idx_to_class, name='oxford_pet_segmentation_datamodule', dataset=SegmentationDatasetMulticlass, batch_size=32, test_size=0.3, val_size=0.3, seed=42, num_workers=6, train_transform=None, test_transform=None, val_transform=None, **kwargs)

Bases: SegmentationMulticlassDataModule

Segmentation datamodule for the Oxford pet dataset (OxfordPetSegmentationDataModule).

Parameters:

  • data_path (str) –

    path to the oxford pet dataset

  • idx_to_class (dict) –

    dict with the correspondence between mask indices and classes: {1: class_1, 2: class_2, ..., N: class_N}, excluding the background class, which is 0.

  • name (str, default: 'oxford_pet_segmentation_datamodule' ) –

    Defaults to "oxford_pet_segmentation_datamodule".

  • dataset (type[SegmentationDatasetMulticlass], default: SegmentationDatasetMulticlass ) –

    Defaults to SegmentationDatasetMulticlass.

  • batch_size (int, default: 32 ) –

    batch size for training. Defaults to 32.

  • test_size (float, default: 0.3 ) –

    Defaults to 0.3.

  • val_size (float, default: 0.3 ) –

    Defaults to 0.3.

  • seed (int, default: 42 ) –

    Defaults to 42.

  • num_workers (int, default: 6 ) –

    number of workers for data loading. Defaults to 6.

  • train_transform (Compose | None, default: None ) –

    Train transform. Defaults to None.

  • test_transform (Compose | None, default: None ) –

    Test transform. Defaults to None.

  • val_transform (Compose | None, default: None ) –

    Validation transform. Defaults to None.

Source code in quadra/datamodules/generic/oxford_pet.py
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
def __init__(
    self,
    data_path: str,
    idx_to_class: dict,
    name: str = "oxford_pet_segmentation_datamodule",
    dataset: type[SegmentationDatasetMulticlass] = SegmentationDatasetMulticlass,
    batch_size: int = 32,
    test_size: float = 0.3,
    val_size: float = 0.3,
    seed: int = 42,
    num_workers: int = 6,
    train_transform: albumentations.Compose | None = None,
    test_transform: albumentations.Compose | None = None,
    val_transform: albumentations.Compose | None = None,
    **kwargs: Any,
):
    """Build the datamodule by handing every argument to the parent constructor.

    Args:
        data_path: path to the oxford pet dataset
        idx_to_class: dict mapping mask indices to classes
            ({1: class_1, 2: class_2, ..., N: class_N}); the background class is 0
            and is not part of the mapping
        name: name of the datamodule
        dataset: dataset class forwarded to the parent datamodule
        batch_size: batch size for training
        test_size: forwarded unchanged to the parent datamodule
        val_size: forwarded unchanged to the parent datamodule
        seed: forwarded unchanged to the parent datamodule
        num_workers: number of workers for data loading
        train_transform: train transform
        test_transform: test transform
        val_transform: validation transform
        **kwargs: any extra keyword arguments, forwarded unchanged to the parent
    """
    # Nothing is stored or altered here: the named arguments are only gathered
    # so they can be forwarded by keyword together with the extra kwargs.
    forwarded = {
        "data_path": data_path,
        "idx_to_class": idx_to_class,
        "name": name,
        "dataset": dataset,
        "batch_size": batch_size,
        "test_size": test_size,
        "val_size": val_size,
        "seed": seed,
        "num_workers": num_workers,
        "train_transform": train_transform,
        "test_transform": test_transform,
        "val_transform": val_transform,
    }
    super().__init__(**forwarded, **kwargs)

_check_exists(image_folder, annotation_folder)

Check if the dataset is already downloaded.

Source code in quadra/datamodules/generic/oxford_pet.py
91
92
93
def _check_exists(self, image_folder: str, annotation_folder: str) -> bool:
    """Tell whether the dataset folders are already on disk.

    Args:
        image_folder: path expected to hold the images
        annotation_folder: path expected to hold the annotations

    Returns:
        True only when both paths exist and are directories, False otherwise.
    """
    for candidate in (image_folder, annotation_folder):
        # isdir() is already False for a missing path, so it covers existence too.
        if not os.path.isdir(candidate):
            return False
    return True

_prepare_data()

Prepare the data to be used by the DataModule.

Source code in quadra/datamodules/generic/oxford_pet.py
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
def _prepare_data(self) -> None:
    """Download the dataset if needed and build the samples table in ``self.data``.

    The official ``trainval.txt`` and ``test.txt`` lists under ``annotations`` are
    read, entries whose ``.jpg`` image is missing from ``images`` are dropped, and
    one row per remaining image is stored in ``self.data`` with the columns
    ``samples`` (image path), ``masks`` (trimap path), ``targets`` (always 1) and
    ``split`` (``train``, ``val`` or ``test``).
    """
    self.download_data()

    images_dir = os.path.join(self.data_path, "images")
    annotations_dir = os.path.join(self.data_path, "annotations")

    def _load_split(split_file: str) -> list:
        """Return the image names listed in ``split_file`` whose image exists on disk."""
        with open(os.path.join(annotations_dir, split_file)) as handle:
            rows = handle.read().strip("\n").split("\n")
        # The image name is the first space-separated field of each row.
        stems = (row.split(" ")[0] for row in rows)
        return [stem for stem in stems if os.path.exists(os.path.join(images_dir, stem + ".jpg"))]

    trainval_names = _load_split("trainval.txt")
    # Every tenth trainval entry (positions 0, 10, 20, ...) is held out for validation.
    val_names = trainval_names[::10]
    train_names = [stem for position, stem in enumerate(trainval_names) if position % 10]
    test_names = _load_split("test.txt")

    frames = []
    for split_name, stems in (("train", train_names), ("val", val_names), ("test", test_names)):
        frame = pd.DataFrame(
            {
                "samples": [os.path.join(images_dir, stem + ".jpg") for stem in stems],
                "masks": [os.path.join(annotations_dir, "trimaps", stem + ".png") for stem in stems],
                "targets": [1] * len(stems),
            }
        )
        frames.append(frame.assign(split=split_name))

    # Rows keep their per-split index; the frames are simply stacked.
    self.data = pd.concat(frames, axis=0)

_preprocess_mask(mask)

Preprocess mask function that is adapted from https://albumentations.ai/docs/examples/pytorch_semantic_segmentation/.

Parameters:

  • mask (ndarray) –

    mask to be preprocessed

Returns:

  • (ndarray) –

    binarized mask

Source code in quadra/datamodules/generic/oxford_pet.py
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
def _preprocess_mask(self, mask: np.ndarray) -> np.ndarray:
    """Turn a trimap into a binary mask; adapted from
    https://albumentations.ai/docs/examples/pytorch_semantic_segmentation/.

    Pixels equal to 2 become 0, as do pixels that are not positive; every other
    pixel (including the values 1 and 3) becomes 1. The input array is not
    modified.

    Args:
        mask: mask to be preprocessed

    Returns:
        binarized mask, as a uint8 array with the same shape as the input
    """
    # Compare on a float32 copy, like the reference implementation does.
    values = mask.astype(np.float32)
    foreground = (values > 0) & (values != 2.0)
    return foreground.astype(np.uint8)

download_data()

Download the dataset if it is not already downloaded.

Source code in quadra/datamodules/generic/oxford_pet.py
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
def download_data(self):
    """Fetch the dataset into ``self.data_path`` unless it is already there.

    When the ``images`` and ``annotations`` folders are both present nothing is
    done. Otherwise every archive listed in ``self._RESOURCES`` is downloaded and
    extracted, then each file in the image folder is checked against its trimap:
    pairs whose binarized mask is empty, or that fail while being processed, are
    deleted, and the remaining images are re-written as ``.jpg`` files.
    """
    image_folder = os.path.join(self.data_path, "images")
    annotation_folder = os.path.join(self.data_path, "annotations")
    if self._check_exists(image_folder, annotation_folder):
        return

    for url, md5 in self._RESOURCES:
        download_and_extract_archive(url, download_root=self.data_path, md5=md5, remove_finished=True)

    log.info("Fixing corrupted files...")
    trimap_folder = os.path.join(annotation_folder, "trimaps")
    for filename in sorted(os.listdir(image_folder)):
        stem = os.path.splitext(os.path.basename(filename))[0]
        image_path = os.path.join(image_folder, filename)
        mask_path = os.path.join(trimap_folder, stem + ".png")
        try:
            binary_mask = self._preprocess_mask(cv2.imread(mask_path))
            if np.sum(binary_mask) == 0:
                # A mask with no foreground pixel: drop both the image and the trimap.
                os.remove(image_path)
                os.remove(mask_path)
                log.info("Removed %s", filename)
            else:
                # Decode and re-encode the image as <stem>.jpg.
                decoded = cv2.imread(image_path)
                cv2.imwrite(os.path.join(image_folder, stem + ".jpg"), decoded)
        except Exception:
            # Any failure while handling the pair means it is discarded entirely.
            for leftover in (image_path, mask_path):
                if os.path.exists(leftover):
                    os.remove(leftover)
            log.info("Removed %s", filename)