
ssl

AugmentationStrategy

Bases: Enum

Augmentation strategy for TwoAugmentationDataset.

TwoAugmentationDataset(dataset, transform, strategy=AugmentationStrategy.SAME_IMAGE)

Bases: Dataset

Dataset that returns two augmented images per sample, for use in self-supervised learning.

Parameters:

  • dataset (Dataset) –

    A torch Dataset object

  • transform (Union[Compose, Tuple[Compose, Compose]]) –

    Albumentations transformations. If a single Compose is given, it is applied to both images. If a tuple of two Compose objects is given, the first is applied to the first image and the second to the second image.

  • strategy (AugmentationStrategy, default: SAME_IMAGE ) –

    Augmentation strategy to use (see AugmentationStrategy). Defaults to AugmentationStrategy.SAME_IMAGE.
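The transform parameter accepts either one transform shared by both images or a pair of transforms. The following is a minimal sketch of how the two cases can be normalized into one transform per image. resolve_transform_pair is a hypothetical helper, not part of quadra, and plain callables stand in for A.Compose, assuming the Albumentations call convention t(image=...)["image"].

```python
from typing import Callable, Sequence, Tuple, Union

# Stand-in for A.Compose: a callable taking image=... and returning {"image": ...}
Transform = Callable[..., dict]


def resolve_transform_pair(
    transform: Union[Transform, Sequence[Transform]],
) -> Tuple[Transform, Transform]:
    """Hypothetical helper: return one transform per augmented image."""
    if isinstance(transform, (tuple, list)):
        # A pair must contain exactly two transforms (they may be the same object)
        if len(transform) != 2:
            raise ValueError("transform must be an Iterable of length 2")
        return transform[0], transform[1]
    # A single transform is shared by both images
    return transform, transform
```

Checking for tuple or list explicitly, rather than for any Iterable, avoids misclassifying a single A.Compose, which in recent Albumentations versions is itself iterable over its child transforms.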

Source code in quadra/datasets/ssl.py
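from typing import Tuple, Union

import albumentations as A
from torch.utils.data import Dataset

# AugmentationStrategy is the Enum defined earlier in this module.


def __init__(
    self,
    dataset: Dataset,
    transform: Union[A.Compose, Tuple[A.Compose, A.Compose]],
    strategy: AugmentationStrategy = AugmentationStrategy.SAME_IMAGE,
):
    self.dataset = dataset
    self.transform = transform
    self.strategy = strategy
    # A single A.Compose is itself iterable, so test for a tuple/list explicitly;
    # len() (not len(set(...))) also accepts the same Compose passed twice.
    if isinstance(transform, (tuple, list)) and len(transform) != 2:
        raise ValueError("transform must be an Iterable of length 2")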
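With the default SAME_IMAGE strategy, the natural reading is that both augmented images are derived from the same sample. The sketch below assumes that behavior. TwoViewDataset is a hypothetical, dependency-free stand-in for TwoAugmentationDataset: a list of (image, label) pairs replaces the torch Dataset, and plain callables replace A.Compose. The real __getitem__ may also return labels or treat other strategies differently.

```python
from typing import Any, Callable, List, Sequence, Tuple

# Stand-in for A.Compose: a callable taking image=... and returning {"image": ...}
Transform = Callable[..., dict]


class TwoViewDataset:
    """Hypothetical sketch: two augmented views of the same image."""

    def __init__(self, samples: Sequence[Tuple[Any, int]], t1: Transform, t2: Transform):
        self.samples = samples
        self.t1 = t1
        self.t2 = t2

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> List[Any]:
        image, _label = self.samples[idx]
        # Both views come from the same source image (SAME_IMAGE assumption)
        view1 = self.t1(image=image)["image"]
        view2 = self.t2(image=image)["image"]
        return [view1, view2]
```

Passing the same transform as t1 and t2 still yields two different views in practice, because Albumentations transforms are usually random.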

TwoSetAugmentationDataset(dataset, global_transforms, local_transform, num_local_transforms)

Bases: Dataset

Dataset that returns two sets of augmented images per sample (global and local views), for use in self-supervised learning methods such as DINO.

Parameters:

  • dataset (Dataset) –

    Base dataset

  • global_transforms (Tuple[Compose, Compose]) –

    Pair of global transformations; each is applied to the original image to produce one of the two global views.

  • local_transform (Compose) –

    Local transformation, applied num_local_transforms times to the original image to produce the local views.

  • num_local_transforms (int) –

    Number of local views to generate. Each sample therefore yields 2 + num_local_transforms images: the two global views followed by num_local_transforms local views, all computed from the original image.

Example

images[0] = global_transform0
images[1] = global_transform1
images[2:] = local_transform(s)(original_image)
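The layout in the example can be reproduced with a short, dependency-free sketch. multi_crop_views is a hypothetical function, not part of quadra, and plain callables following the Albumentations call convention stand in for A.Compose. Every view is computed from the original image, so one sample yields 2 + num_local_transforms images.

```python
from typing import Any, Callable, List, Tuple

# Stand-in for A.Compose: a callable taking image=... and returning {"image": ...}
Transform = Callable[..., dict]


def multi_crop_views(
    image: Any,
    global_transforms: Tuple[Transform, Transform],
    local_transform: Transform,
    num_local_transforms: int,
) -> List[Any]:
    """Hypothetical sketch of the DINO-style view layout."""
    if num_local_transforms < 1:
        raise ValueError("num_local_transforms must be greater than 0")
    # images[0] and images[1]: one view per global transform
    views = [t(image=image)["image"] for t in global_transforms]
    # images[2:]: the local transform applied repeatedly to the original image
    views.extend(local_transform(image=image)["image"] for _ in range(num_local_transforms))
    return views
```

For example, with num_local_transforms=3 the function returns five images: two global views and three local views.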

Source code in quadra/datasets/ssl.py
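from typing import Tuple

import albumentations as A
from torch.utils.data import Dataset


def __init__(
    self,
    dataset: Dataset,
    global_transforms: Tuple[A.Compose, A.Compose],
    local_transform: A.Compose,
    num_local_transforms: int,
):
    self.dataset = dataset
    # Pair of transforms producing the two global views
    self.global_transforms = global_transforms
    # Transform applied num_local_transforms times to produce the local views
    self.local_transform = local_transform
    self.num_local_transforms = num_local_transforms

    # At least one local view is required
    if num_local_transforms < 1:
        raise ValueError("num_local_transforms must be greater than 0")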