From 461ed4b61f76735adaed278618beee842f6ad172 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexis=20M=C3=A9taireau?= Date: Thu, 3 Jan 2013 17:24:40 +0100 Subject: [PATCH 1/2] Add Cross-Origin Resource Sharing (CORS) support. --- CHANGES.txt | 2 +- cornice/cors.py | 128 +++++++++++++++++++++++ cornice/pyramidhook.py | 27 ++++- cornice/service.py | 170 ++++++++++++++++++++++++------ cornice/tests/test_cors.py | 192 ++++++++++++++++++++++++++++++++++ cornice/tests/test_service.py | 138 +++++++++++++++++++++++- 6 files changed, 619 insertions(+), 38 deletions(-) create mode 100644 cornice/cors.py create mode 100644 cornice/tests/test_cors.py diff --git a/CHANGES.txt b/CHANGES.txt index 7640675a..712f5c06 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -1,7 +1,7 @@ 0.13 - XXXX-XX-XX ================= -- ??? +- Added Cross-Origin Resource Sharing (CORS) support. 0.12 - 2012-11-21 ================= diff --git a/cornice/cors.py b/cornice/cors.py new file mode 100644 index 00000000..3d308c52 --- /dev/null +++ b/cornice/cors.py @@ -0,0 +1,128 @@ +import fnmatch + + +CORS_PARAMETERS = ('cors_headers', 'cors_enabled', 'cors_origins', + 'cors_credentials', 'cors_max_age', + 'cors_expose_all_headers') + + +def get_cors_preflight_view(service): + """Return a view for the OPTION method. + + Checks that the User-Agent is authorized to do a request to the server, and + to this particular service, and add the various checks that are specified + in http://www.w3.org/TR/cors/#resource-processing-model. + """ + + def _preflight_view(request): + response = request.response + origin = request.headers.get('Origin') + supported_headers = service.cors_supported_headers + + if not origin: + request.errors.add('header', 'Origin', + 'this header is mandatory') + + requested_method = request.headers.get('Access-Control-Request-Method') + if not requested_method: + request.errors.add('header', 'Access-Control-Request-Method', + 'this header is mandatory') + + if not (requested_method and origin): + return + + requested_headers = ( + request.headers.get('Access-Control-Request-Headers', ())) + + if requested_headers: + requested_headers = requested_headers.split(',') + + if requested_method not in service.cors_supported_methods: + request.errors.add('header', 'Access-Control-Request-Method', + 'Method not allowed') + + if not service.cors_expose_all_headers: + for h in requested_headers: + if not h.lower() in [s.lower() for s in supported_headers]: + request.errors.add( + 'header', + 'Access-Control-Request-Headers', + 'Header "%s" not allowed' % h) + + supported_headers = set(supported_headers) | set(requested_headers) + + response.headers['Access-Control-Allow-Headers'] = ( + ','.join(supported_headers)) + + response.headers['Access-Control-Allow-Methods'] = ( + ','.join(service.cors_supported_methods)) + + max_age = service.cors_max_age_for(requested_method) + if max_age is not None: + response.headers['Access-Control-Max-Age'] = str(max_age) + + return 'ok' + return _preflight_view + + +def _get_method(request): + """Return what's supposed to be the method for CORS operations. + (e.g if the verb is options, look at the A-C-Request-Method header, + otherwise return the HTTP verb). + """ + if request.method == 'OPTIONS': + method = request.headers.get('Access-Control-Request-Method', + request.method) + else: + method = request.method + return method + + +def get_cors_validator(service): + """Create a cornice validator to handle CORS-related verifications. + + Checks, if an "Origin" header is present, that the origin is authorized + (and issue an error if not) + """ + + def _cors_validator(request): + response = request.response + method = _get_method(request) + + # If we have an "Origin" header, check it's authorized and add the + # response headers accordingly. + origin = request.headers.get('Origin') + if origin: + if not any([fnmatch.fnmatchcase(origin, o) + for o in service.cors_origins_for(method)]): + request.errors.add('header', 'Origin', + '%s not allowed' % origin) + else: + response.headers['Access-Control-Allow-Origin'] = origin + return _cors_validator + + +def get_cors_filter(service): + """Create a cornice filter to handle CORS-related post-request + things. + + Add some response headers, such as the Expose-Headers and the + Allow-Credentials ones. + """ + + def _cors_filter(response, request): + method = _get_method(request) + + if (service.cors_support_credentials(method) and + not 'Access-Control-Allow-Credentials' in response.headers): + response.headers['Access-Control-Allow-Credentials'] = 'true' + + if request.method is not 'OPTIONS': + # Which headers are exposed? + supported_headers = service.cors_supported_headers + if supported_headers: + response.headers['Access-Control-Expose-Headers'] = ( + ', '.join(supported_headers)) + + return response + return _cors_filter diff --git a/cornice/pyramidhook.py b/cornice/pyramidhook.py index e7723851..165db8fe 100644 --- a/cornice/pyramidhook.py +++ b/cornice/pyramidhook.py @@ -3,6 +3,7 @@ # You can obtain one at http://mozilla.org/MPL/2.0/. import json import functools +import copy from pyramid.httpexceptions import HTTPMethodNotAllowed, HTTPNotAcceptable from pyramid.exceptions import PredicateMismatch @@ -10,6 +11,8 @@ from cornice.service import decorate_view from cornice.errors import Errors from cornice.util import to_list +from cornice.cors import (get_cors_filter, get_cors_validator, + get_cors_preflight_view, CORS_PARAMETERS) def match_accept_header(func, context, request): @@ -52,7 +55,7 @@ def _fallback_view(request): continue if 'accept' in args: acceptable.extend( - service.get_acceptable(method, filter_callables=True)) + service.get_acceptable(method, filter_callables=True)) if 'acceptable' in request.info: for content_type in request.info['acceptable']: if content_type not in acceptable: @@ -85,7 +88,10 @@ def cornice_tween(request): for _filter in kwargs.get('filters', []): if isinstance(_filter, basestring) and ob is not None: _filter = getattr(ob, _filter) - response = _filter(response) + try: + response = _filter(response, request) + except TypeError: + response = _filter(response) return response return cornice_tween @@ -117,16 +123,29 @@ def register_service_views(config, service): # keep track of the registered routes registered_routes = [] + # before doing anything else, register a view for the OPTIONS method + # if we need to + if service.cors_enabled and 'OPTIONS' not in service.defined_methods: + service.add_view('options', view=get_cors_preflight_view(service)) + # register the fallback view, which takes care of returning good error # messages to the user-agent + cors_validator = get_cors_validator(service) + cors_filter = get_cors_filter(service) + for method, view, args in service.definitions: - args = dict(args) # make a copy of the dict to not modify it + args = copy.deepcopy(args) # make a copy of the dict to not modify it args['request_method'] = method + if service.cors_enabled: + args['validators'].insert(0, cors_validator) + args['filters'].append(cors_filter) + decorated_view = decorate_view(view, dict(args), method) + for item in ('filters', 'validators', 'schema', 'klass', - 'error_handler'): + 'error_handler') + CORS_PARAMETERS: if item in args: del args[item] diff --git a/cornice/service.py b/cornice/service.py index 7dc4c2df..6df5eb51 100644 --- a/cornice/service.py +++ b/cornice/service.py @@ -46,43 +46,76 @@ class Service(object): All the class attributes defined in this class or in childs are considered default values. - :param name: the name of the service. Should be unique among all the - services. + :param name: + The name of the service. Should be unique among all the services. - :param path: the path the service is available at. Should also be unique. + :param path: + The path the service is available at. Should also be unique. - :param renderer: the renderer that should be used by this service. Default - value is 'simplejson'. + :param renderer: + The renderer that should be used by this service. Default value is + 'simplejson'. - :param description: the description of what the webservice does. This is - primarily intended for documentation purposes. + :param description: + The description of what the webservice does. This is primarily intended + for documentation purposes. - :param validators: a list of callables to pass the request into before - passing it to the associated view. + :param validators: + A list of callables to pass the request into before passing it to the + associated view. - :param filters: a list of callables to pass the response into before - returning it to the client. + :param filters: + A list of callables to pass the response into before returning it to + the client. - :param accept: a list of headers accepted for this service (or method if - overwritten when defining a method). It can also be a - callable, in which case the content-type will be discovered - at runtime. If a callable is passed, it should be able to - take the request as a first argument. + :param accept: + A list of headers accepted for this service (or method if overwritten + when defining a method). It can also be a callable, in which case the + content-type will be discovered at runtime. If a callable is passed, it + should be able to take the request as a first argument. - :param factory: A factory returning callables which return boolean values. - The callables take the request as their first argument and - return boolean values. - This param is exclusive with the 'acl' one. + :param factory: + A factory returning callables which return boolean values. The + callables take the request as their first argument and return boolean + values. This param is exclusive with the 'acl' one. - :param acl: a callable defininng the ACL (returns true or false, function - of the given request). Exclusive with the 'factory' option. + :param acl: + A callable defininng the ACL (returns true or false, function of the + given request). Exclusive with the 'factory' option. - :param klass: the class to use when resolving views (if they are not - callables) + :param klass: + The class to use when resolving views (if they are not callables) - :param error_handler: (optional) A callable which is used to render - responses following validation failures. Defaults to - 'json_renderer'. + :param error_handler: + A callable which is used to render responses following validation + failures. Defaults to 'json_renderer'. + + There is also a number of parameters that are related to the support of + CORS (Cross Origin Resource Sharing). You can read the CORS specification + at http://www.w3.org/TR/cors/ + + :param cors_enabled: + To use if you especially want to disable CORS support for a particular + service / method. + + :param cors_origins: + The list of origins for CORS. You can use wildcards here if needed, + e.g. ('list', 'of', '*.domain'). + + :param cors_headers: + The list of headers supported for the services. + + :param cors_credentials: + Should the client send credential information (False by default). + + :param cors_max_age: + Indicates how long the results of a preflight request can be cached in + a preflight result cache. + + :param cors_expose_all_headers: + If set to True, all the headers will be exposed and considered valid + ones (Default: True). If set to False, all the headers need be + explicitely mentionned with the cors_headers parameter. See http://readthedocs.org/docs/pyramid/en/1.0-branch/glossary.html#term-acl @@ -97,7 +130,7 @@ class Service(object): default_filters = DEFAULT_FILTERS mandatory_arguments = ('renderer',) - list_arguments = ('validators', 'filters') + list_arguments = ('validators', 'filters', 'cors_headers', 'cors_origins') def __repr__(self): return u'' % (self.name, self.path) @@ -106,14 +139,16 @@ def __init__(self, name, path, description=None, depth=1, **kw): self.name = name self.path = path self.description = description + self.cors_expose_all_headers = True self._schemas = {} + self._cors_enabled = None - for key in ('validators', 'filters'): + for key in self.list_arguments: # default_{validators,filters} and {filters,validators} doesn't # have to be mutables, so we need to create a new list from them extra = to_list(kw.get(key, [])) kw[key] = [] - kw[key].extend(getattr(self, 'default_%s' % key)) + kw[key].extend(getattr(self, 'default_%s' % key, [])) kw[key].extend(extra) self.arguments = self.get_arguments(kw) @@ -221,6 +256,7 @@ def add_view(self, method, view, **kwargs): if hasattr(self, 'get_view_wrapper'): view = self.get_view_wrapper(kwargs)(view) self.definitions.append((method, view, args)) + # keep track of the defined methods for the service if method not in self.defined_methods: self.defined_methods.append(method) @@ -299,6 +335,80 @@ def schemas(self): warnings.warn(msg, DeprecationWarning) return self._schemas + @property + def cors_enabled(self): + if self._cors_enabled is False: + return False + + return bool(self.cors_origins or self._cors_enabled) + + @cors_enabled.setter + def cors_enabled(self, value): + self._cors_enabled = value + + + + @property + def cors_supported_headers(self): + """Return an iterable of supported headers for this service. + + The supported headers are defined by the :param headers: argument + that is passed to services or methods, at definition time. + """ + headers = set() + for _, _, args in self.definitions: + if args.get('cors_enabled', True): + headers |= set(args.get('cors_headers', ())) + return headers + + @property + def cors_supported_methods(self): + """Return an iterable of methods supported by CORS""" + methods = [] + for meth, _, args in self.definitions: + if args.get('cors_enabled', True) and meth not in methods: + methods.append(meth) + return methods + + @property + def cors_supported_origins(self): + origins = set(getattr(self, 'cors_origins', ())) + for _, _, args in self.definitions: + origins |= set(args.get('cors_origins', ())) + return origins + + def cors_origins_for(self, method): + """Return the list of origins supported for a given HTTP method""" + origins = set() + for meth, view, args in self.definitions: + if meth.upper() == method.upper(): + origins |= set(args.get('cors_origins', ())) + + if not origins: + origins = self.cors_origins + return origins + + def cors_support_credentials(self, method=None): + """Returns if the given method support credentials. + + :param method: + The method to check the credentials support for + """ + for meth, view, args in self.definitions: + if meth.upper() == method.upper(): + return args.get('cors_credentials', False) + + if getattr(self, 'cors_credentials', False): + return self.cors_credentials + return False + + def cors_max_age_for(self, method=None): + for meth, view, args in self.definitions: + if meth.upper() == method.upper(): + return args.get('cors_max_age', False) + + return getattr(self, 'cors_max_age', None) + def decorate_view(view, args, method): """Decorate a given view with cornice niceties. diff --git a/cornice/tests/test_cors.py b/cornice/tests/test_cors.py new file mode 100644 index 00000000..9d2f3d27 --- /dev/null +++ b/cornice/tests/test_cors.py @@ -0,0 +1,192 @@ +from pyramid import testing +from webtest import TestApp + +from cornice.service import Service +from cornice.tests.support import TestCase, CatchErrors + + +squirel = Service(path='/squirel', name='squirel', cors_origins=('foobar',)) +spam = Service(path='/spam', name='spam', cors_origins=('*',)) +eggs = Service(path='/eggs', name='egg', cors_origins=('*'), + cors_expose_all_headers=False) + + +@squirel.get(cors_origins=('notmyidea.org',)) +def get_squirel(request): + return "squirels" + + +@squirel.post(cors_enabled=False, cors_headers=('X-Another-Header')) +def post_squirel(request): + return "moar squirels (take care)" + + +@squirel.put(cors_headers=('X-My-Header',)) +def put_squirel(request): + return "squirels!" + + +@spam.get(cors_credentials=True, cors_headers=('X-My-Header'), + cors_max_age=42) +def gimme_some_spam_please(request): + return 'spam' + + +@spam.post() +def moar_spam(request): + return 'moar spam' + + +class TestCORS(TestCase): + + def setUp(self): + self.config = testing.setUp() + self.config.include("cornice") + self.config.scan("cornice.tests.test_cors") + self.app = TestApp(CatchErrors(self.config.make_wsgi_app())) + + def tearDown(self): + testing.tearDown() + + def test_preflight_missing_headers(self): + # we should have an OPTION method defined. + # If we just try to reach it, without using correct headers: + # "Access-Control-Request-Method"or without the "Origin" header, + # we should get a 400. + resp = self.app.options('/squirel', status=400) + self.assertEquals(len(resp.json['errors']), 2) + + def test_preflight_missing_origin(self): + + resp = self.app.options( + '/squirel', + headers={'Access-Control-Request-Method': 'GET'}, + status=400) + self.assertEquals(len(resp.json['errors']), 1) + + def test_preflight_missing_request_method(self): + + resp = self.app.options( + '/squirel', + headers={'Origin': 'foobar.org'}, + status=400) + + self.assertEquals(len(resp.json['errors']), 1) + + def test_preflight_incorrect_origin(self): + # we put "lolnet.org" where only "notmyidea.org" is authorized + resp = self.app.options( + '/squirel', + headers={'Origin': 'lolnet.org', + 'Access-Control-Request-Method': 'GET'}, + status=400) + self.assertEquals(len(resp.json['errors']), 1) + + def test_preflight_correct_origin(self): + resp = self.app.options( + '/squirel', + headers={'Origin': 'notmyidea.org', + 'Access-Control-Request-Method': 'GET'}) + self.assertEquals( + resp.headers['Access-Control-Allow-Origin'], + 'notmyidea.org') + + allowed_methods = (resp.headers['Access-Control-Allow-Methods'] + .split(',')) + + self.assertNotIn('POST', allowed_methods) + self.assertIn('GET', allowed_methods) + self.assertIn('PUT', allowed_methods) + self.assertIn('HEAD', allowed_methods) + + allowed_headers = (resp.headers['Access-Control-Allow-Headers'] + .split(',')) + + self.assertIn('X-My-Header', allowed_headers) + self.assertNotIn('X-Another-Header', allowed_headers) + + def test_preflight_deactivated_method(self): + self.app.options('/squirel', + headers={'Origin': 'notmyidea.org', + 'Access-Control-Request-Method': 'POST'}, + status=400) + + def test_preflight_origin_not_allowed_for_method(self): + self.app.options('/squirel', + headers={'Origin': 'notmyidea.org', + 'Access-Control-Request-Method': 'PUT'}, + status=400) + + def test_preflight_credentials_are_supported(self): + resp = self.app.options('/spam', + headers={'Origin': 'notmyidea.org', + 'Access-Control-Request-Method': 'GET'}) + + self.assertIn('Access-Control-Allow-Credentials', resp.headers) + self.assertEquals(resp.headers['Access-Control-Allow-Credentials'], + 'true') + + def test_preflight_credentials_header_not_included_when_not_needed(self): + resp = self.app.options('/spam', + headers={'Origin': 'notmyidea.org', + 'Access-Control-Request-Method': 'POST'}) + + self.assertNotIn('Access-Control-Allow-Credentials', resp.headers) + + def test_preflight_contains_max_age(self): + resp = self.app.options('/spam', + headers={'Origin': 'notmyidea.org', + 'Access-Control-Request-Method': 'GET'}) + + self.assertIn('Access-Control-Max-Age', resp.headers) + self.assertEquals(resp.headers['Access-Control-Max-Age'], '42') + + def test_resp_dont_include_allow_origin(self): + resp = self.app.get('/squirel') # omit the Origin header + self.assertNotIn('Access-Control-Allow-Origin', resp.headers) + self.assertEquals(resp.json, 'squirels') + + def test_responses_include_an_allow_origin_header(self): + resp = self.app.get('/squirel', headers={'Origin': 'notmyidea.org'}) + self.assertIn('Access-Control-Allow-Origin', resp.headers) + self.assertEquals(resp.headers['Access-Control-Allow-Origin'], + 'notmyidea.org') + + def test_credentials_are_included(self): + resp = self.app.get('/spam', headers={'Origin': 'notmyidea.org'}) + self.assertIn('Access-Control-Allow-Credentials', resp.headers) + self.assertEquals(resp.headers['Access-Control-Allow-Credentials'], + 'true') + + def test_headers_are_exposed(self): + resp = self.app.get('/squirel', headers={'Origin': 'notmyidea.org'}) + self.assertIn('Access-Control-Expose-Headers', resp.headers) + + headers = resp.headers['Access-Control-Expose-Headers'].split(',') + self.assertIn('X-My-Header', headers) + + def test_preflight_request_headers_are_included(self): + resp = self.app.options('/squirel', + headers={'Origin': 'notmyidea.org', + 'Access-Control-Request-Method': 'GET', + 'Access-Control-Request-Headers': 'foo,bar,baz'}) + + # per default, they should be authorized, and returned in the list of + # authorized headers + headers = resp.headers['Access-Control-Allow-Headers'].split(',') + self.assertIn('foo', headers) + self.assertIn('bar', headers) + self.assertIn('baz', headers) + + def test_preflight_request_headers_isnt_too_permissive(self): + self.app.options('/eggs', + headers={'Origin': 'notmyidea.org', + 'Access-Control-Request-Method': 'GET', + 'Access-Control-Request-Headers': 'foo,bar,baz'}, + status=400) + + def test_preflight_headers_arent_case_sensitive(self): + self.app.options('/spam', headers={ + 'Origin': 'notmyidea.org', + 'Access-Control-Request-Method': 'GET', + 'Access-Control-Request-Headers': 'x-my-header', }) diff --git a/cornice/tests/test_service.py b/cornice/tests/test_service.py index 48719b43..9208653f 100644 --- a/cornice/tests/test_service.py +++ b/cornice/tests/test_service.py @@ -7,6 +7,7 @@ _validator = lambda req: True _validator2 = lambda req: True +_stub = lambda req: None class TestService(TestCase): @@ -173,7 +174,7 @@ def test_schemas_for(self): service.add_view("GET", lambda x: "red", schema=schema) self.assertEquals(len(service.schemas_for("GET")), 1) service.add_view("GET", lambda x: "red", validators=_validator, - schema=schema) + schema=schema) self.assertEquals(len(service.schemas_for("GET")), 2) def test_class_parameters(self): @@ -223,7 +224,7 @@ def freshair(request): Service.default_validators = [custom_validator, ] Service.default_filters = [custom_filter, ] service = Service("TemperatureCooler", "/freshair") - service.add_view("get", freshair) + service.add_view("GET", freshair) method, view, args = service.definitions[0] self.assertIn(custom_validator, args['validators']) @@ -244,7 +245,7 @@ def groove_em_all(request): validators=[another_validator], filters=[another_filter]) - service2.add_view("get", groove_em_all) + service2.add_view("GET", groove_em_all) method, view, args = service2.definitions[0] self.assertIn(custom_validator, args['validators']) @@ -254,3 +255,134 @@ def groove_em_all(request): finally: Service.default_validators = old_validators Service.default_filters = old_filters + + def test_cors_support(self): + self.assertFalse( + Service(name='foo', path='/foo').cors_enabled) + + self.assertTrue( + Service(name='foo', path='/foo', cors_enabled=True) + .cors_enabled) + + self.assertFalse( + Service(name='foo', path='/foo', cors_enabled=False) + .cors_enabled) + + self.assertTrue( + Service(name='foo', path='/foo', cors_origins=('*',)) + .cors_enabled) + + self.assertFalse( + Service(name='foo', path='/foo', + cors_origins=('*'), cors_enabled=False) + .cors_enabled) + + def test_cors_headers_for_service_instanciation(self): + # When definining services, it's possible to add headers. This tests + # it is possible to list all the headers supported by a service. + service = Service('coconuts', '/migrate', + cors_headers=('X-Header-Coconut')) + self.assertNotIn('X-Header-Coconut', service.cors_supported_headers) + + service.add_view('POST', _stub) + self.assertIn('X-Header-Coconut', service.cors_supported_headers) + + def test_cors_headers_for_view_definition(self): + # defining headers in the view should work. + service = Service('coconuts', '/migrate') + service.add_view('POST', _stub, cors_headers=('X-Header-Foobar')) + self.assertIn('X-Header-Foobar', service.cors_supported_headers) + + def test_cors_headers_extension(self): + # definining headers in the service and in the view + service = Service('coconuts', '/migrate', + cors_headers=('X-Header-Foobar')) + service.add_view('POST', _stub, cors_headers=('X-Header-Barbaz')) + self.assertIn('X-Header-Foobar', service.cors_supported_headers) + self.assertIn('X-Header-Barbaz', service.cors_supported_headers) + + # check that adding the same header twice doesn't make bad things + # happen + service.add_view('POST', _stub, cors_headers=('X-Header-Foobar'),) + self.assertEquals(len(service.cors_supported_headers), 2) + + # check that adding a header on a cors disabled method doesn't + # change anything + service.add_view('put', _stub, + cors_headers=('X-Another-Header',), + cors_enabled=False) + + self.assertFalse('X-Another-Header' in service.cors_supported_headers) + + def test_cors_supported_methods(self): + foo = Service(name='foo', path='/foo', cors_enabled=True) + foo.add_view('GET', _stub) + self.assertIn('GET', foo.cors_supported_methods) + + foo.add_view('POST', _stub) + self.assertIn('POST', foo.cors_supported_methods) + + def test_disabling_cors_for_one_method(self): + foo = Service(name='foo', path='/foo', cors_enabled=True) + foo.add_view('GET', _stub) + self.assertIn('GET', foo.cors_supported_methods) + + foo.add_view('POST', _stub, cors_enabled=False) + self.assertIn('GET', foo.cors_supported_methods) + self.assertFalse('POST' in foo.cors_supported_methods) + + def test_cors_supported_origins(self): + foo = Service( + name='foo', path='/foo', cors_origins=('mozilla.org',)) + + foo.add_view('GET', _stub, + cors_origins=('notmyidea.org', 'lolnet.org')) + + self.assertIn('mozilla.org', foo.cors_supported_origins) + self.assertIn('notmyidea.org', foo.cors_supported_origins) + self.assertIn('lolnet.org', foo.cors_supported_origins) + + def test_per_method_supported_origins(self): + foo = Service( + name='foo', path='/foo', cors_origins=('mozilla.org',)) + foo.add_view('GET', _stub, cors_origins=('lolnet.org',)) + + self.assertTrue('mozilla.org' in foo.cors_origins_for('GET')) + self.assertTrue('lolnet.org' in foo.cors_origins_for('GET')) + + foo.add_view('POST', _stub) + self.assertFalse('lolnet.org' in foo.cors_origins_for('POST')) + + def test_credential_support_can_be_enabled(self): + foo = Service(name='foo', path='/foo', cors_credentials=True) + self.assertTrue(foo.cors_support_credentials()) + + def test_credential_support_is_disabled_by_default(self): + foo = Service(name='foo', path='/foo') + self.assertFalse(foo.cors_support_credentials()) + + def test_per_method_credential_support(self): + foo = Service(name='foo', path='/foo') + foo.add_view('GET', _stub, cors_credentials=True) + foo.add_view('POST', _stub) + self.assertTrue(foo.cors_support_credentials('GET')) + self.assertFalse(foo.cors_support_credentials('POST')) + + def test_method_takes_precendence_for_credential_support(self): + foo = Service(name='foo', path='/foo', cors_credentials=True) + foo.add_view('GET', _stub, cors_credentials=False) + self.assertFalse(foo.cors_support_credentials('GET')) + + def test_max_age_can_be_defined(self): + foo = Service(name='foo', path='/foo', cors_max_age=42) + self.assertEquals(foo.cors_max_age_for(), 42) + + def test_max_age_can_be_different_dependeing_methods(self): + foo = Service(name='foo', path='/foo', cors_max_age=42) + foo.add_view('GET', _stub) + foo.add_view('POST', _stub, cors_max_age=32) + foo.add_view('PUT', _stub, cors_max_age=7) + + self.assertEquals(foo.cors_max_age_for('GET'), 42) + self.assertEquals(foo.cors_max_age_for('POST'), 32) + self.assertEquals(foo.cors_max_age_for('PUT'), 7) From 9a79e9d1bda5ee3736f8541f38cca1abae01df49 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexis=20M=C3=A9taireau?= Date: Thu, 24 Jan 2013 17:45:02 +0100 Subject: [PATCH 2/2] Allow usage of a cors_policy dict. This can be useful if you don't want to specify all the cors-related parameters each time you define a service --- cornice/service.py | 20 +++++++++++++++++--- cornice/tests/test_service.py | 13 +++++++++++++ 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/cornice/service.py b/cornice/service.py index 6df5eb51..5f1c1dce 100644 --- a/cornice/service.py +++ b/cornice/service.py @@ -117,6 +117,17 @@ class Service(object): ones (Default: True). If set to False, all the headers need be explicitely mentionned with the cors_headers parameter. + :param cors_policy: + It may be easier to have an external object containing all the policy + information related to CORS, e.g:: + + >>> cors_policy = {'origins': ('*',), 'max_age': 42, + ... 'credentials': True} + + You can pass a dict here and all the values will be + unpacked and considered rather than the parameters starting by `cors_` + here. + See http://readthedocs.org/docs/pyramid/en/1.0-branch/glossary.html#term-acl for more information about ACLs. @@ -135,7 +146,8 @@ class Service(object): def __repr__(self): return u'' % (self.name, self.path) - def __init__(self, name, path, description=None, depth=1, **kw): + def __init__(self, name, path, description=None, cors_policy=None, depth=1, + **kw): self.name = name self.path = path self.description = description @@ -143,6 +155,10 @@ def __init__(self, name, path, description=None, depth=1, **kw): self._schemas = {} self._cors_enabled = None + if cors_policy: + for key, value in cors_policy.items(): + kw.setdefault('cors_' + key, value) + for key in self.list_arguments: # default_{validators,filters} and {filters,validators} doesn't # have to be mutables, so we need to create a new list from them @@ -346,8 +362,6 @@ def cors_enabled(self): def cors_enabled(self, value): self._cors_enabled = value - - @property def cors_supported_headers(self): """Return an iterable of supported headers for this service. diff --git a/cornice/tests/test_service.py b/cornice/tests/test_service.py index 9208653f..8faa811c 100644 --- a/cornice/tests/test_service.py +++ b/cornice/tests/test_service.py @@ -386,3 +386,16 @@ def test_max_age_can_be_different_dependeing_methods(self): self.assertEquals(foo.cors_max_age_for('GET'), 42) self.assertEquals(foo.cors_max_age_for('POST'), 32) self.assertEquals(foo.cors_max_age_for('PUT'), 7) + + def test_cors_policy(self): + policy = {'origins': ('foo', 'bar', 'baz')} + foo = Service(name='foo', path='/foo', cors_policy=policy) + self.assertTrue('foo' in foo.cors_supported_origins) + self.assertTrue('bar' in foo.cors_supported_origins) + self.assertTrue('baz' in foo.cors_supported_origins) + + def test_cors_policy_can_be_overwritten(self): + policy = {'origins': ('foo', 'bar', 'baz')} + foo = Service(name='foo', path='/foo', cors_origins=(), + cors_policy=policy) + self.assertEquals(len(foo.cors_supported_origins), 0)