base
BaseNetworkBuilder(features_extractor, pre_classifier=nn.Identity(), classifier=nn.Identity(), freeze=True, hyperspherical=False, flatten_features=True)
Bases: Module
Baseline feature extractor with the option to map features onto a hypersphere. If hyperspherical is True, the classifier is ignored.
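"Mapping onto a hypersphere" usually means L2-normalizing each feature vector. The docstring does not spell this out, so the following is an assumption rather than a description of quadra's code: for a feature vector $\mathbf{z} \in \mathbb{R}^d$ with $\mathbf{z} \neq \mathbf{0}$, the mapped feature would be

$$\hat{\mathbf{z}} = \frac{\mathbf{z}}{\lVert \mathbf{z} \rVert_2}, \qquad \lVert \hat{\mathbf{z}} \rVert_2 = 1,$$

so every sample lands on the unit sphere in $\mathbb{R}^d$, and comparing two mapped features by dot product is equivalent to comparing them by cosine similarity.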
Parameters:
- features_extractor (Module) – Feature extractor as a torch.nn.Module.
- pre_classifier (Module, default: Identity()) – Pre-classifier as a torch.nn.Module. Defaults to nn.Identity().
- classifier (Module, default: Identity()) – Classifier as a torch.nn.Module. Defaults to nn.Identity().
- freeze (bool, default: True) – Whether to freeze the feature extractor. Defaults to True.
- hyperspherical (bool, default: False) – Whether to map features onto a hypersphere. Defaults to False.
- flatten_features (bool, default: True) – Whether to flatten the features before the pre_classifier. This may be required if your model outputs a feature map rather than a vector. Defaults to True.
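To show how these parameters could fit together, here is a minimal PyTorch sketch. SketchNetworkBuilder is a hypothetical name, and this is not quadra's implementation. It assumes that freezing sets requires_grad=False on the backbone, that flattening keeps the batch dimension, and that the hyperspherical mapping is L2 normalization applied after the pre-classifier, with the classifier skipped.

```python
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn


class SketchNetworkBuilder(nn.Module):
    """Hypothetical illustration of a backbone + pre-classifier + classifier builder.

    Not quadra's source code; the freezing, flattening and hypersphere
    behaviour below are assumptions made for this sketch.
    """

    def __init__(
        self,
        features_extractor: nn.Module,
        pre_classifier: Optional[nn.Module] = None,
        classifier: Optional[nn.Module] = None,
        freeze: bool = True,
        hyperspherical: bool = False,
        flatten_features: bool = True,
    ) -> None:
        super().__init__()
        self.features_extractor = features_extractor
        self.pre_classifier = pre_classifier if pre_classifier is not None else nn.Identity()
        self.classifier = classifier if classifier is not None else nn.Identity()
        self.hyperspherical = hyperspherical
        self.flatten_features = flatten_features
        if freeze:
            # Assumption: "freeze" means the backbone weights receive no gradients.
            for param in self.features_extractor.parameters():
                param.requires_grad = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features_extractor(x)
        if self.flatten_features:
            # e.g. an (N, C, H, W) feature map becomes (N, C*H*W); batch dim is kept.
            x = torch.flatten(x, start_dim=1)
        x = self.pre_classifier(x)
        if self.hyperspherical:
            # Assumption: L2-normalize each vector onto the unit hypersphere and
            # skip the classifier, since the docstring says it is ignored.
            return F.normalize(x, p=2.0, dim=1)
        return self.classifier(x)


# Usage: a tiny convolutional backbone that outputs an (N, 8, 2, 2) feature map.
backbone = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d(2)
)
model = SketchNetworkBuilder(
    backbone, pre_classifier=nn.Linear(32, 16), classifier=nn.Linear(16, 4)
)
logits = model(torch.randn(2, 3, 32, 32))
print(logits.shape)  # torch.Size([2, 4])
```

In this sketch, freezing only disables gradients. A real implementation might also keep the backbone in eval mode so that batch-norm statistics stay fixed, which the sketch does not do.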
Source code in quadra/models/classification/base.py