patch
PatchSklearnClassificationDataModule(data_path, class_to_idx, name='patch_classification_datamodule', train_filename='dataset.txt', exclude_filter=None, include_filter=None, seed=42, batch_size=32, num_workers=6, train_transform=None, val_transform=None, test_transform=None, balance_classes=False, class_to_skip_training=None, **kwargs)
Bases: BaseDataModule
DataModule for patch classification.
Parameters:
- data_path (str) – Location of the dataset.
- name (str, default: 'patch_classification_datamodule') – Name of the datamodule.
- train_filename (str, default: 'dataset.txt') – Name of the file containing the list of training samples.
- exclude_filter (Optional[List[str]], default: None) – Filter to exclude samples from the dataset.
- include_filter (Optional[List[str]], default: None) – Filter to include samples from the dataset.
- class_to_idx (Dict) – Dictionary mapping class names to indices.
- seed (int, default: 42) – Random seed.
- batch_size (int, default: 32) – Batch size.
- num_workers (int, default: 6) – Number of workers.
- train_transform (Optional[Compose], default: None) – Transform to apply to the training samples.
- val_transform (Optional[Compose], default: None) – Transform to apply to the validation samples.
- test_transform (Optional[Compose], default: None) – Transform to apply to the test samples.
- balance_classes (bool, default: False) – If True, repeat samples of under-represented classes to balance the class distribution.
- class_to_skip_training (Optional[list], default: None) – List of classes skipped during training.
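The parameter descriptions above do not spell out how include_filter, exclude_filter, balance_classes and class_to_skip_training act on the list of samples. The sketch below shows one plausible reading, under stated assumptions: filters are substring matches on each sample's path, balancing repeats every sample of a minority class a whole number of times, and skipped classes are simply dropped from the training targets. The helpers filter_samples, balance_by_repetition and to_training_targets are hypothetical and are not part of quadra.

```python
from collections import Counter
from typing import Dict, List, Optional, Tuple

# Hypothetical sample representation: (relative image path, class name)
Sample = Tuple[str, str]


def filter_samples(
    samples: List[Sample],
    include_filter: Optional[List[str]] = None,
    exclude_filter: Optional[List[str]] = None,
) -> List[Sample]:
    """Assumed semantics: keep paths containing any include token and no exclude token."""
    kept = []
    for path, label in samples:
        if include_filter is not None and not any(tok in path for tok in include_filter):
            continue
        if exclude_filter is not None and any(tok in path for tok in exclude_filter):
            continue
        kept.append((path, label))
    return kept


def balance_by_repetition(samples: List[Sample]) -> List[Sample]:
    """Repeat each class's samples so its count approaches that of the largest class."""
    counts = Counter(label for _, label in samples)
    if not counts:
        return []
    target = max(counts.values())
    balanced: List[Sample] = []
    for label, count in counts.items():  # classes in first-seen order
        class_samples = [s for s in samples if s[1] == label]
        repeats = target // count  # whole repetitions only, so balance is approximate
        balanced.extend(class_samples * repeats)
    return balanced


def to_training_targets(
    samples: List[Sample],
    class_to_idx: Dict[str, int],
    class_to_skip_training: Optional[List[str]] = None,
) -> List[Tuple[str, int]]:
    """Map class names to indices, dropping classes assumed to be skipped in training."""
    skip = set(class_to_skip_training or [])
    return [(path, class_to_idx[label]) for path, label in samples if label not in skip]


samples = [
    ("line_a/img0.png", "good"),
    ("line_a/img1.png", "good"),
    ("line_a/img2.png", "good"),
    ("line_a/img3.png", "scratch"),
    ("line_b/img4.png", "dent"),
]
class_to_idx = {"good": 0, "scratch": 1, "dent": 2}

kept = filter_samples(samples, include_filter=["line_a"])
balanced = balance_by_repetition(kept)
targets = to_training_targets(balanced, class_to_idx, class_to_skip_training=["scratch"])
```

Integer repetition only approximates balance: with this reading, a class with 2 samples against a majority of 5 is repeated twice, giving 4.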
Source code in quadra/datamodules/patch.py (lines 35-77)
setup(stage=None)
Set up the datasets required for the given stage.
Source code in quadra/datamodules/patch.py (lines 122-149)
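The setup docstring is terse. Assuming BaseDataModule follows the PyTorch Lightning LightningDataModule hook convention, setup(stage) builds only the datasets needed for the requested stage, and stage=None means all of them. The hypothetical SketchDataModule below illustrates that contract with plain lists; it is not quadra's implementation.

```python
from typing import List, Optional, Tuple

Sample = Tuple[str, int]  # hypothetical (image path, class index) pair


class SketchDataModule:
    """Hypothetical stand-in showing the Lightning-style setup(stage) contract."""

    def __init__(self, train: List[Sample], val: List[Sample], test: List[Sample]) -> None:
        self._train, self._val, self._test = train, val, test
        self.train_dataset: Optional[List[Sample]] = None
        self.val_dataset: Optional[List[Sample]] = None
        self.test_dataset: Optional[List[Sample]] = None

    def setup(self, stage: Optional[str] = None) -> None:
        # stage=None prepares every split; otherwise only the splits the stage needs
        if stage in (None, "fit"):
            self.train_dataset = list(self._train)
        if stage in (None, "fit", "validate"):
            self.val_dataset = list(self._val)
        if stage in (None, "test"):
            self.test_dataset = list(self._test)
```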
test_dataloader()
Return the test dataloader.
Source code in quadra/datamodules/patch.py (lines 177-189)
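Evaluation loaders typically iterate the dataset in a fixed order so that metrics are reproducible. Assuming test_dataloader (and likewise val_dataloader) does not shuffle, the hypothetical iterate_batches helper below mimics how samples are grouped into batches of batch_size, with the final batch possibly smaller. It is a pure-Python illustration, not a torch DataLoader.

```python
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def iterate_batches(samples: Sequence[T], batch_size: int) -> Iterator[List[T]]:
    """Yield consecutive batches in dataset order; the last batch may be smaller."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    for start in range(0, len(samples), batch_size):
        yield list(samples[start:start + batch_size])
```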
train_dataloader()
Return the train dataloader.
Source code in quadra/datamodules/patch.py (lines 151-162)
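Training loaders are usually shuffled. The hypothetical iterate_shuffled_batches helper below sketches seeded shuffling, assuming the seed parameter is what makes the sample order reproducible; this reference does not state whether train_dataloader shuffles or drops the last incomplete batch, so treat both as assumptions.

```python
import random
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def iterate_shuffled_batches(
    samples: Sequence[T], batch_size: int, seed: int, epoch: int = 0
) -> Iterator[List[T]]:
    """Yield batches in a shuffled order that is reproducible for a given (seed, epoch)."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    rng = random.Random(seed + epoch)  # a private RNG so global random state is untouched
    order = list(range(len(samples)))
    rng.shuffle(order)
    for start in range(0, len(order), batch_size):
        yield [samples[i] for i in order[start:start + batch_size]]
```

Deriving the RNG from seed + epoch gives each epoch a different order while keeping every run reproducible.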
val_dataloader()
Return the validation dataloader.
Source code in quadra/datamodules/patch.py (lines 164-175)