Skip to content

Commit

Permalink
Refactor API exception handling through HTTPException (#72)
Browse files Browse the repository at this point in the history
* Refactor API exception handling through HTTPException

* Move HTTPException handler into api_utils

* Cosmetic fix

* Remove redundant async
  • Loading branch information
dinvlad authored Jan 16, 2024
1 parent 2a3ec3d commit 3c9809c
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 19 deletions.
10 changes: 3 additions & 7 deletions src/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
from werkzeug.exceptions import HTTPException

from src import cli, signaling, status
from src.api_utils import handle_http_exception
from src.auth import register_terra_service_account
from src.utils import constants, custom_logging
from src.web import participants, study, web

logger = custom_logging.setup_logging(__name__)


def create_app() -> Quart:
if constants.TERRA:
logger.info("Creating app - on Terra")
Expand Down Expand Up @@ -49,12 +49,8 @@ async def _register_terra_service_account():
await register_terra_service_account()

@app.errorhandler(HTTPException)
async def handle_exception(e: HTTPException):
res = e.get_response()
if e.description:
res.data = json.dumps({ "error": e.description })
res.content_type = "application/json"
return res
async def _handle_http_exception(e: HTTPException):
return handle_http_exception(e)

return app

Expand Down
22 changes: 15 additions & 7 deletions src/api_utils.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,16 @@
import json
import uuid
from urllib.parse import urlparse, urlunsplit

import httpx
import werkzeug.exceptions
from google.cloud.firestore_v1 import FieldFilter
from quart import current_app
from werkzeug.exceptions import HTTPException

from src.utils import constants, custom_logging

logger = custom_logging.setup_logging(__name__)


class APIException(werkzeug.exceptions.HTTPException):
def __init__(self, res: httpx.Response):
super().__init__(description=str(res.read()), response=res)


def get_websocket_origin():
url = urlparse(constants.SFKIT_API_URL)
scheme = "wss" if url.scheme == "https" else "ws"
Expand Down Expand Up @@ -104,3 +99,16 @@ def is_valid_uuid(val):
return True
except ValueError:
return False


def handle_http_exception(e: HTTPException):
res = e.get_response()
error = e.description
if not error:
if res.content_type == "application/json":
error = res.json()
error = error["message"] if "message" in error else str(error)
else:
error = str(res.read())
res.data = json.dumps({"error": error})
return res
6 changes: 3 additions & 3 deletions src/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
from google.cloud import firestore
from jwt import algorithms
from quart import Request, Websocket, current_app, request
from werkzeug.exceptions import Unauthorized
from werkzeug.exceptions import HTTPException, Unauthorized

from src.api_utils import APIException, add_user_to_db
from src.api_utils import add_user_to_db
from src.utils import constants, custom_logging

logger = custom_logging.setup_logging(__name__)
Expand Down Expand Up @@ -108,7 +108,7 @@ async def register_terra_service_account():
)

if res.status_code not in (HTTPStatus.CREATED.value, HTTPStatus.CONFLICT.value):
raise APIException(res)
raise HTTPException(response=res)
else:
logger.info(res.json()["message"])

Expand Down
4 changes: 2 additions & 2 deletions src/utils/studies_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
from quart import current_app, g
from sendgrid import SendGridAPIClient
from sendgrid.helpers.mail import Email, Mail
from werkzeug.exceptions import HTTPException

from src.api_utils import APIException
from src.auth import get_service_account_headers
from src.utils import constants, custom_logging
from src.utils.google_cloud.google_cloud_compute import (GoogleCloudCompute,
Expand Down Expand Up @@ -170,7 +170,7 @@ async def _terra_rawls_post(path: str, json: Dict[str, Any]):
json=json,
)
if res.status_code != HTTPStatus.CREATED.value:
raise APIException(res)
raise HTTPException(response=res)


async def submit_terra_workflow(study_id: str, _role: str) -> None:
Expand Down

0 comments on commit 3c9809c

Please sign in to comment.