Skip to content

oxford_pet

OxfordPetSegmentationDataModule(data_path, test_size=0.3, val_size=0.3, seed=42, name='oxford_pet_segmentation_datamodule', dataset=SegmentationDataset, batch_size=32, num_workers=6, train_transform=None, test_transform=None, val_transform=None)

Bases: SegmentationDataModule

Data module for the Oxford Pet segmentation dataset.

Parameters:

  • data_path (str) –

    path to the oxford pet dataset

  • test_size (float) –

    Defaults to 0.3.

  • val_size (float) –

    Defaults to 0.3.

  • seed (int) –

    Defaults to 42.

  • name (str) –

    Defaults to "oxford_pet_segmentation_datamodule".

  • dataset (Type[SegmentationDataset]) –

    Defaults to SegmentationDataset.

  • batch_size (int) –

    batch size for training. Defaults to 32.

  • num_workers (int) –

    number of workers for data loading. Defaults to 6.

  • train_transform (Optional[albumentations.Compose]) –

    Train transform. Defaults to None.

  • test_transform (Optional[albumentations.Compose]) –

    Test transform. Defaults to None.

  • val_transform (Optional[albumentations.Compose]) –

    Validation transform. Defaults to None.

Source code in quadra/datamodules/generic/oxford_pet.py (lines 34–60)
def __init__(
    self,
    data_path: str,
    test_size: float = 0.3,
    val_size: float = 0.3,
    seed: int = 42,
    name: str = "oxford_pet_segmentation_datamodule",
    dataset: Type[SegmentationDataset] = SegmentationDataset,
    batch_size: int = 32,
    num_workers: int = 6,
    train_transform: Optional[albumentations.Compose] = None,
    test_transform: Optional[albumentations.Compose] = None,
    val_transform: Optional[albumentations.Compose] = None,
):
    """Build the Oxford Pet segmentation data module.

    This constructor adds no logic of its own: every argument is handed
    unchanged to the parent ``SegmentationDataModule`` constructor.

    Args:
        data_path: Path to the Oxford Pet dataset.
        test_size: Forwarded to the parent. Defaults to 0.3.
        val_size: Forwarded to the parent. Defaults to 0.3.
        seed: Forwarded to the parent. Defaults to 42.
        name: Name of the data module.
            Defaults to "oxford_pet_segmentation_datamodule".
        dataset: Dataset class forwarded to the parent.
            Defaults to ``SegmentationDataset``.
        batch_size: Batch size for training. Defaults to 32.
        num_workers: Number of workers for data loading. Defaults to 6.
        train_transform: Train transform. Defaults to None.
        test_transform: Test transform. Defaults to None.
        val_transform: Validation transform. Defaults to None.
    """
    # Identity and dataset location first, then split settings, then loader
    # settings, then per-split transforms. Keyword order is irrelevant to the
    # call; it is grouped here only for readability.
    super().__init__(
        name=name,
        data_path=data_path,
        dataset=dataset,
        seed=seed,
        test_size=test_size,
        val_size=val_size,
        batch_size=batch_size,
        num_workers=num_workers,
        train_transform=train_transform,
        val_transform=val_transform,
        test_transform=test_transform,
    )

download_data()

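Download the dataset if it is not already present, then clean it up. Samples whose preprocessed trimap is empty, or whose image or trimap cannot be processed, are deleted. The remaining images are re-encoded as JPEG.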

Source code in quadra/datamodules/generic/oxford_pet.py (lines 91–119)
def download_data(self):
    """Download the dataset if it is not already downloaded.

    Nothing happens when ``self._check_exists`` reports that the ``images`` and
    ``annotations`` folders under ``self.data_path`` are already present.
    Otherwise every archive in ``self._RESOURCES`` is downloaded and extracted
    into ``self.data_path``. Each file in ``images/`` is then checked against the
    trimap ``annotations/trimaps/<stem>.png``:

    * If the preprocessed mask sums to zero, the file and its trimap are deleted.
    * Otherwise the image is decoded with OpenCV and re-written to
      ``images/<stem>.jpg``.
    * If the file cannot be processed (unreadable mask or image, or any error
      while preprocessing or re-encoding), it is deleted and the error is
      logged. Its trimap is deleted too, but only when the file is the
      ``<stem>.jpg`` image that trimap annotates. A different file that merely
      shares the stem must not orphan a valid sample.
    """
    image_folder = os.path.join(self.data_path, "images")
    annotation_folder = os.path.join(self.data_path, "annotations")
    if self._check_exists(image_folder, annotation_folder):
        return

    for url, md5 in self._RESOURCES:
        download_and_extract_archive(url, download_root=self.data_path, md5=md5, remove_finished=True)

    log.info("Fixing corrupted files...")
    trimap_folder = os.path.join(annotation_folder, "trimaps")
    for filename in sorted(os.listdir(image_folder)):
        file_wo_ext, ext = os.path.splitext(filename)
        image_path = os.path.join(image_folder, filename)
        mask_path = os.path.join(trimap_folder, file_wo_ext + ".png")
        # Repaired images are written to <stem>.jpg, so that is the file a trimap
        # belongs to. Only a failing <stem>.jpg may take the trimap down with it.
        owns_mask = ext.lower() == ".jpg"
        try:
            mask = cv2.imread(mask_path)
            if mask is None:
                # cv2.imread signals a missing/undecodable file by returning None.
                raise ValueError(f"cannot read mask {mask_path}")
            mask = self._preprocess_mask(mask)
            if np.sum(mask) == 0:
                os.remove(image_path)
                os.remove(mask_path)
                log.info("Removed %s", filename)
                continue
            img = cv2.imread(image_path)
            if img is None:
                raise ValueError(f"cannot decode image {image_path}")
            # Re-encode so that every kept image is a JPEG that OpenCV can decode.
            cv2.imwrite(os.path.join(image_folder, file_wo_ext + ".jpg"), img)
        except Exception:
            # Best-effort cleanup: drop the unusable file, but keep the reason.
            log.warning("Removed %s (could not be processed)", filename, exc_info=True)
            if os.path.exists(image_path):
                os.remove(image_path)
            if owns_mask and os.path.exists(mask_path):
                os.remove(mask_path)