diff --git a/python/cog/__init__.py b/python/cog/__init__.py index 34b372955..b8371e0f0 100644 --- a/python/cog/__init__.py +++ b/python/cog/__init__.py @@ -1,7 +1,7 @@ from pydantic import BaseModel from .predictor import BasePredictor -from .types import ConcatenateIterator, File, Input, Path +from .types import ConcatenateIterator, File, Input, Path, Secret try: from ._version import __version__ @@ -17,4 +17,5 @@ "File", "Input", "Path", + "Secret", ] diff --git a/python/cog/predictor.py b/python/cog/predictor.py index 06708c951..46f6920d2 100644 --- a/python/cog/predictor.py +++ b/python/cog/predictor.py @@ -45,13 +45,20 @@ Input, URLPath, ) -from .types import ( - Path as CogPath, -) +from .types import Path as CogPath +from .types import Secret as CogSecret log = structlog.get_logger("cog.server.predictor") -ALLOWED_INPUT_TYPES: List[Type[Any]] = [str, int, float, bool, CogFile, CogPath] +ALLOWED_INPUT_TYPES: List[Type[Any]] = [ + str, + int, + float, + bool, + CogFile, + CogPath, + CogSecret, +] class BasePredictor(ABC): diff --git a/python/cog/types.py b/python/cog/types.py index 8b75fd704..a75b48d99 100644 --- a/python/cog/types.py +++ b/python/cog/types.py @@ -9,7 +9,7 @@ from typing import Any, Dict, Iterator, List, Optional, TypeVar, Union import requests -from pydantic import Field +from pydantic import Field, SecretStr FILENAME_ILLEGAL_CHARS = set("\u0000/") @@ -42,6 +42,19 @@ def Input( ) +class Secret(SecretStr): + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + """Defines what this type should be in openapi.json""" + field_schema.update( + { + "type": "string", + "format": "password", + "x-cog-secret": True, + } + ) + + class File(io.IOBase): validate_always = True diff --git a/python/tests/server/fixtures/input_secret.py b/python/tests/server/fixtures/input_secret.py new file mode 100644 index 000000000..fcc6f0a5a --- /dev/null +++ b/python/tests/server/fixtures/input_secret.py @@ -0,0 +1,6 @@ +from cog import BasePredictor, Secret + + +class Predictor(BasePredictor): + def predict(self, secret: Secret) -> str: + return secret.get_secret_value() diff --git a/python/tests/server/test_http_input.py b/python/tests/server/test_http_input.py index c9966e507..a64bb0104 100644 --- a/python/tests/server/test_http_input.py +++ b/python/tests/server/test_http_input.py @@ -242,6 +242,16 @@ def test_union_integers(client): assert resp.status_code == 422 +@uses_predictor("input_secret") +def test_secret_str(client, match): + resp = client.post("/predictions", json={"input": {"secret": "foo"}}) + assert resp.status_code == 200 + assert resp.json() == match({"output": "foo", "status": "succeeded"}) + + resp = client.post("/predictions", json={"input": {"secret": {}}}) + assert resp.status_code == 422 + + def test_untyped_inputs(): config = {"predict": _fixture_path("input_untyped")} app = create_app( diff --git a/python/tests/test_types.py b/python/tests/test_types.py index 72efdca09..554aed305 100644 --- a/python/tests/test_types.py +++ b/python/tests/test_types.py @@ -3,7 +3,7 @@ import pytest import responses -from cog.types import URLFile, get_filename +from cog.types import Secret, URLFile, get_filename @responses.activate @@ -119,3 +119,11 @@ def test_urlfile_can_be_pickled_even_once_loaded(): ) def test_get_filename(url, filename): assert get_filename(url) == filename + + +def test_secret_type(): + secret_value = "sw0rdf1$h" # noqa: S105 + secret = Secret(secret_value) + + assert secret.get_secret_value() == secret_value + assert str(secret) == "**********" diff --git a/test-integration/test_integration/util.py b/test-integration/test_integration/util.py index 5f7221f84..07faff25d 100644 --- a/test-integration/test_integration/util.py +++ b/test-integration/test_integration/util.py @@ -11,7 +11,9 @@ def random_string(length): def remove_docker_image(image_name, max_attempts=5, wait_seconds=1): for attempt in range(max_attempts): try: - subprocess.run(["docker", "rmi", "-f", image_name], check=True, capture_output=True) + subprocess.run( + ["docker", "rmi", "-f", image_name], check=True, capture_output=True + ) print(f"Image {image_name} successfully removed.") break except subprocess.CalledProcessError as e: