Self-Supervised Learning example¶
In this tutorial we will explain how to train a self-supervised learning model using Quadra. In particular, we will focus on the Bootstrap Your Own Latent (BYOL) algorithm.
Training¶
Dataset¶
For self-supervised learning tasks, we will use the same classification dataset structure defined for the ClassificationDataModule. In fact, the SSLDataModule is a subclass of ClassificationDataModule and shares the same API and implementation, so it's fairly easy to move from a classification task to a self-supervised learning task.
dataset/
├── class_0
│   ├── abc.xyz
│   └── ...
├── class_1
│   ├── abc.xyz
│   └── ...
├── class_N
│   ├── abc.xyz
│   └── ...
├── test.txt   # optional
├── train.txt  # optional
└── val.txt    # optional
The train.txt, val.txt and test.txt files are optional and can be used to specify the lists of images to use for training, validation and testing. If they are not provided, the datamodule will split the data automatically according to the val_size and test_size parameters. Each file should contain the paths of the images relative to the dataset root folder. For example, if the dataset is organized as above, the train.txt file could be:
class_0/abc.xyz
...
class_1/abc.xyz
...
class_N/abc.xyz
...
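If you need to create these split files yourself, something along the lines of the following sketch could generate them. It assumes the directory layout shown above; the file names are those expected by the datamodule, while the split ratios simply mirror the defaults (val_size=0.3, test_size=0.1).
import random
from pathlib import Path

root = Path("dataset")
# Collect image paths relative to the dataset root, e.g. "class_0/abc.xyz"
images = sorted(
    p.relative_to(root).as_posix()
    for p in root.glob("class_*/*")
    if p.is_file()
)
random.seed(42)
random.shuffle(images)

# Example ratios matching the datamodule defaults below
n_val, n_test = int(0.3 * len(images)), int(0.1 * len(images))
splits = {
    "val.txt": images[:n_val],
    "test.txt": images[n_val : n_val + n_test],
    "train.txt": images[n_val + n_test :],
}
for name, files in splits.items():
    (root / name).write_text("\n".join(files) + "\n")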
Validation is not required, but it may be useful for evaluating the embeddings learned by the model during training, for example with a linear classifier. The test set will be used to evaluate the model's performance at the end of training.
The default datamodule configuration is found under configs/datamodule/base/ssl.yaml and is defined as follows:
_target_: quadra.datamodules.SSLDataModule
data_path: ???
exclude_filter:
include_filter:
seed: ${core.seed}
num_workers: 8
batch_size: 16
augmentation_dataset: null
train_transform: null
test_transform: ${transforms.test_transform}
val_transform: ${transforms.val_transform}
train_split_file:
val_split_file:
test_split_file:
val_size: 0.3
test_size: 0.1
split_validation: true
class_to_idx:
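Note that Hydra instantiates the class referenced by _target_ with the remaining keys as keyword arguments, so the same datamodule can also be built directly in Python. A minimal sketch, using a few of the defaults above:
from quadra.datamodules import SSLDataModule

# Keyword arguments mirror the config keys shown above
datamodule = SSLDataModule(
    data_path="/path/to/the/dataset",
    num_workers=8,
    batch_size=16,
    val_size=0.3,
    test_size=0.1,
)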
We will make some changes to the datamodule in the experiment configuration file.
Experiment¶
First, let's look at how the base experiment configuration file for the BYOL algorithm is defined; it is located in configs/experiment/base/ssl/byol.yaml.
# @package _global_
defaults:
  - override /datamodule: base/ssl
  - override /backbone: resnet18
  - override /model: byol
  - override /optimizer: lars
  - override /scheduler: warmup
  - override /transforms: byol
  - override /loss: byol
  - override /task: ssl
  - override /trainer: lightning_gpu_fp16

core:
  tag: "run"
  name: "byol_ssl"

task:
  _target_: quadra.tasks.ssl.BYOL

callbacks:
  model_checkpoint:
    _target_: pytorch_lightning.callbacks.ModelCheckpoint
    monitor: "val_acc"
    mode: "max"

trainer:
  devices: [0]
  max_epochs: 500
  num_sanity_val_steps: 0
  check_val_every_n_epoch: 10

datamodule:
  num_workers: 12
  batch_size: 256
  augmentation_dataset:
    _target_: quadra.datasets.TwoAugmentationDataset
    transform:
      - ${transforms.augmentation1}
      - ${transforms.augmentation2}
    dataset: null

scheduler:
  init_lr:
    - 0.4
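The augmentation_dataset entry wraps the underlying dataset so that each sample yields two independently augmented views of the same image. An illustrative stand-in for quadra.datasets.TwoAugmentationDataset (the real implementation may differ):
from torch.utils.data import Dataset

class TwoViewDataset(Dataset):
    """Illustrative stand-in: each item yields two augmented views."""

    def __init__(self, dataset, transform):
        self.dataset = dataset
        self.transform = transform  # pair of albumentations pipelines

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image, label = self.dataset[idx]
        # albumentations pipelines take keyword arguments and return dicts
        view1 = self.transform[0](image=image)["image"]
        view2 = self.transform[1](image=image)["image"]
        return (view1, view2), label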
The default configuration can be used to train a resnet18 model using lars as the optimizer and a cosine annealing scheduler with warmup.
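Conceptually, this scheduler ramps the learning rate up to init_lr during the warmup phase and then decays it along a cosine curve. A rough sketch of the shape, assuming linear warmup (the actual schedule is defined by the scheduler config):
import math

def warmup_cosine_lr(step, total_steps, warmup_steps, init_lr=0.4):
    """Linear warmup followed by cosine annealing to zero (illustrative)."""
    if step < warmup_steps:
        return init_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return init_lr * 0.5 * (1.0 + math.cos(math.pi * progress))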
Every 10 epochs we perform a validation step using a KNN classifier with 20 neighbors (check the model definition for more details). The model is saved every time the validation accuracy improves, as well as at the end of training. We also make use of automatic mixed precision to speed up the training process.
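Illustratively, the KNN validation step amounts to embedding the training and validation images with the frozen backbone and scoring a 20-nearest-neighbors classifier on those embeddings. A minimal sketch of the idea (not Quadra's actual implementation):
import torch
from sklearn.neighbors import KNeighborsClassifier

@torch.no_grad()
def knn_accuracy(backbone, train_loader, val_loader, device="cuda", k=20):
    """Embed images with the frozen backbone, then fit and score a KNN."""
    backbone.eval().to(device)

    def embed(loader):
        feats, labels = [], []
        for images, targets in loader:
            feats.append(backbone(images.to(device)).flatten(1).cpu())
            labels.append(targets)
        return torch.cat(feats).numpy(), torch.cat(labels).numpy()

    x_train, y_train = embed(train_loader)
    x_val, y_val = embed(val_loader)
    knn = KNeighborsClassifier(n_neighbors=k)
    knn.fit(x_train, y_train)
    return knn.score(x_val, y_val)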
Since we are going to use a custom dataset for this task, we can add a custom experiment configuration under the configs/experiment/custom_experiment/byol.yaml file.
# @package _global_
defaults:
  - base/ssl/byol
  - override /backbone: vit16_tiny # let's use a different backbone instead of the resnet
  - _self_

trainer:
  devices: [2] # we may need to use different gpu(s)
  max_epochs: 1000 # let's assume we would like to train for 1000 epochs

datamodule:
  data_path: /path/to/the/dataset
Note
If you have a GPU with bf16 support, you can use the lightning_gpu_bf16 trainer configuration instead of lightning_gpu_fp16 by overriding the trainer config in the defaults list of the experiment configuration file:
# @package _global_
defaults:
  - base/ssl/byol
  - override /backbone: vit16_tiny # let's use a different backbone instead of the resnet
  - override /trainer: lightning_gpu_bf16
  - _self_

... # rest of the configuration
Run¶
Now we are ready to run our experiment with the following command:
quadra experiment=custom_experiment/byol
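Since the configurations are Hydra-based, individual parameters can also be overridden directly from the command line using the standard Hydra key=value syntax, for example:
quadra experiment=custom_experiment/byol trainer.max_epochs=200 datamodule.batch_size=128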
The output folder should contain the following entries:
checkpoints config_resolved.yaml config_tree.txt data deployment_model main.log
The checkpoints folder contains the saved PyTorch Lightning checkpoints. The data folder contains a joblib version of the datamodule with all its parameters and dataset splits. The deployment_model folder contains the model ready for production, in the format specified by the export.types parameter (torchscript by default).
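A minimal sketch of how these artifacts might be consumed afterwards; the exact file names inside the data and deployment_model folders are assumptions, so check the actual output folder.
import joblib
import torch

# Restore the datamodule with all its parameters and dataset splits
datamodule = joblib.load("data/datamodule.pkl")  # hypothetical file name

# Load the exported TorchScript model for inference
model = torch.jit.load("deployment_model/model.pt")  # hypothetical file name
model.eval()
with torch.no_grad():
    embedding = model(torch.randn(1, 3, 224, 224))  # assumes 224x224 RGB input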
Run (Advanced) - Changing transformations¶
In the previous example, we used the default transformations defined in the original paper. However, these settings may not be suitable for our dataset; for example, Gaussian blur may destroy important details. In this case, we can extend the experiment configuration file and add our custom transformations.
# @package _global_
defaults:
  - base/ssl/byol
  - override /backbone: vit16_tiny # let's use a different backbone instead of the resnet
  - _self_

trainer:
  devices: [2] # we may need to use different gpu(s)
  max_epochs: 1000 # let's assume we would like to train for 1000 epochs

datamodule:
  data_path: /path/to/the/dataset

# check configs/transforms/byol.yaml for more details
transforms:
  augmentation1:
    _target_: albumentations.Compose
    transforms:
      - _target_: albumentations.RandomResizedCrop
        height: ${transforms.input_height}
        width: ${transforms.input_width}
        scale: [0.08, 1.0]
      - ${transforms.flip_and_jitter}
      # remove gaussian blur
      # - _target_: albumentations.GaussianBlur
      #   blur_limit: 23
      #   sigma_limit: [0.1, 2]
      #   p: 1.0
      - ${transforms.normalize}
  augmentation2:
    _target_: albumentations.Compose
    transforms:
      - _target_: albumentations.RandomResizedCrop
        height: ${transforms.input_height}
        width: ${transforms.input_width}
        scale: [0.08, 1.0]
      - ${transforms.flip_and_jitter}
      # remove gaussian blur
      # - _target_: albumentations.GaussianBlur
      #   blur_limit: 23
      #   sigma_limit: [0.1, 2]
      #   p: 0.1
      - _target_: albumentations.Solarize
        p: 0.2
      - ${transforms.normalize}
During training, two different augmentations of the same image are sampled based on the given transforms, and the algorithm tries to match the representations of the two augmented views, so picking the right set of transformations is important.
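For reference, the objective that drives this matching in BYOL is the negative cosine similarity between the online network's prediction for one view and the target network's projection of the other view. A sketch of the loss as described in the paper (Quadra's loss config may wrap it differently):
import torch.nn.functional as F

def byol_loss(online_pred, target_proj):
    """2 - 2 * cosine similarity between the two views, as in the BYOL paper."""
    online_pred = F.normalize(online_pred, dim=-1)
    target_proj = F.normalize(target_proj, dim=-1)
    return 2 - 2 * (online_pred * target_proj).sum(dim=-1).mean()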