From cc3188f9dfdc172fc35aeef1797f28603fb70a91 Mon Sep 17 00:00:00 2001 From: Brannon Dorsey Date: Tue, 30 Apr 2019 00:07:08 -0400 Subject: [PATCH 1/3] Always return application/json. Sniff JSON in POST requests and respond with 400 if none found. --- runway/model.py | 59 +++++++++--- runway/utils.py | 15 +++- tests/test_model.py | 214 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 275 insertions(+), 13 deletions(-) diff --git a/runway/model.py b/runway/model.py index 9eba787..ff9aed1 100644 --- a/runway/model.py +++ b/runway/model.py @@ -5,13 +5,14 @@ import traceback import json from six import reraise -from flask import Flask, request +from flask import Flask, request, jsonify from flask_cors import CORS from gevent.pywsgi import WSGIServer from .exceptions import RunwayError, MissingInputError, MissingOptionError, \ InferenceError, UnknownCommandError, SetupError from .data_types import * -from .utils import gzipped, serialize_command, cast_to_obj +from .utils import gzipped, serialize_command, cast_to_obj, \ + validate_post_request_body_is_json, get_json_or_none_if_invalid class RunwayModel(object): """A Runway Model server. A singleton instance of this class is created automatically @@ -27,12 +28,44 @@ def __init__(self): self.running_status = 'STARTING' self.app = Flask(__name__) CORS(self.app) + self.define_error_handlers() self.define_routes() + def define_error_handlers(self): + + # not yet implemented, but if and when it is lets make sure its returned + # as JSON + @self.app.errorhandler(401) + def unauthorized(e): + msg = 'Unauthorized (well... ' + msg += 'really unauthenticated but hey I didn\'t write the spec).' + return jsonify(dict(error=msg)), 401 + + # not yet implemented, but if and when it is lets make sure its returned + # as JSON + @self.app.errorhandler(403) + def forbidden(e): + return jsonify(dict(error='Forbidden.')), 403 + + @self.app.errorhandler(404) + def page_not_found(e): + return jsonify(dict(error='Not found.')), 404 + + @self.app.errorhandler(405) + def method_not_allowed(e): + return jsonify(dict(error='Method not allowed.')), 405 + + # we shouldn't have any of these as we are wrapping errors in + # RunwayError objects and returning stacktraces, but it can't hurt + # to be safe. + @self.app.errorhandler(500) + def internal_server_error(e): + return jsonify(dict(error='Internal server error.')), 500 + def define_routes(self): @self.app.route('/') def manifest(): - return json.dumps(dict( + return jsonify(dict( options=[opt.to_dict() for opt in self.options], commands=[serialize_command(cmd) for cmd in self.commands.values()] )) @@ -42,20 +75,22 @@ def healthcheck_route(): return self.running_status @self.app.route('/setup', methods=['POST']) + @validate_post_request_body_is_json def setup_route(): - opts = request.json + opts = get_json_or_none_if_invalid(request) try: self.setup_model(opts) - return json.dumps(dict(success=True)) + return jsonify(dict(success=True)) except RunwayError as err: err.print_exception() - return json.dumps(err.to_response()), err.code + return jsonify(err.to_response()), err.code @self.app.route('/setup', methods=['GET']) def setup_options_route(): - return json.dumps(self.options) + return jsonify(self.options) @self.app.route('/', methods=['POST']) + @validate_post_request_body_is_json def command_route(command_name): try: try: @@ -64,7 +99,7 @@ def command_route(command_name): raise UnknownCommandError(command_name) inputs = self.commands[command_name]['inputs'] outputs = self.commands[command_name]['outputs'] - input_dict = request.json + input_dict = get_json_or_none_if_invalid(request) deserialized_inputs = {} for inp in inputs: name = inp.name @@ -86,11 +121,11 @@ def command_route(command_name): for out in outputs: name = out.to_dict()['name'] serialized_outputs[name] = out.serialize(results[name]) - return json.dumps(serialized_outputs).encode('utf8') + return jsonify(json.loads(json.dumps(serialized_outputs).encode('utf8'))) except RunwayError as err: err.print_exception() - return json.dumps(err.to_response()), err.code + return jsonify(err.to_response()), err.code @self.app.route('/', methods=['GET']) def usage_route(command_name): @@ -99,10 +134,10 @@ def usage_route(command_name): command = self.commands[command_name] except KeyError: raise UnknownCommandError(command_name) - return json.dumps(serialize_command(command)) + return jsonify(serialize_command(command)) except RunwayError as err: err.print_exception() - return json.dumps(err.to_response()) + return jsonify(err.to_response()), err.code def setup(self, decorated_fn=None, options=None): """This decorator is used to wrap your own ``setup()`` (or equivalent) diff --git a/runway/utils.py b/runway/utils.py index 486f197..d088e2e 100644 --- a/runway/utils.py +++ b/runway/utils.py @@ -12,7 +12,7 @@ else: from io import BytesIO as IO import numpy as np -from flask import after_this_request, request +from flask import after_this_request, request, jsonify URL_REGEX = re.compile( @@ -23,6 +23,19 @@ r'(?::\d+)?' # optional port r'(?:/?|[/?]\S+)$', re.IGNORECASE) +def validate_post_request_body_is_json(f): + @functools.wraps(f) + def wrapped(*args, **kwargs): + json = get_json_or_none_if_invalid(request) + if json is not None: + return f(*args, **kwargs) + else: + err_msg = 'The body of all POST requests must contain JSON' + return jsonify(dict(error=err_msg)), 400 + return wrapped + +def get_json_or_none_if_invalid(request): + return request.get_json(force=True, silent=True) def serialize_command(cmd): ret = {} diff --git a/tests/test_model.py b/tests/test_model.py index b3b2a60..e440f1a 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -12,6 +12,7 @@ from runway.exceptions import * from utils import get_test_client from deepdiff import DeepDiff +from flask import abort os.environ['RW_NO_SERVE'] = '1' @@ -68,11 +69,15 @@ def test_command(model, opts): # check the manifest via a GET / response = client.get('/') + assert response.is_json + manifest = json.loads(response.data) assert manifest == expected_manifest # check the input/output manifest for GET /test_command response = client.get('/test_command') + assert response.is_json + command_manifest = json.loads(response.data) assert command_manifest == expected_manifest['commands'][0] @@ -80,6 +85,7 @@ def test_command(model, opts): 'input': 'test input' } response = client.post('/test_command', json=post_data) + assert response.is_json assert json.loads(response.data) == { 'output' : 100 } assert closure['command_ran'] == True @@ -263,6 +269,213 @@ def command_2(opts): os.environ['RW_META'] = '0' +def test_post_setup_json_no_mime_type(): + + rw = RunwayModel() + + @rw.setup(options={'input': text}) + def setup(opts): + pass + + rw.run(debug=True) + + client = get_test_client(rw) + response = client.post('/setup', data='{"input": "test input"}') + assert response.is_json + assert json.loads(response.data) == { 'success': True } + +def test_post_setup_invalid_json_no_mime_type(): + + rw = RunwayModel() + + @rw.setup(options={'input': text}) + def setup(opts): + pass + + rw.run(debug=True) + + client = get_test_client(rw) + response = client.post('/setup', data='{"input": test input"}') + + assert response.is_json + assert response.status_code == 400 + + expect = { 'error': 'The body of all POST requests must contain JSON' } + assert json.loads(response.data) == expect + + +def test_post_setup_json_mime_type(): + + rw = RunwayModel() + + @rw.setup(options={'input': text}) + def setup(opts): + pass + + rw.run(debug=True) + + client = get_test_client(rw) + response = client.post('/setup', json={ 'input': 'test input' }) + assert response.is_json + assert json.loads(response.data) == { 'success': True } + +def test_post_setup_form_encoding(): + + rw = RunwayModel() + + @rw.setup(options={'input': text}) + def setup(opts): + pass + + rw.run(debug=True) + + client = get_test_client(rw) + + content_type='application/x-www-form-urlencoded' + response = client.post('/setup', data='input=test', content_type=content_type) + + assert response.is_json + assert response.status_code == 400 + + expect = { 'error': 'The body of all POST requests must contain JSON' } + assert json.loads(response.data) == expect + +def test_post_command_json_no_mime_type(): + + rw = RunwayModel() + + @rw.command('times_two', inputs={ 'input': number }, outputs={ 'output': number }) + def times_two(model, args): + return args['input'] * 2 + + rw.run(debug=True) + + client = get_test_client(rw) + response = client.post('/times_two', data='{ "input": 5 }') + assert response.is_json + assert json.loads(response.data) == { 'output': 10 } + +def test_post_command_json_mime_type(): + + rw = RunwayModel() + + @rw.command('times_two', inputs={ 'input': number }, outputs={ 'output': number }) + def times_two(model, args): + return args['input'] * 2 + + rw.run(debug=True) + + client = get_test_client(rw) + response = client.post('/times_two', json={ 'input': 5 }) + assert response.is_json + assert json.loads(response.data) == { 'output': 10 } + +def test_post_command_form_encoding(): + + rw = RunwayModel() + + @rw.command('times_two', inputs={ 'input': number }, outputs={ 'output': number }) + def times_two(model, args): + return args['input'] * 2 + + rw.run(debug=True) + + client = get_test_client(rw) + + content_type='application/x-www-form-urlencoded' + response = client.post('/times_two', data='input=5', content_type=content_type) + assert response.is_json + assert response.status_code == 400 + + expect = { 'error': 'The body of all POST requests must contain JSON' } + assert json.loads(response.data) == expect + +def test_405_method_not_allowed(): + + rw = RunwayModel() + + @rw.setup(options={'input': text}) + def setup(opts): + pass + + rw.run(debug=True) + + client = get_test_client(rw) + response = client.put('/setup', json= { 'input': 'test input'}) + + assert response.is_json + assert response.status_code == 405 + assert response.json == { 'error': 'Method not allowed.' } + +def test_404_not_found(): + + rw = RunwayModel() + + rw.run(debug=True) + + client = get_test_client(rw) + response = client.get('/asfd') + + assert response.is_json + assert response.status_code == 404 + +def test_401_unauthorized(): + + rw = RunwayModel() + + @rw.app.route('/test/unauthorized') + def unauthorized(): + abort(401) + + rw.run(debug=True) + + client = get_test_client(rw) + response = client.get('/test/unauthorized') + + assert response.is_json + assert response.status_code == 401 + + expect = { 'error': 'Unauthorized (well... really unauthenticated but hey I didn\'t write the spec).' } + assert response.json == expect + +def test_403_forbidden(): + + rw = RunwayModel() + + @rw.app.route('/test/forbidden') + def unauthorized(): + abort(403) + + rw.run(debug=True) + + client = get_test_client(rw) + response = client.get('/test/forbidden') + + assert response.is_json + assert response.status_code == 403 + + expect = { 'error': 'Forbidden.' } + assert response.json == expect + +def test_500_internal_server_error(): + + rw = RunwayModel() + + @rw.app.route('/test/internal_server_error') + def unauthorized(): + abort(500) + + rw.run(debug=True) + + client = get_test_client(rw) + response = client.get('/test/internal_server_error') + + assert response.is_json + assert response.status_code == 500 + + expect = { 'error': 'Internal server error.' } + assert response.json == expect + def test_setup_error_setup_no_args(): rw = RunwayModel() @@ -300,4 +513,5 @@ def test_command(model, inputs): rw.run(debug=True) response = client.post('test_command', json={ 'input': 5 }) + assert response.is_json assert 'InferenceError' in str(response.data) From a288ca8e8ac327da8e67d67dfec5dafada4fb6aa Mon Sep 17 00:00:00 2001 From: Brannon Dorsey Date: Tue, 30 Apr 2019 12:25:23 -0400 Subject: [PATCH 2/3] Remove forced UTF-8 encoding in command output. --- runway/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/runway/model.py b/runway/model.py index ff9aed1..ffe9d6e 100644 --- a/runway/model.py +++ b/runway/model.py @@ -121,7 +121,7 @@ def command_route(command_name): for out in outputs: name = out.to_dict()['name'] serialized_outputs[name] = out.serialize(results[name]) - return jsonify(json.loads(json.dumps(serialized_outputs).encode('utf8'))) + return jsonify(serialized_outputs) except RunwayError as err: err.print_exception() From c2142be9c8e53cb9c300d3d66e4449e43523f42e Mon Sep 17 00:00:00 2001 From: Brannon Dorsey Date: Tue, 30 Apr 2019 12:44:55 -0400 Subject: [PATCH 3/3] Update CHANGELOG --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index c3a0dae..443069c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,9 @@ ## v0.0.70 +- Model server now wraps common server error codes in JSON responses (e.g. 401, 403, 404, 405, 500). +- Model server now sniffs the body of `POST` for JSON even if `content-type: application/json` is not sent in the request header. [#1](https://github.com/runwayml/model-sdk/issues/1) +- Model server now returns `content-type: application/json`. [#6](https://github.com/runwayml/model-sdk/issues/6) - Add `RW_NO_SERVE` environment variable and `no_serve` keyword argument to `runway.run()`. These settings prevent `runway.run()` from starting the Flask server so that mock HTTP requests can be made via `app.test_client()`. See [Testing Flask Applications](http://flask.pocoo.org/docs/1.0/testing/) for more details. - Add model tests in [`tests/test_model.py`](tests/test_model.py) - Minor change to `docs/` so that JavaScript HTTP -> HTTPS redirect only occurs when the protocol is actually `http:`.