Skip to content

model_manager

AbstractModelManager

Bases: ABC

Abstract class for model managers.

delete_model(model_name, version, description=None) abstractmethod

Delete a model with the given version.

Source code in quadra/utils/model_manager.py (lines 46–48)
@abstractmethod
def delete_model(self, model_name: str, version: int, description: str | None = None) -> None:
    """Delete a model with the given version."""

download_model(model_name, version, output_path) abstractmethod

Download the model with the given version to the given output path.

Source code in quadra/utils/model_manager.py (lines 63–65)
@abstractmethod
def download_model(self, model_name: str, version: int, output_path: str) -> None:
    """Fetch a specific model version and store it under ``output_path``.

    Args:
        model_name: Name of the registered model.
        version: Version number to download.
        output_path: Local path where the model is saved.
    """

get_latest_version(model_name) abstractmethod

Get the latest version of a model across all stages.

Source code in quadra/utils/model_manager.py (lines 38–40)
@abstractmethod
def get_latest_version(self, model_name: str) -> Any:
    """Get the most recent version of a registered model.

    Args:
        model_name: Name of the registered model.

    Returns:
        An implementation-specific object describing the latest model version.
    """

register_best_model(experiment_name, metric, model_name, description, tags=None, mode='max', model_path='deployment_model') abstractmethod

Register the best model from an experiment.

Source code in quadra/utils/model_manager.py (lines 50–61)
@abstractmethod
def register_best_model(
    self,
    experiment_name: str,
    metric: str,
    model_name: str,
    description: str,
    tags: dict[str, Any] | None = None,
    mode: Literal["max", "min"] = "max",
    model_path: str = "deployment_model",
) -> Any:
    """Register the best model from an experiment."""

register_model(model_location, model_name, description, tags=None) abstractmethod

Register a model in the model registry.

Source code in quadra/utils/model_manager.py (lines 32–36)
@abstractmethod
def register_model(
    self, model_location: str, model_name: str, description: str, tags: dict[str, Any] | None = None
) -> Any:
    """Register a model in the model registry."""

transition_model(model_name, version, stage, description=None) abstractmethod

Transition the model with the given version to a new stage.

Source code in quadra/utils/model_manager.py (lines 42–44)
@abstractmethod
def transition_model(self, model_name: str, version: int, stage: str, description: str | None = None) -> Any:
    """Transition the model with the given version to a new stage."""

MlflowModelManager()

Bases: AbstractModelManager

Model manager for Mlflow.

Source code in quadra/utils/model_manager.py (lines 71–79)
def __init__(self):
    """Connect an ``MlflowClient`` to the tracking server given by ``MLFLOW_TRACKING_URI``.

    Raises:
        ImportError: If mlflow is not available (``MLFLOW_AVAILABLE`` is false).
        ValueError: If the ``MLFLOW_TRACKING_URI`` environment variable is not set.
    """
    if not MLFLOW_AVAILABLE:
        raise ImportError("Mlflow is not available, please install it with pip install mlflow")

    tracking_uri = os.getenv("MLFLOW_TRACKING_URI")
    if tracking_uri is None:
        raise ValueError("MLFLOW_TRACKING_URI environment variable is not set")

    mlflow.set_tracking_uri(tracking_uri)
    self.client = MlflowClient()

delete_model(model_name, version, description=None)

Delete a specific version of a model, after an interactive confirmation.

Parameters:

  • model_name (str) –

    The name of the model

  • version (int) –

    The version of the model

  • description (str | None, default: None ) –

    Why the model was deleted, this will be added to the model changelog

Source code in quadra/utils/model_manager.py (lines 168–201)
def delete_model(self, model_name: str, version: int, description: str | None = None) -> None:
    """Delete a model.

    Args:
        model_name: The name of the model
        version: The version of the model
        description: Why the model was deleted, this will be added to the model changelog
    """
    model_stage = self._safe_get_stage(model_name, version)

    if model_stage is None:
        return

    if (
        input(
            f"Model named `{model_name}`, version {version} is in stage {model_stage}, "
            "type the model name to continue deletion:"
        )
        != model_name
    ):
        log.warning("Model name did not match, aborting deletion")
        return

    log.info("Deleting model %s version %s", model_name, version)
    self.client.delete_model_version(model_name, version)

    registered_model_description = self.client.get_registered_model(model_name).description

    new_model_description = "## **Deletion:**\n"
    new_model_description += f"### Version {version} from stage: {model_stage}\n"
    new_model_description += self._get_author_and_date()
    new_model_description += self._generate_description(description)

    self.client.update_registered_model(model_name, registered_model_description + new_model_description)

