backbones
TimmNetworkBuilder(model_name, pretrained=True, pre_classifier=None, classifier=None, freeze=True, hyperspherical=False, flatten_features=True, checkpoint_path=None, **timm_kwargs)
Bases: BaseNetworkBuilder
Timm feature extractor, with the option to map features to a hypersphere.
Parameters:
- model_name (str) – Timm model name.
- pretrained (bool, default: True) – Whether to load the pretrained weights for the model.
- pre_classifier (Module | None, default: None) – Pre-classifier as a torch.nn.Module. Defaults to nn.Identity() if None.
- classifier (Module | None, default: None) – Classifier as a torch.nn.Module. Defaults to nn.Identity() if None.
- freeze (bool, default: True) – Whether to freeze the feature extractor.
- hyperspherical (bool, default: False) – Whether to map features to a hypersphere.
- flatten_features (bool, default: True) – Whether to flatten the features before the pre_classifier.
- checkpoint_path (str | None, default: None) – Path to a checkpoint to load after the model is initialized.
- **timm_kwargs (Any, default: {}) – Additional arguments to pass to timm.create_model.
Source code in quadra/models/classification/backbones.py
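The hyperspherical option is described only as mapping features to a hypersphere. A common way to do that is to rescale each feature vector to unit L2 norm, and the sketch below shows that operation on a plain Python list. It is an assumption that quadra normalizes this way; the helper name to_hypersphere and the eps parameter are hypothetical and not part of the library.

```python
from __future__ import annotations

import math
from typing import Sequence


def to_hypersphere(features: Sequence[float], eps: float = 1e-12) -> list[float]:
    """Scale a feature vector to unit L2 norm (illustrative helper, not quadra's API)."""
    norm = math.sqrt(sum(x * x for x in features))
    # eps guards against division by zero when the vector is all zeros
    return [x / max(norm, eps) for x in features]


unit = to_hypersphere([3.0, 4.0])
print(unit)  # [0.6, 0.8]
```

After this step only the direction of the feature vector carries information, which is why the option is usually paired with cosine-similarity style comparisons.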
TorchHubNetworkBuilder(repo_or_dir, model_name, pretrained=True, pre_classifier=None, classifier=None, freeze=True, hyperspherical=False, flatten_features=True, checkpoint_path=None, **torch_hub_kwargs)
Bases: BaseNetworkBuilder
TorchHub feature extractor, with the option to map features to a hypersphere.
Parameters:
- repo_or_dir (str) – The name of the repository or the path to the directory containing the model.
- model_name (str) – The name of the model within the repository.
- pretrained (bool, default: True) – Whether to load the pretrained weights for the model.
- pre_classifier (Module | None, default: None) – Pre-classifier as a torch.nn.Module. Defaults to nn.Identity() if None.
- classifier (Module | None, default: None) – Classifier as a torch.nn.Module. Defaults to nn.Identity() if None.
- freeze (bool, default: True) – Whether to freeze the feature extractor.
- hyperspherical (bool, default: False) – Whether to map features to a hypersphere.
- flatten_features (bool, default: True) – Whether to flatten the features before the pre_classifier.
- checkpoint_path (str | None, default: None) – Path to a checkpoint to load after the model is initialized.
- **torch_hub_kwargs (Any, default: {}) – Additional arguments to pass to torch.hub.load.
Source code in quadra/models/classification/backbones.py
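The parameter descriptions imply an order of operations: features are extracted, optionally flattened, then passed through pre_classifier and classifier, each of which falls back to an identity when None. The framework-free sketch below mimics that flow with plain callables and nested lists. The names build_forward, identity and flatten are hypothetical, the classifier-after-pre_classifier order is an assumption, and the hyperspherical step is left out because the documentation does not say where it is applied.

```python
from typing import Any, Callable, Optional


def identity(x: Any) -> Any:
    """Stand-in for nn.Identity(): returns its input unchanged."""
    return x


def flatten(feature_map: list) -> list:
    """Flatten a 2-level nested list into a single list."""
    return [value for row in feature_map for value in row]


def build_forward(
    extractor: Callable[[Any], list],
    pre_classifier: Optional[Callable] = None,
    classifier: Optional[Callable] = None,
    flatten_features: bool = True,
) -> Callable[[Any], Any]:
    # None falls back to an identity, mirroring the documented defaults
    pre = pre_classifier if pre_classifier is not None else identity
    clf = classifier if classifier is not None else identity

    def forward(x: Any) -> Any:
        feats = extractor(x)
        if flatten_features:
            # flattening happens before the pre_classifier, as documented
            feats = flatten(feats)
        return clf(pre(feats))

    return forward


# Toy extractor producing a 2x2 "feature map" from a number
toy_extractor = lambda x: [[x, x + 1], [x + 2, x + 3]]

print(build_forward(toy_extractor)(1))                  # [1, 2, 3, 4]
print(build_forward(toy_extractor, classifier=sum)(1))  # 10
```

With both heads left as None the builder behaves as a pure feature extractor, which matches the defaults listed above.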
TorchVisionNetworkBuilder(model_name, pretrained=True, pre_classifier=None, classifier=None, freeze=True, hyperspherical=False, flatten_features=True, checkpoint_path=None, **torchvision_kwargs)
Bases: BaseNetworkBuilder
Torchvision feature extractor, with the option to map features to a hypersphere.
Parameters:
- model_name (str) – Torchvision model function that will be evaluated, for example: torchvision.models.resnet18.
- pretrained (bool, default: True) – Whether to load the pretrained weights for the model.
- pre_classifier (Module | None, default: None) – Pre-classifier as a torch.nn.Module. Defaults to nn.Identity() if None.
- classifier (Module | None, default: None) – Classifier as a torch.nn.Module. Defaults to nn.Identity() if None.
- freeze (bool, default: True) – Whether to freeze the feature extractor.
- hyperspherical (bool, default: False) – Whether to map features to a hypersphere.
- flatten_features (bool, default: True) – Whether to flatten the features before the pre_classifier.
- checkpoint_path (str | None, default: None) – Path to a checkpoint to load after the model is initialized.
- **torchvision_kwargs (Any, default: {}) – Additional arguments to pass to the model function.
Source code in quadra/models/classification/backbones.py
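Here model_name is a string naming a function, and the extra keyword arguments are forwarded to that function. The sketch below shows one way a dotted path such as torchvision.models.resnet18 could be turned into a callable and invoked with forwarded keyword arguments. It uses importlib rather than evaluating the string, which is a swapped-in technique and not necessarily how quadra resolves the name; resolve_callable and build are hypothetical helpers, and standard-library targets are used so the example runs without torchvision installed.

```python
import importlib
from typing import Any, Callable


def resolve_callable(dotted_path: str) -> Callable[..., Any]:
    """Import the module part of 'pkg.module.func' and return the named attribute."""
    module_path, _, attr_name = dotted_path.rpartition(".")
    if not module_path:
        raise ValueError(f"Expected a dotted path like 'pkg.func', got {dotted_path!r}")
    module = importlib.import_module(module_path)
    return getattr(module, attr_name)


def build(dotted_path: str, **kwargs: Any) -> Any:
    """Resolve the callable and forward every extra keyword argument to it."""
    return resolve_callable(dotted_path)(**kwargs)


print(resolve_callable("math.sqrt")(16.0))                       # 4.0
print(build("fractions.Fraction", numerator=3, denominator=4))  # 3/4
```

Resolving through importlib restricts the string to importable names, whereas evaluating it would accept arbitrary expressions.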