Skip to content

ssl

SSLDataModule(data_path, augmentation_dataset, name='ssl_datamodule', split_validation=True, **kwargs)

Bases: ClassificationDataModule

Base class for all self-supervised learning data modules.

Parameters:

  • data_path (str) –

    Path to the data main folder.

  • augmentation_dataset (TwoAugmentationDataset | TwoSetAugmentationDataset) –

    Augmentation dataset that wraps the training dataset; the training dataloader iterates over it.

  • name (str, default: 'ssl_datamodule' ) –

    The name for the data module. Defaults to "ssl_datamodule".

  • split_validation (bool, default: True ) –

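    Whether to split the validation set so that 70% of it is used to train the classifier and the remaining 30% is kept for validation. If False and the training set contains more than one class, the classifier is trained on the training set instead. Defaults to True.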

  • **kwargs (Any, default: {} ) –

    Additional keyword arguments forwarded to the classification data module. Defaults to no extra arguments.

Source code in quadra/datamodules/ssl.py
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
def __init__(
    self,
    data_path: str,
    augmentation_dataset: TwoAugmentationDataset | TwoSetAugmentationDataset,
    name: str = "ssl_datamodule",
    split_validation: bool = True,
    **kwargs: Any,
):
    """Initialize the self-supervised learning data module.

    Args:
        data_path: Path to the data main folder, forwarded to the parent data module.
        augmentation_dataset: Augmentation dataset stored on the module; `train_dataloader` attaches the
            training dataset to it and iterates over it.
        name: The name for the data module, forwarded to the parent data module.
        split_validation: Flag read by `setup` when deciding whether the classifier is trained on the
            training set or on a split of the validation set.
        **kwargs: Additional keyword arguments forwarded unchanged to the parent data module.
    """
    super().__init__(data_path=data_path, name=name, **kwargs)

    self.split_validation = split_validation
    self.augmentation_dataset = augmentation_dataset
    # Stays None until `setup` is called with stage "fit"; `classifier_train_dataloader` raises while it is None.
    self.classifier_train_dataset: torch.utils.data.Dataset | None = None

classifier_train_dataloader()

Returns classifier train dataloader.

Source code in quadra/datamodules/ssl.py
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
def classifier_train_dataloader(self) -> DataLoader:
    """Build the dataloader over the classifier training dataset.

    Returns:
        A shuffling dataloader over `self.classifier_train_dataset` that uses the module's `batch_size`
        and `num_workers` and keeps the last incomplete batch.

    Raises:
        ValueError: If `self.classifier_train_dataset` is still None.
    """
    classifier_dataset = self.classifier_train_dataset
    if classifier_dataset is None:
        raise ValueError("Classifier train dataset is not initialized")

    loader_options = {
        "batch_size": self.batch_size,
        "shuffle": True,
        "num_workers": self.num_workers,
        "drop_last": False,
        "pin_memory": True,
        # Only ask for persistent workers when worker processes are actually used.
        "persistent_workers": self.num_workers > 0,
    }
    return DataLoader(classifier_dataset, **loader_options)

setup(stage=None)

Setup data module based on stages of training.

Source code in quadra/datamodules/ssl.py
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
def setup(self, stage: str | None = None) -> None:
    """Setup data module based on stages of training."""
    if stage == "fit":
        self.train_dataset = self.dataset(
            samples=self.train_data["samples"].tolist(),
            targets=self.train_data["targets"].tolist(),
            transform=self.train_transform,
        )

        if np.unique(self.train_data["targets"]).shape[0] > 1 and not self.split_validation:
            self.classifier_train_dataset = self.dataset(
                samples=self.train_data["samples"].tolist(),
                targets=self.train_data["targets"].tolist(),
                transform=self.val_transform,
            )
            self.val_dataset = self.dataset(
                samples=self.val_data["samples"].tolist(),
                targets=self.val_data["targets"].tolist(),
                transform=self.val_transform,
            )
        else:
            train_classifier_samples, val_samples, train_classifier_targets, val_targets = train_test_split(
                self.val_data["samples"],
                self.val_data["targets"],
                test_size=0.3,
                random_state=self.seed,
                stratify=self.val_data["targets"],
            )

            self.classifier_train_dataset = self.dataset(
                samples=train_classifier_samples,
                targets=train_classifier_targets,
                transform=self.test_transform,
            )

            self.val_dataset = self.dataset(
                samples=val_samples,
                targets=val_targets,
                transform=self.val_transform,
            )

            log.warning(
                "The training set contains only one class and cannot be used to train a classifier. To overcome "
                "this issue 70% of the validation set is used to train the classifier. The remaining will be used "
                "as standard validation. To disable this behaviour set the `split_validation` parameter to False."
            )
            self._check_train_dataset_config()
    if stage == "test":
        self.test_dataset = self.dataset(
            samples=self.test_data["samples"].tolist(),
            targets=self.test_data["targets"].tolist(),
            transform=self.test_transform,
        )

train_dataloader()

Returns train dataloader.

Source code in quadra/datamodules/ssl.py
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
def train_dataloader(self) -> DataLoader:
    """Build the training dataloader over the augmentation dataset.

    The current training dataset is assigned to the `dataset` attribute of `self.augmentation_dataset`,
    and the dataloader iterates over the augmentation dataset rather than over the training dataset itself.

    Returns:
        A shuffling dataloader over `self.augmentation_dataset` that uses the module's `batch_size`
        and `num_workers` and keeps the last incomplete batch.

    Raises:
        ValueError: If `self.train_dataset` is not an instance of `torch.utils.data.Dataset`.
    """
    base_dataset = self.train_dataset
    if not isinstance(base_dataset, torch.utils.data.Dataset):
        raise ValueError("Train dataset is not a subclass of `torch.utils.data.Dataset`")

    # Attach the training dataset to the augmentation wrapper before building the loader.
    self.augmentation_dataset.dataset = base_dataset

    loader_options = {
        "batch_size": self.batch_size,
        "shuffle": True,
        "num_workers": self.num_workers,
        "drop_last": False,
        "pin_memory": True,
        # Only ask for persistent workers when worker processes are actually used.
        "persistent_workers": self.num_workers > 0,
    }
    return DataLoader(self.augmentation_dataset, **loader_options)