Skip to content

oxford_pet

OxfordPetSegmentationDataModule(data_path, idx_to_class, name='oxford_pet_segmentation_datamodule', dataset=SegmentationDatasetMulticlass, batch_size=32, test_size=0.3, val_size=0.3, seed=42, num_workers=6, train_transform=None, test_transform=None, val_transform=None, **kwargs)

Bases: SegmentationMulticlassDataModule

Data module for multiclass semantic segmentation on the Oxford-IIIT Pet dataset.

Parameters:

  • data_path (str) –

    path to the oxford pet dataset

  • idx_to_class (dict) –

    Dict mapping each mask index to its class name, e.g. {1: class_1, 2: class_2, ..., N: class_N}. The background class (index 0) is not included.

  • name (str, default: 'oxford_pet_segmentation_datamodule' ) –

    Defaults to "oxford_pet_segmentation_datamodule".

  • dataset (type[SegmentationDatasetMulticlass], default: SegmentationDatasetMulticlass ) –

    Dataset class. Defaults to SegmentationDatasetMulticlass.

  • batch_size (int, default: 32 ) –

    batch size for training. Defaults to 32.

  • test_size (float, default: 0.3 ) –

    Defaults to 0.3.

  • val_size (float, default: 0.3 ) –

    Defaults to 0.3.

  • seed (int, default: 42 ) –

    Random seed. Defaults to 42.

  • num_workers (int, default: 6 ) –

    number of workers for data loading. Defaults to 6.

  • train_transform (Compose | None, default: None ) –

    Train transform. Defaults to None.

  • test_transform (Compose | None, default: None ) –

    Test transform. Defaults to None.

  • val_transform (Compose | None, default: None ) –

    Validation transform. Defaults to None.

Source code in quadra/datamodules/generic/oxford_pet.py (lines 38–68)
def __init__(
    self,
    data_path: str,
    idx_to_class: dict,
    name: str = "oxford_pet_segmentation_datamodule",
    dataset: type[SegmentationDatasetMulticlass] = SegmentationDatasetMulticlass,
    batch_size: int = 32,
    test_size: float = 0.3,
    val_size: float = 0.3,
    seed: int = 42,
    num_workers: int = 6,
    train_transform: albumentations.Compose | None = None,
    test_transform: albumentations.Compose | None = None,
    val_transform: albumentations.Compose | None = None,
    **kwargs: Any,
):
    """Build the Oxford-IIIT Pet segmentation data module.

    This constructor does no work of its own: every argument, including any
    extra ``**kwargs``, is handed by keyword to the parent
    ``SegmentationMulticlassDataModule`` constructor.

    Args:
        data_path: Path to the Oxford Pet dataset.
        idx_to_class: Mapping from mask index to class name; the background
            class (index 0) is not included.
        name: Data module name.
        dataset: Dataset class, ``SegmentationDatasetMulticlass`` by default.
        batch_size: Batch size for training.
        test_size: Test split size.
        val_size: Validation split size.
        seed: Random seed.
        num_workers: Number of data-loading workers.
        train_transform: Optional albumentations pipeline for training data.
        test_transform: Optional albumentations pipeline for test data.
        val_transform: Optional albumentations pipeline for validation data.
        **kwargs: Extra keyword arguments forwarded to the parent class.
    """
    # Pure delegation: the parent owns all configuration and state.
    super().__init__(
        data_path=data_path,
        idx_to_class=idx_to_class,
        name=name,
        dataset=dataset,
        seed=seed,
        batch_size=batch_size,
        num_workers=num_workers,
        test_size=test_size,
        val_size=val_size,
        train_transform=train_transform,
        val_transform=val_transform,
        test_transform=test_transform,
        **kwargs,
    )

download_data()

Download the dataset if it is not already downloaded.

Source code in quadra/datamodules/generic/oxford_pet.py (lines 95–123)
def download_data(self):
    """Download the dataset if it is not already downloaded, then clean it up.

    If ``self._check_exists`` reports that the ``images`` and ``annotations``
    folders under ``self.data_path`` are not ready, each ``(url, md5)`` pair in
    ``self._RESOURCES`` is passed to ``download_and_extract_archive`` with
    ``download_root=self.data_path`` and ``remove_finished=True``. Every file in
    the ``images`` folder is then processed:

    * its trimap mask ``annotations/trimaps/<stem>.png`` is read with OpenCV and
      passed through ``self._preprocess_mask``; if the result sums to zero, both
      the image and the mask are deleted;
    * otherwise the image is read with OpenCV and written back as ``<stem>.jpg``;
    * if any of the above raises, the failure is logged as a warning and the
      image and/or mask (whichever still exist) are deleted.

    Nothing is done when the dataset is already present.
    """
    image_folder = os.path.join(self.data_path, "images")
    annotation_folder = os.path.join(self.data_path, "annotations")
    if self._check_exists(image_folder, annotation_folder):
        return

    for url, md5 in self._RESOURCES:
        download_and_extract_archive(url, download_root=self.data_path, md5=md5, remove_finished=True)

    log.info("Fixing corrupted files...")
    trimap_folder = os.path.join(annotation_folder, "trimaps")
    for filename in sorted(os.listdir(image_folder)):
        file_wo_ext = os.path.splitext(os.path.basename(filename))[0]
        # Compute both paths once so the success path and the cleanup path
        # always refer to exactly the same files.
        image_path = os.path.join(image_folder, filename)
        mask_path = os.path.join(trimap_folder, file_wo_ext + ".png")
        try:
            mask = cv2.imread(mask_path)
            mask = self._preprocess_mask(mask)
            if np.sum(mask) == 0:
                # An all-zero mask presumably carries no usable annotation, so
                # the sample is dropped entirely.
                os.remove(image_path)
                os.remove(mask_path)
                log.info("Removed %s", filename)
            else:
                # Round-tripping through OpenCV presumably normalizes files whose
                # encoding does not match their extension — TODO confirm.
                # NOTE(review): a non-.jpg source file is kept next to the new .jpg.
                img = cv2.imread(image_path)
                cv2.imwrite(os.path.join(image_folder, file_wo_ext + ".jpg"), img)
        except Exception as err:
            # Best-effort cleanup: any failure drops the sample, but the reason
            # is logged instead of being silently discarded.
            log.warning("Failed to process %s (%s: %s); removing it", filename, type(err).__name__, err)
            for path in (image_path, mask_path):
                if os.path.exists(path):
                    os.remove(path)
            log.info("Removed %s", filename)