Skip to content

Commit

Permalink
Choices for Str not just list
Browse files Browse the repository at this point in the history
In the case of the Str (which translates to a StrEnum for choices) the
explicit list isinstance check is incorrect. There are many cases where
it would be valid to send an `Iterable` such as with `dict.keys()`
returning a `dict_keys`. As we do not want to explictly cast to `list`
type this commit results in a check of iterable type instead of list.
  • Loading branch information
tempusfrangit committed Oct 14, 2024
1 parent ed7cc89 commit 2692447
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 2 deletions.
4 changes: 2 additions & 2 deletions python/cog/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import types
import uuid
from abc import ABC, abstractmethod
from collections.abc import Iterator
from collections.abc import Iterable, Iterator
from pathlib import Path
from typing import (
Any,
Expand Down Expand Up @@ -354,7 +354,7 @@ def get_input_create_model_kwargs(signature: inspect.Signature) -> Dict[str, Any
# In either case, remove it as an extra field because it will be
# passed automatically as 'enum' in the schema
if choices:
if InputType == str and isinstance(choices, list): # noqa: E721
if InputType == str and isinstance(choices, Iterable): # noqa: E721

class StringEnum(str, enum.Enum):
pass
Expand Down
7 changes: 7 additions & 0 deletions python/tests/server/fixtures/input_choices_iterable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from cog import BasePredictor, Input


class Predictor(BasePredictor):
def predict(self, text: str = Input(choices={"foo": "x", "bar": "y"}.keys())) -> str:
assert type(text) == str
return text
8 changes: 8 additions & 0 deletions python/tests/server/test_http_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,14 @@ def test_choices_str(client):
assert resp.status_code == 422


@uses_predictor("input_choices_iterable")
def test_choices_str(client):
resp = client.post("/predictions", json={"input": {"text": "foo"}})
assert resp.status_code == 200
resp = client.post("/predictions", json={"input": {"text": "baz"}})
assert resp.status_code == 422


@uses_predictor("input_choices_integer")
def test_choices_int(client):
resp = client.post("/predictions", json={"input": {"x": 1}})
Expand Down
1 change: 1 addition & 0 deletions python/tests/server/test_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
PREDICTOR_FIXTURES = [
("input_choices", "Predictor", "predict"),
("input_choices_integer", "Predictor", "predict"),
("input_choices_iterable", "Predictor", "predict"),
("input_file", "Predictor", "predict"),
("function", "predict", "predict"),
("input_ge_le", "Predictor", "predict"),
Expand Down

0 comments on commit 2692447

Please sign in to comment.