Skip to content

mnist

MNISTAnomalyDataModule(data_path, good_number, limit_data=100, category=None, **kwargs)

Bases: AnomalyDataModule

Standard anomaly datamodule with automatic download of the MNIST dataset.

Parameters:

  • data_path (str) –

    Root directory where MNIST is downloaded; the anomaly dataset is generated in its quadra_mnist_anomaly subfolder.

  • good_number (int) –

    Digit to use as the good (normal) class; all other digits are considered anomalies.

  • category (Optional[str], default: None ) –

    The category of the dataset. It is ignored for MNIST, which has no categories: None is always used.

  • limit_data (int, default: 100 ) –

    Maximum number of images saved per class folder (train good, test good, and each anomalous digit). Defaults to 100.

  • **kwargs (Any, default: {} ) –

    Additional arguments to pass to the AnomalyDataModule.

Source code in quadra/datamodules/generic/mnist.py (lines 17–31)
def __init__(
    self, data_path: str, good_number: int, limit_data: int = 100, category: Optional[str] = None, **kwargs: Any
) -> None:
    """Set up the MNIST anomaly datamodule.

    Args:
        data_path: Root directory where MNIST is downloaded and the anomaly dataset is generated.
        good_number: Digit treated as the good (normal) class; every other digit is an anomaly.
        limit_data: Maximum number of images written per class folder. Defaults to 100.
        category: Accepted for interface compatibility but ignored; MNIST has no categories,
            so ``None`` is always forwarded to the parent class.
        **kwargs: Extra keyword arguments forwarded to ``AnomalyDataModule``.
    """
    # ``category`` is intentionally not forwarded: MNIST has no categories.
    super().__init__(data_path=data_path, category=None, **kwargs)
    self.good_number, self.limit_data = good_number, limit_data

download_data()

Download the MNIST dataset and write the images into the train/test folder structure used by the anomaly datamodule.

Source code in quadra/datamodules/generic/mnist.py (lines 33–74)
def download_data(self) -> None:
    """Download MNIST and write it to disk as a folder-based anomaly dataset.

    MNIST is downloaded (if needed) under the original ``self.data_path``. The
    generated dataset is written to ``<data_path>/quadra_mnist_anomaly``; any
    previous copy of that folder is deleted first. ``self.data_path`` is then
    updated to point to it. The resulting layout is::

        quadra_mnist_anomaly/
            train/good/{i}.png                  # good digit, MNIST train split
            test/good/{good_number}_{i}.png     # good digit, MNIST test split
            test/{digit}/{digit}_{i}.png        # every other digit (anomalies)

    At most ``self.limit_data`` images are written to each folder. A limit of 0
    writes nothing; a negative limit or ``None`` never matches the counter, so
    it effectively disables the limit.
    """
    log.info("Generating MNIST anomaly dataset for good number %s", self.good_number)

    mnist_train_dataset = MNIST(root=self.data_path, train=True, download=True)
    mnist_test_dataset = MNIST(root=self.data_path, train=False, download=True)

    def _save_images(dataset: MNIST, digit: int, folder: str, prefix: str) -> None:
        """Write up to ``self.limit_data`` images of ``digit`` from ``dataset`` as PNGs in ``folder``."""
        samples = dataset.data[dataset.targets == digit]
        for i, image in enumerate(samples.numpy()):
            # Equality (not >=) is deliberate to preserve the original semantics:
            # a negative or None limit never matches, so every sample is written.
            if i == self.limit_data:
                break
            cv2.imwrite(os.path.join(folder, f"{prefix}{i}.png"), image)

    # From here on data_path refers to the generated dataset, not the MNIST root.
    self.data_path = os.path.join(self.data_path, "quadra_mnist_anomaly")

    # Start from a clean folder so stale images from a previous run are not mixed in.
    if os.path.exists(self.data_path):
        shutil.rmtree(self.data_path)

    # Create the folder structure
    train_good_folder = os.path.join(self.data_path, "train", "good")
    test_good_folder = os.path.join(self.data_path, "test", "good")

    os.makedirs(train_good_folder, exist_ok=True)
    os.makedirs(test_good_folder, exist_ok=True)

    # Training set: only the good digit, taken from the MNIST train split.
    _save_images(mnist_train_dataset, self.good_number, train_good_folder, prefix="")

    for number in range(10):
        if number == self.good_number:
            # Good test images come from the MNIST test split.
            _save_images(mnist_test_dataset, number, test_good_folder, prefix=f"{number}_")
            continue

        test_bad_folder = os.path.join(self.data_path, "test", str(number))
        os.makedirs(test_bad_folder, exist_ok=True)
        # NOTE(review): anomalous test images are drawn from the MNIST *train*
        # split, unlike the good test images above. This does not leak into
        # training (only the good digit is used there), but confirm it is
        # intended before changing it, since it alters the generated dataset.
        _save_images(mnist_train_dataset, number, test_bad_folder, prefix=f"{number}_")