Skip to content

anomaly

AnomalyDataModule(data_path, category=None, image_size=None, train_batch_size=32, test_batch_size=32, num_workers=8, train_transform=None, val_transform=None, test_transform=None, seed=0, task='segmentation', mask_suffix=None, create_test_set_if_empty=True, phase='train', name='anomaly_datamodule', valid_area_mask=None, crop_area=None, **kwargs)

Bases: BaseDataModule

Anomalib-like Lightning Data Module.

Parameters:

  • data_path (str) –

    Path to the dataset

  • category (str | None, default: None ) –

    Name of the sub category to use.

  • image_size (int | tuple[int, int] | None, default: None ) –

    Size to which the images are resized.

  • train_batch_size (int, default: 32 ) –

    Training batch size.

  • test_batch_size (int, default: 32 ) –

    Testing batch size.

  • train_transform (Compose | None, default: None ) –

    transformations for training. Defaults to None.

  • val_transform (Compose | None, default: None ) –

    transformations for validation. Defaults to None.

  • test_transform (Compose | None, default: None ) –

    transformations for testing. Defaults to None.

  • num_workers (int, default: 8 ) –

    Number of workers.

  • seed (int, default: 0 ) –

    Seed used for the random subset splitting.

  • task (str, default: 'segmentation' ) –

    Whether we are interested in segmenting the anomalies (segmentation) or not (classification)

  • mask_suffix (str | None, default: None ) –

    String to append to the base filename to get the mask name. For example, in the MVTec dataset masks are saved as imagename_mask.png, in which case this parameter should be set to "_mask".

  • create_test_set_if_empty (bool, default: True ) –

    If True, the test set is created from good images if it is empty.

  • phase (str, default: 'train' ) –

    Either train or test.

  • name (str, default: 'anomaly_datamodule' ) –

    Name of the data module.

  • valid_area_mask (str | None, default: None ) –

    Optional path to the mask to use to filter out the valid area of the image. If None, the whole image is considered valid. The mask should match the image size even if the image is cropped.

  • crop_area (tuple[int, int, int, int] | None, default: None ) –

    Optional tuple of 4 integers (x1, y1, x2, y2) to crop the image to the specified area. If None, no cropping is applied and the whole image is used.

Source code in quadra/datamodules/anomaly.py
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
def __init__(
    self,
    data_path: str,
    category: str | None = None,
    image_size: int | tuple[int, int] | None = None,
    train_batch_size: int = 32,
    test_batch_size: int = 32,
    num_workers: int = 8,
    train_transform: albumentations.Compose | None = None,
    val_transform: albumentations.Compose | None = None,
    test_transform: albumentations.Compose | None = None,
    seed: int = 0,
    task: str = "segmentation",
    mask_suffix: str | None = None,
    create_test_set_if_empty: bool = True,
    phase: str = "train",
    name: str = "anomaly_datamodule",
    valid_area_mask: str | None = None,
    crop_area: tuple[int, int, int, int] | None = None,
    **kwargs,
) -> None:
    """Store the data module configuration and initialise the base data module.

    Args:
        data_path: Path to the dataset. Kept as ``self.root``.
        category: Name of the sub category to use. When given, ``self.data_path``
            becomes ``<data_path>/<category>``; otherwise it equals ``data_path``.
        image_size: Size to which the images are resized (only stored here).
        train_batch_size: Training batch size.
        test_batch_size: Testing batch size.
        num_workers: Number of workers, forwarded to the base class.
        train_transform: Transformations for training, forwarded to the base class.
        val_transform: Transformations for validation, forwarded to the base class.
        test_transform: Transformations for testing, forwarded to the base class.
        seed: Seed used for the random subset splitting, forwarded to the base class.
        task: Either "segmentation" or "classification".
        mask_suffix: String appended to the base filename to get the mask name.
        create_test_set_if_empty: If True, the test set is created from good images
            if it is empty.
        phase: Either "train" or "test".
        name: Name of the data module, forwarded to the base class.
        valid_area_mask: Optional path to the mask delimiting the valid image area.
        crop_area: Optional ``(x1, y1, x2, y2)`` crop rectangle.
        **kwargs: Extra keyword arguments forwarded unchanged to the base class.
    """
    super().__init__(
        data_path=data_path,
        name=name,
        seed=seed,
        train_transform=train_transform,
        val_transform=val_transform,
        test_transform=test_transform,
        num_workers=num_workers,
        **kwargs,
    )

    # Dataset location: the root is kept as given, while data_path is
    # narrowed to the category sub-folder when a category is selected.
    self.root = data_path
    self.category = category
    if self.category is None:
        self.data_path = self.root
    else:
        self.data_path = os.path.join(self.root, self.category)

    # Dataset behaviour.
    self.task = task
    self.phase = phase
    self.mask_suffix = mask_suffix
    self.create_test_set_if_empty = create_test_set_if_empty

    # Image geometry.
    self.image_size = image_size
    self.valid_area_mask = valid_area_mask
    self.crop_area = crop_area

    # Dataloader sizing.
    self.train_batch_size = train_batch_size
    self.test_batch_size = test_batch_size

    # Declared here for type checkers only; the datasets are assigned in setup().
    self.train_dataset: AnomalyDataset
    self.val_dataset: AnomalyDataset
    self.test_dataset: AnomalyDataset

