diff --git a/python/cog/__init__.py b/python/cog/__init__.py index 70e84de40..a77151ad3 100644 --- a/python/cog/__init__.py +++ b/python/cog/__init__.py @@ -28,4 +28,5 @@ "Path", "Secret", "emit_metric", + "Secret", ] diff --git a/python/cog/predictor.py b/python/cog/predictor.py index c546cdda8..bc14a75ec 100644 --- a/python/cog/predictor.py +++ b/python/cog/predictor.py @@ -28,11 +28,8 @@ from typing_extensions import Annotated from .errors import ConfigDoesNotExist, PredictorNotSet -from .types import ( - File as CogFile, -) +from .types import File as CogFile, Path as CogPath, Secret as CogSecret from .types import Input -from .types import Path as CogPath, Secret as CogSecret log = structlog.get_logger("cog.server.predictor") diff --git a/python/cog/server/runner.py b/python/cog/server/runner.py index ca046ce80..36f457aa3 100644 --- a/python/cog/server/runner.py +++ b/python/cog/server/runner.py @@ -88,7 +88,7 @@ class SetupResult: class TimeShareTracker: def __init__(self) -> None: - self._time_shares_per_prediction: "dict[str, float]" = {} + self._time_shares_per_prediction: dict[str, float] = {} self._last_updated_time_shares = 0.0 def update_time_shares(self) -> None: @@ -122,8 +122,8 @@ def __init__( self._shutdown_event = shutdown_event # __main__ waits for this event self._upload_url = upload_url - self._predictions: "dict[str, tuple[schema.PredictionResponse, PredictionTask]]" = {} - self._predictions_in_flight: "set[str]" = set() + self._predictions: dict[str, tuple[schema.PredictionResponse, PredictionTask]] = {} + self._predictions_in_flight: set[str] = set() # it would be lovely to merge these but it's not fully clear how best to handle it # since idempotent requests can kinda come whenever? # p: dict[str, PredictionTask] @@ -144,7 +144,7 @@ def __init__( self._time_shares_per_prediction: "dict[str, float]" = {} self._last_updated_time_shares = 0.0 self._child = _ChildWorker(predictor_ref, child_events, tee_output) - self._events: "AsyncConnection[tuple[str, PublicEventType]]" = AsyncConnection( + self._events: AsyncConnection[tuple[str, PublicEventType]] = AsyncConnection( events ) # shutdown requested diff --git a/python/cog/server/worker.py b/python/cog/server/worker.py index 9584c4974..4b40515e5 100644 --- a/python/cog/server/worker.py +++ b/python/cog/server/worker.py @@ -54,11 +54,11 @@ class WorkerState(Enum): class Mux: def __init__(self, terminating: asyncio.Event) -> None: - self.outs: "defaultdict[str, asyncio.Queue[PublicEventType]]" = defaultdict( + self.outs: defaultdict[str, asyncio.Queue[PublicEventType]] = defaultdict( asyncio.Queue ) self.terminating = terminating - self.fatal: "Optional[FatalWorkerException]" = None + self.fatal: Optional[FatalWorkerException] = None async def write(self, id: str, item: PublicEventType) -> None: await self.outs[id].put(item) @@ -199,11 +199,11 @@ def _loop_sync(self) -> None: print(f"Got unexpected event: {ev}", file=sys.stderr) async def _loop_async(self) -> None: - events: "AsyncConnection[tuple[str, PublicEventType]]" = AsyncConnection( + events: AsyncConnection[tuple[str, PublicEventType]] = AsyncConnection( self._events ) with events: - tasks: "dict[str, asyncio.Task[None]]" = {} + tasks: dict[str, asyncio.Task[None]] = {} while True: try: ev = await events.recv() diff --git a/python/tests/server/test_runner.py b/python/tests/server/test_runner.py index 79cce97fa..4c9600c1d 100644 --- a/python/tests/server/test_runner.py +++ b/python/tests/server/test_runner.py @@ -7,7 +7,6 @@ import pytest import pytest_asyncio - from cog.schema import PredictionRequest, PredictionResponse, Status, WebhookEvent from cog.server.clients import ClientManager from cog.server.eventtypes import ( diff --git a/python/tests/test_types.py b/python/tests/test_types.py index ba3eaca76..c67b31cbb 100644 --- a/python/tests/test_types.py +++ b/python/tests/test_types.py @@ -4,7 +4,7 @@ import pytest import responses -from cog.types import Secret, URLFile, get_filename_from_url, get_filename_from_urlopen, +from cog.types import Secret, URLFile, get_filename_from_url, get_filename_from_urlopen @responses.activate @@ -131,3 +131,11 @@ def test_secret_type(): def test_get_filename_from_urlopen(url, filename): resp = urllib.request.urlopen(url) # noqa: S310 assert get_filename_from_urlopen(resp) == filename + + +def test_secret_type(): + secret_value = "sw0rdf1$h" # noqa: S105 + secret = Secret(secret_value) + + assert secret.get_secret_value() == secret_value + assert str(secret) == "**********"