Skip to content

Commit

Permalink
Define Secret type (#1546)
Browse files Browse the repository at this point in the history
Signed-off-by: Mattt Zmuda <mattt@replicate.com>
  • Loading branch information
mattt committed May 31, 2024
1 parent e07dc07 commit d4749f0
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 5 deletions.
10 changes: 9 additions & 1 deletion python/cog/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,14 @@

from .predictor import BasePredictor
from .server.worker import emit_metric
from .types import AsyncConcatenateIterator, ConcatenateIterator, File, Input, Path
from .types import (
AsyncConcatenateIterator,
ConcatenateIterator,
File,
Input,
Path,
Secret,
)

try:
from ._version import __version__
Expand All @@ -20,4 +27,5 @@
"Input",
"Path",
"emit_metric",
"Secret",
]
13 changes: 11 additions & 2 deletions python/cog/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,17 @@
from .types import (
Path as CogPath,
)

ALLOWED_INPUT_TYPES: List[Type[Any]] = [str, int, float, bool, CogFile, CogPath]
from .types import Secret as CogSecret

ALLOWED_INPUT_TYPES: List[Type[Any]] = [
str,
int,
float,
bool,
CogFile,
CogPath,
CogSecret,
]


class BasePredictor(ABC):
Expand Down
15 changes: 14 additions & 1 deletion python/cog/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import httpx
import requests
from pydantic import Field
from pydantic import Field, SecretStr

FILENAME_ILLEGAL_CHARS = set("\u0000/")

Expand Down Expand Up @@ -44,6 +44,19 @@ def Input(
)


class Secret(SecretStr):
@classmethod
def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
"""Defines what this type should be in openapi.json"""
field_schema.update(
{
"type": "string",
"format": "password",
"x-cog-secret": True,
}
)


class File(io.IOBase):
validate_always = True

Expand Down
6 changes: 6 additions & 0 deletions python/tests/server/fixtures/input_secret.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from cog import BasePredictor, Secret


class Predictor(BasePredictor):
def predict(self, secret: Secret) -> str:
return secret.get_secret_value()
10 changes: 10 additions & 0 deletions python/tests/server/test_http_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,16 @@ def test_union_integers(client):
assert resp.status_code == 422


@uses_predictor("input_secret")
def test_secret_str(client, match):
resp = client.post("/predictions", json={"input": {"secret": "foo"}})
assert resp.status_code == 200
assert resp.json() == match({"output": "foo", "status": "succeeded"})

resp = client.post("/predictions", json={"input": {"secret": {}}})
assert resp.status_code == 422


def test_untyped_inputs():
config = {"predict": _fixture_path("input_untyped")}
app = create_app(
Expand Down
10 changes: 9 additions & 1 deletion python/tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pytest
import responses
from cog.types import URLFile, get_filename_from_url, get_filename_from_urlopen
from cog.types import Secret, URLFile, get_filename_from_url, get_filename_from_urlopen


@responses.activate
Expand Down Expand Up @@ -123,3 +123,11 @@ def test_get_filename(url, filename):
def test_get_filename_from_urlopen(url, filename):
resp = urllib.request.urlopen(url) # noqa: S310
assert get_filename_from_urlopen(resp) == filename


def test_secret_type():
secret_value = "sw0rdf1$h" # noqa: S105
secret = Secret(secret_value)

assert secret.get_secret_value() == secret_value
assert str(secret) == "**********"

0 comments on commit d4749f0

Please sign in to comment.