Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(jinja): improve url parameter formatting #16711

Merged
merged 3 commits into from
Sep 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions UPDATING.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ assists people when migrating to a new version.
## Next

### Breaking Changes

- [16711](https://github.com/apache/incubator-superset/pull/16711): The `url_param` Jinja function will now by default escape the result. For instance, the value `O'Brien` will now be changed to `O''Brien`. To disable this behavior, call `url_param` with `escape_result` set to `False`: `url_param("my_key", "my default", escape_result=False)`.

### Potential Downtime
### Deprecations
### Other
Expand Down
22 changes: 20 additions & 2 deletions superset/jinja_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
from flask_babel import gettext as _
from jinja2 import DebugUndefined
from jinja2.sandbox import SandboxedEnvironment
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.types import String
from typing_extensions import TypedDict

from superset.exceptions import SupersetTemplateException
Expand Down Expand Up @@ -95,9 +97,11 @@ def __init__(
self,
extra_cache_keys: Optional[List[Any]] = None,
removed_filters: Optional[List[str]] = None,
dialect: Optional[Dialect] = None,
):
self.extra_cache_keys = extra_cache_keys
self.removed_filters = removed_filters if removed_filters is not None else []
self.dialect = dialect

def current_user_id(self, add_to_cache_keys: bool = True) -> Optional[int]:
"""
Expand Down Expand Up @@ -145,7 +149,11 @@ def cache_key_wrapper(self, key: Any) -> Any:
return key

def url_param(
self, param: str, default: Optional[str] = None, add_to_cache_keys: bool = True
self,
param: str,
default: Optional[str] = None,
add_to_cache_keys: bool = True,
escape_result: bool = True,
) -> Optional[str]:
"""
Read a url or post parameter and use it in your SQL Lab query.
Expand All @@ -166,6 +174,7 @@ def url_param(
:param param: the parameter to lookup
:param default: the value to return in the absence of the parameter
:param add_to_cache_keys: Whether the value should be included in the cache key
:param escape_result: Should special characters in the result be escaped
:returns: The URL parameters
"""

Expand All @@ -178,6 +187,11 @@ def url_param(
form_data, _ = get_form_data()
url_params = form_data.get("url_params") or {}
result = url_params.get(param, default)
if result and escape_result and self.dialect:
# use the dialect specific quoting logic to escape string
result = String().literal_processor(dialect=self.dialect)(value=result)[
1:-1
]
if add_to_cache_keys:
self.cache_key_wrapper(result)
return result
Expand Down Expand Up @@ -430,7 +444,11 @@ def process_template(self, sql: str, **kwargs: Any) -> str:
class JinjaTemplateProcessor(BaseTemplateProcessor):
def set_context(self, **kwargs: Any) -> None:
super().set_context(**kwargs)
extra_cache = ExtraCache(self._extra_cache_keys, self._removed_filters)
extra_cache = ExtraCache(
extra_cache_keys=self._extra_cache_keys,
removed_filters=self._removed_filters,
dialect=self._database.get_dialect(),
)
self._context.update(
{
"url_param": partial(safe_proxy, extra_cache.url_param),
Expand Down
13 changes: 11 additions & 2 deletions tests/integration_tests/base_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@
from flask import Response
from flask_appbuilder.security.sqla import models as ab_models
from flask_testing import TestCase
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.ext.declarative.api import DeclarativeMeta
from sqlalchemy.orm import Session
from sqlalchemy.sql import func
from sqlalchemy.dialects.mysql import dialect

from tests.integration_tests.test_app import app
from superset.sql_parse import CtasMethod
Expand Down Expand Up @@ -422,15 +424,22 @@ def create_fake_db_for_macros(self):
self.login(username="admin")
database_name = "db_for_macros_testing"
db_id = 200
return self.get_or_create(
database = self.get_or_create(
cls=models.Database,
criteria={"database_name": database_name},
session=db.session,
sqlalchemy_uri="db_for_macros_testing://user@host:8080/hive",
id=db_id,
)

def delete_fake_db_for_macros(self):
def mock_get_dialect() -> Dialect:
return dialect()

database.get_dialect = mock_get_dialect
return database

@staticmethod
def delete_fake_db_for_macros():
database = (
db.session.query(Database)
.filter(Database.database_name == "db_for_macros_testing")
Expand Down
31 changes: 31 additions & 0 deletions tests/integration_tests/jinja_context_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from unittest import mock

import pytest
from sqlalchemy.dialects.postgresql import dialect

import tests.integration_tests.test_app
from superset import app
Expand Down Expand Up @@ -199,6 +200,36 @@ def test_url_param_form_data(self) -> None:
cache = ExtraCache()
self.assertEqual(cache.url_param("foo"), "bar")

def test_url_param_escaped_form_data(self) -> None:
with app.test_request_context(
query_string={"form_data": json.dumps({"url_params": {"foo": "O'Brien"}})}
):
cache = ExtraCache(dialect=dialect())
self.assertEqual(cache.url_param("foo"), "O''Brien")

def test_url_param_escaped_default_form_data(self) -> None:
with app.test_request_context(
query_string={"form_data": json.dumps({"url_params": {"foo": "O'Brien"}})}
):
cache = ExtraCache(dialect=dialect())
self.assertEqual(cache.url_param("bar", "O'Malley"), "O''Malley")

def test_url_param_unescaped_form_data(self) -> None:
with app.test_request_context(
query_string={"form_data": json.dumps({"url_params": {"foo": "O'Brien"}})}
):
cache = ExtraCache(dialect=dialect())
self.assertEqual(cache.url_param("foo", escape_result=False), "O'Brien")

def test_url_param_unescaped_default_form_data(self) -> None:
with app.test_request_context(
query_string={"form_data": json.dumps({"url_params": {"foo": "O'Brien"}})}
):
cache = ExtraCache(dialect=dialect())
self.assertEqual(
cache.url_param("bar", "O'Malley", escape_result=False), "O'Malley"
)

def test_safe_proxy_primitive(self) -> None:
def func(input: Any) -> Any:
return input
Expand Down