Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Define Secret type #1546

Merged
merged 1 commit into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion python/cog/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from pydantic import BaseModel

from .predictor import BasePredictor
from .types import ConcatenateIterator, File, Input, Path
from .types import ConcatenateIterator, File, Input, Path, Secret

try:
from ._version import __version__
Expand All @@ -17,4 +17,5 @@
"File",
"Input",
"Path",
"Secret",
]
15 changes: 11 additions & 4 deletions python/cog/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,20 @@
Input,
URLPath,
)
from .types import (
Path as CogPath,
)
from .types import Path as CogPath
from .types import Secret as CogSecret

log = structlog.get_logger("cog.server.predictor")

ALLOWED_INPUT_TYPES: List[Type[Any]] = [str, int, float, bool, CogFile, CogPath]
ALLOWED_INPUT_TYPES: List[Type[Any]] = [
str,
int,
float,
bool,
CogFile,
CogPath,
CogSecret,
]


class BasePredictor(ABC):
Expand Down
15 changes: 14 additions & 1 deletion python/cog/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from typing import Any, Dict, Iterator, List, Optional, TypeVar, Union

import requests
from pydantic import Field
from pydantic import Field, SecretStr

FILENAME_ILLEGAL_CHARS = set("\u0000/")

Expand Down Expand Up @@ -42,6 +42,19 @@ def Input(
)


class Secret(SecretStr):
@classmethod
def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
"""Defines what this type should be in openapi.json"""
field_schema.update(
{
"type": "string",
"format": "password",
"x-cog-secret": True,
}
)


class File(io.IOBase):
validate_always = True

Expand Down
6 changes: 6 additions & 0 deletions python/tests/server/fixtures/input_secret.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from cog import BasePredictor, Secret


class Predictor(BasePredictor):
def predict(self, secret: Secret) -> str:
return secret.get_secret_value()
10 changes: 10 additions & 0 deletions python/tests/server/test_http_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,16 @@ def test_union_integers(client):
assert resp.status_code == 422


@uses_predictor("input_secret")
def test_secret_str(client, match):
resp = client.post("/predictions", json={"input": {"secret": "foo"}})
assert resp.status_code == 200
assert resp.json() == match({"output": "foo", "status": "succeeded"})

resp = client.post("/predictions", json={"input": {"secret": {}}})
assert resp.status_code == 422


def test_untyped_inputs():
config = {"predict": _fixture_path("input_untyped")}
app = create_app(
Expand Down
10 changes: 9 additions & 1 deletion python/tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import pytest
import responses
from cog.types import URLFile, get_filename
from cog.types import Secret, URLFile, get_filename


@responses.activate
Expand Down Expand Up @@ -119,3 +119,11 @@ def test_urlfile_can_be_pickled_even_once_loaded():
)
def test_get_filename(url, filename):
assert get_filename(url) == filename


def test_secret_type():
secret_value = "sw0rdf1$h" # noqa: S105
secret = Secret(secret_value)

assert secret.get_secret_value() == secret_value
assert str(secret) == "**********"
4 changes: 3 additions & 1 deletion test-integration/test_integration/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ def random_string(length):
def remove_docker_image(image_name, max_attempts=5, wait_seconds=1):
for attempt in range(max_attempts):
try:
subprocess.run(["docker", "rmi", "-f", image_name], check=True, capture_output=True)
subprocess.run(
["docker", "rmi", "-f", image_name], check=True, capture_output=True
)
print(f"Image {image_name} successfully removed.")
break
except subprocess.CalledProcessError as e:
Expand Down