Skip to content
This repository has been archived by the owner on Aug 25, 2023. It is now read-only.

Commit

Permalink
Support explain method and json error (#77)
Browse files Browse the repository at this point in the history
* Support explain method and json error

* Add missing imports

* Remove explain from fastapi

* Add tests for explainer

* Fix import

* Try optional query
  • Loading branch information
Han Qiao authored Apr 23, 2021
1 parent 2752a40 commit 7b8eb1c
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 14 deletions.
2 changes: 1 addition & 1 deletion fastapi-gunicorn/entrypoint.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ fi
EXTRA_OPTS=""
if [ "$WORKERS" -gt 1 ]; then
# Try not to override user defined variable
export prometheus_multiproc_dir="${prometheus_multiproc_dir:-/tmp}"
export PROMETHEUS_MULTIPROC_DIR="${PROMETHEUS_MULTIPROC_DIR:-/tmp}"
EXTRA_OPTS="--config gunicorn_config.py"
fi

Expand Down
27 changes: 21 additions & 6 deletions fastapi-gunicorn/serve_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@
from datetime import datetime
from importlib import import_module
from logging import getLogger
from typing import Optional
from os import getenv
from uuid import UUID

from bedrock_client.bedrock.metrics.context import PredictionContext
from bedrock_client.bedrock.metrics.registry import is_single_value
from bedrock_client.bedrock.metrics.service import ModelMonitoringService
from bedrock_client.bedrock.model import BaseModel
from boxkite.monitoring.context import PredictionContext
from boxkite.monitoring.registry import is_single_value
from boxkite.monitoring.service import ModelMonitoringService
from fastapi import FastAPI, Request, Response
from fastapi.responses import JSONResponse

logger = getLogger()
try:
Expand Down Expand Up @@ -38,6 +40,7 @@ async def middleware(request: Request):


@app.post("/")
@app.post("/predict")
async def predict(request: Request):
await middleware(request)

Expand Down Expand Up @@ -87,13 +90,25 @@ async def predict(request: Request):
return request.app.model.post_process(score=score, prediction_id=pid)


@app.post("/explain/")
@app.post("/explain/<target>")
def explain(target: Optional[str] = None):
return JSONResponse(
content={
"type": "InternalServerError",
"reason": "Model does not implement 'explain' method",
},
status_code=501,
)


@app.get("/metrics")
async def get_metrics(request: Request):
"""Returns real time feature values recorded by Prometheus
"""
"""Returns real time feature values recorded by Prometheus"""
await middleware(request)

body, content_type = request.app.monitor.export_http(
params=dict(request.query_params), headers=request.headers,
params=dict(request.query_params),
headers=request.headers,
)
return Response(body, media_type=content_type)
31 changes: 28 additions & 3 deletions flask-gunicorn-gpu/serve_http.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import json
from dataclasses import replace
from datetime import datetime
from importlib import import_module
from logging import getLogger
from os import getenv
from uuid import UUID

from bedrock_client.bedrock.metrics.context import PredictionContext
from bedrock_client.bedrock.metrics.registry import is_single_value
from bedrock_client.bedrock.metrics.service import ModelMonitoringService
from bedrock_client.bedrock.model import BaseModel
from boxkite.monitoring.context import PredictionContext
from boxkite.monitoring.registry import is_single_value
from boxkite.monitoring.service import ModelMonitoringService
from flask import Flask, Response, current_app, request
from werkzeug.exceptions import HTTPException

logger = getLogger()
try:
Expand Down Expand Up @@ -38,7 +40,19 @@ def init_background_threads():
current_app.monitor = ModelMonitoringService()


@app.errorhandler(HTTPException)
def handle_exception(e):
"""Return JSON instead of HTML for HTTP errors."""
# Start with the correct headers and status code from the error
response = e.get_response()
# Replace the http body with JSON
response.data = json.dumps({"type": e.name, "reason": e.description})
response.content_type = "application/json"
return response


@app.route("/", methods=["POST"])
@app.route("/predict", methods=["POST"])
def predict():
# User code to load features
features = current_app.model.pre_process(http_body=request.data, files=request.files)
Expand Down Expand Up @@ -82,6 +96,17 @@ def predict():
return current_app.model.post_process(score=score, prediction_id=pid)


@app.route("/explain/", defaults={"target": None}, methods=["POST"])
@app.route("/explain/<target>", methods=["POST"])
def explain(target):
if not callable(getattr(current_app.model, "explain", None)):
return "Model does not implement 'explain' method", 501
features = current_app.model.pre_process(
http_body=request.data, files=request.files
)
return current_app.model.explain(features=features, target=target)[0]


@app.route("/metrics", methods=["GET"])
def get_metrics():
"""Returns real time feature values recorded by Prometheus
Expand Down
2 changes: 1 addition & 1 deletion flask-gunicorn/entrypoint.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ fi
EXTRA_OPTS=""
if [ "$WORKERS" -gt 1 ]; then
# Try not to override user defined variable
export prometheus_multiproc_dir="${prometheus_multiproc_dir:-/tmp}"
export PROMETHEUS_MULTIPROC_DIR="${PROMETHEUS_MULTIPROC_DIR:-/tmp}"
EXTRA_OPTS="--config gunicorn_config.py"
fi

Expand Down
31 changes: 28 additions & 3 deletions flask-gunicorn/serve_http.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import json
from dataclasses import replace
from datetime import datetime
from importlib import import_module
from logging import getLogger
from os import getenv
from uuid import UUID

from bedrock_client.bedrock.metrics.context import PredictionContext
from bedrock_client.bedrock.metrics.registry import is_single_value
from bedrock_client.bedrock.metrics.service import ModelMonitoringService
from bedrock_client.bedrock.model import BaseModel
from boxkite.monitoring.context import PredictionContext
from boxkite.monitoring.registry import is_single_value
from boxkite.monitoring.service import ModelMonitoringService
from flask import Flask, Response, current_app, request
from werkzeug.exceptions import HTTPException

logger = getLogger()
try:
Expand Down Expand Up @@ -38,7 +40,19 @@ def init_background_threads():
current_app.monitor = ModelMonitoringService()


@app.errorhandler(HTTPException)
def handle_exception(e):
"""Return JSON instead of HTML for HTTP errors."""
# Start with the correct headers and status code from the error
response = e.get_response()
# Replace the http body with JSON
response.data = json.dumps({"type": e.name, "reason": e.description})
response.content_type = "application/json"
return response


@app.route("/", methods=["POST"])
@app.route("/predict", methods=["POST"])
def predict():
# User code to load features
features = current_app.model.pre_process(http_body=request.data, files=request.files)
Expand Down Expand Up @@ -82,6 +96,17 @@ def predict():
return current_app.model.post_process(score=score, prediction_id=pid)


@app.route("/explain/", defaults={"target": None}, methods=["POST"])
@app.route("/explain/<target>", methods=["POST"])
def explain(target):
if not callable(getattr(current_app.model, "explain", None)):
return "Model does not implement 'explain' method", 501
features = current_app.model.pre_process(
http_body=request.data, files=request.files
)
return current_app.model.explain(features=features, target=target)[0]


@app.route("/metrics", methods=["GET"])
def get_metrics():
"""Returns real time feature values recorded by Prometheus
Expand Down
6 changes: 6 additions & 0 deletions tests/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ def test_file_upload(self):
after = get_inference_count(resp.text, 208)
self.assertEqual(after - before, 1)

@skipIf(getenv("MODEL", None) not in MODELS["image"], "test explainer on vision models")
def test_explain(self):
with Session() as s:
resp = s.post(f"{self.url}/explain")
assert resp.status_code == 501

@skipIf(getenv("MODEL", None) not in MODELS["language"], "post body for language models")
def test_sentiment(self):
with Session() as s:
Expand Down

0 comments on commit 7b8eb1c

Please sign in to comment.