download_model(model_name, version, output_path)

Download the model with the given version to the given output path.

Parameters:

  • model_name (str) –

    The name of the model

  • version (int) –

    The version of the model

  • output_path (str) –

    The path to save the model to

Source code in quadra/utils/model_manager.py (lines 274–287)
def download_model(self, model_name: str, version: int, output_path: str) -> None:
    """Download the model with the given version to the given output path.

    The artifacts are located through the model version download URI and fetched with
    ``mlflow.artifacts.download_artifacts``; ``output_path`` is created first if it does not exist.

    Args:
        model_name: The name of the model
        version: The version of the model
        output_path: The path to save the model to
    """
    artifact_uri = self.client.get_model_version_download_uri(model_name, version)
    log.info("Downloading model %s version %s from %s to %s", model_name, version, artifact_uri, output_path)
    if not os.path.exists(output_path):
        log.info("Creating output path %s", output_path)
        # exist_ok guards against the directory being created between the check above and this call
        os.makedirs(output_path, exist_ok=True)
    mlflow.artifacts.download_artifacts(artifact_uri=artifact_uri, dst_path=output_path)

get_latest_version(model_name)

Get the latest version of a model.

Parameters:

  • model_name (str) –

    The name of the model

Returns:

  • ModelVersion

    The model version

Source code in quadra/utils/model_manager.py (lines 116–128)
def get_latest_version(self, model_name: str) -> ModelVersion:
    """Get the latest version of a model.

    The candidates are the versions returned by ``MlflowClient.get_latest_versions``; the one with the
    highest version number (compared as an integer) is fetched again with ``get_model_version``.

    Args:
        model_name: The name of the model

    Returns:
        The model version

    Raises:
        ValueError: If no version is returned for ``model_name``.
    """
    # Compare numerically so that e.g. version "10" ranks above version "2"
    version_numbers = [int(x.version) for x in self.client.get_latest_versions(model_name)]
    if not version_numbers:
        raise ValueError(f"No versions found for model {model_name}")

    return self.client.get_model_version(model_name, max(version_numbers))

register_best_model(experiment_name, metric, model_name, description=None, tags=None, mode='max', model_path='deployment_model')

Register the best model from an experiment.

Parameters:

  • experiment_name (str) –

    The name of the experiment

  • metric (str) –

    The metric to use to determine the best model

  • model_name (str) –

    The name of the model after it is registered

  • description (str | None, default: None ) –

    A description of the model, this will be added to the model changelog

  • tags (dict[str, Any] | None, default: None ) –

    A dictionary of tags to add to the model

  • mode (Literal['max', 'min'], default: 'max' ) –

    The mode to use to determine the best model, either "max" or "min"

  • model_path (str, default: 'deployment_model' ) –

    The path to the model within the experiment run

Returns:

  • ModelVersion | None

    The registered model version if successful, otherwise None

Source code in quadra/utils/model_manager.py (lines 203–272)
def register_best_model(
    self,
    experiment_name: str,
    metric: str,
    model_name: str,
    description: str | None = None,
    tags: dict[str, Any] | None = None,
    mode: Literal["max", "min"] = "max",
    model_path: str = "deployment_model",
) -> ModelVersion | None:
    """Register the best model from an experiment.

    Args:
        experiment_name: The name of the experiment
        metric: The metric to use to determine the best model
        model_name: The name of the model after it is registered
        description: A description of the model, this will be added to the model changelog
        tags: A dictionary of tags to add to the model
        mode: The mode to use to determine the best model, either "max" or "min"
        model_path: The path to the model within the experiment run

    Returns:
        The registered model version if successful, otherwise None
    """
    if mode not in ["max", "min"]:
        raise ValueError(f"Mode must be either 'max' or 'min', got {mode}")

    experiment_id = self.client.get_experiment_by_name(experiment_name).experiment_id
    runs = self.client.search_runs(experiment_ids=[experiment_id])

    if len(runs) == 0:
        log.error("No runs found for experiment %s", experiment_name)
        return None

    best_run: Run | None = None

    # We can only make comparisons if the model is on the top folder, otherwise just check if the folder exists
    # TODO: Is there a better way to do this?
    base_model_path = model_path.split("/")[0]

    for run in runs:
        run_artifacts = [x.path for x in self.client.list_artifacts(run.info.run_id) if x.path == base_model_path]

        if len(run_artifacts) == 0:
            # If we don't find the given model path, skip this run
            continue

        if best_run is None:
            # If we find a run with the model it must also have the metric
            if run.data.metrics.get(metric) is not None:
                best_run = run
            continue

        if mode == "max":
            if run.data.metrics[metric] > best_run.data.metrics[metric]:
                best_run = run
        elif run.data.metrics[metric] < best_run.data.metrics[metric]:
            best_run = run

    if best_run is None:
        log.error("No runs found for experiment %s with the given metric", experiment_name)
        return None

    best_model_uri = f"runs:/{best_run.info.run_id}/{model_path}"

    model_version = self.register_model(
        model_location=best_model_uri, model_name=model_name, tags=tags, description=description
    )

    return model_version

