
backbones

TimmNetworkBuilder(model_name, pretrained=True, pre_classifier=nn.Identity(), classifier=nn.Identity(), freeze=True, hyperspherical=False, flatten_features=True, **timm_kwargs)

Bases: BaseNetworkBuilder

Timm feature extractor, with the possibility to map features to a hypersphere.

Parameters:

  • model_name (str) –

    Name of the timm model to create, passed to timm.create_model.

  • pretrained (bool) –

    Whether to load the pretrained weights for the model. Defaults to True.

  • pre_classifier (nn.Module) –

    Pre-classifier as a torch.nn.Module. Defaults to nn.Identity().

  • classifier (nn.Module) –

    Classifier as a torch.nn.Module. Defaults to nn.Identity().

  • freeze (bool) –

    Whether to freeze the feature extractor. Defaults to True.

  • hyperspherical (bool) –

    Whether to map features to a hypersphere. Defaults to False.

  • flatten_features (bool) –

    Whether to flatten the features before the pre_classifier. Defaults to True.

  • **timm_kwargs (Any) –

    Additional keyword arguments to pass to timm.create_model.
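Mapping features to a hypersphere usually means L2-normalizing each feature vector, so that dot products between features equal their cosine similarity. The sketch below assumes that is what hyperspherical=True does; the actual implementation lives in BaseNetworkBuilder, which is not shown on this page, and the tensor sizes are arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical batch of 4 feature vectors with 8 dimensions each.
features = torch.randn(4, 8)

# Divide each row by its L2 norm so every vector lies on the unit hypersphere.
unit_features = F.normalize(features, p=2, dim=1)
norms = unit_features.norm(p=2, dim=1)

# On the unit sphere, plain dot products equal the cosine similarity of the raw features.
raw_norms = features.norm(p=2, dim=1, keepdim=True)
cosine = (features @ features.T) / (raw_norms * raw_norms.T)
dot = unit_features @ unit_features.T

print(norms)  # every entry is approximately 1.0
```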

Source code in quadra/models/classification/backbones.py
def __init__(
    self,
    model_name: str,
    pretrained: bool = True,
    pre_classifier: nn.Module = nn.Identity(),
    classifier: nn.Module = nn.Identity(),
    freeze: bool = True,
    hyperspherical: bool = False,
    flatten_features: bool = True,
    **timm_kwargs: Any,
):
    self.pretrained = pretrained
    # num_classes=0 makes timm drop the classification head, so the model returns features
    features_extractor = timm.create_model(model_name, pretrained=self.pretrained, num_classes=0, **timm_kwargs)

    super().__init__(
        features_extractor=features_extractor,
        pre_classifier=pre_classifier,
        classifier=classifier,
        freeze=freeze,
        hyperspherical=hyperspherical,
        flatten_features=flatten_features,
    )

TorchHubNetworkBuilder(repo_or_dir, model_name, pretrained=True, pre_classifier=nn.Identity(), classifier=nn.Identity(), freeze=True, hyperspherical=False, flatten_features=True, **torch_hub_kwargs)

Bases: BaseNetworkBuilder

TorchHub feature extractor, with the possibility to map features to a hypersphere.

Parameters:

  • repo_or_dir (str) –

    The name of the repository or the path to the directory containing the model.

  • model_name (str) –

    The name of the model within the repository.

  • pretrained (bool) –

    Whether to load the pretrained weights for the model. Defaults to True.

  • pre_classifier (nn.Module) –

    Pre-classifier as a torch.nn.Module. Defaults to nn.Identity().

  • classifier (nn.Module) –

    Classifier as a torch.nn.Module. Defaults to nn.Identity().

  • freeze (bool) –

    Whether to freeze the feature extractor. Defaults to True.

  • hyperspherical (bool) –

    Whether to map features to a hypersphere. Defaults to False.

  • flatten_features (bool) –

    Whether to flatten the features before the pre_classifier. Defaults to True.

  • **torch_hub_kwargs (Any) –

    Additional keyword arguments to pass to torch.hub.load.
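Because the builder forwards **torch_hub_kwargs straight to torch.hub.load, options such as source="local" (needed when repo_or_dir is a local directory, since torch.hub.load defaults to source="github") can be passed through it. The entrypoint must also accept a pretrained keyword, since the builder always sends one. The sketch below builds a throwaway hub repository containing only a hubconf.py; the entrypoint name tiny_net and its out_features argument are hypothetical.

```python
import os
import tempfile
import textwrap

import torch

# torch.hub only needs a hubconf.py whose top-level functions are the entrypoints.
HUBCONF = textwrap.dedent(
    """
    dependencies = ["torch"]

    import torch.nn as nn

    def tiny_net(pretrained=False, out_features=16):
        # A real entrypoint would load weights when pretrained=True.
        if pretrained:
            raise RuntimeError("no pretrained weights in this toy repo")
        return nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, out_features))
    """
)

with tempfile.TemporaryDirectory() as repo_dir:
    with open(os.path.join(repo_dir, "hubconf.py"), "w") as f:
        f.write(HUBCONF)
    # Same call shape as the builder; source and out_features play the role of **torch_hub_kwargs.
    extractor = torch.hub.load(
        repo_or_dir=repo_dir, model="tiny_net", pretrained=False, source="local", out_features=32
    )

features = extractor(torch.randn(4, 3, 8, 8))
print(features.shape)  # torch.Size([4, 32])
```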

Source code in quadra/models/classification/backbones.py
def __init__(
    self,
    repo_or_dir: str,
    model_name: str,
    pretrained: bool = True,
    pre_classifier: nn.Module = nn.Identity(),
    classifier: nn.Module = nn.Identity(),
    freeze: bool = True,
    hyperspherical: bool = False,
    flatten_features: bool = True,
    **torch_hub_kwargs: Any,
):
    self.pretrained = pretrained
    features_extractor = torch.hub.load(
        repo_or_dir=repo_or_dir, model=model_name, pretrained=self.pretrained, **torch_hub_kwargs
    )
    super().__init__(
        features_extractor=features_extractor,
        pre_classifier=pre_classifier,
        classifier=classifier,
        freeze=freeze,
        hyperspherical=hyperspherical,
        flatten_features=flatten_features,
    )

TorchVisionNetworkBuilder(model_name, pretrained=True, pre_classifier=nn.Identity(), classifier=nn.Identity(), freeze=True, hyperspherical=False, flatten_features=True, **torchvision_kwargs)

Bases: BaseNetworkBuilder

Torchvision feature extractor, with the possibility to map features to a hypersphere.

Parameters:

  • model_name (str) –

    Name of a torchvision model function, looked up in torchvision.models (for example "resnet18" for torchvision.models.resnet18).

  • pretrained (bool) –

    Whether to load the pretrained weights for the model. Defaults to True.

  • pre_classifier (nn.Module) –

    Pre-classifier as a torch.nn.Module. Defaults to nn.Identity().

  • classifier (nn.Module) –

    Classifier as a torch.nn.Module. Defaults to nn.Identity().

  • freeze (bool) –

    Whether to freeze the feature extractor. Defaults to True.

  • hyperspherical (bool) –

    Whether to map features to a hypersphere. Defaults to False.

  • flatten_features (bool) –

    Whether to flatten the features before the pre_classifier. Defaults to True.

  • **torchvision_kwargs (Any) –

    Additional arguments to pass to the model function.
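The forward pass that combines these parameters lives in BaseNetworkBuilder, which is not shown on this page. The sketch below is a hypothetical composition that follows the parameter descriptions only: freeze stops gradients for the feature extractor, and flatten_features flattens features before pre_classifier, which feeds classifier. ToyBuilder, the toy convolutional "backbone" and all layer sizes are illustrative, and hyperspherical mapping is left out.

```python
import torch
from torch import nn


class ToyBuilder(nn.Module):
    """Hypothetical stand-in for a network builder; not quadra's BaseNetworkBuilder."""

    def __init__(self, features_extractor, pre_classifier, classifier, freeze=True, flatten_features=True):
        super().__init__()
        self.features_extractor = features_extractor
        self.pre_classifier = pre_classifier
        self.classifier = classifier
        self.flatten_features = flatten_features
        if freeze:
            # Freezing: the optimizer will not update the feature extractor weights.
            self.features_extractor.requires_grad_(False)

    def forward(self, x):
        x = self.features_extractor(x)
        if self.flatten_features:
            # (N, C, H, W) -> (N, C * H * W) so a Linear pre-classifier can consume it.
            x = torch.flatten(x, start_dim=1)
        x = self.pre_classifier(x)
        return self.classifier(x)


# Toy "backbone" mapping (N, 3, 8, 8) images to (N, 4, 8, 8) feature maps.
extractor = nn.Conv2d(3, 4, kernel_size=3, padding=1)
model = ToyBuilder(extractor, pre_classifier=nn.Linear(4 * 8 * 8, 16), classifier=nn.Linear(16, 2))

logits = model(torch.randn(5, 3, 8, 8))
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(logits.shape)  # torch.Size([5, 2])
print(trainable)  # only pre_classifier and classifier parameters
```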

Source code in quadra/models/classification/backbones.py
def __init__(
    self,
    model_name: str,
    pretrained: bool = True,
    pre_classifier: nn.Module = nn.Identity(),
    classifier: nn.Module = nn.Identity(),
    freeze: bool = True,
    hyperspherical: bool = False,
    flatten_features: bool = True,
    **torchvision_kwargs: Any,
):
    self.pretrained = pretrained
    model_function = models.__dict__[model_name]
    features_extractor = model_function(pretrained=self.pretrained, progress=True, **torchvision_kwargs)
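    # Remove the classifier head. This only affects models that store their head in `classifier`
    # (e.g. VGG, MobileNet, EfficientNet); ResNet, for instance, keeps its head in `fc`.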
    features_extractor.classifier = nn.Identity()
    super().__init__(
        features_extractor=features_extractor,
        pre_classifier=pre_classifier,
        classifier=classifier,
        freeze=freeze,
        hyperspherical=hyperspherical,
        flatten_features=flatten_features,
    )