Commit

Merge pull request #3 from pier-digital/f/add-pydantic-support
Support for pydantic validators
gabrielguarisa authored Mar 12, 2024
2 parents ad4b8f8 + 41f992f commit 6f0c15e
Showing 10 changed files with 407 additions and 313 deletions.
4 changes: 4 additions & 0 deletions Makefile
@@ -10,3 +10,7 @@ tests:
formatting:
poetry run ruff format .
poetry run ruff check .

.PHONY: example
example:
poetry run python example.py
28 changes: 23 additions & 5 deletions README.md
@@ -39,13 +39,25 @@ MODEL = Pipeline(
Let's assume that you have a dataset with the following columns:

```python
FEATURES = [
request_model = [
{"name": "sepal length (cm)", "dtype": "float64"},
{"name": "sepal width (cm)", "dtype": "float64"},
{"name": "petal length (cm)", "dtype": "float64"},
{"name": "petal width (cm)", "dtype": "float64"},
]
```
Alternatively, you can use a pydantic model to define the request model; the `alias` on each field maps the attribute name to the matching column name in the training dataset:

```python
class InputData(pydantic.BaseModel):
sepal_length: float = pydantic.Field(alias="sepal length (cm)")
sepal_width: float = pydantic.Field(alias="sepal width (cm)")
petal_length: float = pydantic.Field(alias="petal length (cm)")
petal_width: float = pydantic.Field(alias="petal width (cm)")

request_model = InputData
```
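
To illustrate how the aliases resolve (a small sketch, not part of the library itself): instantiating `InputData` with the original column names populates the snake_case attributes, which is standard pydantic behavior when `alias` is set.

```python
row = {
    "sepal length (cm)": 5.1,
    "sepal width (cm)": 3.5,
    "petal length (cm)": 1.4,
    "petal width (cm)": 0.2,
}

# pydantic accepts the aliased keys and exposes the values under the field names
sample = InputData(**row)
assert sample.sepal_length == 5.1
```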

After the model is created and trained, you can create a modelib runner for this model as follows:

```python
@@ -55,7 +67,7 @@ simple_runner = ml.SklearnRunner(
name="my simple model",
predictor=MODEL,
method_name="predict",
features=FEATURES,
request_model=request_model,
)
```

@@ -66,18 +78,24 @@ pipeline_runner = ml.SklearnPipelineRunner(
"Pipeline Model",
predictor=MODEL,
method_names=["transform", "predict"],
features=FEATURES,
request_model=request_model,
)
```

Now you can extend a FastAPI app with the runners:
Now you can create a FastAPI app with the runners:

```python
app = ml.init_app(runners=[simple_runner, pipeline_runner])
```

You can also pass an existing FastAPI app to the `init_app` function:

```python
import fastapi

app = fastapi.FastAPI()

app = ml.init_app(app, [simple_runner, pipeline_runner])
app = ml.init_app(app=app, runners=[simple_runner, pipeline_runner])
```

The `init_app` function will add the necessary routes to the FastAPI app to serve the models. You can now start the app with:
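For instance, a minimal launcher (a sketch; it assumes the FastAPI instance is exposed as `app` in `example.py` — adjust the import string to match your module):

```python
# Hypothetical launcher script, assuming the app object lives in example.py
import uvicorn

uvicorn.run("example:app", host="0.0.0.0", port=8000)
```

Each runner is mounted as a `POST` route under the slug of its name (for example, `my simple model` is served at `/my-simple-model`), as the route factory in `modelib/core/endpoint_factory.py` shows.
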
22 changes: 14 additions & 8 deletions example.py
@@ -1,5 +1,5 @@
import fastapi

import pydantic
import modelib as ml


@@ -12,9 +12,7 @@ def create_model():

X, y = load_iris(return_X_y=True, as_frame=True)

X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42
)
X_train, _, y_train, _ = train_test_split(X, y, test_size=0.2, random_state=42)

model = Pipeline(
[
@@ -28,32 +26,40 @@ def create_model():
return model


FEATURES = [
features_metadata = [
{"name": "sepal length (cm)", "dtype": "float64"},
{"name": "sepal width (cm)", "dtype": "float64"},
{"name": "petal length (cm)", "dtype": "float64"},
{"name": "petal width (cm)", "dtype": "float64"},
]


class InputData(pydantic.BaseModel):
sepal_length: float = pydantic.Field(alias="sepal length (cm)")
sepal_width: float = pydantic.Field(alias="sepal width (cm)")
petal_length: float = pydantic.Field(alias="petal length (cm)")
petal_width: float = pydantic.Field(alias="petal width (cm)")


MODEL = create_model()

simple_runner = ml.SklearnRunner(
name="my simple model",
predictor=MODEL,
method_name="predict",
features=FEATURES,
request_model=InputData, # OR request_model=features_metadata
)

pipeline_runner = ml.SklearnPipelineRunner(
"Pipeline Model",
predictor=MODEL,
method_names=["transform", "predict"],
features=FEATURES,
request_model=InputData,
)

app = fastapi.FastAPI()

app = ml.init_app(app, [simple_runner, pipeline_runner])
app = ml.init_app(app=app, runners=[simple_runner, pipeline_runner])

if __name__ == "__main__":
import uvicorn
46 changes: 28 additions & 18 deletions modelib/core/endpoint_factory.py
@@ -1,47 +1,57 @@
import typing

import fastapi
from slugify import slugify


from modelib.core import schemas
from modelib.runners.base import BaseRunner
from modelib.runners.base import EndpointMetadataManager, BaseRunner


def create_runner_endpoint(
app: fastapi.FastAPI,
runner: BaseRunner,
runner_func: typing.Callable,
endpoint_metadata_manager: EndpointMetadataManager,
**kwargs,
) -> fastapi.FastAPI:
path = f"/{slugify(runner.name)}"
path = f"/{endpoint_metadata_manager.slug}"

route_kwargs = {
"name": runner.name,
"name": endpoint_metadata_manager.name,
"methods": ["POST"],
"response_model": runner.response_model,
"response_model": endpoint_metadata_manager.response_model,
}
route_kwargs.update(kwargs)

app.add_api_route(
path,
runner.get_runner_func(),
runner_func,
**route_kwargs,
)

return app


def create_runners_router(runners: typing.List[BaseRunner]) -> fastapi.APIRouter:
router = fastapi.APIRouter(
tags=["runners"],
responses={
500: {
"model": schemas.JsonApiErrorModel,
"description": "Inference Internal Server Error",
}
},
)
def create_runners_router(
runners: typing.List[BaseRunner], **runners_router_kwargs
) -> fastapi.APIRouter:
responses = runners_router_kwargs.pop("responses", {})
if 500 not in responses:
responses[500] = {
"model": schemas.JsonApiErrorModel,
"description": "Inference Internal Server Error",
}

runners_router_kwargs["responses"] = responses

runners_router_kwargs["tags"] = runners_router_kwargs.get("tags", ["runners"])

router = fastapi.APIRouter(**runners_router_kwargs)

for runner in runners:
router = create_runner_endpoint(router, runner)
router = create_runner_endpoint(
router,
runner_func=runner.get_runner_func(),
endpoint_metadata_manager=runner.endpoint_metadata_manager,
)

return router
54 changes: 27 additions & 27 deletions modelib/runners/base.py
@@ -1,7 +1,7 @@
import typing

import pydantic

from slugify import slugify
from modelib.core import schemas


@@ -12,51 +12,51 @@ def remove_unset_features(features: typing.List[dict]) -> typing.List[dict]:
]


class BaseRunner:
class EndpointMetadataManager:
def __init__(
self,
name: str,
predictor: typing.Any,
features: typing.List[dict] = None,
request_model: typing.Union[typing.Type[pydantic.BaseModel], typing.List[dict]],
response_model: typing.Type[pydantic.BaseModel] = schemas.ResultResponseModel,
**kwargs,
):
self._name = name
self._predictor = predictor
self._features = remove_unset_features(features) if features else None
self._request_model = (
schemas.pydantic_model_from_list_of_dicts(self.name, self.features)
if self.features
else None
)
if isinstance(request_model, list):
request_model = remove_unset_features(request_model)
self._request_model = schemas.pydantic_model_from_list_of_dicts(
name, request_model
)
elif issubclass(request_model, pydantic.BaseModel):
self._request_model = request_model
else:
raise ValueError("request_model must be a pydantic.BaseModel subclass")

if not issubclass(response_model, pydantic.BaseModel):
raise ValueError("response_model must be a pydantic.BaseModel subclass")

self._response_model = response_model

@property
def name(self) -> str:
return self._name

@property
def predictor(self) -> typing.Any:
return self._predictor

@property
def features(self) -> typing.List[str]:
return self._features
def slug(self) -> str:
return slugify(self.name)

@property
def request_model(self) -> typing.Type[pydantic.BaseModel]:
return self._request_model

@property
def response_model(self) -> typing.Type[pydantic.BaseModel]:
return schemas.ResultResponseModel
return self._response_model

def get_runner_func(self) -> typing.Callable:
raise NotImplementedError()

def to_dict(self) -> dict:
return {
"name": self.name,
"features": self.features,
}
class BaseRunner:
@property
def endpoint_metadata_manager(self) -> EndpointMetadataManager:
raise NotImplementedError

def __dict__(self) -> dict:
return self.to_dict()
def get_runner_func(self) -> typing.Callable:
raise NotImplementedError