base

BaseNetworkBuilder(features_extractor, pre_classifier=None, classifier=None, freeze=True, hyperspherical=False, flatten_features=True)

Bases: Module

Baseline feature extractor, with the option of mapping features onto a hypersphere. If hyperspherical is True, the classifier is ignored.

Parameters:

  • features_extractor (Module) –

    Feature extractor as a torch.nn.Module.

  • pre_classifier (Module | None, default: None ) –

    Pre classifier as a torch.nn.Module. Defaults to nn.Identity() if None.

  • classifier (Module | None, default: None ) –

    Classifier as a torch.nn.Module. Defaults to nn.Identity() if None.

  • freeze (bool, default: True ) –

    Whether to freeze the feature extractor. Defaults to True.

  • hyperspherical (bool, default: False ) –

    Whether to map features onto a hypersphere. Defaults to False.

  • flatten_features (bool, default: True ) –

    Whether to flatten the features before the pre_classifier. May be required if your model is outputting a feature map rather than a vector. Defaults to True.
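
To illustrate why flattening may be needed, here is a minimal sketch using plain torch modules. The backbone, input size, and layer widths are hypothetical stand-ins and are not part of quadra; the exact flattening call used by BaseNetworkBuilder is not shown in this section, so `flatten(1)` is an assumption.

```python
import torch
from torch import nn

# Hypothetical backbone that returns a feature map of shape (N, C, H, W)
# rather than a vector of shape (N, D).
backbone = nn.Conv2d(3, 8, kernel_size=3, padding=1)

x = torch.randn(2, 3, 4, 4)
feature_map = backbone(x)        # shape (2, 8, 4, 4)

# Collapse every dimension after the batch dimension into one vector per sample.
flat = feature_map.flatten(1)    # shape (2, 8 * 4 * 4) = (2, 128)

# A linear pre-classifier expects a vector per sample, so it consumes the
# flattened tensor, not the raw feature map.
pre_classifier = nn.Linear(flat.shape[1], 16)
out = pre_classifier(flat)       # shape (2, 16)
```

If the feature extractor already returns a vector per sample, flattening leaves the shape unchanged.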

Source code in quadra/models/classification/base.py
def __init__(
    self,
    features_extractor: nn.Module,
    pre_classifier: nn.Module | None = None,
    classifier: nn.Module | None = None,
    freeze: bool = True,
    hyperspherical: bool = False,
    flatten_features: bool = True,
):
    super().__init__()
    if pre_classifier is None:
        pre_classifier = nn.Identity()

    if classifier is None:
        classifier = nn.Identity()

    self.features_extractor = features_extractor
    self.freeze = freeze
    self.hyperspherical = hyperspherical
    self.pre_classifier = pre_classifier
    self.classifier = classifier
    self.flatten: bool = False
    self._hyperspherical: bool = False
    self.l2: L2Norm | None = None
    self.flatten_features = flatten_features

    self.freeze = freeze
    self.hyperspherical = hyperspherical

    if self.freeze:
        for p in self.features_extractor.parameters():
            p.requires_grad = False

freeze: bool property writable

Whether to freeze the feature extractor.
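
The constructor freezes the extractor by setting requires_grad to False on each of its parameters. The sketch below shows the effect of that pattern on a backward pass; the two small modules are hypothetical stand-ins, not quadra classes.

```python
import torch
from torch import nn

# Hypothetical stand-ins for a feature extractor and a classifier head.
backbone = nn.Sequential(nn.Linear(4, 8), nn.ReLU())
head = nn.Linear(8, 2)

# Same pattern as in the constructor: disable gradients on the extractor.
for p in backbone.parameters():
    p.requires_grad = False

loss = head(backbone(torch.randn(3, 4))).sum()
loss.backward()

# Frozen parameters receive no gradient; the head still does.
frozen_grads = [p.grad for p in backbone.parameters()]
head_has_grad = all(p.grad is not None for p in head.parameters())
```

Only the head's parameters would be updated by an optimizer step in this setup.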

hyperspherical: bool property writable

Whether to map the extracted features onto a hypersphere.
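
The constructor holds an L2Norm attribute, which suggests the mapping is an L2 normalization of each feature vector. The dependency-free sketch below shows that operation under this assumption; `l2_normalize` and its `eps` guard are hypothetical and do not reproduce quadra's L2Norm implementation.

```python
import math


def l2_normalize(vec, eps=1e-12):
    """Scale a vector to unit L2 norm so it lies on the unit hypersphere."""
    norm = math.sqrt(sum(v * v for v in vec))
    # eps guards against division by zero for an all-zero vector.
    return [v / max(norm, eps) for v in vec]


features = [3.0, 4.0]
unit = l2_normalize(features)   # direction is preserved, length becomes 1
```

After this step only the direction of a feature vector carries information, which is why comparisons between such features typically rely on cosine similarity.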