Skip to content

Commit

Permalink
use pydantic models to populate fastapi docs
Browse files Browse the repository at this point in the history
Signed-off-by: Rob Howley <howley.robert@gmail.com>
  • Loading branch information
robhowley committed Oct 30, 2024
1 parent 60fbc62 commit ad98954
Showing 1 changed file with 32 additions and 32 deletions.
64 changes: 32 additions & 32 deletions sdk/python/feast/feature_server.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import json
import sys
import threading
import time
import traceback
from contextlib import asynccontextmanager
from typing import List, Optional
from typing import Any, Dict, List, Optional

import pandas as pd
import psutil
Expand Down Expand Up @@ -69,6 +68,13 @@ class MaterializeIncrementalRequest(BaseModel):
feature_views: Optional[List[str]] = None


class GetOnlineFeaturesRequest(BaseModel):
entities: Dict[str, List[Any]]
feature_service: Optional[str] = None
features: Optional[List[str]] = None
full_feature_names: bool = False


def get_app(
store: "feast.FeatureStore",
registry_ttl_sec: int = DEFAULT_FEATURE_SERVER_REGISTRY_TTL,
Expand Down Expand Up @@ -108,33 +114,26 @@ async def lifespan(app: FastAPI):

app = FastAPI(lifespan=lifespan)

async def get_body(request: Request):
return await request.body()

@app.post(
"/get-online-features",
dependencies=[Depends(inject_user_details)],
)
async def get_online_features(body=Depends(get_body)):
body = json.loads(body)
full_feature_names = body.get("full_feature_names", False)
entity_rows = body["entities"]
async def get_online_features(request: GetOnlineFeaturesRequest):
# Initialize parameters for FeatureStore.get_online_features(...) call
if "feature_service" in body:
if request.feature_service:
feature_service = store.get_feature_service(
body["feature_service"], allow_cache=True
request.feature_service, allow_cache=True
)
assert_permissions(
resource=feature_service, actions=[AuthzedAction.READ_ONLINE]
)
features = feature_service
features = request.feature_service # type: ignore
else:
features = body["features"]
all_feature_views, all_on_demand_feature_views = (
utils._get_feature_views_to_use(
store.registry,
store.project,
features,
request.features,
allow_cache=True,
hide_dummy_entity=False,
)
Expand All @@ -147,18 +146,19 @@ async def get_online_features(body=Depends(get_body)):
assert_permissions(
resource=od_feature_view, actions=[AuthzedAction.READ_ONLINE]
)
features = request.features # type: ignore

read_params = dict(
features=features,
entity_rows=entity_rows,
full_feature_names=full_feature_names,
entity_rows=request.entities,
full_feature_names=request.full_feature_names,
)

if store._get_provider().async_supported.online.read:
response = await store.get_online_features_async(**read_params)
response = await store.get_online_features_async(**read_params) # type: ignore
else:
response = await run_in_threadpool(
lambda: store.get_online_features(**read_params)
lambda: store.get_online_features(**read_params) # type: ignore
)

# Convert the Protobuf object to JSON and return it
Expand All @@ -167,8 +167,7 @@ async def get_online_features(body=Depends(get_body)):
)

@app.post("/push", dependencies=[Depends(inject_user_details)])
async def push(body=Depends(get_body)):
request = PushFeaturesRequest(**json.loads(body))
async def push(request: PushFeaturesRequest):
df = pd.DataFrame(request.df)
actions = []
if request.to == "offline":
Expand Down Expand Up @@ -220,17 +219,16 @@ async def push(body=Depends(get_body)):
store.push(**push_params)

@app.post("/write-to-online-store", dependencies=[Depends(inject_user_details)])
def write_to_online_store(body=Depends(get_body)):
request = WriteToFeatureStoreRequest(**json.loads(body))
def write_to_online_store(request: WriteToFeatureStoreRequest):
df = pd.DataFrame(request.df)
feature_view_name = request.feature_view_name
allow_registry_cache = request.allow_registry_cache
try:
feature_view = store.get_stream_feature_view(
feature_view = store.get_stream_feature_view( # type: ignore
feature_view_name, allow_registry_cache=allow_registry_cache
)
except FeatureViewNotFoundException:
feature_view = store.get_feature_view(
feature_view = store.get_feature_view( # type: ignore
feature_view_name, allow_registry_cache=allow_registry_cache
)

Expand All @@ -250,11 +248,12 @@ async def health():
)

@app.post("/materialize", dependencies=[Depends(inject_user_details)])
def materialize(body=Depends(get_body)):
request = MaterializeRequest(**json.loads(body))
for feature_view in request.feature_views:
def materialize(request: MaterializeRequest):
for feature_view in request.feature_views or []:
# TODO: receives a str for resource but isn't in the Union. is str actually allowed?
assert_permissions(
resource=feature_view, actions=[AuthzedAction.WRITE_ONLINE]
resource=feature_view, # type: ignore
actions=[AuthzedAction.WRITE_ONLINE],
)
store.materialize(
utils.make_tzaware(parser.parse(request.start_ts)),
Expand All @@ -263,11 +262,12 @@ def materialize(body=Depends(get_body)):
)

@app.post("/materialize-incremental", dependencies=[Depends(inject_user_details)])
def materialize_incremental(body=Depends(get_body)):
request = MaterializeIncrementalRequest(**json.loads(body))
for feature_view in request.feature_views:
def materialize_incremental(request: MaterializeIncrementalRequest):
for feature_view in request.feature_views or []:
# TODO: receives a str for resource but isn't in the Union. is str actually allowed?
assert_permissions(
resource=feature_view, actions=[AuthzedAction.WRITE_ONLINE]
resource=feature_view, # type: ignore
actions=[AuthzedAction.WRITE_ONLINE],
)
store.materialize_incremental(
utils.make_tzaware(parser.parse(request.end_ts)), request.feature_views
Expand Down

0 comments on commit ad98954

Please sign in to comment.