val_data: pd.DataFrame property

Get validation data.

predict_dataloader()

Returns a dataloader used for predictions.

Source code in quadra/datamodules/anomaly.py
172
173
174
175
176
177
178
179
180
def predict_dataloader(self) -> DataLoader:
    """Build the dataloader used for predictions.

    Returns:
        A ``DataLoader`` over ``self.test_dataset`` that does not shuffle,
        batches with ``self.test_batch_size``, uses ``self.num_workers``
        workers and requests pinned memory.
    """
    loader_options = {
        "shuffle": False,
        "batch_size": self.test_batch_size,
        "num_workers": self.num_workers,
        "pin_memory": True,
    }
    return DataLoader(self.test_dataset, **loader_options)

setup(stage=None)

Setup data module based on stages of training.

Source code in quadra/datamodules/anomaly.py
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
def setup(self, stage: str | None = None) -> None:
    """Setup data module based on stages of training.

    Args:
        stage: Stage being set up. With ``"fit"`` (and ``self.phase == "train"``)
            the train and validation datasets are created. With ``"test"``, or
            whenever ``self.phase == "test"``, the test dataset is created.
            Any other combination creates nothing.
    """
    if stage == "fit" and self.phase == "train":
        self.train_dataset = AnomalyDataset(
            transform=self.train_transform,
            task=self.task,
            samples=self.train_data,
            valid_area_mask=self.valid_area_mask,
            crop_area=self.crop_area,
        )

        # val_data is a property: read it once instead of re-evaluating it
        # for every emptiness check.
        val_samples = self.val_data
        if len(val_samples) == 0:
            log.info("Validation dataset is empty, using test set instead")
            # NOTE(review): the fallback is `self.data`, while the message above
            # talks about the test set - confirm which one is intended.
            val_samples = self.data

        # NOTE(review): validation uses `test_transform`, not `val_transform` -
        # presumably intentional, confirm.
        self.val_dataset = AnomalyDataset(
            transform=self.test_transform,
            task=self.task,
            samples=val_samples,
            valid_area_mask=self.valid_area_mask,
            crop_area=self.crop_area,
        )
    if stage == "test" or self.phase == "test":
        self.test_dataset = AnomalyDataset(
            transform=self.test_transform,
            task=self.task,
            samples=self.test_data,
            valid_area_mask=self.valid_area_mask,
            crop_area=self.crop_area,
        )

test_dataloader()

Get test dataloader.

Source code in quadra/datamodules/anomaly.py
162
163
164
165
166
167
168
169
170
def test_dataloader(self) -> DataLoader:
    """Get test dataloader."""
    return DataLoader(
        self.test_dataset,
        shuffle=False,
        batch_size=self.test_batch_size,
        num_workers=self.num_workers,
        pin_memory=True,
    )

train_dataloader()

Get train dataloader.

Source code in quadra/datamodules/anomaly.py
142
143
144
145
146
147
148
149
150
def train_dataloader(self) -> DataLoader:
    """Build the dataloader used for training.

    Returns:
        A shuffling ``DataLoader`` over ``self.train_dataset`` that batches
        with ``self.train_batch_size``, uses ``self.num_workers`` workers
        and requests pinned memory.
    """
    train_loader = DataLoader(
        self.train_dataset,
        batch_size=self.train_batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=self.num_workers,
    )
    return train_loader

val_dataloader()

Get validation dataloader.

Source code in quadra/datamodules/anomaly.py
152
153
154
155
156
157
158
159
160
def val_dataloader(self) -> DataLoader:
    """Build the dataloader used for validation.

    Returns:
        A non-shuffling ``DataLoader`` over ``self.val_dataset``. It is batched
        with ``self.test_batch_size`` (the test batch size is used for
        validation as well), uses ``self.num_workers`` workers and requests
        pinned memory.
    """
    options = dict(
        batch_size=self.test_batch_size,
        num_workers=self.num_workers,
        pin_memory=True,
        shuffle=False,
    )
    return DataLoader(self.val_dataset, **options)