From 8e5998cd759d0a87a0197a87e7f780fcf84f2933 Mon Sep 17 00:00:00 2001 From: Greg Kempe Date: Sun, 2 Feb 2014 15:39:00 +0200 Subject: [PATCH] Install csrf_token() for all template types. This installs csrf_token() in the render context so that it works by default for all template types, not just Jinja2. Also adds a test for the helper function and bumps the version number. --- flask_wtf/__init__.py | 2 +- flask_wtf/csrf.py | 6 +++++- tests/templates/csrf.html | 8 ++++++++ tests/test_csrf.py | 8 ++++++++ 4 files changed, 22 insertions(+), 2 deletions(-) create mode 100644 tests/templates/csrf.html diff --git a/flask_wtf/__init__.py b/flask_wtf/__init__.py index 91fb1d5c..4ae9851f 100644 --- a/flask_wtf/__init__.py +++ b/flask_wtf/__init__.py @@ -16,4 +16,4 @@ from .csrf import CsrfProtect from .recaptcha import * -__version__ = '0.9.4' +__version__ = '0.9.5' diff --git a/flask_wtf/csrf.py b/flask_wtf/csrf.py index 9bcc49ae..5b4a8c34 100644 --- a/flask_wtf/csrf.py +++ b/flask_wtf/csrf.py @@ -130,9 +130,13 @@ def __init__(self, app=None): self.init_app(app) def init_app(self, app): - app.jinja_env.globals['csrf_token'] = generate_csrf app.config.setdefault('WTF_CSRF_SSL_STRICT', True) app.config.setdefault('WTF_CSRF_ENABLED', True) + + # expose csrf_token as a helper in all templates + @app.context_processor + def csrf_token(): + return dict(csrf_token=generate_csrf) @app.before_request def _csrf_protect(): diff --git a/tests/templates/csrf.html b/tests/templates/csrf.html new file mode 100644 index 00000000..c16c98ac --- /dev/null +++ b/tests/templates/csrf.html @@ -0,0 +1,8 @@ + + + + + + token: {{ csrf_token() }} + + diff --git a/tests/test_csrf.py b/tests/test_csrf.py index c43b22ff..0b0d89d1 100644 --- a/tests/test_csrf.py +++ b/tests/test_csrf.py @@ -184,3 +184,11 @@ def test_validate_csrf(self): assert not validate_csrf('ff##dd') csrf_token = generate_csrf() assert validate_csrf(csrf_token) + + def test_csrf_token_helper(self): + @self.app.route("/token") + def withtoken(): + return render_template("csrf.html") + + response = self.client.get('/token') + assert re.compile('token: [a-zA-Z0-9#.]{3,}').search(response.data)