From 269244795e85d2ac18b7c5f3c2eed48d05eec945 Mon Sep 17 00:00:00 2001 From: Morgan Fainberg Date: Mon, 14 Oct 2024 13:19:07 -0700 Subject: [PATCH] Choices for Str not just list In the case of the Str (which translates to a StrEnum for choices) the explicit list isinstance check is incorrect. There are many cases where it would be valid to send an `Iterable` such as with `dict.keys()` returning a `dict_keys`. As we do not want to explictly cast to `list` type this commit results in a check of iterable type instead of list. --- python/cog/predictor.py | 4 ++-- python/tests/server/fixtures/input_choices_iterable.py | 7 +++++++ python/tests/server/test_http_input.py | 8 ++++++++ python/tests/server/test_predictor.py | 1 + 4 files changed, 18 insertions(+), 2 deletions(-) create mode 100644 python/tests/server/fixtures/input_choices_iterable.py diff --git a/python/cog/predictor.py b/python/cog/predictor.py index dc0ddc3ac..72d997cdf 100644 --- a/python/cog/predictor.py +++ b/python/cog/predictor.py @@ -8,7 +8,7 @@ import types import uuid from abc import ABC, abstractmethod -from collections.abc import Iterator +from collections.abc import Iterable, Iterator from pathlib import Path from typing import ( Any, @@ -354,7 +354,7 @@ def get_input_create_model_kwargs(signature: inspect.Signature) -> Dict[str, Any # In either case, remove it as an extra field because it will be # passed automatically as 'enum' in the schema if choices: - if InputType == str and isinstance(choices, list): # noqa: E721 + if InputType == str and isinstance(choices, Iterable): # noqa: E721 class StringEnum(str, enum.Enum): pass diff --git a/python/tests/server/fixtures/input_choices_iterable.py b/python/tests/server/fixtures/input_choices_iterable.py new file mode 100644 index 000000000..a9a5f498d --- /dev/null +++ b/python/tests/server/fixtures/input_choices_iterable.py @@ -0,0 +1,7 @@ +from cog import BasePredictor, Input + + +class Predictor(BasePredictor): + def predict(self, text: str = Input(choices={"foo": "x", "bar": "y"}.keys())) -> str: + assert type(text) == str + return text diff --git a/python/tests/server/test_http_input.py b/python/tests/server/test_http_input.py index 0fab05012..9aa9203b8 100644 --- a/python/tests/server/test_http_input.py +++ b/python/tests/server/test_http_input.py @@ -214,6 +214,14 @@ def test_choices_str(client): assert resp.status_code == 422 +@uses_predictor("input_choices_iterable") +def test_choices_str(client): + resp = client.post("/predictions", json={"input": {"text": "foo"}}) + assert resp.status_code == 200 + resp = client.post("/predictions", json={"input": {"text": "baz"}}) + assert resp.status_code == 422 + + @uses_predictor("input_choices_integer") def test_choices_int(client): resp = client.post("/predictions", json={"input": {"x": 1}}) diff --git a/python/tests/server/test_predictor.py b/python/tests/server/test_predictor.py index f51900928..1b86b7da5 100644 --- a/python/tests/server/test_predictor.py +++ b/python/tests/server/test_predictor.py @@ -14,6 +14,7 @@ PREDICTOR_FIXTURES = [ ("input_choices", "Predictor", "predict"), ("input_choices_integer", "Predictor", "predict"), + ("input_choices_iterable", "Predictor", "predict"), ("input_file", "Predictor", "predict"), ("function", "predict", "predict"), ("input_ge_le", "Predictor", "predict"),