
patch

PatchSklearnClassificationDataModule(data_path, class_to_idx, name='patch_classification_datamodule', train_filename='dataset.txt', exclude_filter=None, include_filter=None, seed=42, batch_size=32, num_workers=6, train_transform=None, val_transform=None, test_transform=None, balance_classes=False, class_to_skip_training=None, **kwargs)

Bases: BaseDataModule

DataModule for patch classification.

Parameters:

  • data_path (str) –

    Location of the dataset

  • name (str, default: 'patch_classification_datamodule' ) –

    Name of the datamodule

  • train_filename (str, default: 'dataset.txt' ) –

    Name of the file containing the list of training samples

  • exclude_filter (list[str] | None, default: None ) –

    Filter to exclude samples from the dataset

  • include_filter (list[str] | None, default: None ) –

    Filter to include samples from the dataset

  • class_to_idx (dict) –

    Dictionary mapping class names to indices

  • seed (int, default: 42 ) –

    Random seed

  • batch_size (int, default: 32 ) –

    Batch size

  • num_workers (int, default: 6 ) –

    Number of workers

  • train_transform (Compose | None, default: None ) –

    Transform to apply to the training samples

  • val_transform (Compose | None, default: None ) –

    Transform to apply to the validation samples

  • test_transform (Compose | None, default: None ) –

    Transform to apply to the test samples

  • balance_classes (bool, default: False ) –

    If True, repeat samples from under-represented classes to balance the training set

  • class_to_skip_training (list | None, default: None ) –

    List of classes skipped during training.

Source code in quadra/datamodules/patch.py
def __init__(
    self,
    data_path: str,
    class_to_idx: dict,
    name: str = "patch_classification_datamodule",
    train_filename: str = "dataset.txt",
    exclude_filter: list[str] | None = None,
    include_filter: list[str] | None = None,
    seed: int = 42,
    batch_size: int = 32,
    num_workers: int = 6,
    train_transform: albumentations.Compose | None = None,
    val_transform: albumentations.Compose | None = None,
    test_transform: albumentations.Compose | None = None,
    balance_classes: bool = False,
    class_to_skip_training: list | None = None,
    **kwargs,
):
    super().__init__(
        data_path=data_path,
        name=name,
        seed=seed,
        num_workers=num_workers,
        batch_size=batch_size,
        train_transform=train_transform,
        val_transform=val_transform,
        test_transform=test_transform,
        **kwargs,
    )
    self.class_to_idx = class_to_idx
    self.balance_classes = balance_classes
    self.train_filename = train_filename
    self.include_filter = include_filter
    self.exclude_filter = exclude_filter
    self.class_to_skip_training = class_to_skip_training

    self.train_folder = os.path.join(self.data_path, "train")
    self.val_folder = os.path.join(self.data_path, "val")
    self.test_folder = os.path.join(self.data_path, "test")
    self.info: PatchDatasetInfo
    self.train_dataset: PatchSklearnClassificationTrainDataset
    self.val_dataset: ImageClassificationListDataset
    self.test_dataset: ImageClassificationListDataset
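
A minimal instantiation sketch. The dataset path, class mapping, and transforms below are placeholder assumptions standing in for whatever on-disk layout (train/val/test folders plus the train_filename listing) your project actually uses:

import albumentations
from albumentations.pytorch import ToTensorV2

from quadra.datamodules.patch import PatchSklearnClassificationDataModule

# Hypothetical paths and class mapping
transform = albumentations.Compose(
    [albumentations.Resize(224, 224), albumentations.Normalize(), ToTensorV2()]
)
datamodule = PatchSklearnClassificationDataModule(
    data_path="/path/to/patch_dataset",  # expected to contain train/, val/ and test/ subfolders
    class_to_idx={"good": 0, "defect": 1},
    batch_size=32,
    num_workers=6,
    train_transform=transform,
    val_transform=transform,
    test_transform=transform,
)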

setup(stage=None)

Setup function.

Source code in quadra/datamodules/patch.py
def setup(self, stage: str | None = None) -> None:
    """Setup function."""
    if stage == "fit":
        self.train_dataset = PatchSklearnClassificationTrainDataset(
            data_path=self.data_path,
            class_to_idx=self.class_to_idx,
            samples=self.data[self.data["split"] == "train"]["samples"].tolist(),
            targets=self.data[self.data["split"] == "train"]["targets"].tolist(),
            transform=self.train_transform,
            balance_classes=self.balance_classes,
        )

        self.val_dataset = ImageClassificationListDataset(
            class_to_idx=self.class_to_idx,
            samples=self.data[self.data["split"] == "val"]["samples"].tolist(),
            targets=self.data[self.data["split"] == "val"]["targets"].tolist(),
            transform=self.val_transform,
            allow_missing_label=False,
        )

    elif stage in ["test", "predict"]:
        self.test_dataset = ImageClassificationListDataset(
            class_to_idx=self.class_to_idx,
            samples=self.data[self.data["split"] == "test"]["samples"].tolist(),
            targets=self.data[self.data["split"] == "test"]["targets"].tolist(),
            transform=self.test_transform,
            allow_missing_label=True,
        )
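
A sketch of how setup is typically driven. It assumes the data preparation step inherited from BaseDataModule has already populated self.data with the "samples", "targets" and "split" columns referenced above; that base-class behaviour is an assumption, not documented on this page:

# Assumes `datamodule` was built as in the previous sketch
datamodule.prepare_data()       # assumed to build the split dataframe consumed by setup
datamodule.setup(stage="fit")   # creates train_dataset and val_dataset
datamodule.setup(stage="test")  # creates test_dataset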

test_dataloader()

Return the test dataloader.

Source code in quadra/datamodules/patch.py
def test_dataloader(self) -> DataLoader:
    """Return the test dataloader."""
    if not self.test_dataset_available:
        raise ValueError("No test dataset is available")

    return DataLoader(
        self.test_dataset,
        batch_size=self.batch_size,
        shuffle=False,
        num_workers=self.num_workers,
        drop_last=False,
        pin_memory=True,
    )

train_dataloader()

Return the train dataloader.

Source code in quadra/datamodules/patch.py
def train_dataloader(self) -> DataLoader:
    """Return the train dataloader."""
    if not self.train_dataset_available:
        raise ValueError("No training sample is available")
    return DataLoader(
        self.train_dataset,
        batch_size=self.batch_size,
        shuffle=True,
        num_workers=self.num_workers,
        drop_last=False,
        pin_memory=True,
    )

val_dataloader()

Return the validation dataloader.

Source code in quadra/datamodules/patch.py
def val_dataloader(self) -> DataLoader:
    """Return the validation dataloader."""
    if not self.val_dataset_available:
        raise ValueError("No validation dataset is available")
    return DataLoader(
        self.val_dataset,
        batch_size=self.batch_size,
        shuffle=False,
        num_workers=self.num_workers,
        drop_last=False,
        pin_memory=True,
    )
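
Continuing the previous sketches, once setup has run for the relevant stage the dataloaders can be consumed directly. The batch structure shown is an assumption based on a typical classification dataset, not guaranteed by this page:

train_loader = datamodule.train_dataloader()  # shuffled, as configured above
val_loader = datamodule.val_dataloader()      # not shuffled
test_loader = datamodule.test_dataloader()    # requires setup(stage="test") or setup(stage="predict")

images, targets = next(iter(train_loader))    # assumed (image, target) batch structure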