diff --git a/client/verta/tests/finetune/__init__.py b/client/verta/tests/finetune/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/client/verta/tests/finetune/test_finetune.py b/client/verta/tests/finetune/test_finetune.py
new file mode 100644
index 0000000000..ca26a2268f
--- /dev/null
+++ b/client/verta/tests/finetune/test_finetune.py
@@ -0,0 +1,46 @@
+# -*- coding: utf-8 -*-
+
+"""Test RegisteredModelVersion.finetune() and its downstream effects."""
+
+from verta.dataset import Path
+from verta import finetune
+
+
+def test_finetune(client, registered_model, dataset):
+    """Verify that happy-path ``finetune()`` works."""
+    base_model_ver = registered_model.create_version()  # mocked base LLM RMV
+    name = "v1"
+
+    reg_model = client.create_registered_model()
+    train_dataset_version = dataset.create_version(
+        Path(__file__, enable_mdb_versioning=True),
+    )
+    model_ver = base_model_ver.finetune(
+        destination_registered_model=reg_model,
+        train_dataset=train_dataset_version,
+        name=name,
+    )
+    run = client.get_experiment_run(id=model_ver.experiment_run_id)
+
+    # check entity names
+    assert client.proj.name == reg_model.name + finetune._PROJECT_NAME_SUFFIX
+    assert client.expt.name == finetune._EXPERIMENT_NAME_PREFIX + dataset.name
+    # TODO: wait for fine-tuning to launch, then check ER name
+    assert model_ver.name == name
+
+    # check dataset association
+    for entity in [model_ver, run]:
+        for key, value in [
+            (finetune._TRAIN_DATASET_NAME, train_dataset_version),
+            # TODO: eval and test, too
+        ]:
+            assert entity.get_dataset_version(key).id == value.id
+
+    # check attributes
+    for entity in [model_ver, run]:
+        assert (
+            entity.get_attributes().items()
+            >= {
+                finetune._FINETUNE_ATTR_KEY: True,
+            }.items()
+        )
diff --git a/client/verta/verta/client.py b/client/verta/verta/client.py
index 60ee4202cc..6e9e097ca6 100644
--- a/client/verta/verta/client.py
+++ b/client/verta/verta/client.py
@@ -154,7 +154,9 @@ def __init__(
         _connect=True,
     ):
         if organization_id is not None and organization_name is not None:
-            raise ValueError("cannot provide both `organization_id` and `organization_name`")
+            raise ValueError(
+                "cannot provide both `organization_id` and `organization_name`"
+            )
 
         self._load_config()
         host = self._get_with_fallback(host, env_var="VERTA_HOST", config_var="host")
