Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: [VRD-1120, VRD-1130, VRD-1132] Add finetune() #4085

Merged
merged 18 commits into from
Sep 21, 2023
Empty file.
46 changes: 46 additions & 0 deletions client/verta/tests/finetune/test_finetune.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# -*- coding: utf-8 -*-

"""Test RegisteredModelVersion.finetune() and its downstream effects."""

from verta.dataset import Path
from verta import finetune


def test_finetune(client, registered_model, dataset):
    """Exercise the happy path of ``finetune()`` and inspect what it creates."""
    version_name = "v1"
    base_rmv = registered_model.create_version()  # mocked base LLM RMV

    destination_rm = client.create_registered_model()
    train_dsv = dataset.create_version(
        Path(__file__, enable_mdb_versioning=True),
    )
    finetuned_rmv = base_rmv.finetune(
        destination_registered_model=destination_rm,
        train_dataset=train_dsv,
        name=version_name,
    )
    finetune_run = client.get_experiment_run(id=finetuned_rmv.experiment_run_id)
    tracked_entities = (finetuned_rmv, finetune_run)

    # project name is built from the destination RM's name,
    # experiment name from the training dataset's name
    assert client.proj.name == destination_rm.name + finetune._PROJECT_NAME_SUFFIX
    assert client.expt.name == finetune._EXPERIMENT_NAME_PREFIX + dataset.name
    # TODO: wait for fine-tuning to launch, then check ER name
    assert finetuned_rmv.name == version_name

    # both the new model version and its run must point at the same
    # dataset versions that were passed to finetune()
    expected_dataset_versions = {
        finetune._TRAIN_DATASET_NAME: train_dsv,
        # TODO: eval and test, too
    }
    for entity in tracked_entities:
        for dataset_key, dataset_version in expected_dataset_versions.items():
            assert entity.get_dataset_version(dataset_key).id == dataset_version.id

    # the fine-tuning marker attribute must be present (other attributes
    # are allowed, hence a subset check rather than equality)
    required_attrs = {finetune._FINETUNE_ATTR_KEY: True}
    for entity in tracked_entities:
        logged_attrs = entity.get_attributes()
        assert required_attrs.items() <= logged_attrs.items()
159 changes: 125 additions & 34 deletions client/verta/verta/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,9 @@ def __init__(
_connect=True,
):
if organization_id is not None and organization_name is not None:
raise ValueError("cannot provide both `organization_id` and `organization_name`")
raise ValueError(
"cannot provide both `organization_id` and `organization_name`"
)
Comment on lines -157 to +159
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ignore; black did this.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rude

self._load_config()

