-
Notifications
You must be signed in to change notification settings - Fork 286
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: [VRD-1120, VRD-1130, VRD-1132] Add finetune()
#4085
Changes from all commits
1c14266
c431d3a
329d80b
97dd8c8
cce2cff
25b2a78
1640b23
c48f9ef
817907c
3e5ed53
fbb161e
0a1267d
ebabcd0
340a77d
c752616
6698949
62e229e
8d726bf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
"""Test RegisteredModelVersion.finetune() and its downstream effects.""" | ||
|
||
from verta.dataset import Path | ||
from verta import finetune | ||
|
||
|
||
def test_finetune(client, registered_model, dataset): | ||
"""Verify that happy-path ``finetune()`` works.""" | ||
base_model_ver = registered_model.create_version() # mocked base LLM RMV | ||
name = "v1" | ||
|
||
reg_model = client.create_registered_model() | ||
train_dataset_version = dataset.create_version( | ||
Path(__file__, enable_mdb_versioning=True), | ||
) | ||
model_ver = base_model_ver.finetune( | ||
destination_registered_model=reg_model, | ||
train_dataset=train_dataset_version, | ||
name=name, | ||
) | ||
run = client.get_experiment_run(id=model_ver.experiment_run_id) | ||
|
||
# check entity names | ||
assert client.proj.name == reg_model.name + finetune._PROJECT_NAME_SUFFIX | ||
assert client.expt.name == finetune._EXPERIMENT_NAME_PREFIX + dataset.name | ||
# TODO: wait for fine-tuning to launch, then check ER name | ||
assert model_ver.name == name | ||
|
||
# check dataset association | ||
for entity in [model_ver, run]: | ||
for key, value in [ | ||
(finetune._TRAIN_DATASET_NAME, train_dataset_version), | ||
# TODO: eval and test, too | ||
]: | ||
assert entity.get_dataset_version(key).id == value.id | ||
|
||
# check attributes | ||
for entity in [model_ver, run]: | ||
assert ( | ||
entity.get_attributes().items() | ||
>= { | ||
finetune._FINETUNE_ATTR_KEY: True, | ||
}.items() | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -154,7 +154,9 @@ def __init__( | |
_connect=True, | ||
): | ||
if organization_id is not None and organization_name is not None: | ||
raise ValueError("cannot provide both `organization_id` and `organization_name`") | ||
raise ValueError( | ||
"cannot provide both `organization_id` and `organization_name`" | ||
) | ||
self._load_config() | ||
|
||
host = self._get_with_fallback(host, env_var="VERTA_HOST", config_var="host") | ||
|
@@ -225,10 +227,10 @@ def __init__( | |
# ) | ||
# response = conn.must_proto_response(request, OrganizationV2_pb2.GetOrganizationByNameV2.Response) | ||
# organization_id = response.organization.id | ||
request = conn.make_proto_request( | ||
"GET", "/api/v2/uac-proxy/organization" | ||
request = conn.make_proto_request("GET", "/api/v2/uac-proxy/organization") | ||
response = conn.must_proto_response( | ||
request, OrganizationV2_pb2.ListOrganizationsV2.Response | ||
) | ||
response = conn.must_proto_response(request, OrganizationV2_pb2.ListOrganizationsV2.Response) | ||
Comment on lines
-228
to
-231
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ignore; `black` did this. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I mean, this was bad. |
||
for org in response.organizations: | ||
if org.name == organization_name: | ||
organization_id = org.id | ||
|
@@ -509,26 +511,55 @@ def set_project( | |
self._ctx = _Context(self._conn, self._conf) | ||
self._ctx.workspace_name = workspace | ||
|
||
return self._get_or_create_project( | ||
self._conn, | ||
self._conf, | ||
self._ctx, | ||
name=name, | ||
desc=desc, | ||
tags=tags, | ||
attrs=attrs, | ||
public_within_org=public_within_org, | ||
visibility=visibility, | ||
id=id, | ||
) | ||
|
||
@staticmethod | ||
def _get_or_create_project( | ||
Comment on lines
+527
to
+528
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This whole file's changes are basically factoring out static versions of these methods. Ideally these would actually be moved to new classes—which I have vague, vague plans for—but I wanted to keep this diff as readable as possible for now. |
||
conn, | ||
conf, | ||
ctx, | ||
name=None, | ||
desc=None, | ||
tags=None, | ||
attrs=None, | ||
public_within_org=None, | ||
visibility=None, | ||
id=None, | ||
): | ||
""":meth:`set_project`, but static. | ||
|
||
Requires either `ctx.workspace_name` or `id`. | ||
|
||
""" | ||
resource_name = "Project" | ||
param_names = "`desc`, `tags`, `attrs`, `public_within_org`, or `visibility`" | ||
params = (desc, tags, attrs, public_within_org, visibility) | ||
if id is not None: | ||
self._ctx.proj = Project._get_by_id(self._conn, self._conf, id) | ||
ctx.proj = Project._get_by_id(conn, conf, id) | ||
check_unnecessary_params_warning( | ||
resource_name, "id {}".format(id), param_names, params | ||
) | ||
self._ctx.populate() | ||
ctx.populate() | ||
else: | ||
self._ctx.proj = Project._get_or_create_by_name( | ||
self._conn, | ||
ctx.proj = Project._get_or_create_by_name( | ||
conn, | ||
name, | ||
lambda name: Project._get_by_name( | ||
self._conn, self._conf, name, self._ctx.workspace_name | ||
), | ||
lambda name: Project._get_by_name(conn, conf, name, ctx.workspace_name), | ||
lambda name: Project._create( | ||
self._conn, | ||
self._conf, | ||
self._ctx, | ||
conn, | ||
conf, | ||
ctx, | ||
name=name, | ||
desc=desc, | ||
tags=tags, | ||
|
@@ -541,7 +572,7 @@ def set_project( | |
), | ||
) | ||
|
||
return self._ctx.proj | ||
return ctx.proj | ||
|
||
def get_experiment(self, name=None, id=None): | ||
""" | ||
|
@@ -621,29 +652,54 @@ def set_experiment(self, name=None, desc=None, tags=None, attrs=None, id=None): | |
|
||
name = self._set_from_config_if_none(name, "experiment") | ||
|
||
if id is None and self._ctx.proj is None: | ||
self.set_project() | ||
Comment on lines
+655
to
+656
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is moved here from the old lines 634–635 below, because the static method can't access `self`. |
||
|
||
return self._get_or_create_experiment( | ||
self._conn, | ||
self._conf, | ||
self._ctx, | ||
name=name, | ||
desc=desc, | ||
tags=tags, | ||
attrs=attrs, | ||
id=id, | ||
) | ||
|
||
@staticmethod | ||
def _get_or_create_experiment( | ||
conn, | ||
conf, | ||
ctx, | ||
name=None, | ||
desc=None, | ||
tags=None, | ||
attrs=None, | ||
id=None, | ||
): | ||
""":meth:`set_experiment`, but static. | ||
|
||
Requires either `ctx.proj` or `id` | ||
|
||
""" | ||
resource_name = "Experiment" | ||
param_names = "`desc`, `tags`, or `attrs`" | ||
params = (desc, tags, attrs) | ||
if id is not None: | ||
self._ctx.expt = Experiment._get_by_id(self._conn, self._conf, id) | ||
ctx.expt = Experiment._get_by_id(conn, conf, id) | ||
check_unnecessary_params_warning( | ||
resource_name, "id {}".format(id), param_names, params | ||
) | ||
self._ctx.populate() | ||
ctx.populate() | ||
else: | ||
if self._ctx.proj is None: | ||
self.set_project() | ||
|
||
self._ctx.expt = Experiment._get_or_create_by_name( | ||
self._conn, | ||
ctx.expt = Experiment._get_or_create_by_name( | ||
conn, | ||
name, | ||
lambda name: Experiment._get_by_name( | ||
self._conn, self._conf, name, self._ctx.proj.id | ||
), | ||
lambda name: Experiment._get_by_name(conn, conf, name, ctx.proj.id), | ||
lambda name: Experiment._create( | ||
self._conn, | ||
self._conf, | ||
self._ctx, | ||
conn, | ||
conf, | ||
ctx, | ||
name=name, | ||
desc=desc, | ||
tags=tags, | ||
|
@@ -654,7 +710,7 @@ def set_experiment(self, name=None, desc=None, tags=None, attrs=None, id=None): | |
), | ||
) | ||
|
||
return self._ctx.expt | ||
return ctx.expt | ||
|
||
def get_experiment_run(self, name=None, id=None): | ||
""" | ||
|
@@ -872,24 +928,59 @@ def get_or_create_registered_model( | |
ctx = _Context(self._conn, self._conf) | ||
ctx.workspace_name = workspace | ||
|
||
return self._get_or_create_registered_model( | ||
self._conn, | ||
self._conf, | ||
ctx, | ||
name=name, | ||
desc=desc, | ||
labels=labels, | ||
public_within_org=public_within_org, | ||
visibility=visibility, | ||
id=id, | ||
task_type=task_type, | ||
data_type=data_type, | ||
pii=pii, | ||
) | ||
|
||
@staticmethod | ||
def _get_or_create_registered_model( | ||
conn, | ||
conf, | ||
ctx, | ||
name=None, | ||
desc=None, | ||
labels=None, | ||
public_within_org=None, | ||
visibility=None, | ||
id=None, | ||
task_type=None, | ||
data_type=None, | ||
pii=False, | ||
): | ||
""":meth:`get_or_create_registered_model`, but static. | ||
|
||
Requires either `ctx.workspace_name` or `id`. | ||
|
||
""" | ||
resource_name = "Registered Model" | ||
param_names = "`desc`, `labels`, `public_within_org`, or `visibility`" | ||
params = (desc, labels, public_within_org, visibility) | ||
if id is not None: | ||
registered_model = RegisteredModel._get_by_id(self._conn, self._conf, id) | ||
registered_model = RegisteredModel._get_by_id(conn, conf, id) | ||
check_unnecessary_params_warning( | ||
resource_name, "id {}".format(id), param_names, params | ||
) | ||
else: | ||
registered_model = RegisteredModel._get_or_create_by_name( | ||
self._conn, | ||
conn, | ||
name, | ||
lambda name: RegisteredModel._get_by_name( | ||
self._conn, self._conf, name, ctx.workspace_name | ||
conn, conf, name, ctx.workspace_name | ||
), | ||
lambda name: RegisteredModel._create( | ||
self._conn, | ||
self._conf, | ||
conn, | ||
conf, | ||
ctx, | ||
name=name, | ||
desc=desc, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ignore; `black` did this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
rude