Skip to content

Commit

Permalink
Backport Secret type to async branch (#1706)
Browse files Browse the repository at this point in the history
* Define Secret type (#1546)

Signed-off-by: Mattt Zmuda <mattt@replicate.com>

* Fix linter errors (#1691)

* ruff --fix

Signed-off-by: Mattt Zmuda <mattt@replicate.com>

* Update ruff pyproject settings

Signed-off-by: Mattt Zmuda <mattt@replicate.com>

* Update ruff lint command in Makefile

Signed-off-by: Mattt Zmuda <mattt@replicate.com>

---------

Signed-off-by: Mattt Zmuda <mattt@replicate.com>

* Run ruff --fix python

Signed-off-by: Mattt Zmuda <mattt@replicate.com>

---------

Signed-off-by: Mattt Zmuda <mattt@replicate.com>
  • Loading branch information
mattt authored and technillogue committed Jun 19, 2024
1 parent b3064a9 commit a45bff0
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 14 deletions.
1 change: 1 addition & 0 deletions python/cog/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,5 @@
"Path",
"Secret",
"emit_metric",
"Secret",
]
5 changes: 1 addition & 4 deletions python/cog/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,8 @@
from typing_extensions import Annotated

from .errors import ConfigDoesNotExist, PredictorNotSet
from .types import (
File as CogFile,
)
from .types import File as CogFile, Path as CogPath, Secret as CogSecret
from .types import Input
from .types import Path as CogPath, Secret as CogSecret

log = structlog.get_logger("cog.server.predictor")

Expand Down
8 changes: 4 additions & 4 deletions python/cog/server/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class SetupResult:

class TimeShareTracker:
def __init__(self) -> None:
self._time_shares_per_prediction: "dict[str, float]" = {}
self._time_shares_per_prediction: dict[str, float] = {}
self._last_updated_time_shares = 0.0

def update_time_shares(self) -> None:
Expand Down Expand Up @@ -122,8 +122,8 @@ def __init__(
self._shutdown_event = shutdown_event # __main__ waits for this event

self._upload_url = upload_url
self._predictions: "dict[str, tuple[schema.PredictionResponse, PredictionTask]]" = {}
self._predictions_in_flight: "set[str]" = set()
self._predictions: dict[str, tuple[schema.PredictionResponse, PredictionTask]] = {}
self._predictions_in_flight: set[str] = set()
# it would be lovely to merge these but it's not fully clear how best to handle it
# since idempotent requests can kinda come whenever?
# p: dict[str, PredictionTask]
Expand All @@ -144,7 +144,7 @@ def __init__(
self._time_shares_per_prediction: "dict[str, float]" = {}
self._last_updated_time_shares = 0.0
self._child = _ChildWorker(predictor_ref, child_events, tee_output)
self._events: "AsyncConnection[tuple[str, PublicEventType]]" = AsyncConnection(
self._events: AsyncConnection[tuple[str, PublicEventType]] = AsyncConnection(
events
)
# shutdown requested
Expand Down
8 changes: 4 additions & 4 deletions python/cog/server/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,11 @@ class WorkerState(Enum):

class Mux:
def __init__(self, terminating: asyncio.Event) -> None:
self.outs: "defaultdict[str, asyncio.Queue[PublicEventType]]" = defaultdict(
self.outs: defaultdict[str, asyncio.Queue[PublicEventType]] = defaultdict(
asyncio.Queue
)
self.terminating = terminating
self.fatal: "Optional[FatalWorkerException]" = None
self.fatal: Optional[FatalWorkerException] = None

async def write(self, id: str, item: PublicEventType) -> None:
await self.outs[id].put(item)
Expand Down Expand Up @@ -199,11 +199,11 @@ def _loop_sync(self) -> None:
print(f"Got unexpected event: {ev}", file=sys.stderr)

async def _loop_async(self) -> None:
events: "AsyncConnection[tuple[str, PublicEventType]]" = AsyncConnection(
events: AsyncConnection[tuple[str, PublicEventType]] = AsyncConnection(
self._events
)
with events:
tasks: "dict[str, asyncio.Task[None]]" = {}
tasks: dict[str, asyncio.Task[None]] = {}
while True:
try:
ev = await events.recv()
Expand Down
1 change: 0 additions & 1 deletion python/tests/server/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import pytest
import pytest_asyncio

from cog.schema import PredictionRequest, PredictionResponse, Status, WebhookEvent
from cog.server.clients import ClientManager
from cog.server.eventtypes import (
Expand Down
10 changes: 9 additions & 1 deletion python/tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pytest
import responses
from cog.types import Secret, URLFile, get_filename_from_url, get_filename_from_urlopen,
from cog.types import Secret, URLFile, get_filename_from_url, get_filename_from_urlopen


@responses.activate
Expand Down Expand Up @@ -131,3 +131,11 @@ def test_secret_type():
def test_get_filename_from_urlopen(url, filename):
resp = urllib.request.urlopen(url) # noqa: S310
assert get_filename_from_urlopen(resp) == filename


def test_secret_type():
secret_value = "sw0rdf1$h" # noqa: S105
secret = Secret(secret_value)

assert secret.get_secret_value() == secret_value
assert str(secret) == "**********"

0 comments on commit a45bff0

Please sign in to comment.