host = self._get_with_fallback(host, env_var="VERTA_HOST", config_var="host")
Expand Down Expand Up @@ -225,10 +227,10 @@ def __init__(
# )
# response = conn.must_proto_response(request, OrganizationV2_pb2.GetOrganizationByNameV2.Response)
# organization_id = response.organization.id
request = conn.make_proto_request(
"GET", "/api/v2/uac-proxy/organization"
request = conn.make_proto_request("GET", "/api/v2/uac-proxy/organization")
response = conn.must_proto_response(
request, OrganizationV2_pb2.ListOrganizationsV2.Response
)
response = conn.must_proto_response(request, OrganizationV2_pb2.ListOrganizationsV2.Response)
Comment on lines -228 to -231
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ignore; black did this.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean this was bad

for org in response.organizations:
if org.name == organization_name:
organization_id = org.id
Expand Down Expand Up @@ -509,26 +511,55 @@ def set_project(
self._ctx = _Context(self._conn, self._conf)
self._ctx.workspace_name = workspace

return self._get_or_create_project(
self._conn,
self._conf,
self._ctx,
name=name,
desc=desc,
tags=tags,
attrs=attrs,
public_within_org=public_within_org,
visibility=visibility,
id=id,
)

@staticmethod
def _get_or_create_project(
Comment on lines +527 to +528
Copy link
Contributor Author

@liuverta liuverta Sep 19, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This whole file's changes are basically factoring out static versions of these get_or_create methods so that they can be called without needing to instantiate a new Client. See the absolute nightmare that is Client.__init__()—this company's oldest extant Python function—for why we want to avoid that.

Ideally these would actually be moved to new classes—which I have vague, vague plans for—but I wanted to keep this diff as readable as possible for now.

conn,
conf,
ctx,
name=None,
desc=None,
tags=None,
attrs=None,
public_within_org=None,
visibility=None,
id=None,
):
""":meth:`set_project`, but static.

Requires either `ctx.workspace_name` or `id`.

"""
resource_name = "Project"
param_names = "`desc`, `tags`, `attrs`, `public_within_org`, or `visibility`"
params = (desc, tags, attrs, public_within_org, visibility)
if id is not None:
self._ctx.proj = Project._get_by_id(self._conn, self._conf, id)
ctx.proj = Project._get_by_id(conn, conf, id)
check_unnecessary_params_warning(
resource_name, "id {}".format(id), param_names, params
)
self._ctx.populate()
ctx.populate()
else:
self._ctx.proj = Project._get_or_create_by_name(
self._conn,
ctx.proj = Project._get_or_create_by_name(
conn,
name,
lambda name: Project._get_by_name(
self._conn, self._conf, name, self._ctx.workspace_name
),
lambda name: Project._get_by_name(conn, conf, name, ctx.workspace_name),
lambda name: Project._create(
self._conn,
self._conf,
self._ctx,
conn,
conf,
ctx,
name=name,
desc=desc,
tags=tags,
Expand All @@ -541,7 +572,7 @@ def set_project(
),
)

return self._ctx.proj
return ctx.proj

def get_experiment(self, name=None, id=None):
"""
Expand Down Expand Up @@ -621,29 +652,54 @@ def set_experiment(self, name=None, desc=None, tags=None, attrs=None, id=None):

name = self._set_from_config_if_none(name, "experiment")

if id is None and self._ctx.proj is None:
self.set_project()
Comment on lines +655 to +656
Copy link
Contributor Author

@liuverta liuverta Sep 19, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is moved here from the old lines 634–635 below, because the static method can't access self.set_project()


return self._get_or_create_experiment(
self._conn,
self._conf,
self._ctx,
name=name,
desc=desc,
tags=tags,
attrs=attrs,
id=id,
)

@staticmethod
def _get_or_create_experiment(
conn,
conf,
ctx,
name=None,
desc=None,
tags=None,
attrs=None,
id=None,
):
""":meth:`set_experiment`, but static.

Requires either `ctx.proj` or `id`.

"""
resource_name = "Experiment"
param_names = "`desc`, `tags`, or `attrs`"
params = (desc, tags, attrs)
if id is not None:
self._ctx.expt = Experiment._get_by_id(self._conn, self._conf, id)
ctx.expt = Experiment._get_by_id(conn, conf, id)
check_unnecessary_params_warning(
resource_name, "id {}".format(id), param_names, params
)
self._ctx.populate()
ctx.populate()
else:
if self._ctx.proj is None:
self.set_project()

self._ctx.expt = Experiment._get_or_create_by_name(
self._conn,
ctx.expt = Experiment._get_or_create_by_name(
conn,
name,
lambda name: Experiment._get_by_name(
self._conn, self._conf, name, self._ctx.proj.id
),
lambda name: Experiment._get_by_name(conn, conf, name, ctx.proj.id),
lambda name: Experiment._create(
self._conn,
self._conf,
self._ctx,
conn,
conf,
ctx,
name=name,
desc=desc,
tags=tags,
Expand All @@ -654,7 +710,7 @@ def set_experiment(self, name=None, desc=None, tags=None, attrs=None, id=None):
),
)

return self._ctx.expt
return ctx.expt

def get_experiment_run(self, name=None, id=None):
"""
Expand Down Expand Up @@ -872,24 +928,59 @@ def get_or_create_registered_model(
ctx = _Context(self._conn, self._conf)
ctx.workspace_name = workspace

return self._get_or_create_registered_model(
self._conn,
self._conf,
ctx,
name=name,
desc=desc,
labels=labels,
public_within_org=public_within_org,
visibility=visibility,
id=id,
task_type=task_type,
data_type=data_type,
pii=pii,
)

@staticmethod
def _get_or_create_registered_model(
conn,
conf,
ctx,
name=None,
desc=None,
labels=None,
public_within_org=None,
visibility=None,
id=None,
task_type=None,
data_type=None,
pii=False,
):
""":meth:`get_or_create_registered_model`, but static.

Requires either `ctx.workspace_name` or `id`.

"""
resource_name = "Registered Model"
param_names = "`desc`, `labels`, `public_within_org`, or `visibility`"
params = (desc, labels, public_within_org, visibility)
if id is not None:
registered_model = RegisteredModel._get_by_id(self._conn, self._conf, id)
registered_model = RegisteredModel._get_by_id(conn, conf, id)
check_unnecessary_params_warning(
resource_name, "id {}".format(id), param_names, params
)
else:
registered_model = RegisteredModel._get_or_create_by_name(
self._conn,
conn,
name,
lambda name: RegisteredModel._get_by_name(
self._conn, self._conf, name, ctx.workspace_name
conn, conf, name, ctx.workspace_name
),
lambda name: RegisteredModel._create(
self._conn,
self._conf,
conn,
conf,
ctx,
name=name,
desc=desc,
Expand Down
10 changes: 10 additions & 0 deletions client/verta/verta/finetune/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,21 @@
"""Utilities for model fine-tuning."""

from verta._internal_utils import documentation
from verta.tracking.entities._deployable_entity import _RESERVED_ATTR_PREFIX

from ._finetuning_config import _FinetuningConfig
from ._lora_config import LoraConfig


# Naming conventions and keys shared by the fine-tuning code and its tests.
_PROJECT_NAME_SUFFIX = " Fine-Tuning" # appended to the registered model's name to form the project name
_EXPERIMENT_NAME_PREFIX = "On Dataset: " # prepended to the train dataset's name to form the experiment name
# Keys for the dataset versions associated with a fine-tuning job; the tests
# read the train dataset version back with ``get_dataset_version(<key>)``.
# NOTE(review): eval and test are presumably used the same way -- confirm.
_TRAIN_DATASET_NAME = "train"
_EVAL_DATASET_NAME = "eval"
_TEST_DATASET_NAME = "test"
# Reserved attribute keys (built from ``_RESERVED_ATTR_PREFIX``).
# NOTE(review): presumably records which base RMV was fine-tuned -- confirm
# against finetune().
_FINETUNE_BASE_RMV_ATTR_KEY = f"{_RESERVED_ATTR_PREFIX}FINETUNE_BASE"
# The tests expect this attribute to be ``True`` on both the fine-tuned model
# version and its experiment run.
_FINETUNE_ATTR_KEY = f"{_RESERVED_ATTR_PREFIX}FINETUNE"


documentation.reassign_module(
[
_FinetuningConfig,
Expand Down
5 changes: 3 additions & 2 deletions client/verta/verta/registry/_constants.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# -*- coding: utf-8 -*-

from verta._protos.public.registry import ModelMetadata_pb2
from verta.tracking.entities._deployable_entity import _RESERVED_ATTR_PREFIX


MODEL_LANGUAGE_ATTR_KEY = "__verta_reserved__model_language"
MODEL_TYPE_ATTR_KEY = "__verta_reserved__model_type"
MODEL_LANGUAGE_ATTR_KEY = f"{_RESERVED_ATTR_PREFIX}model_language"
MODEL_TYPE_ATTR_KEY = f"{_RESERVED_ATTR_PREFIX}model_type"


class ModelLanguage(object):
Expand Down
Loading
Loading