From ee0d42646f7d37b420100b4b053d0ed47efd715f Mon Sep 17 00:00:00 2001 From: Michael Liu Date: Mon, 18 Sep 2023 17:07:06 -0700 Subject: [PATCH] Launch fine-tuning --- client/verta/verta/client.py | 3 +- .../verta/registry/entities/_modelversion.py | 44 +++++++++++++++---- 2 files changed, 38 insertions(+), 9 deletions(-) diff --git a/client/verta/verta/client.py b/client/verta/verta/client.py index d6eb8da50f..6e9e097ca6 100644 --- a/client/verta/verta/client.py +++ b/client/verta/verta/client.py @@ -931,7 +931,8 @@ def get_or_create_registered_model( return self._get_or_create_registered_model( self._conn, self._conf, - self._ctx, + ctx, + name=name, desc=desc, labels=labels, public_within_org=public_within_org, diff --git a/client/verta/verta/registry/entities/_modelversion.py b/client/verta/verta/registry/entities/_modelversion.py index e32c5e1c7b..4894846b96 100644 --- a/client/verta/verta/registry/entities/_modelversion.py +++ b/client/verta/verta/registry/entities/_modelversion.py @@ -1814,15 +1814,18 @@ def finetune( test_dataset: Optional[_dataset_version.DatasetVersion] = None, name: Optional[str] = None, finetuning_config: Optional[verta.finetune._FinetuningConfig] = None, - ): + ) -> "verta.registry.entities.RegisteredModelVersion": """""" from verta import Client + from verta.tracking.entities import ExperimentRun + + # TODO: check `enable_mdb_versioning` ctx = _Context(self._conn, self._conf) ctx.workspace_name = self.workspace - if not self.get_attributes().get(verta.finetune._FINETUNE_BASE_RMV_ATTR_KEY): - raise ValueError("this model version is not eligible for fine-tuning") + # if not self.get_attributes().get(verta.finetune._FINETUNE_BASE_RMV_ATTR_KEY): + # raise ValueError("this model version is not eligible for fine-tuning") if isinstance(destination_registered_model, str): destination_registered_model = Client._get_or_create_registered_model( self._conn, @@ -1831,6 +1834,7 @@ def finetune( name=destination_registered_model, ) + # create experiment run ctx.proj = Client._get_or_create_project( self._conn, self._conf, @@ -1844,10 +1848,34 @@ def finetune( ctx, name=verta.finetune._EXPERIMENT_NAME, ) - # TODO: create ER - # TODO: log attributes - # TODO: log dataset versions + run = ExperimentRun._create( + self._conn, + self._conf, + ctx, + attrs={verta.finetune._FINETUNE_ATTR_KEY: True}, + ) + + # log dataset versions + run.log_dataset_version(verta.finetune._TRAIN_DATASET_NAME, train_dataset) + if eval_dataset is not None: + run.log_dataset_version(verta.finetune._EVAL_DATASET_NAME, eval_dataset) + if test_dataset is not None: + run.log_dataset_version(verta.finetune._TEST_DATASET_NAME, test_dataset) - # TODO: create RMV from ER + model_ver = destination_registered_model.create_version_from_run(run, name=name) + + # launch fine-tuning + data = { + "base_model_version_id": self.id, + "run_id": run.id, + finetuning_config._JOB_DICT_KEY: finetuning_config._as_dict(), + } + url = "{}://{}/api/v1/deployment/workspace/{}/finetuning-job".format( + self._conn.scheme, + self._conn.socket, + self.workspace, + ) + response = _utils.make_request("POST", url, self._conn, json=data) + _utils.raise_for_http_error(response) - # TODO: POST fine-tune job + return model_ver