Skip to content

Commit

Permalink
Launch fine-tuning
Browse files Browse the repository at this point in the history
  • Loading branch information
liuverta committed Sep 19, 2023
1 parent 7e3ec37 commit ee0d426
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 9 deletions.
3 changes: 2 additions & 1 deletion client/verta/verta/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -931,7 +931,8 @@ def get_or_create_registered_model(
return self._get_or_create_registered_model(
self._conn,
self._conf,
self._ctx,
ctx,
name=name,
desc=desc,
labels=labels,
public_within_org=public_within_org,
Expand Down
44 changes: 36 additions & 8 deletions client/verta/verta/registry/entities/_modelversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -1814,15 +1814,18 @@ def finetune(
test_dataset: Optional[_dataset_version.DatasetVersion] = None,
name: Optional[str] = None,
finetuning_config: Optional[verta.finetune._FinetuningConfig] = None,
):
) -> "verta.registry.entities.RegisteredModelVersion":
""""""
from verta import Client
from verta.tracking.entities import ExperimentRun

# TODO: check `enable_mdb_versioning`

ctx = _Context(self._conn, self._conf)
ctx.workspace_name = self.workspace

if not self.get_attributes().get(verta.finetune._FINETUNE_BASE_RMV_ATTR_KEY):
raise ValueError("this model version is not eligible for fine-tuning")
# if not self.get_attributes().get(verta.finetune._FINETUNE_BASE_RMV_ATTR_KEY):
# raise ValueError("this model version is not eligible for fine-tuning")
if isinstance(destination_registered_model, str):
destination_registered_model = Client._get_or_create_registered_model(
self._conn,
Expand All @@ -1831,6 +1834,7 @@ def finetune(
name=destination_registered_model,
)

# create experiment run
ctx.proj = Client._get_or_create_project(
self._conn,
self._conf,
Expand All @@ -1844,10 +1848,34 @@ def finetune(
ctx,
name=verta.finetune._EXPERIMENT_NAME,
)
# TODO: create ER
# TODO: log attributes
# TODO: log dataset versions
run = ExperimentRun._create(
self._conn,
self._conf,
ctx,
attrs={verta.finetune._FINETUNE_ATTR_KEY: True},
)

# log dataset versions
run.log_dataset_version(verta.finetune._TRAIN_DATASET_NAME, train_dataset)
if eval_dataset is not None:
run.log_dataset_version(verta.finetune._EVAL_DATASET_NAME, eval_dataset)
if test_dataset is not None:
run.log_dataset_version(verta.finetune._TEST_DATASET_NAME, test_dataset)

# TODO: create RMV from ER
model_ver = destination_registered_model.create_version_from_run(run, name=name)

# launch fine-tuning
data = {
"base_model_version_id": self.id,
"run_id": run.id,
finetuning_config._JOB_DICT_KEY: finetuning_config._as_dict(),
}
url = "{}://{}/api/v1/deployment/workspace/{}/finetuning-job".format(
self._conn.scheme,
self._conn.socket,
self.workspace,
)
response = _utils.make_request("POST", url, self._conn, json=data)
_utils.raise_for_http_error(response)

# TODO: POST fine-tune job
return model_ver

2 comments on commit ee0d426

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Docker Tag: liu_finetune-2023-09-19T00-07-06--ee0d426

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Total coverage (common): 15.09
Total coverage (server): 61.32

Changed Files coverage (common): coverage 100
Changed Files coverage (server): 100

Please sign in to comment.