base

BaseDataModule(data_path, name='base_datamodule', num_workers=16, batch_size=32, seed=42, load_aug_images=False, aug_name=None, n_aug_to_take=None, replace_str_from=None, replace_str_to=None, train_transform=None, val_transform=None, test_transform=None, enable_hashing=True, hash_size=64, hash_type='content')

Bases: LightningDataModule

Base class for all data modules.

Parameters:

  • data_path (str) –

    Path to the data main folder.

  • name (str, default: 'base_datamodule' ) –

    The name for the data module. Defaults to "base_datamodule".

  • num_workers (int, default: 16 ) –

    Number of workers for dataloaders. Defaults to 16.

  • batch_size (int, default: 32 ) –

    Batch size. Defaults to 32.

  • seed (int, default: 42 ) –

    Random generator seed. Defaults to 42.

  • train_transform (Optional[Compose], default: None ) –

    Transformations for train dataset. Defaults to None.

  • val_transform (Optional[Compose], default: None ) –

    Transformations for validation dataset. Defaults to None.

  • test_transform (Optional[Compose], default: None ) –

    Transformations for test dataset. Defaults to None.

  • enable_hashing (bool, default: True ) –

    Whether to enable hashing of images. Defaults to True.

  • hash_size (Literal[32, 64, 128], default: 64 ) –

    Size of the hash. Must be one of [32, 64, 128]. Defaults to 64.

  • hash_type (Literal['content', 'size'], default: 'content' ) –

    Type of hash to use. With "content" the hash is computed from the file content; with "size" it is computed from the file size, which is faster but less safe. Defaults to "content".

Source code in quadra/datamodules/base.py
def __init__(
    self,
    data_path: str,
    name: str = "base_datamodule",
    num_workers: int = 16,
    batch_size: int = 32,
    seed: int = 42,
    load_aug_images: bool = False,
    aug_name: Optional[str] = None,
    n_aug_to_take: Optional[int] = None,
    replace_str_from: Optional[str] = None,
    replace_str_to: Optional[str] = None,
    train_transform: Optional[albumentations.Compose] = None,
    val_transform: Optional[albumentations.Compose] = None,
    test_transform: Optional[albumentations.Compose] = None,
    enable_hashing: bool = True,
    hash_size: Literal[32, 64, 128] = 64,
    hash_type: Literal["content", "size"] = "content",
):
    super().__init__()
    self.num_workers = num_workers
    self.batch_size = batch_size
    self.seed = seed
    self.data_path = data_path
    self.name = name
    self.train_transform = train_transform
    self.val_transform = val_transform
    self.test_transform = test_transform
    self.enable_hashing = enable_hashing
    self.hash_size = hash_size
    self.hash_type = hash_type

    if self.hash_size not in [32, 64, 128]:
        raise ValueError(f"Invalid hash size {self.hash_size}. Must be one of [32, 64, 128].")

    self.load_aug_images = load_aug_images
    self.aug_name = aug_name
    self.n_aug_to_take = n_aug_to_take
    self.replace_str_from = replace_str_from
    self.replace_str_to = replace_str_to
    self.extra_args: Dict[str, Any] = {}
    self.train_dataset: TrainDataset
    self.val_dataset: ValDataset
    self.test_dataset: TestDataset
    self.data: pd.DataFrame
    self.data_folder = "data"
    os.makedirs(self.data_folder, exist_ok=True)
    self.datamodule_checkpoint_file = os.path.join(self.data_folder, "datamodule.pkl")
    self.dataset_file = os.path.join(self.data_folder, "dataset.csv")
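
A minimal instantiation sketch. The path and transform values below are illustrative only; any albumentations.Compose (or None) can be passed:

import albumentations as A
from albumentations.pytorch import ToTensorV2

from quadra.datamodules.base import BaseDataModule

# Illustrative transforms; the base class only stores them for subclasses to use.
train_transform = A.Compose([A.Resize(224, 224), A.HorizontalFlip(p=0.5), ToTensorV2()])
val_transform = A.Compose([A.Resize(224, 224), ToTensorV2()])

dm = BaseDataModule(
    data_path="path/to/dataset",
    batch_size=16,
    train_transform=train_transform,
    val_transform=val_transform,
    hash_size=64,
    hash_type="content",
)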

test_data: pd.DataFrame property

Get test data.

test_dataset_available: bool property

Checks if the test dataset is available.

train_data: pd.DataFrame property

Get train data.

train_dataset_available: bool property

Checks if the train dataset is available.

val_data: pd.DataFrame property

Get validation data.

val_dataset_available: bool property

Checks if the validation dataset is available.

__getstate__()

This method is called when pickling the object. It is used to remove attributes that should not be pickled.

Source code in quadra/datamodules/base.py
def __getstate__(self) -> Dict[str, Any]:
    """This method is called when pickling the object.
    It's useful to remove attributes that shouldn't be pickled.
    """
    state = self.__dict__.copy()
    if "trainer" in state:
        # Lightning injects the trainer in the datamodule, we don't want to pickle it.
        del state["trainer"]

    return state

hash_data()

Computes the hash of the files inside the datasets.

Source code in quadra/datamodules/base.py
def hash_data(self) -> None:
    """Computes the hash of the files inside the datasets."""
    if not self.enable_hashing:
        return

    # TODO: We need to find a way to annotate the columns of data.
    paths_and_hash_length = zip(self.data["samples"], [self.hash_size] * len(self.data))

    with mp.Pool(min(8, mp.cpu_count() - 1)) as pool:
        self.data["hash"] = list(
            tqdm(
                pool.istarmap(  # type: ignore[attr-defined]
                    compute_file_content_hash if self.hash_type == "content" else compute_file_size_hash,
                    paths_and_hash_length,
                ),
                total=len(self.data),
                desc="Computing hashes",
            )
        )

    self.data["hash_type"] = self.hash_type

load_augmented_samples(samples, targets, replace_str_from=None, replace_str_to=None, shuffle=False)

Loads augmented samples.

Source code in quadra/datamodules/base.py
def load_augmented_samples(
    self,
    samples: List[str],
    targets: List[Any],
    replace_str_from: Optional[str] = None,
    replace_str_to: Optional[str] = None,
    shuffle: bool = False,
) -> Tuple[List[str], List[str]]:
    """Loads augmented samples."""
    if self.n_aug_to_take is None:
        raise ValueError("`n_aug_to_take` is not set. Cannot load augmented samples.")
    aug_samples = []
    aug_labels = []
    for sample, label in zip(samples, targets):
        aug_samples.append(sample)
        aug_labels.append(label)
        if replace_str_from is not None and replace_str_to is not None:
            sample = sample.replace(replace_str_from, replace_str_to)
        base, ext = os.path.splitext(sample)
        for k in range(self.n_aug_to_take):
            aug_samples.append(base + "_" + str(k + 1) + ext)
            aug_labels.append(label)
    samples = aug_samples
    targets = aug_labels
    if shuffle:
        indexes = np.arange(len(aug_samples))
        np.random.shuffle(indexes)
        samples = np.array(samples)[indexes].tolist()
        targets = np.array(targets)[indexes].tolist()
    return samples, targets
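
For example, with n_aug_to_take=2 each original sample expands to itself plus two derived file names (the paths are made up; the augmented files are expected to exist on disk with the _1, _2, ... suffixes):

from quadra.datamodules.base import BaseDataModule

dm = BaseDataModule(data_path="dataset", n_aug_to_take=2)
samples, targets = dm.load_augmented_samples(
    samples=["dataset/good/img001.png"],
    targets=["good"],
)
# samples == ["dataset/good/img001.png",
#             "dataset/good/img001_1.png",
#             "dataset/good/img001_2.png"]
# targets == ["good", "good", "good"]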

prepare_data()

Prepares the data; should be overridden by subclasses.

Source code in quadra/datamodules/base.py
def prepare_data(self) -> None:
    """Prepares the data, should be overridden by subclasses."""
    if hasattr(self, "data"):
        return

    self._prepare_data()
    self.hash_data()
    self.save_checkpoint()
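
Subclasses are expected to provide the _prepare_data step that populates self.data; prepare_data then takes care of hashing and checkpointing. A hedged sketch (the CsvDataModule class and the index.csv layout are hypothetical, not part of quadra):

import pandas as pd

from quadra.datamodules.base import BaseDataModule


class CsvDataModule(BaseDataModule):
    """Hypothetical subclass reading an index CSV with samples, targets and split columns."""

    def _prepare_data(self) -> None:
        # prepare_data() will hash the files and save the checkpoint afterwards;
        # the subclass only has to populate self.data.
        self.data = pd.read_csv(f"{self.data_path}/index.csv")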

restore_checkpoint()

Loads the data from disk; utility function that should be called from setup.

Source code in quadra/datamodules/base.py
def restore_checkpoint(self) -> None:
    """Loads the data from disk, utility function that should be called from setup."""
    if hasattr(self, "data"):
        return

    if not os.path.isfile(self.datamodule_checkpoint_file):
        raise ValueError(f"Dataset file {self.datamodule_checkpoint_file} does not exist.")

    with open(self.datamodule_checkpoint_file, "rb") as f:
        checkpoint_datamodule = pkl.load(f)
        for key, value in checkpoint_datamodule.__dict__.items():
            setattr(self, key, value)

save_checkpoint()

Saves the datamodule to disk. Utility function called from prepare_data; the datamodule must be saved to disk because attributes cannot be assigned to it inside prepare_data when working with multiple GPUs.

Source code in quadra/datamodules/base.py
def save_checkpoint(self) -> None:
    """Saves the datamodule to disk, utility function that is called from prepare_data. We are required to save
    datamodule to disk because we can't assign attributes to the datamodule in prepare_data when working with
    multiple gpus.
    """
    if not os.path.exists(self.datamodule_checkpoint_file) and not os.path.exists(self.dataset_file):
        with open(self.datamodule_checkpoint_file, "wb") as f:
            pkl.dump(self, f)

        self.data.to_csv(self.dataset_file, index=False)
        log.info("Datamodule checkpoint saved to disk.")

    if "targets" in self.data and not isinstance(self.data["targets"].iloc[0], np.ndarray):
        # If we find a numpy array target it's very likely one hot encoded, in that case we don't want to print
        log.info("Dataset Info:")
        split_order = {"train": 0, "val": 1, "test": 2}
        log.info(
            "\n%s",
            self.data.groupby(["split", "targets"])
            .size()
            .to_frame()
            .reset_index()
            .sort_values(by=["split"], key=lambda x: x.map(split_order))
            .rename(columns={0: "count"})
            .to_string(index=False),
        )
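
To get a feel for the logged summary, the same groupby chain can be replayed on a toy DataFrame (the values are made up):

import pandas as pd

data = pd.DataFrame(
    {
        "split": ["train", "train", "val", "test"],
        "targets": ["good", "bad", "good", "bad"],
    }
)
split_order = {"train": 0, "val": 1, "test": 2}
# Prints one row per (split, target) pair with its sample count, ordered train/val/test.
print(
    data.groupby(["split", "targets"])
    .size()
    .to_frame()
    .reset_index()
    .sort_values(by=["split"], key=lambda x: x.map(split_order))
    .rename(columns={0: "count"})
    .to_string(index=False)
)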

DecorateParentMethod

Bases: type

Metaclass to decorate methods of subclasses.

__new__(name, bases, dct)

Create the new class, decorating selected parent-class methods.

Source code in quadra/datamodules/base.py
def __new__(cls, name, bases, dct):
    """Create new  decorator for parent class methods."""
    method_decorator_mapper = {
        "setup": load_data_from_disk_dec,
    }
    for method_name, decorator in method_decorator_mapper.items():
        if method_name in dct:
            dct[method_name] = decorator(dct[method_name])

    return super().__new__(cls, name, bases, dct)
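
The same pattern in isolation, with a toy decorator standing in for load_data_from_disk_dec (ToyMeta and Demo are illustrative names, not part of quadra):

class ToyMeta(type):
    """Decorate a subclass-defined setup method, mirroring DecorateParentMethod."""

    def __new__(cls, name, bases, dct):
        def announce(func):
            def wrapper(*args, **kwargs):
                print(f"calling {func.__name__}")
                return func(*args, **kwargs)

            return wrapper

        if "setup" in dct:
            dct["setup"] = announce(dct["setup"])
        return super().__new__(cls, name, bases, dct)


class Demo(metaclass=ToyMeta):
    def setup(self):
        return "done"


Demo().setup()  # prints "calling setup" before returning "done"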

compute_file_content_hash(path, hash_size=64)

Get hash of a file based on its content.

Parameters:

  • path (str) –

    Path to the file.

  • hash_size (Literal[32, 64, 128], default: 64 ) –

    Size of the hash. Must be one of [32, 64, 128].

Returns:

  • str

    The hash of the file.

Source code in quadra/datamodules/base.py
def compute_file_content_hash(path: str, hash_size: Literal[32, 64, 128] = 64) -> str:
    """Get hash of a file based on its content.

    Args:
        path: Path to the file.
        hash_size: Size of the hash. Must be one of [32, 64, 128].

    Returns:
        The hash of the file.
    """
    with open(path, "rb") as f:
        data = f.read()

        if hash_size == 32:
            file_hash = xxhash.xxh32(data, seed=42).hexdigest()
        elif hash_size == 64:
            file_hash = xxhash.xxh64(data, seed=42).hexdigest()
        elif hash_size == 128:
            file_hash = xxhash.xxh128(data, seed=42).hexdigest()
        else:
            raise ValueError(f"Invalid hash size {hash_size}. Must be one of [32, 64, 128].")

    return file_hash
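
Usage is straightforward (the file path is a placeholder):

from quadra.datamodules.base import compute_file_content_hash

digest = compute_file_content_hash("path/to/image.png", hash_size=64)
print(digest)  # 16-character hexadecimal string for the 64-bit hash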

compute_file_size_hash(path, hash_size=64)

Get hash of a file based on its size.

Parameters:

  • path (str) –

    Path to the file.

  • hash_size (Literal[32, 64, 128], default: 64 ) –

    Size of the hash. Must be one of [32, 64, 128].

Returns:

  • str

    The hash of the file.

Source code in quadra/datamodules/base.py
def compute_file_size_hash(path: str, hash_size: Literal[32, 64, 128] = 64) -> str:
    """Get hash of a file based on its size.

    Args:
        path: Path to the file.
        hash_size: Size of the hash. Must be one of [32, 64, 128].

    Returns:
        The hash of the file.
    """
    data = str(os.path.getsize(path))

    if hash_size == 32:
        file_hash = xxhash.xxh32(data, seed=42).hexdigest()
    elif hash_size == 64:
        file_hash = xxhash.xxh64(data, seed=42).hexdigest()
    elif hash_size == 128:
        file_hash = xxhash.xxh128(data, seed=42).hexdigest()
    else:
        raise ValueError(f"Invalid hash size {hash_size}. Must be one of [32, 64, 128].")

    return file_hash
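
Because only the size is hashed, two different files of the same size produce identical hashes, which is why this mode is faster but less safe. A small demonstration with temporary files:

import tempfile

from quadra.datamodules.base import compute_file_size_hash

with tempfile.NamedTemporaryFile(delete=False) as f1, tempfile.NamedTemporaryFile(delete=False) as f2:
    f1.write(b"aaaa")
    f2.write(b"bbbb")

# Different content, same size -> same hash.
print(compute_file_size_hash(f1.name) == compute_file_size_hash(f2.name))  # True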

istarmap(self, func, iterable, chunksize=1)

Starmap-version of imap.

Source code in quadra/datamodules/base.py
@typing.no_type_check
def istarmap(self, func: Callable, iterable: Iterable, chunksize: int = 1):
    # pylint: disable=all
    """Starmap-version of imap."""
    self._check_running()
    if chunksize < 1:
        raise ValueError("Chunksize must be 1+, not {0:n}".format(chunksize))

    task_batches = mpp.Pool._get_tasks(func, iterable, chunksize)
    result = mpp.IMapIterator(self)
    self._taskqueue.put((self._guarded_task_generation(result._job, mpp.starmapstar, task_batches), result._set_length))
    return (item for chunk in result for item in chunk)
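
Note that istarmap takes self as its first argument: it is meant to be attached to multiprocessing.pool.Pool so that pool.istarmap(...) becomes available, as used by hash_data above. A minimal sketch of that pattern (the add worker is made up; quadra presumably performs the equivalent patching when the module is imported):

import multiprocessing.pool as mpp

from quadra.datamodules.base import istarmap

# Attach the method so every Pool instance exposes istarmap.
mpp.Pool.istarmap = istarmap


def add(a, b):
    return a + b


if __name__ == "__main__":
    with mpp.Pool(2) as pool:
        results = list(pool.istarmap(add, [(1, 2), (3, 4)]))
    print(results)  # [3, 7]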

load_data_from_disk_dec(func)

Load data from disk if it exists.

Source code in quadra/datamodules/base.py
def load_data_from_disk_dec(func):
    """Load data from disk if it exists."""

    @wraps(func)
    def wrapper(*args, **kwargs):
        """Wrapper function to load data from disk if it exists."""
        self = cast(BaseDataModule, args[0])
        self.restore_checkpoint()
        return func(*args, **kwargs)

    return wrapper