From 2aff4de390deca933c97fc24facc3515fea788ce Mon Sep 17 00:00:00 2001 From: Rob DiCiuccio Date: Tue, 17 Nov 2020 11:55:47 -0800 Subject: [PATCH] feat(templating): Safer Jinja template processing (#11704) * Enable safer Jinja template processing * Allow JINJA_CONTEXT_ADDONS with SAFE_JINJA_PROCESSING * Make template processor initialization less magical, refactor classes * Consolidat Jinja logic, remove config flag in favor of sane defaults * Restore previous ENABLE_TEMPLATE_PROCESSING default * Add recursive type checking, update tests * remove erroneous config file * Remove TableColumn models from template context * pylint refactoring * Add entry to UPDATING.md * Resolve botched merge conflict * Update docs on running single python test * Refactor template context checking to support engine-specific methods --- CONTRIBUTING.md | 2 +- UPDATING.md | 4 +- superset/__init__.py | 2 - superset/app.py | 5 - superset/config.py | 13 +- superset/connectors/sqla/models.py | 6 +- superset/extensions.py | 39 +----- superset/jinja_context.py | 187 ++++++++++++++++++++------- tests/core_tests.py | 93 +------------- tests/jinja_context_tests.py | 197 ++++++++++++++++++++++++++++- 10 files changed, 358 insertions(+), 190 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index cffbedc74d039..2e7d79e405303 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -607,7 +607,7 @@ tox -e -- tests/test_file.py or for a specific test via, ```bash -tox -e -- tests/test_file.py:TestClassName.test_method_name +tox -e -- tests/test_file.py::TestClassName::test_method_name ``` Note that the test environment uses a temporary directory for defining the diff --git a/UPDATING.md b/UPDATING.md index c0e03afbe9a03..b4c686338804a 100644 --- a/UPDATING.md +++ b/UPDATING.md @@ -24,6 +24,8 @@ assists people when migrating to a new version. ## Next +- [11704](https://github.com/apache/incubator-superset/pull/11704) Breaking change: Jinja templating for SQL queries has been updated, removing default modules such as `datetime` and `random` and enforcing static template values. To restore or extend functionality, use `JINJA_CONTEXT_ADDONS` and `CUSTOM_TEMPLATE_PROCESSORS` in `superset_config.py`. + - [11509](https://github.com/apache/incubator-superset/pull/11509): Config value `TABLE_NAMES_CACHE_CONFIG` has been renamed to `DATA_CACHE_CONFIG`, which will now also hold query results cache from connected datasources (previously held in `CACHE_CONFIG`), in addition to the table names. If you will set `DATA_CACHE_CONFIG` to a new cache backend different than your previous `CACHE_CONFIG`, plan for additional cache warmup to avoid degrading charting performance for the end users. - [11575](https://github.com/apache/incubator-superset/pull/11575) The Row Level Security (RLS) config flag has been moved to a feature flag. To migrate, add `ROW_LEVEL_SECURITY: True` to the `FEATURE_FLAGS` dict in `superset_config.py`. @@ -38,7 +40,7 @@ assists people when migrating to a new version. and requires more work. You can easily turn on the languages you want to expose in your environment in superset_config.py -- [11172](https://github.com/apache/incubator-superset/pull/11172): Breaking change: SQL templating is turned off be default. To turn it on set `ENABLE_TEMPLATE_PROCESSING` to True on `DEFAULT_FEATURE_FLAGS` +- [11172](https://github.com/apache/incubator-superset/pull/11172): Breaking change: SQL templating is turned off by default. To turn it on set `ENABLE_TEMPLATE_PROCESSING` to True on `FEATURE_FLAGS` - [11155](https://github.com/apache/incubator-superset/pull/11155): The `FAB_UPDATE_PERMS` config parameter is no longer required as the Superset application correctly informs FAB under which context permissions should be updated. diff --git a/superset/__init__.py b/superset/__init__.py index 97e086edcbd89..6df897f3ecdb1 100644 --- a/superset/__init__.py +++ b/superset/__init__.py @@ -26,7 +26,6 @@ db, event_logger, feature_flag_manager, - jinja_context_manager, manifest_processor, results_backend_manager, security_manager, @@ -44,7 +43,6 @@ get_feature_flags = feature_flag_manager.get_feature_flags get_manifest_files = manifest_processor.get_manifest_files is_feature_enabled = feature_flag_manager.is_feature_enabled -jinja_base_context = jinja_context_manager.base_context results_backend = LocalProxy(lambda: results_backend_manager.results_backend) results_backend_use_msgpack = LocalProxy( lambda: results_backend_manager.should_use_msgpack diff --git a/superset/app.py b/superset/app.py index 806dbfd482179..acc2f2612747b 100644 --- a/superset/app.py +++ b/superset/app.py @@ -35,7 +35,6 @@ csrf, db, feature_flag_manager, - jinja_context_manager, machine_auth_provider_factory, manifest_processor, migrate, @@ -515,7 +514,6 @@ def init_app(self) -> None: self.configure_logging() self.configure_middlewares() self.configure_cache() - self.configure_jinja_context() with self.flask_app.app_context(): # type: ignore self.init_app_in_ctx() @@ -573,9 +571,6 @@ def configure_url_map_converters(self) -> None: self.flask_app.url_map.converters["regex"] = RegexConverter self.flask_app.url_map.converters["object_type"] = ObjectTypeConverter - def configure_jinja_context(self) -> None: - jinja_context_manager.init_app(self.flask_app) - def configure_middlewares(self) -> None: if self.config["ENABLE_CORS"]: from flask_cors import CORS diff --git a/superset/config.py b/superset/config.py index 87f7a51dab580..9eb6cf16b587b 100644 --- a/superset/config.py +++ b/superset/config.py @@ -672,14 +672,19 @@ class CeleryConfig: # pylint: disable=too-few-public-methods # A dictionary of items that gets merged into the Jinja context for # SQL Lab. The existing context gets updated with this dictionary, # meaning values for existing keys get overwritten by the content of this -# dictionary. +# dictionary. Exposing functionality through JINJA_CONTEXT_ADDONS has security +# implications as it opens a window for a user to execute untrusted code. +# It's important to make sure that the objects exposed (as well as objects attached +# to those objets) are harmless. We recommend only exposing simple/pure functions that +# return native types. JINJA_CONTEXT_ADDONS: Dict[str, Callable[..., Any]] = {} -# A dictionary of macro template processors that gets merged into global +# A dictionary of macro template processors (by engine) that gets merged into global # template processors. The existing template processors get updated with this # dictionary, which means the existing keys get overwritten by the content of this -# dictionary. The customized addons don't necessarily need to use jinjia templating -# language. This allows you to define custom logic to process macro template. +# dictionary. The customized addons don't necessarily need to use Jinja templating +# language. This allows you to define custom logic to process templates on a per-engine +# basis. Example value = `{"presto": CustomPrestoTemplateProcessor}` CUSTOM_TEMPLATE_PROCESSORS: Dict[str, Type[BaseTemplateProcessor]] = {} # Roles that are controlled by the API / Superset and should not be changes diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 34961f0b4b6e9..f2ed5f1f4c0be 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -875,14 +875,14 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma ) -> SqlaQuery: """Querying any sqla table from this common interface""" template_kwargs = { - "from_dttm": from_dttm, + "from_dttm": from_dttm.isoformat() if from_dttm else None, "groupby": groupby, "metrics": metrics, "row_limit": row_limit, "row_offset": row_offset, - "to_dttm": to_dttm, + "to_dttm": to_dttm.isoformat() if to_dttm else None, "filter": filter, - "columns": {col.column_name: col for col in self.columns}, + "columns": [col.column_name for col in self.columns], } is_sip_38 = is_feature_enabled("SIP_38_VIZ_REARCHITECTURE") template_kwargs.update(self.template_params_dict) diff --git a/superset/extensions.py b/superset/extensions.py index 9be0c37bb937a..7011bb237d636 100644 --- a/superset/extensions.py +++ b/superset/extensions.py @@ -16,15 +16,10 @@ # under the License. import json import os -import random -import time -import uuid -from datetime import datetime, timedelta -from typing import Any, Callable, Dict, List, Optional, Type, TYPE_CHECKING +from typing import Any, Callable, Dict, List, Optional import celery from cachelib.base import BaseCache -from dateutil.relativedelta import relativedelta from flask import Flask from flask_appbuilder import AppBuilder, SQLA from flask_migrate import Migrate @@ -36,37 +31,6 @@ from superset.utils.feature_flag_manager import FeatureFlagManager from superset.utils.machine_auth import MachineAuthProviderFactory -if TYPE_CHECKING: - from superset.jinja_context import BaseTemplateProcessor - - -class JinjaContextManager: - def __init__(self) -> None: - self._base_context = { - "datetime": datetime, - "random": random, - "relativedelta": relativedelta, - "time": time, - "timedelta": timedelta, - "uuid1": uuid.uuid1, - "uuid3": uuid.uuid3, - "uuid4": uuid.uuid4, - "uuid5": uuid.uuid5, - } - self._template_processors: Dict[str, Type["BaseTemplateProcessor"]] = {} - - def init_app(self, app: Flask) -> None: - self._base_context.update(app.config["JINJA_CONTEXT_ADDONS"]) - self._template_processors.update(app.config["CUSTOM_TEMPLATE_PROCESSORS"]) - - @property - def base_context(self) -> Dict[str, Any]: - return self._base_context - - @property - def template_processors(self) -> Dict[str, Type["BaseTemplateProcessor"]]: - return self._template_processors - class ResultsBackendManager: def __init__(self) -> None: @@ -140,7 +104,6 @@ def get_manifest_files(self, bundle: str, asset_type: str) -> List[str]: _event_logger: Dict[str, Any] = {} event_logger = LocalProxy(lambda: _event_logger.get("event_logger")) feature_flag_manager = FeatureFlagManager() -jinja_context_manager = JinjaContextManager() machine_auth_provider_factory = MachineAuthProviderFactory() manifest_processor = UIManifestProcessor(APP_DIR) migrate = Migrate() diff --git a/superset/jinja_context.py b/superset/jinja_context.py index 988aea89bdb2a..400b9d60df7e0 100644 --- a/superset/jinja_context.py +++ b/superset/jinja_context.py @@ -15,22 +15,49 @@ # specific language governing permissions and limitations # under the License. """Defines the templating context for SQL Lab""" -import inspect +import json import re -from typing import Any, cast, List, Optional, Tuple, TYPE_CHECKING +from functools import partial +from typing import Any, Callable, cast, Dict, List, Optional, Tuple, TYPE_CHECKING -from flask import g, request +from flask import current_app, g, request +from flask_babel import gettext as _ from jinja2.sandbox import SandboxedEnvironment -from superset import jinja_base_context -from superset.extensions import feature_flag_manager, jinja_context_manager -from superset.utils.core import convert_legacy_filters_into_adhoc, merge_extra_filters +from superset.exceptions import SupersetTemplateException +from superset.extensions import feature_flag_manager +from superset.utils.core import ( + convert_legacy_filters_into_adhoc, + memoized, + merge_extra_filters, +) if TYPE_CHECKING: from superset.connectors.sqla.models import SqlaTable from superset.models.core import Database from superset.models.sql_lab import Query +NONE_TYPE = type(None).__name__ +ALLOWED_TYPES = ( + NONE_TYPE, + "bool", + "str", + "unicode", + "int", + "long", + "float", + "list", + "dict", + "tuple", + "set", +) +COLLECTION_TYPES = ("list", "dict", "tuple", "set") + + +@memoized +def context_addons() -> Dict[str, Any]: + return current_app.config.get("JINJA_CONTEXT_ADDONS", {}) + def filter_values(column: str, default: Optional[str] = None) -> List[str]: """ Gets a values for a particular filter as a list @@ -151,7 +178,7 @@ def cache_key_wrapper(self, key: Any) -> Any: def url_param( self, param: str, default: Optional[str] = None, add_to_cache_keys: bool = True - ) -> Optional[Any]: + ) -> Optional[str]: """ Read a url or post parameter and use it in your SQL Lab query. @@ -186,19 +213,68 @@ def url_param( return result +def safe_proxy(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: + return_value = func(*args, **kwargs) + value_type = type(return_value).__name__ + if value_type not in ALLOWED_TYPES: + raise SupersetTemplateException( + _( + "Unsafe return type for function %(func)s: %(value_type)s", + func=func.__name__, + value_type=value_type, + ) + ) + if value_type in COLLECTION_TYPES: + try: + return_value = json.loads(json.dumps(return_value)) + except TypeError: + raise SupersetTemplateException( + _("Unsupported return value for method %(name)s", name=func.__name__,) + ) + + return return_value + + +def validate_context_types(context: Dict[str, Any]) -> Dict[str, Any]: + for key in context: + arg_type = type(context[key]).__name__ + if arg_type not in ALLOWED_TYPES and key not in context_addons(): + if arg_type == "partial" and context[key].func.__name__ == "safe_proxy": + continue + raise SupersetTemplateException( + _( + "Unsafe template value for key %(key)s: %(value_type)s", + key=key, + value_type=arg_type, + ) + ) + if arg_type in COLLECTION_TYPES: + try: + context[key] = json.loads(json.dumps(context[key])) + except TypeError: + raise SupersetTemplateException( + _("Unsupported template value for key %(key)s", key=key) + ) + + return context + + +def validate_template_context( + engine: Optional[str], context: Dict[str, Any] +) -> Dict[str, Any]: + if engine and engine in context: + # validate engine context separately to allow for engine-specific methods + engine_context = validate_context_types(context.pop(engine)) + valid_context = validate_context_types(context) + valid_context[engine] = engine_context + return valid_context + + return validate_context_types(context) + + class BaseTemplateProcessor: # pylint: disable=too-few-public-methods - """Base class for database-specific jinja context - - There's this bit of magic in ``process_template`` that instantiates only - the database context for the active database as a ``models.Database`` - object binds it to the context object, so that object methods - have access to - that context. This way, {{ hive.latest_partition('mytable') }} just - knows about the database it is operating in. - - This means that object methods are only available for the active database - and are given access to the ``models.Database`` object and schema - name. For globally available methods use ``@classmethod``. + """ + Base class for database-specific jinja context """ engine: Optional[str] = None @@ -218,22 +294,14 @@ def __init__( self._schema = query.schema elif table: self._schema = table.schema + self._extra_cache_keys = extra_cache_keys + self._context: Dict[str, Any] = {} + self._env = SandboxedEnvironment() + self.set_context(**kwargs) - extra_cache = ExtraCache(extra_cache_keys) - - self._context = { - "url_param": extra_cache.url_param, - "current_user_id": extra_cache.current_user_id, - "current_username": extra_cache.current_username, - "cache_key_wrapper": extra_cache.cache_key_wrapper, - "filter_values": filter_values, - "form_data": {}, - } + def set_context(self, **kwargs: Any) -> None: self._context.update(kwargs) - self._context.update(jinja_base_context) - if self.engine: - self._context[self.engine] = self - self._env = SandboxedEnvironment() + self._context.update(context_addons()) def process_template(self, sql: str, **kwargs: Any) -> str: """Processes a sql template @@ -244,7 +312,24 @@ def process_template(self, sql: str, **kwargs: Any) -> str: """ template = self._env.from_string(sql) kwargs.update(self._context) - return template.render(kwargs) + + context = validate_template_context(self.engine, kwargs) + return template.render(context) + + +class JinjaTemplateProcessor(BaseTemplateProcessor): + def set_context(self, **kwargs: Any) -> None: + super().set_context(**kwargs) + extra_cache = ExtraCache(self._extra_cache_keys) + self._context.update( + { + "url_param": partial(safe_proxy, extra_cache.url_param), + "current_user_id": partial(safe_proxy, extra_cache.current_user_id), + "current_username": partial(safe_proxy, extra_cache.current_username), + "cache_key_wrapper": partial(safe_proxy, extra_cache.cache_key_wrapper), + "filter_values": partial(safe_proxy, filter_values), + } + ) class NoOpTemplateProcessor( @@ -257,7 +342,7 @@ def process_template(self, sql: str, **kwargs: Any) -> str: return sql -class PrestoTemplateProcessor(BaseTemplateProcessor): +class PrestoTemplateProcessor(JinjaTemplateProcessor): """Presto Jinja context The methods described here are namespaced under ``presto`` in the @@ -266,6 +351,15 @@ class PrestoTemplateProcessor(BaseTemplateProcessor): engine = "presto" + def set_context(self, **kwargs: Any) -> None: + super().set_context(**kwargs) + self._context[self.engine] = { + "first_latest_partition": partial(safe_proxy, self.first_latest_partition), + "latest_partitions": partial(safe_proxy, self.latest_partitions), + "latest_sub_partition": partial(safe_proxy, self.latest_sub_partition), + "latest_partition": partial(safe_proxy, self.latest_partition), + } + @staticmethod def _schema_table( table_name: str, schema: Optional[str] @@ -319,13 +413,18 @@ class HiveTemplateProcessor(PrestoTemplateProcessor): engine = "hive" -# The global template processors from Jinja context manager. -template_processors = jinja_context_manager.template_processors -keys = tuple(globals().keys()) -for k in keys: - o = globals()[k] - if o and inspect.isclass(o) and issubclass(o, BaseTemplateProcessor): - template_processors[o.engine] = o +DEFAULT_PROCESSORS = {"presto": PrestoTemplateProcessor, "hive": HiveTemplateProcessor} + + +@memoized +def get_template_processors() -> Dict[str, Any]: + processors = current_app.config.get("CUSTOM_TEMPLATE_PROCESSORS", {}) + for engine in DEFAULT_PROCESSORS: + # do not overwrite engine-specific CUSTOM_TEMPLATE_PROCESSORS + if not engine in processors: + processors[engine] = DEFAULT_PROCESSORS[engine] + + return processors def get_template_processor( @@ -335,8 +434,8 @@ def get_template_processor( **kwargs: Any, ) -> BaseTemplateProcessor: if feature_flag_manager.is_feature_enabled("ENABLE_TEMPLATE_PROCESSING"): - template_processor = template_processors.get( - database.backend, BaseTemplateProcessor + template_processor = get_template_processors().get( + database.backend, JinjaTemplateProcessor ) else: template_processor = NoOpTemplateProcessor diff --git a/tests/core_tests.py b/tests/core_tests.py index 64abd94fdf810..87b698f4d8097 100644 --- a/tests/core_tests.py +++ b/tests/core_tests.py @@ -669,103 +669,14 @@ def test_extra_table_metadata(self): f"/superset/extra_table_metadata/{example_db.id}/birth_names/{schema}/" ) - def test_process_template(self): - maindb = utils.get_example_database() - if maindb.backend == "presto": - # TODO: make it work for presto - return - sql = "SELECT '{{ datetime(2017, 1, 1).isoformat() }}'" - tp = jinja_context.get_template_processor(database=maindb) - rendered = tp.process_template(sql) - self.assertEqual("SELECT '2017-01-01T00:00:00'", rendered) - - def test_get_template_kwarg(self): - maindb = utils.get_example_database() - if maindb.backend == "presto": - # TODO: make it work for presto - return - s = "{{ foo }}" - tp = jinja_context.get_template_processor(database=maindb, foo="bar") - rendered = tp.process_template(s) - self.assertEqual("bar", rendered) - - def test_template_kwarg(self): - maindb = utils.get_example_database() - if maindb.backend == "presto": - # TODO: make it work for presto - return - s = "{{ foo }}" - tp = jinja_context.get_template_processor(database=maindb) - rendered = tp.process_template(s, foo="bar") - self.assertEqual("bar", rendered) - def test_templated_sql_json(self): if utils.get_example_database().backend == "presto": # TODO: make it work for presto return self.login() - sql = "SELECT '{{ datetime(2017, 1, 1).isoformat() }}' as test" + sql = "SELECT '{{ 1+1 }}' as test" data = self.run_sql(sql, "fdaklj3ws") - self.assertEqual(data["data"][0]["test"], "2017-01-01T00:00:00") - - @mock.patch("tests.superset_test_custom_template_processors.datetime") - def test_custom_process_template(self, mock_dt) -> None: - """Test macro defined in custom template processor works.""" - mock_dt.utcnow = mock.Mock(return_value=datetime.datetime(1970, 1, 1)) - db = mock.Mock() - db.backend = "db_for_macros_testing" - tp = jinja_context.get_template_processor(database=db) - - sql = "SELECT '$DATE()'" - rendered = tp.process_template(sql) - self.assertEqual("SELECT '{}'".format("1970-01-01"), rendered) - - sql = "SELECT '$DATE(1, 2)'" - rendered = tp.process_template(sql) - self.assertEqual("SELECT '{}'".format("1970-01-02"), rendered) - - def test_custom_get_template_kwarg(self): - """Test macro passed as kwargs when getting template processor - works in custom template processor.""" - db = mock.Mock() - db.backend = "db_for_macros_testing" - s = "$foo()" - tp = jinja_context.get_template_processor(database=db, foo=lambda: "bar") - rendered = tp.process_template(s) - self.assertEqual("bar", rendered) - - def test_custom_template_kwarg(self) -> None: - """Test macro passed as kwargs when processing template - works in custom template processor.""" - db = mock.Mock() - db.backend = "db_for_macros_testing" - s = "$foo()" - tp = jinja_context.get_template_processor(database=db) - rendered = tp.process_template(s, foo=lambda: "bar") - self.assertEqual("bar", rendered) - - def test_custom_template_processors_overwrite(self) -> None: - """Test template processor for presto gets overwritten by custom one.""" - db = mock.Mock() - db.backend = "db_for_macros_testing" - tp = jinja_context.get_template_processor(database=db) - - sql = "SELECT '{{ datetime(2017, 1, 1).isoformat() }}'" - rendered = tp.process_template(sql) - self.assertEqual(sql, rendered) - - sql = "SELECT '{{ DATE(1, 2) }}'" - rendered = tp.process_template(sql) - self.assertEqual(sql, rendered) - - def test_custom_template_processors_ignored(self) -> None: - """Test custom template processor is ignored for a difference backend - database.""" - maindb = utils.get_example_database() - sql = "SELECT '$DATE()'" - tp = jinja_context.get_template_processor(database=maindb) - rendered = tp.process_template(sql) - assert sql == rendered + self.assertEqual(data["data"][0]["test"], "2") @mock.patch("tests.superset_test_custom_template_processors.datetime") @mock.patch("superset.sql_lab.get_sql_results") diff --git a/tests/jinja_context_tests.py b/tests/jinja_context_tests.py index 349e13593aa8a..a3314d05ea25c 100644 --- a/tests/jinja_context_tests.py +++ b/tests/jinja_context_tests.py @@ -15,10 +15,22 @@ # specific language governing permissions and limitations # under the License. import json +from datetime import datetime +from typing import Any +from unittest import mock + +import pytest import tests.test_app from superset import app -from superset.jinja_context import ExtraCache, filter_values +from superset.exceptions import SupersetTemplateException +from superset.jinja_context import ( + ExtraCache, + filter_values, + get_template_processor, + safe_proxy, +) +from superset.utils import core as utils from tests.base_tests import SupersetTestCase @@ -97,3 +109,186 @@ def test_url_param_form_data(self) -> None: query_string={"form_data": json.dumps({"url_params": {"foo": "bar"}})} ): self.assertEqual(ExtraCache().url_param("foo"), "bar") + + def test_safe_proxy_primitive(self) -> None: + def func(input: Any) -> Any: + return input + + return_value = safe_proxy(func, "foo") + self.assertEqual("foo", return_value) + + def test_safe_proxy_dict(self) -> None: + def func(input: Any) -> Any: + return input + + return_value = safe_proxy(func, {"foo": "bar"}) + self.assertEqual({"foo": "bar"}, return_value) + + def test_safe_proxy_lambda(self) -> None: + def func(input: Any) -> Any: + return input + + with pytest.raises(SupersetTemplateException): + safe_proxy(func, lambda: "bar") + + def test_safe_proxy_nested_lambda(self) -> None: + def func(input: Any) -> Any: + return input + + with pytest.raises(SupersetTemplateException): + safe_proxy(func, {"foo": lambda: "bar"}) + + def test_process_template(self) -> None: + maindb = utils.get_example_database() + sql = "SELECT '{{ 1+1 }}'" + tp = get_template_processor(database=maindb) + rendered = tp.process_template(sql) + self.assertEqual("SELECT '2'", rendered) + + def test_get_template_kwarg(self) -> None: + maindb = utils.get_example_database() + s = "{{ foo }}" + tp = get_template_processor(database=maindb, foo="bar") + rendered = tp.process_template(s) + self.assertEqual("bar", rendered) + + def test_template_kwarg(self) -> None: + maindb = utils.get_example_database() + s = "{{ foo }}" + tp = get_template_processor(database=maindb) + rendered = tp.process_template(s, foo="bar") + self.assertEqual("bar", rendered) + + def test_get_template_kwarg_dict(self) -> None: + maindb = utils.get_example_database() + s = "{{ foo.bar }}" + tp = get_template_processor(database=maindb, foo={"bar": "baz"}) + rendered = tp.process_template(s) + self.assertEqual("baz", rendered) + + def test_template_kwarg_dict(self) -> None: + maindb = utils.get_example_database() + s = "{{ foo.bar }}" + tp = get_template_processor(database=maindb) + rendered = tp.process_template(s, foo={"bar": "baz"}) + self.assertEqual("baz", rendered) + + def test_get_template_kwarg_lambda(self) -> None: + maindb = utils.get_example_database() + s = "{{ foo() }}" + tp = get_template_processor(database=maindb, foo=lambda: "bar") + with pytest.raises(SupersetTemplateException): + tp.process_template(s) + + def test_template_kwarg_lambda(self) -> None: + maindb = utils.get_example_database() + s = "{{ foo() }}" + tp = get_template_processor(database=maindb) + with pytest.raises(SupersetTemplateException): + tp.process_template(s, foo=lambda: "bar") + + def test_get_template_kwarg_module(self) -> None: + maindb = utils.get_example_database() + s = "{{ dt(2017, 1, 1).isoformat() }}" + tp = get_template_processor(database=maindb, dt=datetime) + with pytest.raises(SupersetTemplateException): + tp.process_template(s) + + def test_template_kwarg_module(self) -> None: + maindb = utils.get_example_database() + s = "{{ dt(2017, 1, 1).isoformat() }}" + tp = get_template_processor(database=maindb) + with pytest.raises(SupersetTemplateException): + tp.process_template(s, dt=datetime) + + def test_get_template_kwarg_nested_module(self) -> None: + maindb = utils.get_example_database() + s = "{{ foo.dt }}" + tp = get_template_processor(database=maindb, foo={"dt": datetime}) + with pytest.raises(SupersetTemplateException): + tp.process_template(s) + + def test_template_kwarg_nested_module(self) -> None: + maindb = utils.get_example_database() + s = "{{ foo.dt }}" + tp = get_template_processor(database=maindb) + with pytest.raises(SupersetTemplateException): + tp.process_template(s, foo={"bar": datetime}) + + @mock.patch("superset.jinja_context.HiveTemplateProcessor.latest_partition") + def test_template_hive(self, lp_mock) -> None: + lp_mock.return_value = "the_latest" + db = mock.Mock() + db.backend = "hive" + s = "{{ hive.latest_partition('my_table') }}" + tp = get_template_processor(database=db) + rendered = tp.process_template(s) + self.assertEqual("the_latest", rendered) + + @mock.patch("superset.jinja_context.context_addons") + def test_template_context_addons(self, addons_mock) -> None: + addons_mock.return_value = {"datetime": datetime} + maindb = utils.get_example_database() + s = "SELECT '{{ datetime(2017, 1, 1).isoformat() }}'" + tp = get_template_processor(database=maindb) + rendered = tp.process_template(s) + self.assertEqual("SELECT '2017-01-01T00:00:00'", rendered) + + @mock.patch("tests.superset_test_custom_template_processors.datetime") + def test_custom_process_template(self, mock_dt) -> None: + """Test macro defined in custom template processor works.""" + mock_dt.utcnow = mock.Mock(return_value=datetime(1970, 1, 1)) + db = mock.Mock() + db.backend = "db_for_macros_testing" + tp = get_template_processor(database=db) + + sql = "SELECT '$DATE()'" + rendered = tp.process_template(sql) + self.assertEqual("SELECT '{}'".format("1970-01-01"), rendered) + + sql = "SELECT '$DATE(1, 2)'" + rendered = tp.process_template(sql) + self.assertEqual("SELECT '{}'".format("1970-01-02"), rendered) + + def test_custom_get_template_kwarg(self) -> None: + """Test macro passed as kwargs when getting template processor + works in custom template processor.""" + db = mock.Mock() + db.backend = "db_for_macros_testing" + s = "$foo()" + tp = get_template_processor(database=db, foo=lambda: "bar") + rendered = tp.process_template(s) + self.assertEqual("bar", rendered) + + def test_custom_template_kwarg(self) -> None: + """Test macro passed as kwargs when processing template + works in custom template processor.""" + db = mock.Mock() + db.backend = "db_for_macros_testing" + s = "$foo()" + tp = get_template_processor(database=db) + rendered = tp.process_template(s, foo=lambda: "bar") + self.assertEqual("bar", rendered) + + def test_custom_template_processors_overwrite(self) -> None: + """Test template processor for presto gets overwritten by custom one.""" + db = mock.Mock() + db.backend = "db_for_macros_testing" + tp = get_template_processor(database=db) + + sql = "SELECT '{{ datetime(2017, 1, 1).isoformat() }}'" + rendered = tp.process_template(sql) + self.assertEqual(sql, rendered) + + sql = "SELECT '{{ DATE(1, 2) }}'" + rendered = tp.process_template(sql) + self.assertEqual(sql, rendered) + + def test_custom_template_processors_ignored(self) -> None: + """Test custom template processor is ignored for a difference backend + database.""" + maindb = utils.get_example_database() + sql = "SELECT '$DATE()'" + tp = get_template_processor(database=maindb) + rendered = tp.process_template(sql) + assert sql == rendered