register_model(model_location, model_name, description=None, tags=None)

Register a model in the model registry.

Parameters:

  • model_location (str) –

    The model uri

  • model_name (str) –

    The name of the model after it is registered

  • description (str | None, default: None ) –

    A description of the model, this will be added to the model changelog

  • tags (dict[str, Any] | None, default: None ) –

    A dictionary of tags to add to the model

Returns:

  • ModelVersion

    The model version

Source code in quadra/utils/model_manager.py (lines 81–114)
def register_model(
    self, model_location: str, model_name: str, description: str | None = None, tags: dict[str, Any] | None = None
) -> ModelVersion:
    """Register a model in the model registry.

    A changelog entry (``VERSION_MD_TEMPLATE`` for the new version, the author/date and ``description``)
    is appended to the registered model description, prefixed with a ``# MODEL CHANGELOG`` header when
    the new version is version 1. The same entry, with that header, becomes the description of the new
    model version.

    Args:
        model_location: The model uri
        model_name: The name of the model after it is registered
        description: A description of the model, this will be added to the model changelog
        tags: A dictionary of tags to add to the model

    Returns:
        The model version
    """
    model_version = mlflow.register_model(model_uri=model_location, name=model_name, tags=tags)
    log.info("Registered model %s with version %s", model_name, model_version.version)
    # The description is not guaranteed to be set (it may be None, e.g. on a freshly created model);
    # treat a missing one as empty so the concatenation below cannot fail.
    registered_model_description = self.client.get_registered_model(model_name).description or ""

    # Compare numerically in case the backend reports the version as an int rather than a string
    header = "# MODEL CHANGELOG\n" if int(model_version.version) == 1 else ""

    new_model_description = VERSION_MD_TEMPLATE.format(model_version.version)
    new_model_description += self._get_author_and_date()
    new_model_description += self._generate_description(description)

    self.client.update_registered_model(model_name, header + registered_model_description + new_model_description)

    # The version's own changelog only contains its registration entry
    self.client.update_model_version(
        model_name, model_version.version, "# MODEL CHANGELOG\n" + new_model_description
    )

    return model_version

transition_model(model_name, version, stage, description=None)

Transition a model to a new stage.

Parameters:

  • model_name (str) –

    The name of the model

  • version (int) –

    The version of the model

  • stage (str) –

    The stage of the model

  • description (str | None, default: None ) –

    A description of the transition, this will be added to the model changelog

Source code in quadra/utils/model_manager.py (lines 130–166)
def transition_model(
    self, model_name: str, version: int, stage: str, description: str | None = None
) -> ModelVersion | None:
    """Transition a model to a new stage.

    A "Transition" entry is appended both to the registered model description and to the
    description of the transitioned version (the model changelogs).

    Args:
        model_name: The name of the model
        version: The version of the model
        stage: The stage of the model
        description: A description of the transition, this will be added to the model changelog

    Returns:
        The transitioned model version; the current model version, unchanged, if it is already in
        ``stage`` (case-insensitive comparison); None if ``_safe_get_stage`` returns None
    """
    previous_stage = self._safe_get_stage(model_name, version)

    if previous_stage is None:
        return None

    if previous_stage.lower() == stage.lower():
        log.warning("Model %s version %s is already in stage %s", model_name, version, stage)
        return self.client.get_model_version(model_name, version)

    log.info("Transitioning model %s version %s from %s to %s", model_name, version, previous_stage, stage)
    model_version = self.client.transition_model_version_stage(name=model_name, version=version, stage=stage)
    new_stage = model_version.current_stage
    # Descriptions are not guaranteed to be set (they may be None); treat a missing one as empty
    registered_model_description = self.client.get_registered_model(model_name).description or ""
    single_model_description = self.client.get_model_version(model_name, version).description or ""

    new_model_description = "## **Transition:**\n"
    new_model_description += f"### Version {model_version.version} from {previous_stage} to {new_stage}\n"
    new_model_description += self._get_author_and_date()
    new_model_description += self._generate_description(description)

    self.client.update_registered_model(model_name, registered_model_description + new_model_description)
    self.client.update_model_version(
        model_name, model_version.version, single_model_description + new_model_description
    )

    return model_version