Skip to content

Commit

Permalink
PROD-435 Add model_name to search
Browse files Browse the repository at this point in the history
* PROD-435 Add model_name to CLI and SDK, as well
as allowing the query param at the endpoint

GitOrigin-RevId: 1a1ea68d8a06b6f19ff8553363e9d618a082a038
  • Loading branch information
mckornfield authored and pimlock committed May 26, 2023
1 parent d8f0f38 commit 0a002eb
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 12 deletions.
10 changes: 7 additions & 3 deletions src/gretel_client/cli/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ def create(
dry_run: bool,
name: str,
):

if wait >= 0 and output:
raise click.BadOptionUsage(
"--output",
Expand Down Expand Up @@ -304,9 +303,14 @@ def get(sc: SessionContext, project: str, model_id: str, output: str):
@models.command(help="Search for models of the project.")
@project_option
@click.option("--limit", help="Limit the number of projects.", default=100)
@click.option("--model-name", help="Model name to match on", default="")
@pass_session
def search(sc: SessionContext, project: str, limit: int):
sc.print(data=list(sc.project.search_models(factory=dict, limit=limit)))
def search(sc: SessionContext, project: str, limit: int, model_name: str):
sc.print(
data=list(
sc.project.search_models(factory=dict, limit=limit, model_name=model_name)
)
)


@models.command(help="Delete model.")
Expand Down
25 changes: 16 additions & 9 deletions src/gretel_client/projects/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,25 +153,32 @@ def info(self) -> dict:

@check_not_deleted
def search_models(
self, factory: Type[MT] = Model, limit: int = 100
self,
factory: Type[MT] = Model,
limit: int = 100,
model_name: str = "",
) -> Iterator[MT]:
"""Search for project models.
Args:
limit: Limits the number of project models to return
factory: Determines what type of Model representation is returned.
If ``Model`` is passed, a ``Model`` will be returned. If ``dict``
is passed, a dictionary representation of the search results
will be returned.
limit: Limits the number of project models to return
model_name: Name of the model to try and match on (partial match)
"""
if factory not in (dict, Model):
raise ValueError("factory must be one of ``str`` or ``Model``.")
models = (
self.projects_api.get_models(project_id=self.name, limit=limit)
.get(DATA)
.get(MODELS)
)
for model in models:
raise ValueError("factory must be one of ``dict`` or ``Model``.")

api_args = {"project_id": self.name, "limit": limit}
if model_name:
api_args["model_name"] = model_name

result = self.projects_api.get_models(**api_args)
searched_models = result.get(DATA).get(MODELS)

for model in searched_models:
if factory == Model:
model = self.get_model(model_id=model[f.UID])
yield model
Expand Down
5 changes: 5 additions & 0 deletions src/gretel_client/rest/api/projects_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1598,6 +1598,7 @@ def __get_models(self, project_id, **kwargs):
Keyword Args:
limit (int): Limit number of models to return. [optional]
model_name (str): Model name to match on. [optional]
_return_http_data_only (bool): response data without head status
code and headers. Default is True.
_preload_content (bool): if False, the urllib3.HTTPResponse object
Expand Down Expand Up @@ -1662,6 +1663,7 @@ def __get_models(self, project_id, **kwargs):
"all": [
"project_id",
"limit",
"model_name",
],
"required": [
"project_id",
Expand All @@ -1676,14 +1678,17 @@ def __get_models(self, project_id, **kwargs):
"openapi_types": {
"project_id": (str,),
"limit": (int,),
"model_name": (str,),
},
"attribute_map": {
"project_id": "project_id",
"limit": "limit",
"model_name": "model_name",
},
"location_map": {
"project_id": "path",
"limit": "query",
"model_name": "query",
},
"collection_format_map": {},
},
Expand Down
11 changes: 11 additions & 0 deletions tests/gretel_client/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,3 +274,14 @@ def test_does_run_manual_artifacts(
status_strings=ANY,
model_path=None,
)


def test_search_models_with_model_name(
get_project: MagicMock,
runner: CliRunner,
):
cmd = runner.invoke(
cli,
["models", "search", "--model-name", "model-boi"],
)
assert cmd.exit_code == 0

0 comments on commit 0a002eb

Please sign in to comment.