
backbones

TimmNetworkBuilder(model_name, pretrained=True, pre_classifier=None, classifier=None, freeze=True, hyperspherical=False, flatten_features=True, checkpoint_path=None, **timm_kwargs)

Bases: BaseNetworkBuilder

Timm feature extractor, with the option to map features onto a hypersphere.

Parameters:

  • model_name (str) –

    Name of the timm model, passed to timm.create_model.

  • pretrained (bool, default: True ) –

    Whether to load the pretrained weights for the model.

  • pre_classifier (Module | None, default: None ) –

    Pre-classifier as a torch.nn.Module. Defaults to nn.Identity() if None.

  • classifier (Module | None, default: None ) –

    Classifier as a torch.nn.Module. Defaults to nn.Identity() if None.

  • freeze (bool, default: True ) –

    Whether to freeze the feature extractor. Defaults to True.

  • hyperspherical (bool, default: False ) –

    Whether to map features onto a hypersphere. Defaults to False.

  • flatten_features (bool, default: True ) –

    Whether to flatten the features before the pre_classifier. Defaults to True.

  • checkpoint_path (str | None, default: None ) –

    Path to a checkpoint to load after the model is initialized. Defaults to None.

  • **timm_kwargs (Any, default: {} ) –

    Additional arguments to pass to timm.create_model.
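The parameters above describe a pipeline of the form features → flatten → pre_classifier → classifier, with optional hyperspherical mapping. The following is a minimal pure-Python sketch of that flow, assuming that "map features onto a hypersphere" means L2-normalizing each feature vector to unit length and that it is applied after the pre-classifier. The helper names (`flatten`, `l2_normalize`, `identity`, `forward`) are hypothetical, and BaseNetworkBuilder's actual ordering may differ.

```python
import math
from typing import Callable, List

Batch = List[List[float]]  # a batch of flattened feature vectors


def flatten(sample: list) -> List[float]:
    """Recursively flatten a nested list (e.g. C x H x W) into one vector."""
    out: List[float] = []
    for item in sample:
        if isinstance(item, list):
            out.extend(flatten(item))
        else:
            out.append(float(item))
    return out


def l2_normalize(vector: List[float], eps: float = 1e-12) -> List[float]:
    """Scale a vector to unit L2 norm, i.e. project it onto the unit hypersphere."""
    norm = math.sqrt(sum(v * v for v in vector))
    return [v / max(norm, eps) for v in vector]


def identity(vector: List[float]) -> List[float]:
    """Stand-in for nn.Identity(), used when no (pre-)classifier is given."""
    return vector


def forward(
    batch: list,
    pre_classifier: Callable[[List[float]], List[float]] = identity,
    classifier: Callable[[List[float]], List[float]] = identity,
    flatten_features: bool = True,
    hyperspherical: bool = False,
) -> Batch:
    """Apply the assumed steps to each sample produced by the feature extractor."""
    outputs: Batch = []
    for sample in batch:
        x = flatten(sample) if flatten_features else sample
        x = pre_classifier(x)
        if hyperspherical:  # assumption: normalization happens after pre_classifier
            x = l2_normalize(x)
        outputs.append(classifier(x))
    return outputs


features = [[[3.0], [4.0]]]  # one sample with a 2 x 1 feature map
print(forward(features))                       # [[3.0, 4.0]]
print(forward(features, hyperspherical=True))  # [[0.6, 0.8]]
```

With hyperspherical mapping enabled, every output vector has norm 1, so downstream comparisons depend only on direction, not magnitude.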

Source code in quadra/models/classification/backbones.py
def __init__(
    self,
    model_name: str,
    pretrained: bool = True,
    pre_classifier: nn.Module | None = None,
    classifier: nn.Module | None = None,
    freeze: bool = True,
    hyperspherical: bool = False,
    flatten_features: bool = True,
    checkpoint_path: str | None = None,
    **timm_kwargs: Any,
):
    self.pretrained = pretrained
    features_extractor = timm.create_model(
        model_name, pretrained=self.pretrained, num_classes=0, checkpoint_path=checkpoint_path, **timm_kwargs
    )

    super().__init__(
        features_extractor=features_extractor,
        pre_classifier=pre_classifier,
        classifier=classifier,
        freeze=freeze,
        hyperspherical=hyperspherical,
        flatten_features=flatten_features,
    )
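The constructor above fixes num_classes=0 (which makes timm build the model without its classification head) and forwards everything else through **timm_kwargs. As a consequence, a num_classes entry in timm_kwargs cannot override it: Python rejects the duplicate keyword before timm is even called. The sketch below reproduces this call pattern with a hypothetical stand-in, `fake_create_model`, so it runs without timm installed; `drop_rate` is only an example keyword.

```python
from typing import Any, Dict, Optional


def fake_create_model(
    model_name: str,
    pretrained: bool = False,
    num_classes: int = 1000,
    checkpoint_path: Optional[str] = None,
    **kwargs: Any,
) -> Dict[str, Any]:
    """Hypothetical stand-in for timm.create_model that just records its inputs."""
    return {"model_name": model_name, "pretrained": pretrained,
            "num_classes": num_classes, "checkpoint_path": checkpoint_path, **kwargs}


def build_features_extractor(model_name: str, pretrained: bool = True,
                             checkpoint_path: Optional[str] = None,
                             **timm_kwargs: Any) -> Dict[str, Any]:
    """Mirror the call pattern of TimmNetworkBuilder.__init__ shown above."""
    return fake_create_model(model_name, pretrained=pretrained, num_classes=0,
                             checkpoint_path=checkpoint_path, **timm_kwargs)


config = build_features_extractor("resnet18", drop_rate=0.2)
print(config["num_classes"], config["drop_rate"])  # 0 0.2

try:
    # num_classes lands in timm_kwargs and collides with the fixed num_classes=0
    build_features_extractor("resnet18", num_classes=10)
except TypeError as err:
    print("TypeError:", err)
```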

TorchHubNetworkBuilder(repo_or_dir, model_name, pretrained=True, pre_classifier=None, classifier=None, freeze=True, hyperspherical=False, flatten_features=True, checkpoint_path=None, **torch_hub_kwargs)

Bases: BaseNetworkBuilder

TorchHub feature extractor, with the option to map features onto a hypersphere.

Parameters:

  • repo_or_dir (str) –

    The GitHub repository, in the form repo_owner/repo_name[:ref], or the path to a local directory containing the model.

  • model_name (str) –

    The name of the model entrypoint defined in the repository's hubconf.py.

  • pretrained (bool, default: True ) –

    Whether to load the pretrained weights for the model.

  • pre_classifier (Module | None, default: None ) –

    Pre-classifier as a torch.nn.Module. Defaults to nn.Identity() if None.

  • classifier (Module | None, default: None ) –

    Classifier as a torch.nn.Module. Defaults to nn.Identity() if None.

  • freeze (bool, default: True ) –

    Whether to freeze the feature extractor. Defaults to True.

  • hyperspherical (bool, default: False ) –

    Whether to map features onto a hypersphere. Defaults to False.

  • flatten_features (bool, default: True ) –

    Whether to flatten the features before the pre_classifier. Defaults to True.

  • checkpoint_path (str | None, default: None ) –

    Path to a checkpoint to load after the model is initialized. Defaults to None.

  • **torch_hub_kwargs (Any, default: {} ) –

    Additional arguments to pass to torch.hub.load.
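For GitHub sources, torch.hub.load expects repo_or_dir in the form repo_owner/repo_name[:ref], where ref is an optional tag or branch; a local directory is used instead when source="local" is passed (here, through **torch_hub_kwargs). The sketch below splits such a string into its parts to make the format concrete. `parse_hub_repo` and `HubRepo` are hypothetical names, and torch.hub performs its own parsing and validation.

```python
from typing import NamedTuple, Optional


class HubRepo(NamedTuple):
    """Parts of a GitHub-style repo_or_dir string."""

    owner: str
    name: str
    ref: Optional[str]  # tag or branch; None means torch.hub uses the default branch


def parse_hub_repo(repo_or_dir: str) -> HubRepo:
    """Split 'repo_owner/repo_name[:ref]' into its parts (hypothetical helper)."""
    repo, _, ref = repo_or_dir.partition(":")
    owner, slash, name = repo.partition("/")
    if not slash or not owner or not name or "/" in name:
        raise ValueError(f"expected 'repo_owner/repo_name[:ref]', got {repo_or_dir!r}")
    return HubRepo(owner=owner, name=name, ref=ref or None)


print(parse_hub_repo("pytorch/vision:v0.10.0"))  # HubRepo(owner='pytorch', name='vision', ref='v0.10.0')
print(parse_hub_repo("pytorch/vision"))          # HubRepo(owner='pytorch', name='vision', ref=None)
```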

Source code in quadra/models/classification/backbones.py
def __init__(
    self,
    repo_or_dir: str,
    model_name: str,
    pretrained: bool = True,
    pre_classifier: nn.Module | None = None,
    classifier: nn.Module | None = None,
    freeze: bool = True,
    hyperspherical: bool = False,
    flatten_features: bool = True,
    checkpoint_path: str | None = None,
    **torch_hub_kwargs: Any,
):
    self.pretrained = pretrained
    features_extractor = torch.hub.load(
        repo_or_dir=repo_or_dir, model=model_name, pretrained=self.pretrained, **torch_hub_kwargs
    )
    if checkpoint_path:
        log.info("Loading checkpoint from %s", checkpoint_path)
        load_checkpoint(model=features_extractor, checkpoint_path=checkpoint_path)

    super().__init__(
        features_extractor=features_extractor,
        pre_classifier=pre_classifier,
        classifier=classifier,
        freeze=freeze,
        hyperspherical=hyperspherical,
        flatten_features=flatten_features,
    )
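In the source above, torch.hub.load runs first (fetching pretrained weights when pretrained=True) and load_checkpoint runs afterwards, so values from checkpoint_path take precedence over the pretrained ones. The dictionary-based sketch below illustrates only that ordering. It assumes a checkpoint that simply overwrites the entries it contains and does not reproduce quadra's load_checkpoint, whose strictness and key handling are not documented here; `build_state` is a hypothetical name.

```python
from typing import Dict, Optional

State = Dict[str, float]  # toy "state dict": parameter name -> value


def build_state(pretrained_state: State, checkpoint: Optional[State] = None) -> State:
    """Start from pretrained values, then overwrite them with checkpoint values."""
    state = dict(pretrained_state)  # step 1: like torch.hub.load(..., pretrained=True)
    if checkpoint:                  # step 2: like `if checkpoint_path: load_checkpoint(...)`
        state.update(checkpoint)
    return state


pretrained = {"conv.weight": 0.5, "conv.bias": 0.1}
finetuned = {"conv.weight": 0.9}
print(build_state(pretrained))             # {'conv.weight': 0.5, 'conv.bias': 0.1}
print(build_state(pretrained, finetuned))  # {'conv.weight': 0.9, 'conv.bias': 0.1}
```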

TorchVisionNetworkBuilder(model_name, pretrained=True, pre_classifier=None, classifier=None, freeze=True, hyperspherical=False, flatten_features=True, checkpoint_path=None, **torchvision_kwargs)

Bases: BaseNetworkBuilder

Torchvision feature extractor, with the option to map features onto a hypersphere.

Parameters:

  • model_name (str) –

    Name of a torchvision model function, looked up in torchvision.models; for example, "resnet18" selects torchvision.models.resnet18.

  • pretrained (bool, default: True ) –

    Whether to load the pretrained weights for the model.

  • pre_classifier (Module | None, default: None ) –

    Pre-classifier as a torch.nn.Module. Defaults to nn.Identity() if None.

  • classifier (Module | None, default: None ) –

    Classifier as a torch.nn.Module. Defaults to nn.Identity() if None.

  • freeze (bool, default: True ) –

    Whether to freeze the feature extractor. Defaults to True.

  • hyperspherical (bool, default: False ) –

    Whether to map features onto a hypersphere. Defaults to False.

  • flatten_features (bool, default: True ) –

    Whether to flatten the features before the pre_classifier. Defaults to True.

  • checkpoint_path (str | None, default: None ) –

    Path to a checkpoint to load after the model is initialized. Defaults to None.

  • **torchvision_kwargs (Any, default: {} ) –

    Additional arguments to pass to the model function.
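Because the constructor resolves the model with models.__dict__[model_name], model_name must be the string name of a function in torchvision.models (for example "resnet18"), not the function object itself, and a misspelled name surfaces as a bare KeyError. The same lookup pattern is shown below on the standard-library math module so the sketch runs without torchvision. The `resolve` helper is hypothetical; unlike the source as shown, it converts the KeyError into a more descriptive error.

```python
import math
from types import ModuleType
from typing import Callable


def resolve(module: ModuleType, name: str) -> Callable:
    """Look up a callable by name in a module's namespace, as models.__dict__[name] does."""
    try:
        fn = module.__dict__[name]
    except KeyError:
        raise ValueError(f"{module.__name__} has no attribute {name!r}") from None
    if not callable(fn):
        raise TypeError(f"{module.__name__}.{name} is not callable")
    return fn


sqrt = resolve(math, "sqrt")  # analogous to models.__dict__["resnet18"]
print(sqrt(16.0))             # 4.0
```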

Source code in quadra/models/classification/backbones.py
def __init__(
    self,
    model_name: str,
    pretrained: bool = True,
    pre_classifier: nn.Module | None = None,
    classifier: nn.Module | None = None,
    freeze: bool = True,
    hyperspherical: bool = False,
    flatten_features: bool = True,
    checkpoint_path: str | None = None,
    **torchvision_kwargs: Any,
):
    self.pretrained = pretrained
    model_function = models.__dict__[model_name]
    features_extractor = model_function(pretrained=self.pretrained, progress=True, **torchvision_kwargs)
    if checkpoint_path:
        log.info("Loading checkpoint from %s", checkpoint_path)
        load_checkpoint(model=features_extractor, checkpoint_path=checkpoint_path)

    # Remove classifier
    features_extractor.classifier = nn.Identity()
    super().__init__(
        features_extractor=features_extractor,
        pre_classifier=pre_classifier,
        classifier=classifier,
        freeze=freeze,
        hyperspherical=hyperspherical,
        flatten_features=flatten_features,
    )
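The line features_extractor.classifier = nn.Identity() replaces the head only for architectures that store it in an attribute named classifier (in torchvision, for example, VGG, DenseNet, MobileNet and EfficientNet). ResNet models store their final layer as fc, so as written the assignment would add an unused attribute and leave fc in place, unless this is handled elsewhere. The plain-Python sketch below uses hypothetical toy classes (`ToyVGG`, `ToyResNet`) to show the difference; in PyTorch, assigning a module to a new attribute likewise registers a new submodule without touching existing ones.

```python
from typing import Callable, List


def identity(x: List[float]) -> List[float]:
    """Stand-in for nn.Identity()."""
    return x


def linear_head(x: List[float]) -> List[float]:
    """Toy 'classifier head': collapses features into a single score."""
    return [sum(x)]


class ToyVGG:
    """Toy model whose head lives in `classifier`, as in torchvision's VGG."""

    def __init__(self) -> None:
        self.classifier: Callable[[List[float]], List[float]] = linear_head

    def __call__(self, x: List[float]) -> List[float]:
        return self.classifier(x)


class ToyResNet:
    """Toy model whose head lives in `fc`, as in torchvision's ResNet."""

    def __init__(self) -> None:
        self.fc: Callable[[List[float]], List[float]] = linear_head

    def __call__(self, x: List[float]) -> List[float]:
        return self.fc(x)


features = [1.0, 2.0, 3.0]
for model in (ToyVGG(), ToyResNet()):
    model.classifier = identity  # same assignment as in the source above
    print(type(model).__name__, model(features))
# ToyVGG [1.0, 2.0, 3.0]   (head removed)
# ToyResNet [6.0]          (`fc` still applied)
```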