@@ -225,10 +227,10 @@ def __init__(
             # )
             # response = conn.must_proto_response(request, OrganizationV2_pb2.GetOrganizationByNameV2.Response)
             # organization_id = response.organization.id
-            request = conn.make_proto_request(
-                "GET", "/api/v2/uac-proxy/organization"
+            request = conn.make_proto_request("GET", "/api/v2/uac-proxy/organization")
+            response = conn.must_proto_response(
+                request, OrganizationV2_pb2.ListOrganizationsV2.Response
             )
-            response = conn.must_proto_response(request, OrganizationV2_pb2.ListOrganizationsV2.Response)
             for org in response.organizations:
                 if org.name == organization_name:
                     organization_id = org.id
@@ -509,26 +511,55 @@ def set_project(
         self._ctx = _Context(self._conn, self._conf)
         self._ctx.workspace_name = workspace
 
+        return self._get_or_create_project(
+            self._conn,
+            self._conf,
+            self._ctx,
+            name=name,
+            desc=desc,
+            tags=tags,
+            attrs=attrs,
+            public_within_org=public_within_org,
+            visibility=visibility,
+            id=id,
+        )
+
+    @staticmethod
+    def _get_or_create_project(
+        conn,
+        conf,
+        ctx,
+        name=None,
+        desc=None,
+        tags=None,
+        attrs=None,
+        public_within_org=None,
+        visibility=None,
+        id=None,
+    ):
+        """:meth:`set_project`, but static.
+
+        Requires either `ctx.workspace_name` or `id`.
+
+        """
         resource_name = "Project"
         param_names = "`desc`, `tags`, `attrs`, `public_within_org`, or `visibility`"
         params = (desc, tags, attrs, public_within_org, visibility)
         if id is not None:
-            self._ctx.proj = Project._get_by_id(self._conn, self._conf, id)
+            ctx.proj = Project._get_by_id(conn, conf, id)
             check_unnecessary_params_warning(
                 resource_name, "id {}".format(id), param_names, params
             )
-            self._ctx.populate()
+            ctx.populate()
         else:
-            self._ctx.proj = Project._get_or_create_by_name(
-                self._conn,
+            ctx.proj = Project._get_or_create_by_name(
+                conn,
                 name,
-                lambda name: Project._get_by_name(
-                    self._conn, self._conf, name, self._ctx.workspace_name
-                ),
+                lambda name: Project._get_by_name(conn, conf, name, ctx.workspace_name),
                 lambda name: Project._create(
-                    self._conn,
-                    self._conf,
-                    self._ctx,
+                    conn,
+                    conf,
+                    ctx,
                     name=name,
                     desc=desc,
                     tags=tags,
@@ -541,7 +572,7 @@ def set_project(
                 ),
             )
 
-        return self._ctx.proj
+        return ctx.proj
 
     def get_experiment(self, name=None, id=None):
         """
@@ -621,29 +652,54 @@ def set_experiment(self, name=None, desc=None, tags=None, attrs=None, id=None):
 
         name = self._set_from_config_if_none(name, "experiment")
 
+        if id is None and self._ctx.proj is None:
+            self.set_project()
+
+        return self._get_or_create_experiment(
+            self._conn,
+            self._conf,
+            self._ctx,
+            name=name,
+            desc=desc,
+            tags=tags,
+            attrs=attrs,
+            id=id,
+        )
+
+    @staticmethod
+    def _get_or_create_experiment(
+        conn,
+        conf,
+        ctx,
+        name=None,
+        desc=None,
+        tags=None,
+        attrs=None,
+        id=None,
+    ):
+        """:meth:`set_experiment`, but static.
+
+        Requires either `ctx.proj` or `id`.
+
+        """
         resource_name = "Experiment"
         param_names = "`desc`, `tags`, or `attrs`"
         params = (desc, tags, attrs)
         if id is not None:
-            self._ctx.expt = Experiment._get_by_id(self._conn, self._conf, id)
+            ctx.expt = Experiment._get_by_id(conn, conf, id)
             check_unnecessary_params_warning(
                 resource_name, "id {}".format(id), param_names, params
            )
-            self._ctx.populate()
+            ctx.populate()
         else:
-            if self._ctx.proj is None:
-                self.set_project()
-
-            self._ctx.expt = Experiment._get_or_create_by_name(
-                self._conn,
+            ctx.expt = Experiment._get_or_create_by_name(
+                conn,
                 name,
-                lambda name: Experiment._get_by_name(
-                    self._conn, self._conf, name, self._ctx.proj.id
-                ),
+                lambda name: Experiment._get_by_name(conn, conf, name, ctx.proj.id),
                 lambda name: Experiment._create(
-                    self._conn,
-                    self._conf,
-                    self._ctx,
+                    conn,
+                    conf,
+                    ctx,
                     name=name,
                     desc=desc,
                     tags=tags,
@@ -654,7 +710,7 @@ def set_experiment(self, name=None, desc=None, tags=None, attrs=None, id=None):
                 ),
             )
 
-        return self._ctx.expt
+        return ctx.expt
 
     def get_experiment_run(self, name=None, id=None):
         """
@@ -872,24 +928,59 @@ def get_or_create_registered_model(
         ctx = _Context(self._conn, self._conf)
         ctx.workspace_name = workspace
 
+        return self._get_or_create_registered_model(
+            self._conn,
+            self._conf,
+            ctx,
+            name=name,
+            desc=desc,
+            labels=labels,
+            public_within_org=public_within_org,
+            visibility=visibility,
+            id=id,
+            task_type=task_type,
+            data_type=data_type,
+            pii=pii,
+        )
+
+    @staticmethod
+    def _get_or_create_registered_model(
+        conn,
+        conf,
+        ctx,
+        name=None,
+        desc=None,
+        labels=None,
+        public_within_org=None,
+        visibility=None,
+        id=None,
+        task_type=None,
+        data_type=None,
+        pii=False,
+    ):
+        """:meth:`get_or_create_registered_model`, but static.
+
+        Requires either `ctx.workspace_name` or `id`.
+ + """ resource_name = "Registered Model" param_names = "`desc`, `labels`, `public_within_org`, or `visibility`" params = (desc, labels, public_within_org, visibility) if id is not None: - registered_model = RegisteredModel._get_by_id(self._conn, self._conf, id) + registered_model = RegisteredModel._get_by_id(conn, conf, id) check_unnecessary_params_warning( resource_name, "id {}".format(id), param_names, params ) else: registered_model = RegisteredModel._get_or_create_by_name( - self._conn, + conn, name, lambda name: RegisteredModel._get_by_name( - self._conn, self._conf, name, ctx.workspace_name + conn, conf, name, ctx.workspace_name ), lambda name: RegisteredModel._create( - self._conn, - self._conf, + conn, + conf, ctx, name=name, desc=desc, diff --git a/client/verta/verta/finetune/__init__.py b/client/verta/verta/finetune/__init__.py index 4535bb3e20..45a87c6a56 100644 --- a/client/verta/verta/finetune/__init__.py +++ b/client/verta/verta/finetune/__init__.py @@ -2,11 +2,21 @@ """Utilities for model fine-tuning.""" from verta._internal_utils import documentation +from verta.tracking.entities._deployable_entity import _RESERVED_ATTR_PREFIX from ._finetuning_config import _FinetuningConfig from ._lora_config import LoraConfig +_PROJECT_NAME_SUFFIX = " Fine-Tuning" # append to RM name for projname +_EXPERIMENT_NAME_PREFIX = "On Dataset: " # prepend to train dataset name for expt name +_TRAIN_DATASET_NAME = "train" +_EVAL_DATASET_NAME = "eval" +_TEST_DATASET_NAME = "test" +_FINETUNE_BASE_RMV_ATTR_KEY = f"{_RESERVED_ATTR_PREFIX}FINETUNE_BASE" +_FINETUNE_ATTR_KEY = f"{_RESERVED_ATTR_PREFIX}FINETUNE" + + documentation.reassign_module( [ _FinetuningConfig, diff --git a/client/verta/verta/registry/_constants.py b/client/verta/verta/registry/_constants.py index b6f9b03cc6..3a128af85b 100644 --- a/client/verta/verta/registry/_constants.py +++ b/client/verta/verta/registry/_constants.py @@ -1,10 +1,11 @@ # -*- coding: utf-8 -*- from verta._protos.public.registry import ModelMetadata_pb2 +from verta.tracking.entities._deployable_entity import _RESERVED_ATTR_PREFIX -MODEL_LANGUAGE_ATTR_KEY = "__verta_reserved__model_language" -MODEL_TYPE_ATTR_KEY = "__verta_reserved__model_type" +MODEL_LANGUAGE_ATTR_KEY = f"{_RESERVED_ATTR_PREFIX}model_language" +MODEL_TYPE_ATTR_KEY = f"{_RESERVED_ATTR_PREFIX}model_type" class ModelLanguage(object): diff --git a/client/verta/verta/registry/entities/_modelversion.py b/client/verta/verta/registry/entities/_modelversion.py index 0a6a0ebc5c..59bb07882b 100644 --- a/client/verta/verta/registry/entities/_modelversion.py +++ b/client/verta/verta/registry/entities/_modelversion.py @@ -8,7 +8,7 @@ import pathlib import pickle import tempfile -from typing import List, Optional +from typing import List, Optional, Union import warnings from google.protobuf.struct_pb2 import Value @@ -38,6 +38,8 @@ from verta import _blob, code, data_types, environment from verta.endpoint.build import Build +import verta.finetune +from verta.tracking import _Context from verta.tracking.entities._entity import _MODEL_ARTIFACTS_ATTR_KEY from verta.tracking.entities import _deployable_entity from .. 
import lock, DockerImage @@ -1762,7 +1764,13 @@ def list_builds(self) -> List[Build]: builds = Build._list_model_version_builds(self._conn, self.workspace, self.id) return sorted(builds, key=lambda build: build.date_created, reverse=True) - def create_external_build(self, location: str, requires_root: Optional[bool] = None, scan_external: Optional[bool] = None, self_contained: Optional[bool] = None) -> Build: + def create_external_build( + self, + location: str, + requires_root: Optional[bool] = None, + scan_external: Optional[bool] = None, + self_contained: Optional[bool] = None, + ) -> Build: """ (alpha) Creates a new external build for this model version. @@ -1784,4 +1792,138 @@ def create_external_build(self, location: str, requires_root: Optional[bool] = N :class:`~verta.endpoint.build.Build` """ - return Build._create_external(self._conn, self.workspace, self.id, location, requires_root, scan_external, self_contained) + return Build._create_external( + self._conn, + self.workspace, + self.id, + location, + requires_root, + scan_external, + self_contained, + ) + + def finetune( + self, + destination_registered_model: Union[ + str, + "verta.registry.entities.RegisteredModel", + ], + train_dataset: _dataset_version.DatasetVersion, + eval_dataset: Optional[_dataset_version.DatasetVersion] = None, + test_dataset: Optional[_dataset_version.DatasetVersion] = None, + name: Optional[str] = None, + finetuning_config: Optional[verta.finetune._FinetuningConfig] = None, + ) -> "RegisteredModelVersion": + """Fine-tune this model version using the provided dataset(s). + + Parameters + ---------- + destination_registered_model : str or :class:`~verta.registry.entities.RegisteredModel` + Registered model (or simply its name) in which to create the new fine-tuned + model version. + train_dataset : :class:`~verta.dataset.entities.DatasetVersion` + Dataset version to use for training. The `content` passed to + :meth:`Dataset.create_version() ` + must have ``enable_mdb_versioning=True``. + eval_dataset : :class:`~verta.dataset.entities.DatasetVersion`, optional + Dataset version to use for evaluation. The `content` passed to + :meth:`Dataset.create_version() ` + must have ``enable_mdb_versioning=True``. + test_dataset : :class:`~verta.dataset.entities.DatasetVersion`, optional + Dataset version to use for final testing at the end of fine-tuning. The + `content` passed to :meth:`Dataset.create_version() ` + must have ``enable_mdb_versioning=True``. + name : str, optional + Name for the new fine-tuned model version. If no name is provided, one will + be generated. + finetuning_config : :mod:`fine-tuning configuration `, default :class:`~verta.finetune.LoraConfig` + Fine-tuning algorithm and configuration. + + Returns + ------- + :class:`~verta.registry.entities.RegisteredModelVersion` + New fine-tuned model version. 
+ + """ + # import here to circumvent circular imports + from verta import Client + from verta.dataset.entities import Dataset + from verta.tracking.entities import ExperimentRun + + # TODO: [VRD-1131] check `enable_mdb_versioning` + + ctx = _Context(self._conn, self._conf) + ctx.workspace_name = self.workspace + + # TODO: check base RMV for fine-tunability + # if not self.get_attributes().get(verta.finetune._FINETUNE_BASE_RMV_ATTR_KEY): + # raise ValueError("this model version is not eligible for fine-tuning") + if isinstance(destination_registered_model, str): + destination_registered_model = Client._get_or_create_registered_model( + self._conn, + self._conf, + ctx, + name=destination_registered_model, + ) + if finetuning_config is None: + finetuning_config = verta.finetune.LoraConfig() + + # create experiment run + ctx.proj = Client._get_or_create_project( + self._conn, + self._conf, + ctx, + name=destination_registered_model.name + + verta.finetune._PROJECT_NAME_SUFFIX, + ) + ctx.expt = Client._get_or_create_experiment( + self._conn, + self._conf, + ctx, + name=verta.finetune._EXPERIMENT_NAME_PREFIX + + Dataset._get_by_id(self._conn, self._conf, train_dataset.dataset_id).name, + ) + run = ExperimentRun._create( + self._conn, + self._conf, + ctx, + name=None, # autogenerate unique name, to be set later by fine-tuning job + attrs={verta.finetune._FINETUNE_ATTR_KEY: True}, + ) + + try: + # log dataset versions + run.log_dataset_version(verta.finetune._TRAIN_DATASET_NAME, train_dataset) + if eval_dataset is not None: + run.log_dataset_version(verta.finetune._EVAL_DATASET_NAME, eval_dataset) + if test_dataset is not None: + run.log_dataset_version(verta.finetune._TEST_DATASET_NAME, test_dataset) + + # create model version + model_ver = destination_registered_model.create_version_from_run( + run, + name=name, + ) + + try: + # launch fine-tuning + data = { + "base_model_version_id": self.id, + "run_id": run.id, + finetuning_config._JOB_DICT_KEY: finetuning_config._as_dict(), + } + url = "{}://{}/api/v1/deployment/workspace/{}/finetuning-job".format( + self._conn.scheme, + self._conn.socket, + self.workspace, + ) + response = _utils.make_request("POST", url, self._conn, json=data) + _utils.raise_for_http_error(response) + except: + model_ver.delete() + raise + except: + run.delete() + raise + + return model_ver diff --git a/client/verta/verta/tracking/entities/_deployable_entity.py b/client/verta/verta/tracking/entities/_deployable_entity.py index 18447594cf..5084007cf5 100644 --- a/client/verta/verta/tracking/entities/_deployable_entity.py +++ b/client/verta/verta/tracking/entities/_deployable_entity.py @@ -42,6 +42,7 @@ ) _INTERNAL_ATTR_PREFIX = "__verta_" +_RESERVED_ATTR_PREFIX = f"{_INTERNAL_ATTR_PREFIX}reserved__" @six.add_metaclass(abc.ABCMeta)