multiple PredictionEventHandler #1392

Closed · wants to merge 6 commits into from
29 changes: 20 additions & 9 deletions python/cog/command/ast_openapi_schema.py
@@ -309,10 +309,15 @@ def find(obj: ast.AST, name: str) -> ast.AST:
"""Find a particular named node in a tree"""
return next(node for node in ast.walk(obj) if getattr(node, "name", "") == name)


if typing.TYPE_CHECKING:
AstVal: "typing.TypeAlias" = "int | float | complex | str | list[AstVal] | bytes | None"
AstVal: "typing.TypeAlias" = (
"int | float | complex | str | list[AstVal] | bytes | None"
)
AstValNoBytes: "typing.TypeAlias" = "int | float | str | list[AstValNoBytes]"
JSONObject: "typing.TypeAlias" = "int | float | str | list[JSONObject] | JSONDict | None"
JSONObject: "typing.TypeAlias" = (
"int | float | str | list[JSONObject] | JSONDict | None"
)
JSONDict: "typing.TypeAlias" = "dict[str, JSONObject]"


@@ -327,6 +332,7 @@ def to_serializable(val: "AstVal") -> "JSONObject":
else:
return val


def get_value(node: ast.AST) -> "AstVal":
"""Return the value of constant or list of constants"""
if isinstance(node, ast.Constant):
@@ -339,7 +345,7 @@ def get_value(node: ast.AST) -> "AstVal":
if isinstance(node, (ast.List, ast.Tuple)):
return [get_value(e) for e in node.elts]
if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
return -typing.cast(typing.Union[int, float, complex], get_value(node.operand))
return -typing.cast(typing.Union[int, float, complex], get_value(node.operand))
raise ValueError("Unexpected node type", type(node))


@@ -366,12 +372,13 @@ def get_call_name(call: ast.Call) -> str:
def parse_args(tree: ast.AST) -> "list[tuple[ast.arg, ast.expr | types.EllipsisType]]":
"""Parse argument, default pairs from a file with a predict function"""
predict = find(tree, "predict")
assert isinstance(predict, ast.FunctionDef)
assert isinstance(predict, (ast.FunctionDef, ast.AsyncFunctionDef))
args = predict.args.args # [-len(defaults) :]
# use Ellipsis instead of None here to distinguish a default of None
defaults = [...] * (len(args) - len(predict.args.defaults)) + predict.args.defaults
return list(zip(args, defaults))


def parse_assignment(assignment: ast.AST) -> "None | tuple[str, JSONObject]":
"""Parse an assignment into an OpenAPI object property"""
if isinstance(assignment, ast.AnnAssign):
@@ -403,7 +410,9 @@ def parse_class(classdef: ast.AST) -> "JSONDict":
"""Parse a class definition into an OpenAPI object"""
assert isinstance(classdef, ast.ClassDef)
properties = {
assignment[0]: assignment[1] for assignment in map(parse_assignment, classdef.body) if assignment
assignment[0]: assignment[1]
for assignment in map(parse_assignment, classdef.body)
if assignment
}
return {
"title": classdef.name,
@@ -428,17 +437,19 @@ def resolve_name(node: ast.expr) -> str:
return node.id
if isinstance(node, ast.Index):
# deprecated, but needed for py3.8
return resolve_name(node.value) # type: ignore
return resolve_name(node.value) # type: ignore
if isinstance(node, ast.Attribute):
return node.attr
if isinstance(node, ast.Subscript):
return resolve_name(node.value)
raise ValueError("Unexpected node type", type(node), ast.unparse(node))


def parse_return_annotation(tree: ast.AST, fn: str = "predict") -> "tuple[JSONDict, JSONDict]":
def parse_return_annotation(
tree: ast.AST, fn: str = "predict"
) -> "tuple[JSONDict, JSONDict]":
predict = find(tree, fn)
if not isinstance(predict, ast.FunctionDef):
if not isinstance(predict, (ast.FunctionDef, ast.AsyncFunctionDef)):
raise ValueError("Could not find predict function")
annotation = predict.returns
if not annotation:
@@ -550,7 +561,7 @@ def extract_info(code: str) -> "JSONDict":
**return_schema,
}
# trust me, typechecker, I know BASE_SCHEMA
x: "JSONDict" = schema["components"]["schemas"] # type: ignore
x: "JSONDict" = schema["components"]["schemas"] # type: ignore
x.update(components)
return schema

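Taken together, the parse_args and parse_return_annotation changes above mean the AST-based schema extraction no longer chokes on an async predictor. A minimal sketch of what the widened isinstance checks now accept, using only the standard library (the sample predictor source is hypothetical):

import ast

source = '''
async def predict(prompt: str = "hello") -> str:
    return prompt
'''
tree = ast.parse(source)
# same lookup strategy as find(): walk the tree for a node named "predict"
node = next(n for n in ast.walk(tree) if getattr(n, "name", "") == "predict")
# previously only ast.FunctionDef passed this check; async definitions now do too
assert isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
print([a.arg for a in node.args.args])  # ['prompt']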
82 changes: 48 additions & 34 deletions python/cog/predictor.py
@@ -6,10 +6,11 @@
import sys
import types
from abc import ABC, abstractmethod
from collections.abc import Iterator
from collections.abc import Iterator, AsyncIterator
from pathlib import Path
from typing import (
Any,
Awaitable,
Callable,
Dict,
List,
@@ -48,7 +49,9 @@


class BasePredictor(ABC):
def setup(self, weights: Optional[Union[CogFile, CogPath]] = None) -> None:
def setup(
self, weights: Optional[Union[CogFile, CogPath]] = None
) -> Optional[Awaitable[None]]:
"""
An optional method to prepare the model so multiple predictions run efficiently.
"""
@@ -63,15 +66,25 @@ def predict(self, **kwargs: Any) -> Any:


def run_setup(predictor: BasePredictor) -> None:
weights_type = get_weights_type(predictor.setup)

# No weights need to be passed, so just run setup() without any arguments.
if weights_type is None:
weights = get_weights_argument(predictor)
if weights:
predictor.setup(weights=weights)
else:
predictor.setup()
return

weights: Union[io.IOBase, Path, None]

async def run_setup_async(predictor: BasePredictor) -> None:
weights = get_weights_argument(predictor)
maybe_coro = predictor.setup(weights=weights) if weights else predictor.setup()
if maybe_coro:
return await maybe_coro


def get_weights_argument(predictor: BasePredictor) -> Union[CogFile, CogPath, None]:
# by the time we get here we assume predictor has a setup method
weights_type = get_weights_type(predictor.setup)
if weights_type is None:
return None
weights_url = os.environ.get("COG_WEIGHTS")
weights_path = "weights"

@@ -81,30 +94,25 @@ def run_setup(predictor: BasePredictor) -> None:
# TODO: CogFile/CogPath should have subclasses for each of the subtypes
if weights_url:
if weights_type == CogFile:
weights = cast(CogFile, CogFile.validate(weights_url))
elif weights_type == CogPath:
return cast(CogFile, CogFile.validate(weights_url))
if weights_type == CogPath:
# TODO: So this can be a url. evil!
weights = cast(CogPath, CogPath.validate(weights_url))
else:
raise ValueError(
f"Predictor.setup() has an argument 'weights' of type {weights_type}, but only File and Path are supported"
)
elif os.path.exists(weights_path):
return cast(CogPath, CogPath.validate(weights_url))
raise ValueError(
f"Predictor.setup() has an argument 'weights' of type {weights_type}, but only File and Path are supported"
)
if os.path.exists(weights_path):
if weights_type == CogFile:
weights = cast(CogFile, open(weights_path, "rb"))
elif weights_type == CogPath:
weights = CogPath(weights_path)
else:
raise ValueError(
f"Predictor.setup() has an argument 'weights' of type {weights_type}, but only File and Path are supported"
)
else:
weights = None

predictor.setup(weights=weights)
return cast(CogFile, open(weights_path, "rb"))
if weights_type == CogPath:
return CogPath(weights_path)
raise ValueError(
f"Predictor.setup() has an argument 'weights' of type {weights_type}, but only File and Path are supported"
)
return None


def get_weights_type(setup_function: Callable[[Any], None]) -> Optional[Any]:
def get_weights_type(setup_function: Callable[[Any], Optional[Awaitable[None]]]) -> Optional[Any]:
signature = inspect.signature(setup_function)
if "weights" not in signature.parameters:
return None
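For reference, a sketch of how the sync/async split might be driven by a caller, assuming run_setup, run_setup_async and BasePredictor are importable from cog.predictor on this branch (both predictors are hypothetical):

import asyncio

class SyncPredictor(BasePredictor):
    def setup(self) -> None:  # no weights parameter, so get_weights_argument() returns None
        self.ready = True

    def predict(self, prompt: str) -> str:
        return prompt

class AsyncPredictor(BasePredictor):
    async def setup(self) -> None:  # calling this returns a coroutine
        self.ready = True

    def predict(self, prompt: str) -> str:
        return prompt

run_setup(SyncPredictor())                      # setup() runs synchronously, nothing to await
asyncio.run(run_setup_async(AsyncPredictor()))  # maybe_coro is truthy and gets awaited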
@@ -118,7 +126,9 @@ def get_weights_type(setup_function: Callable[[Any], None]) -> Optional[Any]:


def run_prediction(
predictor: BasePredictor, inputs: Dict[Any, Any], cleanup_functions: List[Callable[[], None]],
predictor: BasePredictor,
inputs: Dict[Any, Any],
cleanup_functions: List[Callable[[], None]],
) -> Any:
"""
Run the predictor on the inputs, and append resulting paths
@@ -218,20 +228,24 @@ def get_predict(predictor: Any) -> Callable[..., Any]:
return predictor.predict
return predictor


def validate_input_type(type: Type[Any], name: str) -> None:
if type is inspect.Signature.empty:
raise TypeError(
f"No input type provided for parameter `{name}`. Supported input types are: {readable_types_list(ALLOWED_INPUT_TYPES)}, or a Union or List of those types."
)
raise TypeError(
f"No input type provided for parameter `{name}`. Supported input types are: {readable_types_list(ALLOWED_INPUT_TYPES)}, or a Union or List of those types."
)
elif type not in ALLOWED_INPUT_TYPES:
if get_origin(type) in (Union, List, list) or (hasattr(types, "UnionType") and get_origin(type) is types.UnionType): # noqa: E721
if get_origin(type) in (Union, List, list) or (
hasattr(types, "UnionType") and get_origin(type) is types.UnionType
): # noqa: E721
for t in get_args(type):
validate_input_type(t, name)
else:
raise TypeError(
f"Unsupported input type {human_readable_type_name(type)} for parameter `{name}`. Supported input types are: {readable_types_list(ALLOWED_INPUT_TYPES)}, or a Union or List of those types."
)
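The reformatted validate_input_type still recurses into Union and List annotations and rejects everything else. Roughly, assuming str and int remain in ALLOWED_INPUT_TYPES:

from typing import List, Union

validate_input_type(Union[int, str], "x")  # recurses into the Union args; both are allowed
validate_input_type(List[str], "tags")     # likewise recurses into the element type
validate_input_type(dict, "config")        # raises TypeError: Unsupported input type ...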


def get_input_type(predictor: BasePredictor) -> Type[BaseInput]:
"""
Creates a Pydantic Input model from the arguments of a Predictor's predict() method.
@@ -329,7 +343,7 @@ def predict(
OutputType = signature.return_annotation

# The type that goes in the response is a list of the yielded type
if get_origin(OutputType) is Iterator:
if get_origin(OutputType) in {Iterator, AsyncIterator}:
# Annotated allows us to attach Field annotations to the list, which we use to mark that this is an iterator
# https://pydantic-docs.helpmanual.io/usage/schema/#typingannotated-fields
OutputType: Type[BaseModel] = Annotated[List[get_args(OutputType)[0]], Field(**{"x-cog-array-type": "iterator"})] # type: ignore
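So an async generator predict() now gets the same list-shaped response schema as a sync Iterator one. A sketch (the predictor is hypothetical; BasePredictor as above):

from typing import AsyncIterator

class StreamingPredictor(BasePredictor):
    async def predict(self, prompt: str) -> AsyncIterator[str]:
        for token in prompt.split():
            yield token

# typing.get_origin(AsyncIterator[str]) is collections.abc.AsyncIterator, so the
# `get_origin(OutputType) in {Iterator, AsyncIterator}` branch above tags the output
# as an x-cog-array-type iterator, exactly as it already did for Iterator[str].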
4 changes: 4 additions & 0 deletions python/cog/server/eventtypes.py
@@ -7,6 +7,7 @@
#
@define
class PredictionInput:
id: str
payload: Dict[str, Any]


@@ -25,16 +26,19 @@ class Log:

@define
class PredictionOutput:
id: str
payload: Any


@define
class PredictionOutputType:
id: str
multi: bool = False


@define
class Done:
id: str
canceled: bool = False
error: bool = False
error_detail: str = ""
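With every event carrying the prediction id, a consumer can demultiplex an interleaved stream, which is presumably what running multiple PredictionEventHandlers requires. A small sketch using this branch's event types:

from cog.server.eventtypes import Done, PredictionInput, PredictionOutput

events = [
    PredictionInput(id="p1", payload={"prompt": "hi"}),
    PredictionInput(id="p2", payload={"prompt": "bye"}),
    PredictionOutput(id="p1", payload="hello"),
    Done(id="p2"),
    Done(id="p1"),
]

by_prediction: dict = {}
for event in events:
    by_prediction.setdefault(event.id, []).append(event)
# by_prediction["p1"] -> [PredictionInput, PredictionOutput, Done]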
10 changes: 9 additions & 1 deletion python/cog/server/http.py
@@ -56,14 +56,17 @@ class Health(Enum):
BUSY = auto()
SETUP_FAILED = auto()


class State:
health: Health
setup_result: "Optional[asyncio.Task[schema.PredictionResponse]]"
setup_result_payload: Optional[schema.PredictionResponse]


class MyFastAPI(FastAPI):
state: State


def create_app(
config: Dict[str, Any],
shutdown_event: Optional[threading.Event],
@@ -95,6 +98,7 @@ def create_app(

class PredictionRequest(schema.PredictionRequest.with_types(input_type=InputType)):
pass

PredictionResponse = schema.PredictionResponse.with_types(
input_type=InputType, output_type=OutputType
)
@@ -104,6 +108,7 @@ class PredictionRequest(schema.PredictionRequest.with_types(input_type=InputType
if TYPE_CHECKING:
P = ParamSpec("P")
T = TypeVar("T")

def limited(f: "Callable[P, Awaitable[T]]") -> "Callable[P, Awaitable[T]]":
@functools.wraps(f)
async def wrapped(*args: "P.args", **kwargs: "P.kwargs") -> "T":
@@ -148,7 +153,10 @@ async def healthcheck() -> Any:
response_model=PredictionResponse,
response_model_exclude_unset=True,
)
async def predict(request: PredictionRequest = Body(default=None), prefer: Union[str, None] = Header(default=None)) -> Any: # type: ignore
async def predict(
request: PredictionRequest = Body(default=None),
prefer: Union[str, None] = Header(default=None),
) -> Any: # type: ignore
"""
Run a single prediction on the model
"""
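For completeness, a hypothetical client call against the reformatted endpoint, assuming the app from create_app() is served on localhost:5000 and the prediction route is the usual /predictions:

import requests

resp = requests.post(
    "http://localhost:5000/predictions",
    json={"input": {"prompt": "hello"}},
    headers={"Prefer": "respond-async"},  # surfaces as the `prefer` Header parameter above
)
print(resp.status_code, resp.json())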