Skip to content
This repository has been archived by the owner on Mar 28, 2022. It is now read-only.

Better JSON handling in requests and responses #45

Merged
merged 3 commits into from
Apr 30, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 47 additions & 12 deletions runway/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
import traceback
import json
from six import reraise
from flask import Flask, request
from flask import Flask, request, jsonify
from flask_cors import CORS
from gevent.pywsgi import WSGIServer
from .exceptions import RunwayError, MissingInputError, MissingOptionError, \
InferenceError, UnknownCommandError, SetupError
from .data_types import *
from .utils import gzipped, serialize_command, cast_to_obj
from .utils import gzipped, serialize_command, cast_to_obj, \
validate_post_request_body_is_json, get_json_or_none_if_invalid

class RunwayModel(object):
"""A Runway Model server. A singleton instance of this class is created automatically
Expand All @@ -27,12 +28,44 @@ def __init__(self):
self.running_status = 'STARTING'
self.app = Flask(__name__)
CORS(self.app)
self.define_error_handlers()
self.define_routes()

def define_error_handlers(self):

# not yet implemented, but if and when it is lets make sure its returned
# as JSON
@self.app.errorhandler(401)
def unauthorized(e):
msg = 'Unauthorized (well... '
msg += 'really unauthenticated but hey I didn\'t write the spec).'
return jsonify(dict(error=msg)), 401

# not yet implemented, but if and when it is lets make sure its returned
# as JSON
@self.app.errorhandler(403)
def forbidden(e):
return jsonify(dict(error='Forbidden.')), 403

@self.app.errorhandler(404)
def page_not_found(e):
return jsonify(dict(error='Not found.')), 404

@self.app.errorhandler(405)
def method_not_allowed(e):
return jsonify(dict(error='Method not allowed.')), 405

# we shouldn't have any of these as we are wrapping errors in
# RunwayError objects and returning stacktraces, but it can't hurt
# to be safe.
@self.app.errorhandler(500)
def internal_server_error(e):
return jsonify(dict(error='Internal server error.')), 500

def define_routes(self):
@self.app.route('/')
def manifest():
return json.dumps(dict(
return jsonify(dict(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Flask's jsonify() is a helpful replacement for json.dumps() that wraps its parameter in a Flask Response object and sets the MIME type to application/json.

options=[opt.to_dict() for opt in self.options],
commands=[serialize_command(cmd) for cmd in self.commands.values()]
))
Expand All @@ -42,20 +75,22 @@ def healthcheck_route():
return self.running_status

@self.app.route('/setup', methods=['POST'])
@validate_post_request_body_is_json
def setup_route():
opts = request.json
opts = get_json_or_none_if_invalid(request)
try:
self.setup_model(opts)
return json.dumps(dict(success=True))
return jsonify(dict(success=True))
except RunwayError as err:
err.print_exception()
return json.dumps(err.to_response()), err.code
return jsonify(err.to_response()), err.code

@self.app.route('/setup', methods=['GET'])
def setup_options_route():
return json.dumps(self.options)
return jsonify(self.options)

@self.app.route('/<command_name>', methods=['POST'])
@validate_post_request_body_is_json
def command_route(command_name):
try:
try:
Expand All @@ -64,7 +99,7 @@ def command_route(command_name):
raise UnknownCommandError(command_name)
inputs = self.commands[command_name]['inputs']
outputs = self.commands[command_name]['outputs']
input_dict = request.json
input_dict = get_json_or_none_if_invalid(request)
deserialized_inputs = {}
for inp in inputs:
name = inp.name
Expand All @@ -86,11 +121,11 @@ def command_route(command_name):
for out in outputs:
name = out.to_dict()['name']
serialized_outputs[name] = out.serialize(results[name])
return json.dumps(serialized_outputs).encode('utf8')
return jsonify(json.loads(json.dumps(serialized_outputs).encode('utf8')))
brannondorsey marked this conversation as resolved.
Show resolved Hide resolved

except RunwayError as err:
err.print_exception()
return json.dumps(err.to_response()), err.code
return jsonify(err.to_response()), err.code

@self.app.route('/<command_name>', methods=['GET'])
def usage_route(command_name):
Expand All @@ -99,10 +134,10 @@ def usage_route(command_name):
command = self.commands[command_name]
except KeyError:
raise UnknownCommandError(command_name)
return json.dumps(serialize_command(command))
return jsonify(serialize_command(command))
except RunwayError as err:
err.print_exception()
return json.dumps(err.to_response())
return jsonify(err.to_response()), err.code

def setup(self, decorated_fn=None, options=None):
"""This decorator is used to wrap your own ``setup()`` (or equivalent)
Expand Down
15 changes: 14 additions & 1 deletion runway/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
else:
from io import BytesIO as IO
import numpy as np
from flask import after_this_request, request
from flask import after_this_request, request, jsonify


URL_REGEX = re.compile(
Expand All @@ -23,6 +23,19 @@
r'(?::\d+)?' # optional port
r'(?:/?|[/?]\S+)$', re.IGNORECASE)

def validate_post_request_body_is_json(f):
@functools.wraps(f)
def wrapped(*args, **kwargs):
json = get_json_or_none_if_invalid(request)
if json is not None:
return f(*args, **kwargs)
else:
err_msg = 'The body of all POST requests must contain JSON'
return jsonify(dict(error=err_msg)), 400
return wrapped
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A decorator to wrap around POST endpoints to ensure that requests that are sent to them contain JSON bodies. If they don't we immediately return a JSON error message and a 400 error code.


def get_json_or_none_if_invalid(request):
return request.get_json(force=True, silent=True)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a thin wrapper around Flask's request.get_json() but it ensures that we are calling it such that it will sniff JSON even if the content-type: application/json MIME type is not set, and will return None if the sniff fails instead of throwing an error.

We can then use this util function elsewhere to make sure that we are parsing JSON from Flask response objects in a consistent way.


def serialize_command(cmd):
ret = {}
Expand Down
Loading