
datamodules

AnomalyDataModule(data_path, category=None, image_size=None, train_batch_size=32, test_batch_size=32, num_workers=8, train_transform=None, val_transform=None, test_transform=None, seed=0, task='segmentation', mask_suffix=None, create_test_set_if_empty=True, phase='train', name='anomaly_datamodule', valid_area_mask=None, crop_area=None, **kwargs)

Bases: BaseDataModule

Anomalib-like Lightning Data Module.

Parameters:

  • data_path (str) –

    Path to the dataset

  • category (str | None, default: None ) –

    Name of the sub-category to use.

  • image_size (int | tuple[int, int] | None, default: None ) –

    Size to which the image is resized.

  • train_batch_size (int, default: 32 ) –

    Training batch size.

  • test_batch_size (int, default: 32 ) –

    Testing batch size.

  • train_transform (Compose | None, default: None ) –

    Transformations for training. Defaults to None.

  • val_transform (Compose | None, default: None ) –

    Transformations for validation. Defaults to None.

  • test_transform (Compose | None, default: None ) –

    Transformations for testing. Defaults to None.

  • num_workers (int, default: 8 ) –

    Number of workers.

  • seed (int, default: 0 ) –

    Seed used for the random subset splitting.

  • task (str, default: 'segmentation' ) –

    Whether the goal is to segment the anomalies (segmentation) or only to detect them (classification).

  • mask_suffix (str | None, default: None ) –

    String appended to the base filename to get the mask name. For the MVTec dataset, masks are saved as imagename_mask.png, so in that case the parameter should be set to "_mask".

  • create_test_set_if_empty (bool, default: True ) –

    If True, the test set is created from good images if it is empty.

  • phase (str, default: 'train' ) –

    Either train or test.

  • name (str, default: 'anomaly_datamodule' ) –

    Name of the data module.

  • valid_area_mask (str | None, default: None ) –

    Optional path to the mask to use to filter out the valid area of the image. If None, the whole image is considered valid. The mask should match the image size even if the image is cropped.

  • crop_area (tuple[int, int, int, int] | None, default: None ) –

    Optional tuple of 4 integers (x1, y1, x2, y2) to crop the image to the specified area. If None, the whole image is considered valid.
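
Example:

A minimal usage sketch (the dataset path and category below are placeholders; an MVTec-style layout where masks are saved as imagename_mask.png is assumed):

from quadra.datamodules.anomaly import AnomalyDataModule

# Hypothetical MVTec-style dataset rooted at /datasets/mvtec
datamodule = AnomalyDataModule(
    data_path="/datasets/mvtec",
    category="bottle",
    image_size=(224, 224),
    task="segmentation",
    mask_suffix="_mask",  # MVTec masks are stored as imagename_mask.png
)

datamodule.prepare_data()        # normally invoked by the Lightning Trainer
datamodule.setup(stage="fit")    # builds the train and validation datasets
train_loader = datamodule.train_dataloader()
datamodule.setup(stage="test")   # builds the test dataset
test_loader = datamodule.test_dataloader()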

Source code in quadra/datamodules/anomaly.py
def __init__(
    self,
    data_path: str,
    category: str | None = None,
    image_size: int | tuple[int, int] | None = None,
    train_batch_size: int = 32,
    test_batch_size: int = 32,
    num_workers: int = 8,
    train_transform: albumentations.Compose | None = None,
    val_transform: albumentations.Compose | None = None,
    test_transform: albumentations.Compose | None = None,
    seed: int = 0,
    task: str = "segmentation",
    mask_suffix: str | None = None,
    create_test_set_if_empty: bool = True,
    phase: str = "train",
    name: str = "anomaly_datamodule",
    valid_area_mask: str | None = None,
    crop_area: tuple[int, int, int, int] | None = None,
    **kwargs,
) -> None:
    super().__init__(
        data_path=data_path,
        name=name,
        seed=seed,
        train_transform=train_transform,
        val_transform=val_transform,
        test_transform=test_transform,
        num_workers=num_workers,
        **kwargs,
    )

    self.root = data_path
    self.category = category
    self.data_path = os.path.join(self.root, self.category) if self.category is not None else self.root
    self.image_size = image_size

    self.train_batch_size = train_batch_size
    self.test_batch_size = test_batch_size
    self.task = task

    self.train_dataset: AnomalyDataset
    self.test_dataset: AnomalyDataset
    self.val_dataset: AnomalyDataset
    self.mask_suffix = mask_suffix
    self.create_test_set_if_empty = create_test_set_if_empty
    self.phase = phase
    self.valid_area_mask = valid_area_mask
    self.crop_area = crop_area

val_data: pd.DataFrame property

Get validation data.

predict_dataloader()

Returns a dataloader used for predictions.

Source code in quadra/datamodules/anomaly.py
def predict_dataloader(self) -> DataLoader:
    """Returns a dataloader used for predictions."""
    return DataLoader(
        self.test_dataset,
        shuffle=False,
        batch_size=self.test_batch_size,
        num_workers=self.num_workers,
        pin_memory=True,
    )

setup(stage=None)

Setup data module based on stages of training.

Source code in quadra/datamodules/anomaly.py
def setup(self, stage: str | None = None) -> None:
    """Setup data module based on stages of training."""
    if stage == "fit" and self.phase == "train":
        self.train_dataset = AnomalyDataset(
            transform=self.train_transform,
            task=self.task,
            samples=self.train_data,
            valid_area_mask=self.valid_area_mask,
            crop_area=self.crop_area,
        )

        if len(self.val_data) == 0:
            log.info("Validation dataset is empty, using test set instead")

        self.val_dataset = AnomalyDataset(
            transform=self.test_transform,
            task=self.task,
            samples=self.val_data if len(self.val_data) > 0 else self.data,
            valid_area_mask=self.valid_area_mask,
            crop_area=self.crop_area,
        )
    if stage == "test" or self.phase == "test":
        self.test_dataset = AnomalyDataset(
            transform=self.test_transform,
            task=self.task,
            samples=self.test_data,
            valid_area_mask=self.valid_area_mask,
            crop_area=self.crop_area,
        )

test_dataloader()

Get test dataloader.

Source code in quadra/datamodules/anomaly.py
def test_dataloader(self) -> DataLoader:
    """Get test dataloader."""
    return DataLoader(
        self.test_dataset,
        shuffle=False,
        batch_size=self.test_batch_size,
        num_workers=self.num_workers,
        pin_memory=True,
    )

train_dataloader()

Get train dataloader.

Source code in quadra/datamodules/anomaly.py
def train_dataloader(self) -> DataLoader:
    """Get train dataloader."""
    return DataLoader(
        self.train_dataset,
        shuffle=True,
        batch_size=self.train_batch_size,
        num_workers=self.num_workers,
        pin_memory=True,
    )

val_dataloader()

Get validation dataloader.

Source code in quadra/datamodules/anomaly.py
def val_dataloader(self) -> DataLoader:
    """Get validation dataloader."""
    return DataLoader(
        dataset=self.val_dataset,
        shuffle=False,
        batch_size=self.test_batch_size,
        num_workers=self.num_workers,
        pin_memory=True,
    )

ClassificationDataModule(data_path, dataset=ImageClassificationListDataset, name='classification_datamodule', num_workers=8, batch_size=32, seed=42, val_size=0.2, test_size=0.2, num_data_class=None, exclude_filter=None, include_filter=None, label_map=None, load_aug_images=False, aug_name=None, n_aug_to_take=4, replace_str_from=None, replace_str_to=None, train_transform=None, val_transform=None, test_transform=None, train_split_file=None, test_split_file=None, val_split_file=None, class_to_idx=None, **kwargs)

Bases: BaseDataModule

Base class for single-folder classification datamodules. If there are no nested folders, use this class.

Parameters:

  • data_path (str) –

    Path to the data main folder.

  • name (str, default: 'classification_datamodule' ) –

    The name for the data module. Defaults to "classification_datamodule".

  • num_workers (int, default: 8 ) –

    Number of workers for dataloaders. Defaults to 8.

  • batch_size (int, default: 32 ) –

    Batch size. Defaults to 32.

  • seed (int, default: 42 ) –

    Random generator seed. Defaults to 42.

  • dataset (type[ImageClassificationListDataset], default: ImageClassificationListDataset ) –

    Dataset class.

  • val_size (float | None, default: 0.2 ) –

    The validation split. Defaults to 0.2.

  • test_size (float, default: 0.2 ) –

    The test split. Defaults to 0.2.

  • exclude_filter (list[str] | None, default: None ) –

    The filter for excluding folders. Defaults to None.

  • include_filter (list[str] | None, default: None ) –

    The filter for including folders. Defaults to None.

  • label_map (dict[str, Any] | None, default: None ) –

    The mapping for labels. Defaults to None.

  • num_data_class (int | None, default: None ) –

    The number of samples per class. Defaults to None.

  • train_transform (Compose | None, default: None ) –

    Transformations for train dataset. Defaults to None.

  • val_transform (Compose | None, default: None ) –

    Transformations for validation dataset. Defaults to None.

  • test_transform (Compose | None, default: None ) –

    Transformations for test dataset. Defaults to None.

  • train_split_file (str | None, default: None ) –

    The file with train split. Defaults to None.

  • val_split_file (str | None, default: None ) –

    The file with validation split. Defaults to None.

  • test_split_file (str | None, default: None ) –

    The file with test split. Defaults to None.

  • class_to_idx (dict[str, int] | None, default: None ) –

    The mapping from class name to index. Defaults to None.

  • **kwargs (Any, default: {} ) –

    Additional arguments for BaseDataModule.
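
Example:

A minimal usage sketch (the data path is a placeholder; one sub-folder per class is assumed so that class_to_idx can be inferred from the folder names):

from quadra.datamodules.classification import ClassificationDataModule

# Hypothetical layout: /datasets/flowers/<class_name>/<image>
datamodule = ClassificationDataModule(
    data_path="/datasets/flowers",
    batch_size=32,
    val_size=0.2,
    test_size=0.2,
)

datamodule.prepare_data()      # normally invoked by the Lightning Trainer
datamodule.setup(stage="fit")  # builds the train and validation datasets
train_loader = datamodule.train_dataloader()
val_loader = datamodule.val_dataloader()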

Source code in quadra/datamodules/classification.py
def __init__(
    self,
    data_path: str,
    dataset: type[ImageClassificationListDataset] = ImageClassificationListDataset,
    name: str = "classification_datamodule",
    num_workers: int = 8,
    batch_size: int = 32,
    seed: int = 42,
    val_size: float | None = 0.2,
    test_size: float = 0.2,
    num_data_class: int | None = None,
    exclude_filter: list[str] | None = None,
    include_filter: list[str] | None = None,
    label_map: dict[str, Any] | None = None,
    load_aug_images: bool = False,
    aug_name: str | None = None,
    n_aug_to_take: int | None = 4,
    replace_str_from: str | None = None,
    replace_str_to: str | None = None,
    train_transform: albumentations.Compose | None = None,
    val_transform: albumentations.Compose | None = None,
    test_transform: albumentations.Compose | None = None,
    train_split_file: str | None = None,
    test_split_file: str | None = None,
    val_split_file: str | None = None,
    class_to_idx: dict[str, int] | None = None,
    **kwargs: Any,
):
    super().__init__(
        data_path=data_path,
        name=name,
        seed=seed,
        batch_size=batch_size,
        num_workers=num_workers,
        train_transform=train_transform,
        val_transform=val_transform,
        test_transform=test_transform,
        load_aug_images=load_aug_images,
        aug_name=aug_name,
        n_aug_to_take=n_aug_to_take,
        replace_str_from=replace_str_from,
        replace_str_to=replace_str_to,
        **kwargs,
    )
    self.replace_str = None
    self.exclude_filter = exclude_filter
    self.include_filter = include_filter
    self.val_size = val_size
    self.test_size = test_size
    self.label_map = label_map
    self.num_data_class = num_data_class
    self.dataset = dataset
    self.train_split_file = train_split_file
    self.test_split_file = test_split_file
    self.val_split_file = val_split_file
    self.class_to_idx: dict[str, int] | None

    if class_to_idx is not None:
        self.class_to_idx = class_to_idx
        self.num_classes = len(self.class_to_idx)
    else:
        self.class_to_idx = self._find_classes_from_data_path(self.data_path)
        if self.class_to_idx is None:
            log.warning("Could not build a class_to_idx from the data_path subdirectories")
            self.num_classes = 0
        else:
            self.num_classes = len(self.class_to_idx)

predict_dataloader()

Returns a dataloader used for predictions.

Source code in quadra/datamodules/classification.py
def predict_dataloader(self) -> DataLoader:
    """Returns a dataloader used for predictions."""
    return self.test_dataloader()

setup(stage=None)

Setup data module based on stages of training.

Source code in quadra/datamodules/classification.py
def setup(self, stage: str | None = None) -> None:
    """Setup data module based on stages of training."""
    if stage in ["train", "fit"]:
        self.train_dataset = self.dataset(
            samples=self.data[self.data["split"] == "train"]["samples"].tolist(),
            targets=self.data[self.data["split"] == "train"]["targets"].tolist(),
            transform=self.train_transform,
            class_to_idx=self.class_to_idx,
        )
        self.val_dataset = self.dataset(
            samples=self.data[self.data["split"] == "val"]["samples"].tolist(),
            targets=self.data[self.data["split"] == "val"]["targets"].tolist(),
            transform=self.val_transform,
            class_to_idx=self.class_to_idx,
        )
    if stage in ["test", "predict"]:
        self.test_dataset = self.dataset(
            samples=self.data[self.data["split"] == "test"]["samples"].tolist(),
            targets=self.data[self.data["split"] == "test"]["targets"].tolist(),
            transform=self.test_transform,
            class_to_idx=self.class_to_idx,
        )

test_dataloader()

Returns the test dataloader.

Raises:

  • ValueError

    If test dataset is not initialized.

Returns:

  • DataLoader

    test dataloader.

Source code in quadra/datamodules/classification.py
def test_dataloader(self) -> DataLoader:
    """Returns the test dataloader.

    Raises:
        ValueError: If test dataset is not initialized.


    Returns:
        test dataloader.
    """
    if not self.test_dataset_available:
        raise ValueError("Test dataset is not initialized")

    loader = DataLoader(
        self.test_dataset,
        batch_size=self.batch_size,
        shuffle=False,
        num_workers=self.num_workers,
        drop_last=False,
        pin_memory=True,
        persistent_workers=self.num_workers > 0,
    )
    return loader

train_dataloader()

Returns the train dataloader.

Raises:

  • ValueError

    If train dataset is not initialized.

Returns:

  • DataLoader

    Train dataloader.

Source code in quadra/datamodules/classification.py
def train_dataloader(self) -> DataLoader:
    """Returns the train dataloader.

    Raises:
        ValueError: If train dataset is not initialized.

    Returns:
        Train dataloader.
    """
    if not self.train_dataset_available:
        raise ValueError("Train dataset is not initialized")
    if not isinstance(self.train_dataset, torch.utils.data.Dataset):
        raise ValueError("Train dataset has to be single `torch.utils.data.Dataset` instance.")
    return DataLoader(
        self.train_dataset,
        batch_size=self.batch_size,
        shuffle=True,
        num_workers=self.num_workers,
        drop_last=False,
        pin_memory=True,
        persistent_workers=self.num_workers > 0,
    )

val_dataloader()

Returns the validation dataloader.

Raises:

  • ValueError

    If validation dataset is not initialized.

Returns:

  • DataLoader

    val dataloader.

Source code in quadra/datamodules/classification.py
def val_dataloader(self) -> DataLoader:
    """Returns the validation dataloader.

    Raises:
        ValueError: If validation dataset is not initialized.

    Returns:
        val dataloader.
    """
    if not self.val_dataset_available:
        raise ValueError("Validation dataset is not initialized")
    if not isinstance(self.val_dataset, torch.utils.data.Dataset):
        raise ValueError("Validation dataset has to be single `torch.utils.data.Dataset` instance.")
    return DataLoader(
        self.val_dataset,
        batch_size=self.batch_size,
        shuffle=False,
        num_workers=self.num_workers,
        drop_last=False,
        pin_memory=True,
        persistent_workers=self.num_workers > 0,
    )

MultilabelClassificationDataModule(data_path, images_and_labels_file=None, train_split_file=None, test_split_file=None, val_split_file=None, name='multilabel_datamodule', dataset=MultilabelClassificationDataset, num_classes=None, num_workers=16, batch_size=64, test_batch_size=64, seed=42, val_size=0.2, test_size=0.2, train_transform=None, val_transform=None, test_transform=None, class_to_idx=None, **kwargs)

Bases: BaseDataModule

Base class for all multi-label modules.

Parameters:

  • data_path (str) –

    Path to the data main folder.

  • images_and_labels_file (str | None, default: None ) –

    Path to a txt file containing the paths of the images (relative to data_path) together with their labels, comma-separated. E.g.:

    • path1,l1,l2,l3
    • path2,l4,l5
    • ...

    Either images_and_labels_file or both train_split_file and test_split_file must be set. Defaults to None.

  • name (str, default: 'multilabel_datamodule' ) –

    The name for the data module. Defaults to "multilabel_datamodule".

  • dataset (Callable, default: MultilabelClassificationDataset ) –

    a callable returning a torch.utils.data.Dataset class.

  • num_classes (int | None, default: None ) –

    the number of classes in the dataset. This is used to create one-hot encoded targets. Defaults to None.

  • num_workers (int, default: 16 ) –

    Number of workers for dataloaders. Defaults to 16.

  • batch_size (int, default: 64 ) –

    Training batch size. Defaults to 64.

  • test_batch_size (int, default: 64 ) –

    Testing batch size. Defaults to 64.

  • seed (int, default: 42 ) –

    Random generator seed. Defaults to 42.

  • val_size (float | None, default: 0.2 ) –

    The validation split. Defaults to 0.2.

  • test_size (float | None, default: 0.2 ) –

    The test split. Defaults to 0.2.

  • train_transform (Compose | None, default: None ) –

    Transformations for train dataset. Defaults to None.

  • val_transform (Compose | None, default: None ) –

    Transformations for validation dataset. Defaults to None.

  • test_transform (Compose | None, default: None ) –

    Transformations for test dataset. Defaults to None.

  • train_split_file (str | None, default: None ) –

    The file with train split. Defaults to None.

  • val_split_file (str | None, default: None ) –

    The file with validation split. Defaults to None.

  • test_split_file (str | None, default: None ) –

    The file with test split. Defaults to None.

  • class_to_idx (dict[str, int] | None, default: None ) –

    A class to idx dictionary. Defaults to None.
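
Example:

A minimal usage sketch (paths are placeholders; the txt file follows the comma-separated format described above):

from quadra.datamodules.classification import MultilabelClassificationDataModule

# Hypothetical /datasets/tags/images_and_labels.txt with lines such as:
#   images/img1.png,cat,outdoor
#   images/img2.png,dog
datamodule = MultilabelClassificationDataModule(
    data_path="/datasets/tags",
    images_and_labels_file="/datasets/tags/images_and_labels.txt",
    num_classes=5,  # used to build one-hot encoded targets
    batch_size=64,
)

datamodule.prepare_data()
datamodule.setup(stage="fit")
train_loader = datamodule.train_dataloader()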

Source code in quadra/datamodules/classification.py
def __init__(
    self,
    data_path: str,
    images_and_labels_file: str | None = None,
    train_split_file: str | None = None,
    test_split_file: str | None = None,
    val_split_file: str | None = None,
    name: str = "multilabel_datamodule",
    dataset: Callable = MultilabelClassificationDataset,
    num_classes: int | None = None,
    num_workers: int = 16,
    batch_size: int = 64,
    test_batch_size: int = 64,
    seed: int = 42,
    val_size: float | None = 0.2,
    test_size: float | None = 0.2,
    train_transform: albumentations.Compose | None = None,
    val_transform: albumentations.Compose | None = None,
    test_transform: albumentations.Compose | None = None,
    class_to_idx: dict[str, int] | None = None,
    **kwargs,
):
    super().__init__(
        data_path=data_path,
        name=name,
        num_workers=num_workers,
        batch_size=batch_size,
        seed=seed,
        train_transform=train_transform,
        val_transform=val_transform,
        test_transform=test_transform,
        **kwargs,
    )
    if not (images_and_labels_file is not None or (train_split_file is not None and test_split_file is not None)):
        raise ValueError(
            "Either `images_and_labels_file` or both `train_split_file` and `test_split_file` must be set"
        )
    self.images_and_labels_file = images_and_labels_file
    self.dataset = dataset
    self.num_classes = num_classes
    self.train_batch_size = batch_size
    self.test_batch_size = test_batch_size
    self.val_size = val_size
    self.test_size = test_size
    self.train_split_file = train_split_file
    self.test_split_file = test_split_file
    self.val_split_file = val_split_file
    self.class_to_idx = class_to_idx
    self.train_dataset: MultilabelClassificationDataset
    self.val_dataset: MultilabelClassificationDataset
    self.test_dataset: MultilabelClassificationDataset

predict_dataloader()

Returns a dataloader used for predictions.

Source code in quadra/datamodules/classification.py
def predict_dataloader(self) -> DataLoader:
    """Returns a dataloader used for predictions."""
    return self.test_dataloader()

setup(stage=None)

Setup data module based on stages of training.

Source code in quadra/datamodules/classification.py
def setup(self, stage: str | None = None) -> None:
    """Setup data module based on stages of training."""
    if stage in ["train", "fit"]:
        train_samples = self.data[self.data["split"] == "train"]["samples"].tolist()
        train_targets = self.data[self.data["split"] == "train"]["targets"].tolist()
        val_samples = self.data[self.data["split"] == "val"]["samples"].tolist()
        val_targets = self.data[self.data["split"] == "val"]["targets"].tolist()
        self.train_dataset = self.dataset(
            samples=train_samples,
            targets=train_targets,
            transform=self.train_transform,
            class_to_idx=self.class_to_idx,
        )
        self.val_dataset = self.dataset(
            samples=val_samples,
            targets=val_targets,
            transform=self.val_transform,
            class_to_idx=self.class_to_idx,
        )
    if stage == "test":
        test_samples = self.data[self.data["split"] == "test"]["samples"].tolist()
        test_targets = self.data[self.data["split"] == "test"]["targets"].tolist()
        self.test_dataset = self.dataset(
            samples=test_samples,
            targets=test_targets,
            transform=self.test_transform,
            class_to_idx=self.class_to_idx,
        )

test_dataloader()

Returns the test dataloader.

Raises:

  • ValueError

    If test dataset is not initialized.

Returns:

  • DataLoader

    test dataloader.

Source code in quadra/datamodules/classification.py
def test_dataloader(self) -> DataLoader:
    """Returns the test dataloader.

    Raises:
        ValueError: If test dataset is not initialized.


    Returns:
        test dataloader.
    """
    if not self.test_dataset_available:
        raise ValueError("Test dataset is not initialized")

    loader = DataLoader(
        self.test_dataset,
        batch_size=self.batch_size,
        shuffle=False,
        num_workers=self.num_workers,
        drop_last=False,
        pin_memory=True,
        persistent_workers=self.num_workers > 0,
    )
    return loader

train_dataloader()

Returns the train dataloader.

Raises:

  • ValueError

    If train dataset is not initialized.

Returns:

  • DataLoader

    Train dataloader.

Source code in quadra/datamodules/classification.py
def train_dataloader(self) -> DataLoader:
    """Returns the train dataloader.

    Raises:
        ValueError: If train dataset is not initialized.

    Returns:
        Train dataloader.
    """
    if not self.train_dataset_available:
        raise ValueError("Train dataset is not initialized")
    return DataLoader(
        self.train_dataset,
        batch_size=self.batch_size,
        shuffle=True,
        num_workers=self.num_workers,
        drop_last=False,
        pin_memory=True,
        persistent_workers=self.num_workers > 0,
    )

val_dataloader()

Returns the validation dataloader.

Raises:

  • ValueError

    If validation dataset is not initialized.

Returns:

  • DataLoader

    val dataloader.

Source code in quadra/datamodules/classification.py
def val_dataloader(self) -> DataLoader:
    """Returns the validation dataloader.

    Raises:
        ValueError: If validation dataset is not initialized.

    Returns:
        val dataloader.
    """
    if not self.val_dataset_available:
        raise ValueError("Validation dataset is not initialized")
    return DataLoader(
        self.val_dataset,
        batch_size=self.batch_size,
        shuffle=False,
        num_workers=self.num_workers,
        drop_last=False,
        pin_memory=True,
        persistent_workers=self.num_workers > 0,
    )

PatchSklearnClassificationDataModule(data_path, class_to_idx, name='patch_classification_datamodule', train_filename='dataset.txt', exclude_filter=None, include_filter=None, seed=42, batch_size=32, num_workers=6, train_transform=None, val_transform=None, test_transform=None, balance_classes=False, class_to_skip_training=None, **kwargs)

Bases: BaseDataModule

DataModule for patch classification.

Parameters:

  • data_path (str) –

    Location of the dataset

  • name (str, default: 'patch_classification_datamodule' ) –

    Name of the datamodule

  • train_filename (str, default: 'dataset.txt' ) –

    Name of the file containing the list of training samples

  • exclude_filter (list[str] | None, default: None ) –

    Filter to exclude samples from the dataset

  • include_filter (list[str] | None, default: None ) –

    Filter to include samples from the dataset

  • class_to_idx (dict) –

    Dictionary mapping class names to indices

  • seed (int, default: 42 ) –

    Random seed

  • batch_size (int, default: 32 ) –

    Batch size

  • num_workers (int, default: 6 ) –

    Number of workers

  • train_transform (Compose | None, default: None ) –

    Transform to apply to the training samples

  • val_transform (Compose | None, default: None ) –

    Transform to apply to the validation samples

  • test_transform (Compose | None, default: None ) –

    Transform to apply to the test samples

  • balance_classes (bool, default: False ) –

    If True, repeat samples of under-represented classes.

  • class_to_skip_training (list | None, default: None ) –

    List of classes skipped during training.
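
Example:

A minimal usage sketch (the data path and class mapping are placeholders; train/val/test sub-folders and a dataset.txt file listing the training samples are assumed):

from quadra.datamodules.patch import PatchSklearnClassificationDataModule

datamodule = PatchSklearnClassificationDataModule(
    data_path="/datasets/patches",
    class_to_idx={"good": 0, "defect": 1},
    balance_classes=True,  # repeat samples of under-represented classes
)

datamodule.prepare_data()
datamodule.setup(stage="fit")
train_loader = datamodule.train_dataloader()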

Source code in quadra/datamodules/patch.py
def __init__(
    self,
    data_path: str,
    class_to_idx: dict,
    name: str = "patch_classification_datamodule",
    train_filename: str = "dataset.txt",
    exclude_filter: list[str] | None = None,
    include_filter: list[str] | None = None,
    seed: int = 42,
    batch_size: int = 32,
    num_workers: int = 6,
    train_transform: albumentations.Compose | None = None,
    val_transform: albumentations.Compose | None = None,
    test_transform: albumentations.Compose | None = None,
    balance_classes: bool = False,
    class_to_skip_training: list | None = None,
    **kwargs,
):
    super().__init__(
        data_path=data_path,
        name=name,
        seed=seed,
        num_workers=num_workers,
        batch_size=batch_size,
        train_transform=train_transform,
        val_transform=val_transform,
        test_transform=test_transform,
        **kwargs,
    )
    self.class_to_idx = class_to_idx
    self.balance_classes = balance_classes
    self.train_filename = train_filename
    self.include_filter = include_filter
    self.exclude_filter = exclude_filter
    self.class_to_skip_training = class_to_skip_training

    self.train_folder = os.path.join(self.data_path, "train")
    self.val_folder = os.path.join(self.data_path, "val")
    self.test_folder = os.path.join(self.data_path, "test")
    self.info: PatchDatasetInfo
    self.train_dataset: PatchSklearnClassificationTrainDataset
    self.val_dataset: ImageClassificationListDataset
    self.test_dataset: ImageClassificationListDataset

setup(stage=None)

Setup function.

Source code in quadra/datamodules/patch.py
def setup(self, stage: str | None = None) -> None:
    """Setup function."""
    if stage == "fit":
        self.train_dataset = PatchSklearnClassificationTrainDataset(
            data_path=self.data_path,
            class_to_idx=self.class_to_idx,
            samples=self.data[self.data["split"] == "train"]["samples"].tolist(),
            targets=self.data[self.data["split"] == "train"]["targets"].tolist(),
            transform=self.train_transform,
            balance_classes=self.balance_classes,
        )

        self.val_dataset = ImageClassificationListDataset(
            class_to_idx=self.class_to_idx,
            samples=self.data[self.data["split"] == "val"]["samples"].tolist(),
            targets=self.data[self.data["split"] == "val"]["targets"].tolist(),
            transform=self.val_transform,
            allow_missing_label=False,
        )

    elif stage in ["test", "predict"]:
        self.test_dataset = ImageClassificationListDataset(
            class_to_idx=self.class_to_idx,
            samples=self.data[self.data["split"] == "test"]["samples"].tolist(),
            targets=self.data[self.data["split"] == "test"]["targets"].tolist(),
            transform=self.test_transform,
            allow_missing_label=True,
        )

test_dataloader()

Return the test dataloader.

Source code in quadra/datamodules/patch.py
def test_dataloader(self) -> DataLoader:
    """Return the test dataloader."""
    if not self.test_dataset_available:
        raise ValueError("No test dataset is available")

    return DataLoader(
        self.test_dataset,
        batch_size=self.batch_size,
        shuffle=False,
        num_workers=self.num_workers,
        drop_last=False,
        pin_memory=True,
    )

train_dataloader()

Return the train dataloader.

Source code in quadra/datamodules/patch.py
def train_dataloader(self) -> DataLoader:
    """Return the train dataloader."""
    if not self.train_dataset_available:
        raise ValueError("No training sample is available")
    return DataLoader(
        self.train_dataset,
        batch_size=self.batch_size,
        shuffle=True,
        num_workers=self.num_workers,
        drop_last=False,
        pin_memory=True,
    )

val_dataloader()

Return the validation dataloader.

Source code in quadra/datamodules/patch.py
def val_dataloader(self) -> DataLoader:
    """Return the validation dataloader."""
    if not self.val_dataset_available:
        raise ValueError("No validation dataset is available")
    return DataLoader(
        self.val_dataset,
        batch_size=self.batch_size,
        shuffle=False,
        num_workers=self.num_workers,
        drop_last=False,
        pin_memory=True,
    )

SSLDataModule(data_path, augmentation_dataset, name='ssl_datamodule', split_validation=True, **kwargs)

Bases: ClassificationDataModule

Base class for all self-supervised learning data modules.

Parameters:

  • data_path (str) –

    Path to the data main folder.

  • augmentation_dataset (TwoAugmentationDataset | TwoSetAugmentationDataset) –

    Augmentation dataset wrapping the training dataset.

  • name (str, default: 'ssl_datamodule' ) –

    The name for the data module. Defaults to "ssl_datamodule".

  • split_validation (bool, default: True ) –

    Whether to split part of the validation set to train the classifier when the training set contains only one class. Defaults to True.

  • **kwargs (Any, default: {} ) –

    Additional keyword arguments for the ClassificationDataModule. Defaults to {}.
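
Example:

A minimal usage sketch (the data path is a placeholder; augmentation_dataset is assumed to be an already-constructed TwoAugmentationDataset producing two augmented views per image, and its construction is omitted here):

from quadra.datamodules.ssl import SSLDataModule

datamodule = SSLDataModule(
    data_path="/datasets/unlabeled",
    augmentation_dataset=augmentation_dataset,  # two augmented views per sample
    split_validation=True,
)

datamodule.prepare_data()
datamodule.setup(stage="fit")
train_loader = datamodule.train_dataloader()  # wraps the train dataset with the augmentations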

Source code in quadra/datamodules/ssl.py
def __init__(
    self,
    data_path: str,
    augmentation_dataset: TwoAugmentationDataset | TwoSetAugmentationDataset,
    name: str = "ssl_datamodule",
    split_validation: bool = True,
    **kwargs: Any,
):
    super().__init__(
        data_path=data_path,
        name=name,
        **kwargs,
    )
    self.augmentation_dataset = augmentation_dataset
    self.classifier_train_dataset: torch.utils.data.Dataset | None = None
    self.split_validation = split_validation

classifier_train_dataloader()

Returns classifier train dataloader.

Source code in quadra/datamodules/ssl.py
def classifier_train_dataloader(self) -> DataLoader:
    """Returns classifier train dataloader."""
    if self.classifier_train_dataset is None:
        raise ValueError("Classifier train dataset is not initialized")

    loader = DataLoader(
        self.classifier_train_dataset,
        batch_size=self.batch_size,
        shuffle=True,
        num_workers=self.num_workers,
        drop_last=False,
        pin_memory=True,
        persistent_workers=self.num_workers > 0,
    )
    return loader

setup(stage=None)

Setup data module based on stages of training.

Source code in quadra/datamodules/ssl.py
def setup(self, stage: str | None = None) -> None:
    """Setup data module based on stages of training."""
    if stage == "fit":
        self.train_dataset = self.dataset(
            samples=self.train_data["samples"].tolist(),
            targets=self.train_data["targets"].tolist(),
            transform=self.train_transform,
        )

        if np.unique(self.train_data["targets"]).shape[0] > 1 and not self.split_validation:
            self.classifier_train_dataset = self.dataset(
                samples=self.train_data["samples"].tolist(),
                targets=self.train_data["targets"].tolist(),
                transform=self.val_transform,
            )
            self.val_dataset = self.dataset(
                samples=self.val_data["samples"].tolist(),
                targets=self.val_data["targets"].tolist(),
                transform=self.val_transform,
            )
        else:
            train_classifier_samples, val_samples, train_classifier_targets, val_targets = train_test_split(
                self.val_data["samples"],
                self.val_data["targets"],
                test_size=0.3,
                random_state=self.seed,
                stratify=self.val_data["targets"],
            )

            self.classifier_train_dataset = self.dataset(
                samples=train_classifier_samples,
                targets=train_classifier_targets,
                transform=self.test_transform,
            )

            self.val_dataset = self.dataset(
                samples=val_samples,
                targets=val_targets,
                transform=self.val_transform,
            )

            log.warning(
                "The training set contains only one class and cannot be used to train a classifier. To overcome "
                "this issue 70% of the validation set is used to train the classifier. The remaining will be used "
                "as standard validation. To disable this behaviour set the `split_validation` parameter to False."
            )
            self._check_train_dataset_config()
    if stage == "test":
        self.test_dataset = self.dataset(
            samples=self.test_data["samples"].tolist(),
            targets=self.test_data["targets"].tolist(),
            transform=self.test_transform,
        )

train_dataloader()

Returns train dataloader.

Source code in quadra/datamodules/ssl.py
def train_dataloader(self) -> DataLoader:
    """Returns train dataloader."""
    if not isinstance(self.train_dataset, torch.utils.data.Dataset):
        raise ValueError("Train dataset is not a subclass of `torch.utils.data.Dataset`")
    self.augmentation_dataset.dataset = self.train_dataset
    loader = DataLoader(
        self.augmentation_dataset,
        batch_size=self.batch_size,
        shuffle=True,
        num_workers=self.num_workers,
        drop_last=False,
        pin_memory=True,
        persistent_workers=self.num_workers > 0,
    )
    return loader

SegmentationDataModule(data_path, name='segmentation_datamodule', test_size=0.3, val_size=0.3, seed=42, dataset=SegmentationDataset, batch_size=32, num_workers=6, train_transform=None, test_transform=None, val_transform=None, train_split_file=None, test_split_file=None, val_split_file=None, num_data_class=None, exclude_good=False, **kwargs)

Bases: BaseDataModule

Base class for segmentation datasets.

Parameters:

  • data_path (str) –

    Path to the data main folder.

  • name (str, default: 'segmentation_datamodule' ) –

    The name for the data module. Defaults to "segmentation_datamodule".

  • val_size (float, default: 0.3 ) –

    The validation split. Defaults to 0.3.

  • test_size (float, default: 0.3 ) –

    The test split. Defaults to 0.3.

  • seed (int, default: 42 ) –

    Random generator seed. Defaults to 42.

  • dataset (type[SegmentationDataset], default: SegmentationDataset ) –

    Dataset class.

  • batch_size (int, default: 32 ) –

    Batch size. Defaults to 32.

  • num_workers (int, default: 6 ) –

    Number of workers for dataloaders. Defaults to 6.

  • train_transform (Compose | None, default: None ) –

    Transformations for train dataset. Defaults to None.

  • val_transform (Compose | None, default: None ) –

    Transformations for validation dataset. Defaults to None.

  • test_transform (Compose | None, default: None ) –

    Transformations for test dataset. Defaults to None.

  • num_data_class (int | None, default: None ) –

    The number of samples per class. Defaults to None.

  • exclude_good (bool, default: False ) –

    If True, exclude good samples from the dataset. Defaults to False.
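
Example:

A minimal usage sketch (the data path is a placeholder; images with matching binary masks are assumed):

from quadra.datamodules.segmentation import SegmentationDataModule

datamodule = SegmentationDataModule(
    data_path="/datasets/defects",
    batch_size=32,
    test_size=0.3,
    val_size=0.3,
)

datamodule.prepare_data()      # normally invoked by the Lightning Trainer
datamodule.setup(stage="fit")  # builds the train and validation datasets
train_loader = datamodule.train_dataloader()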

Source code in quadra/datamodules/segmentation.py
def __init__(
    self,
    data_path: str,
    name: str = "segmentation_datamodule",
    test_size: float = 0.3,
    val_size: float = 0.3,
    seed: int = 42,
    dataset: type[SegmentationDataset] = SegmentationDataset,
    batch_size: int = 32,
    num_workers: int = 6,
    train_transform: albumentations.Compose | None = None,
    test_transform: albumentations.Compose | None = None,
    val_transform: albumentations.Compose | None = None,
    train_split_file: str | None = None,
    test_split_file: str | None = None,
    val_split_file: str | None = None,
    num_data_class: int | None = None,
    exclude_good: bool = False,
    **kwargs: Any,
):
    super().__init__(
        data_path=data_path,
        name=name,
        seed=seed,
        batch_size=batch_size,
        num_workers=num_workers,
        train_transform=train_transform,
        val_transform=val_transform,
        test_transform=test_transform,
        **kwargs,
    )
    self.test_size = test_size
    self.val_size = val_size
    self.num_data_class = num_data_class
    self.exclude_good = exclude_good
    self.train_split_file = train_split_file
    self.test_split_file = test_split_file
    self.val_split_file = val_split_file
    self.dataset = dataset
    self.train_dataset: SegmentationDataset
    self.val_dataset: SegmentationDataset
    self.test_dataset: SegmentationDataset

predict_dataloader()

Returns a dataloader used for predictions.

Source code in quadra/datamodules/segmentation.py
def predict_dataloader(self) -> DataLoader:
    """Returns a dataloader used for predictions."""
    return self.test_dataloader()

setup(stage=None)

Setup data module based on stages of training.

Source code in quadra/datamodules/segmentation.py
def setup(self, stage=None):
    """Setup data module based on stages of training."""
    if stage in ["fit", "train"]:
        self.train_dataset = self.dataset(
            image_paths=self.data[self.data["split"] == "train"]["samples"].tolist(),
            mask_paths=self.data[self.data["split"] == "train"]["masks"].tolist(),
            mask_preprocess=self._preprocess_mask,
            labels=self.data[self.data["split"] == "train"]["targets"].tolist(),
            object_masks=None,
            transform=self.train_transform,
            batch_size=None,
            defect_transform=None,
            resize=None,
        )
        self.val_dataset = self.dataset(
            image_paths=self.data[self.data["split"] == "val"]["samples"].tolist(),
            mask_paths=self.data[self.data["split"] == "val"]["masks"].tolist(),
            defect_transform=None,
            labels=self.data[self.data["split"] == "val"]["targets"].tolist(),
            object_masks=None,
            batch_size=None,
            mask_preprocess=self._preprocess_mask,
            transform=self.test_transform,
            resize=None,
        )
    elif stage == "test":
        self.test_dataset = self.dataset(
            image_paths=self.data[self.data["split"] == "test"]["samples"].tolist(),
            mask_paths=self.data[self.data["split"] == "test"]["masks"].tolist(),
            labels=self.data[self.data["split"] == "test"]["targets"].tolist(),
            object_masks=None,
            batch_size=None,
            mask_preprocess=self._preprocess_mask,
            transform=self.test_transform,
            resize=None,
        )
    elif stage == "predict":
        pass
    else:
        raise ValueError(f"Unknown stage {stage}")

test_dataloader()

Returns the test dataloader.

Raises:

  • ValueError

    If test dataset is not initialized.

Returns:

  • DataLoader

    test dataloader.

Source code in quadra/datamodules/segmentation.py
def test_dataloader(self) -> DataLoader:
    """Returns the test dataloader.

    Raises:
        ValueError: If test dataset is not initialized.


    Returns:
        test dataloader.
    """
    if not self.test_dataset_available:
        raise ValueError("Test dataset is not initialized")

    loader = DataLoader(
        self.test_dataset,
        batch_size=self.batch_size,
        shuffle=False,
        num_workers=self.num_workers,
        drop_last=False,
        pin_memory=True,
        persistent_workers=self.num_workers > 0,
    )
    return loader

train_dataloader()

Returns the train dataloader.

Raises:

  • ValueError

    If train dataset is not initialized.

Returns:

  • DataLoader

    Train dataloader.

Source code in quadra/datamodules/segmentation.py
def train_dataloader(self) -> DataLoader:
    """Returns the train dataloader.

    Raises:
        ValueError: If train dataset is not initialized.

    Returns:
        Train dataloader.
    """
    if not self.train_dataset_available:
        raise ValueError("Train dataset is not initialized")

    return DataLoader(
        self.train_dataset,
        batch_size=self.batch_size,
        shuffle=True,
        num_workers=self.num_workers,
        drop_last=False,
        pin_memory=True,
        persistent_workers=self.num_workers > 0,
    )

val_dataloader()

Returns the validation dataloader.

Raises:

  • ValueError

    If validation dataset is not initialized.

Returns:

  • DataLoader

    val dataloader.

Source code in quadra/datamodules/segmentation.py
def val_dataloader(self) -> DataLoader:
    """Returns the validation dataloader.

    Raises:
        ValueError: If validation dataset is not initialized.

    Returns:
        val dataloader.
    """
    if not self.val_dataset_available:
        raise ValueError("Validation dataset is not initialized")

    return DataLoader(
        self.val_dataset,
        batch_size=self.batch_size,
        shuffle=False,
        num_workers=self.num_workers,
        drop_last=False,
        pin_memory=True,
        persistent_workers=self.num_workers > 0,
    )

SegmentationMulticlassDataModule(data_path, idx_to_class, name='multiclass_segmentation_datamodule', dataset=SegmentationDatasetMulticlass, batch_size=32, test_size=0.3, val_size=0.3, seed=42, num_workers=6, train_transform=None, test_transform=None, val_transform=None, train_split_file=None, test_split_file=None, val_split_file=None, exclude_good=False, num_data_train=None, one_hot_encoding=False, **kwargs)

Bases: BaseDataModule

Base class for segmentation datasets with multiple classes.

Parameters:

  • data_path

    Path to the data main folder.

  • idx_to_class (dict) –

    Dict with the correspondence between mask indices and classes: {1: class_1, 2: class_2, ..., N: class_N}, excluding the background class, which is 0.

  • name

    The name for the data module. Defaults to "multiclass_segmentation_datamodule".

  • dataset (type[SegmentationDatasetMulticlass], default: SegmentationDatasetMulticlass ) –

    Dataset class.

  • batch_size

    Batch size. Defaults to 32.

  • val_size

    The validation split. Defaults to 0.3.

  • test_size

    The test split. Defaults to 0.3.

  • seed

    Random generator seed. Defaults to 42.

  • num_workers (int, default: 6 ) –

    Number of workers for dataloaders. Defaults to 6.

  • train_transform (Compose | None, default: None ) –

    Transformations for train dataset. Defaults to None.

  • val_transform

    Transformations for validation dataset. Defaults to None.

  • test_transform

    Transformations for test dataset. Defaults to None.

  • train_split_file (str | None, default: None ) –

    Path to a txt file with the list of training samples.

  • val_split_file (str | None, default: None ) –

    Path to a txt file with the list of validation samples.

  • test_split_file (str | None, default: None ) –

    Path to a txt file with the list of test samples.

  • exclude_good

    If True, exclude good samples from the dataset. Defaults to False.

  • num_data_train (int | None, default: None ) –

    Number of samples to use in the train split (the samples are shuffled and the first num_data_train are picked).

  • one_hot_encoding (bool, default: False ) –

    If True, the masks are one-hot encoded to N channels, where N is the number of classes. If False, masks are single-channel with class indices as values. Defaults to False.
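
Example:

A minimal usage sketch (the data path and class names are placeholders; mask index 0 is the background, indices 1..N map to the named classes):

from quadra.datamodules.segmentation import SegmentationMulticlassDataModule

datamodule = SegmentationMulticlassDataModule(
    data_path="/datasets/multiclass_defects",
    idx_to_class={1: "scratch", 2: "dent"},
    one_hot_encoding=False,  # single-channel masks with class indices as values
)

datamodule.prepare_data()
datamodule.setup(stage="fit")
train_loader = datamodule.train_dataloader()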

Source code in quadra/datamodules/segmentation.py
def __init__(
    self,
    data_path: str,
    idx_to_class: dict,
    name: str = "multiclass_segmentation_datamodule",
    dataset: type[SegmentationDatasetMulticlass] = SegmentationDatasetMulticlass,
    batch_size: int = 32,
    test_size: float = 0.3,
    val_size: float = 0.3,
    seed: int = 42,
    num_workers: int = 6,
    train_transform: albumentations.Compose | None = None,
    test_transform: albumentations.Compose | None = None,
    val_transform: albumentations.Compose | None = None,
    train_split_file: str | None = None,
    test_split_file: str | None = None,
    val_split_file: str | None = None,
    exclude_good: bool = False,
    num_data_train: int | None = None,
    one_hot_encoding: bool = False,
    **kwargs: Any,
):
    super().__init__(
        data_path=data_path,
        name=name,
        seed=seed,
        batch_size=batch_size,
        num_workers=num_workers,
        train_transform=train_transform,
        val_transform=val_transform,
        test_transform=test_transform,
        **kwargs,
    )
    self.test_size = test_size
    self.val_size = val_size
    self.exclude_good = exclude_good
    self.train_split_file = train_split_file
    self.test_split_file = test_split_file
    self.val_split_file = val_split_file
    self.dataset = dataset
    self.idx_to_class = idx_to_class
    self.num_data_train = num_data_train
    self.one_hot_encoding = one_hot_encoding
    self.train_dataset: SegmentationDataset
    self.val_dataset: SegmentationDataset
    self.test_dataset: SegmentationDataset

predict_dataloader()

Returns a dataloader used for predictions.

Source code in quadra/datamodules/segmentation.py
def predict_dataloader(self) -> DataLoader:
    """Returns a dataloader used for predictions."""
    return self.test_dataloader()

setup(stage=None)

Setup data module based on stages of training.

Source code in quadra/datamodules/segmentation.py
def setup(self, stage=None):
    """Setup data module based on stages of training."""
    if stage in ["fit", "train"]:
        train_data = self.data[self.data["split"] == "train"]
        val_data = self.data[self.data["split"] == "val"]

        self.train_dataset = self.dataset(
            image_paths=train_data["samples"].tolist(),
            mask_paths=train_data["masks"].tolist(),
            idx_to_class=self.idx_to_class,
            transform=self.train_transform,
            one_hot=self.one_hot_encoding,
        )
        self.val_dataset = self.dataset(
            image_paths=val_data["samples"].tolist(),
            mask_paths=val_data["masks"].tolist(),
            transform=self.val_transform,
            idx_to_class=self.idx_to_class,
            one_hot=self.one_hot_encoding,
        )
    elif stage == "test":
        self.test_dataset = self.dataset(
            image_paths=self.data[self.data["split"] == "test"]["samples"].tolist(),
            mask_paths=self.data[self.data["split"] == "test"]["masks"].tolist(),
            transform=self.test_transform,
            idx_to_class=self.idx_to_class,
            one_hot=self.one_hot_encoding,
        )
    elif stage == "predict":
        pass
    else:
        raise ValueError(f"Unknown stage {stage}")

test_dataloader()

Returns the test dataloader.

Raises:

  • ValueError

    If test dataset is not initialized.

Returns:

  • DataLoader

    test dataloader.

Source code in quadra/datamodules/segmentation.py
def test_dataloader(self) -> DataLoader:
    """Returns the test dataloader.

    Raises:
        ValueError: If test dataset is not initialized.


    Returns:
        test dataloader.
    """
    if not self.test_dataset_available:
        raise ValueError("Test dataset is not initialized")

    loader = DataLoader(
        self.test_dataset,
        batch_size=self.batch_size,
        shuffle=False,
        num_workers=self.num_workers,
        drop_last=False,
        pin_memory=True,
        persistent_workers=self.num_workers > 0,
    )
    return loader

train_dataloader()

Returns the train dataloader.

Raises:

  • ValueError

    If train dataset is not initialized.

Returns:

  • DataLoader

    Train dataloader.

Source code in quadra/datamodules/segmentation.py
def train_dataloader(self) -> DataLoader:
    """Returns the train dataloader.

    Raises:
        ValueError: If train dataset is not initialized.

    Returns:
        Train dataloader.
    """
    if not self.train_dataset_available:
        raise ValueError("Train dataset is not initialized")

    return DataLoader(
        self.train_dataset,
        batch_size=self.batch_size,
        shuffle=True,
        num_workers=self.num_workers,
        drop_last=False,
        pin_memory=True,
        persistent_workers=self.num_workers > 0,
    )

val_dataloader()

Returns the validation dataloader.

Raises:

  • ValueError

    If validation dataset is not initialized.

Returns:

  • DataLoader

    val dataloader.

Source code in quadra/datamodules/segmentation.py
def val_dataloader(self) -> DataLoader:
    """Returns the validation dataloader.

    Raises:
        ValueError: If validation dataset is not initialized.

    Returns:
        val dataloader.
    """
    if not self.val_dataset_available:
        raise ValueError("Validation dataset is not initialized")

    return DataLoader(
        self.val_dataset,
        batch_size=self.batch_size,
        shuffle=False,
        num_workers=self.num_workers,
        drop_last=False,
        pin_memory=True,
        persistent_workers=self.num_workers > 0,
    )

SklearnClassificationDataModule(data_path, exclude_filter=None, include_filter=None, val_size=0.2, class_to_idx=None, label_map=None, seed=42, batch_size=32, num_workers=6, train_transform=None, val_transform=None, test_transform=None, roi=None, n_splits=1, phase='train', cache=False, limit_training_data=None, train_split_file=None, test_split_file=None, name='sklearn_classification_datamodule', dataset=ImageClassificationListDataset, **kwargs)

Bases: BaseDataModule

A generic data module for classification with a frozen torch backbone and an sklearn classifier.

It can also handle k-fold cross validation.

Parameters:

  • name (str, default: 'sklearn_classification_datamodule' ) –

    The name for the data module. Defaults to "sklearn_classification_datamodule".

  • data_path (str) –

    Path to images main folder

  • exclude_filter (list[str] | None, default: None ) –

    List of string filters used to exclude images. If None, no filter is applied.

  • include_filter (list[str] | None, default: None ) –

    List of string filters used to include images. Only images that satisfy at least one of the filters will be included.

  • val_size (float, default: 0.2 ) –

    The validation split. Defaults to 0.2.

  • class_to_idx (dict[str, int] | None, default: None ) –

    Dictionary mapping folder names to indices. Only files whose label is among the dictionary keys will be considered. If None, all files are considered and a custom mapping is created.

  • seed (int, default: 42 ) –

    Fixed seed for random operations

  • batch_size (int, default: 32 ) –

    Batch size for the dataloaders.

  • num_workers (int, default: 6 ) –

    Number of workers for dataloader

  • train_transform (Compose | None, default: None ) –

    Albumentation transformations for training set

  • val_transform (Compose | None, default: None ) –

    Albumentation transformations for validation set

  • test_transform (Compose | None, default: None ) –

    Albumentation transformations for test set

  • roi (tuple[int, int, int, int] | None, default: None ) –

    Optional cropping region

  • n_splits (int, default: 1 ) –

    Number of dataset subdivisions (default 1 -> train/test). Use a value >= 2 for cross-validation.

  • phase (str, default: 'train' ) –

    Either train or test

  • cache (bool, default: False ) –

    If True, disable shuffling in all dataloaders to enable feature caching.

  • limit_training_data (int | None, default: None ) –

    If defined, each class will be downsampled to this number of samples. It must be >= 2 to allow splitting.

  • label_map (dict[str, Any] | None, default: None ) –

    Dictionary mapping folder names to labels.

  • train_split_file (str | None, default: None ) –

    Optional path to a csv file containing the train split samples.

  • test_split_file (str | None, default: None ) –

    Optional path to a csv file containing the test split samples.

  • dataset (type[ImageClassificationListDataset], default: ImageClassificationListDataset ) –

    Dataset class used to build the train, validation, test and full datasets.

  • **kwargs (Any, default: {} ) –

    Additional arguments for BaseDataModule

Source code in quadra/datamodules/classification.py
def __init__(
    self,
    data_path: str,
    exclude_filter: list[str] | None = None,
    include_filter: list[str] | None = None,
    val_size: float = 0.2,
    class_to_idx: dict[str, int] | None = None,
    label_map: dict[str, Any] | None = None,
    seed: int = 42,
    batch_size: int = 32,
    num_workers: int = 6,
    train_transform: albumentations.Compose | None = None,
    val_transform: albumentations.Compose | None = None,
    test_transform: albumentations.Compose | None = None,
    roi: tuple[int, int, int, int] | None = None,
    n_splits: int = 1,
    phase: str = "train",
    cache: bool = False,
    limit_training_data: int | None = None,
    train_split_file: str | None = None,
    test_split_file: str | None = None,
    name: str = "sklearn_classification_datamodule",
    dataset: type[ImageClassificationListDataset] = ImageClassificationListDataset,
    **kwargs: Any,
):
    super().__init__(
        data_path=data_path,
        name=name,
        seed=seed,
        batch_size=batch_size,
        num_workers=num_workers,
        train_transform=train_transform,
        val_transform=val_transform,
        test_transform=test_transform,
        **kwargs,
    )

    self.class_to_idx = class_to_idx
    self.roi = roi
    self.cache = cache
    self.limit_training_data = limit_training_data

    self.dataset = dataset
    self.phase = phase
    self.n_splits = n_splits
    self.train_split_file = train_split_file
    self.test_split_file = test_split_file
    self.exclude_filter = exclude_filter
    self.include_filter = include_filter
    self.val_size = val_size
    self.label_map = label_map
    self.full_dataset: ImageClassificationListDataset
    self.train_dataset: list[ImageClassificationListDataset]
    self.val_dataset: list[ImageClassificationListDataset]
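
As an illustration, a hypothetical instantiation for 5-fold cross validation might look like the following sketch. The import path is inferred from the source location shown above, and the data path is a made-up example (one sub-folder per class is assumed).

from quadra.datamodules.classification import SklearnClassificationDataModule

datamodule = SklearnClassificationDataModule(
    data_path="data/defects",  # hypothetical dataset folder, one sub-folder per class
    val_size=0.2,
    n_splits=5,  # 5-fold cross validation
    batch_size=32,
    num_workers=6,
    cache=True,  # keep dataloader order fixed so extracted features can be cached
)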

full_dataloader()

Return a dataloader to perform training on the entire dataset.

Raises:

  • ValueError

    If the full dataset is not initialized.

Returns:

  • DataLoader

    Dataloader over the entire dataset, useful to perform a final training after the evaluation phase.

Source code in quadra/datamodules/classification.py
def full_dataloader(self) -> DataLoader:
    """Return a dataloader to perform training on the entire dataset.

    Returns:
        dataloader to perform training on the entire dataset after evaluation. This is useful
        to perform a final training on the entire dataset after the evaluation phase.

    """
    if self.full_dataset is None:
        raise ValueError("Full dataset is not initialized")

    return DataLoader(
        self.full_dataset,
        batch_size=self.batch_size,
        shuffle=not self.cache,
        num_workers=self.num_workers,
        drop_last=False,
        pin_memory=True,
    )
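
A sketch of the intended workflow: after cross validation has selected a configuration, the full dataloader supports one final fit on all training data. Here backbone is a hypothetical frozen feature extractor, and the (images, targets) batch structure is an assumption about what the underlying dataset yields.

datamodule.setup(stage="fit")
full_loader = datamodule.full_dataloader()

features, labels = [], []
for images, targets in full_loader:
    features.append(backbone(images))  # backbone: hypothetical frozen torch model
    labels.append(targets)
# fit the final sklearn classifier on the concatenated features and labels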

predict_dataloader()

Returns a dataloader used for predictions.

Source code in quadra/datamodules/classification.py
def predict_dataloader(self) -> DataLoader:
    """Returns a dataloader used for predictions."""
    return self.test_dataloader()
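
Note that predict_dataloader simply reuses test_dataloader, so predictions run with the test-time transform, the test batch size, and no shuffling.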

setup(stage)

Set up the data module based on the training stage.

Source code in quadra/datamodules/classification.py
def setup(self, stage: str) -> None:
    """Setup data module based on stages of training."""
    if stage == "fit":
        self.train_dataset = []
        self.val_dataset = []

        # Build one train/val dataset pair per cross-validation fold.
        for cv_idx in range(self.n_splits):
            cv_df = self.data[self.data["cv"] == cv_idx]
            train_samples = cv_df[cv_df["split"] == "train"]["samples"].tolist()
            train_targets = cv_df[cv_df["split"] == "train"]["targets"].tolist()
            val_samples = cv_df[cv_df["split"] == "val"]["samples"].tolist()
            val_targets = cv_df[cv_df["split"] == "val"]["targets"].tolist()
            self.train_dataset.append(
                self.dataset(
                    class_to_idx=self.class_to_idx,
                    samples=train_samples,
                    targets=train_targets,
                    transform=self.train_transform,
                    roi=self.roi,
                )
            )
            self.val_dataset.append(
                self.dataset(
                    class_to_idx=self.class_to_idx,
                    samples=val_samples,
                    targets=val_targets,
                    transform=self.val_transform,
                    roi=self.roi,
                )
            )
        # Fold 0 covers every sample (its train and val splits together),
        # so it is reused to build the dataset over the whole training data.
        all_samples = self.data[self.data["cv"] == 0]["samples"].tolist()
        all_targets = self.data[self.data["cv"] == 0]["targets"].tolist()
        self.full_dataset = self.dataset(
            class_to_idx=self.class_to_idx,
            samples=all_samples,
            targets=all_targets,
            transform=self.train_transform,
            roi=self.roi,
        )
    if stage == "test":
        test_samples = self.data[self.data["split"] == "test"]["samples"].tolist()
        test_targets = self.data[self.data["split"] == "test"]["targets"]
        self.test_dataset = self.dataset(
            class_to_idx=self.class_to_idx,
            samples=test_samples,
            targets=test_targets.tolist(),
            transform=self.test_transform,
            roi=self.roi,
            allow_missing_label=True,
        )
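
After the fit stage, the datamodule holds one train/validation dataset pair per fold. A small sanity-check sketch, assuming an instance named datamodule:

datamodule.setup(stage="fit")

# One dataset pair is created per fold; fold i pairs
# train_dataset[i] with val_dataset[i].
assert len(datamodule.train_dataset) == datamodule.n_splits
assert len(datamodule.val_dataset) == datamodule.n_splits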

test_dataloader()

Returns the test dataloader.

Raises:

  • ValueError

    If test dataset is not initialized.

Returns:

  • DataLoader

    test dataloader.

Source code in quadra/datamodules/classification.py
def test_dataloader(self) -> DataLoader:
    """Returns the test dataloader.

    Raises:
        ValueError: If test dataset is not initialized.

    Returns:
        test dataloader.
    """
    if not self.test_dataset_available:
        raise ValueError("Test dataset is not initialized")

    loader = DataLoader(
        self.test_dataset,
        batch_size=self.batch_size,
        shuffle=False,
        num_workers=self.num_workers,
        drop_last=False,
        pin_memory=True,
        persistent_workers=self.num_workers > 0,
    )
    return loader
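
Test-phase usage follows the same pattern; a minimal sketch, where the (images, targets) batch structure is an assumption:

datamodule.setup(stage="test")
test_loader = datamodule.test_dataloader()
for images, targets in test_loader:  # assumed batch structure
    ...  # run inference on the batch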

train_dataloader()

Returns a list of train dataloaders, one per fold.

Raises:

  • ValueError

    If train dataset is not initialized.

Returns:

  • list[DataLoader]

    list of train dataloaders.

Source code in quadra/datamodules/classification.py
def train_dataloader(self) -> list[DataLoader]:
    """Returns a list of train dataloader.

    Raises:
        ValueError: If train dataset is not initialized.

    Returns:
        list of train dataloader.
    """
    if not self.train_dataset_available:
        raise ValueError("Train dataset is not initialized")

    loader = []
    for dataset in self.train_dataset:
        loader.append(
            DataLoader(
                dataset,
                batch_size=self.batch_size,
                shuffle=not self.cache,
                num_workers=self.num_workers,
                drop_last=False,
                pin_memory=True,
            )
        )
    return loader

val_dataloader()

Returns a list of validation dataloaders, one per fold.

Raises:

  • ValueError

    If validation dataset is not initialized.

Returns:

  • list[DataLoader]

    List of validation dataloaders.

Source code in quadra/datamodules/classification.py
def val_dataloader(self) -> list[DataLoader]:
    """Returns a list of validation dataloader.

    Raises:
        ValueError: If validation dataset is not initialized.

    Returns:
        List of validation dataloader.
    """
    if not self.val_dataset_available:
        raise ValueError("Validation dataset is not initialized")

    loader = []
    for dataset in self.val_dataset:
        loader.append(
            DataLoader(
                dataset,
                batch_size=self.batch_size,
                shuffle=False,
                num_workers=self.num_workers,
                drop_last=False,
                pin_memory=True,
            )
        )

    return loader
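
Putting the two together, a cross-validation loop over the per-fold loaders could look like this sketch; the feature extraction and classifier fitting steps are left as placeholders since they live outside the datamodule:

train_loaders = datamodule.train_dataloader()
val_loaders = datamodule.val_dataloader()

for fold, (train_loader, val_loader) in enumerate(zip(train_loaders, val_loaders)):
    # extract features with the frozen backbone, fit the sklearn classifier
    # on train_loader, then score the fold on its val_loader
    ...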