diff --git a/fastapi-gunicorn/serve_http.py b/fastapi-gunicorn/serve_http.py index 9f93f25..9da1ccd 100644 --- a/fastapi-gunicorn/serve_http.py +++ b/fastapi-gunicorn/serve_http.py @@ -2,8 +2,8 @@ from datetime import datetime from importlib import import_module from logging import getLogger -from typing import Optional from os import getenv +from typing import Optional from uuid import UUID from bedrock_client.bedrock.model import BaseModel @@ -108,7 +108,6 @@ async def get_metrics(request: Request): await middleware(request) body, content_type = request.app.monitor.export_http( - params=dict(request.query_params), - headers=request.headers, + params=dict(request.query_params), headers=request.headers ) return Response(body, media_type=content_type) diff --git a/flask-gunicorn-gpu/serve_http.py b/flask-gunicorn-gpu/serve_http.py index 0e7f23c..ba4234c 100644 --- a/flask-gunicorn-gpu/serve_http.py +++ b/flask-gunicorn-gpu/serve_http.py @@ -101,18 +101,15 @@ def predict(): def explain(target): if not callable(getattr(current_app.model, "explain", None)): return "Model does not implement 'explain' method", 501 - features = current_app.model.pre_process( - http_body=request.data, files=request.files - ) + features = current_app.model.pre_process(http_body=request.data, files=request.files) return current_app.model.explain(features=features, target=target)[0] @app.route("/metrics", methods=["GET"]) def get_metrics(): - """Returns real time feature values recorded by Prometheus - """ + """Returns real time feature values recorded by Prometheus""" body, content_type = current_app.monitor.export_http( - params=request.args.to_dict(flat=False), headers=request.headers, + params=request.args.to_dict(flat=False), headers=request.headers ) return Response(body, content_type=content_type) diff --git a/flask-gunicorn/serve_http.py b/flask-gunicorn/serve_http.py index 0e7f23c..ba4234c 100644 --- a/flask-gunicorn/serve_http.py +++ b/flask-gunicorn/serve_http.py @@ -101,18 +101,15 @@ def predict(): def explain(target): if not callable(getattr(current_app.model, "explain", None)): return "Model does not implement 'explain' method", 501 - features = current_app.model.pre_process( - http_body=request.data, files=request.files - ) + features = current_app.model.pre_process(http_body=request.data, files=request.files) return current_app.model.explain(features=features, target=target)[0] @app.route("/metrics", methods=["GET"]) def get_metrics(): - """Returns real time feature values recorded by Prometheus - """ + """Returns real time feature values recorded by Prometheus""" body, content_type = current_app.monitor.export_http( - params=request.args.to_dict(flat=False), headers=request.headers, + params=request.args.to_dict(flat=False), headers=request.headers ) return Response(body, content_type=content_type) diff --git a/requirements-dev.txt b/requirements-dev.txt index c5000c7..05604ab 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,3 +2,5 @@ black flake8 isort mypy +types-Flask +types-Werkzeug diff --git a/style.sh b/style.sh index 6cd9f59..a812c68 100755 --- a/style.sh +++ b/style.sh @@ -18,10 +18,10 @@ while getopts 'cf:' flag; do done if [ -z "$check" ]; then - isort "$files" --check --diff + isort --profile black "$files" --check-only --diff black "$files" --check else - isort "$files" + isort --profile black "$files" black "$files" fi diff --git a/tests/lightgbm/model-server/requirements.txt b/tests/lightgbm/model-server/requirements.txt index 53bf3d3..b9ec155 100644 --- a/tests/lightgbm/model-server/requirements.txt +++ b/tests/lightgbm/model-server/requirements.txt @@ -1,3 +1,3 @@ -bdrk==0.7.2 +bdrk==0.8.2 boxkite==0.0.4 lightgbm==2.3.1 diff --git a/tests/test_e2e.py b/tests/test_e2e.py index a7eebe2..4155baa 100644 --- a/tests/test_e2e.py +++ b/tests/test_e2e.py @@ -81,9 +81,7 @@ def test_sentiment(self): after = get_inference_count(resp.text) self.assertEqual(after - before, 4) - @skipIf( - getenv("MODEL", None) not in MODELS["churn"], "post body for churn prediction models", - ) + @skipIf(getenv("MODEL", None) not in MODELS["churn"], "post body for churn prediction models") def test_post(self): with Session() as s: resp = s.get(f"{self.url}/metrics") diff --git a/tests/tf-vision/model-server/requirements.txt b/tests/tf-vision/model-server/requirements.txt index fe636ee..c945604 100644 --- a/tests/tf-vision/model-server/requirements.txt +++ b/tests/tf-vision/model-server/requirements.txt @@ -1,3 +1,3 @@ tensorflow_hub==0.6.0 tensorflow==2.5.0 -Pillow==8.1.1 +Pillow==8.2.0