Skip to content

Commit

Permalink
fix(Jinja): Extra cache keys for Jinja columns (#30715)
Browse files Browse the repository at this point in the history
  • Loading branch information
geido authored Oct 25, 2024
1 parent 1c56857 commit a12ccf2
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 1 deletion.
6 changes: 5 additions & 1 deletion superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@
)
from superset.utils import core as utils, json
from superset.utils.backports import StrEnum
from superset.utils.core import GenericDataType, MediumText
from superset.utils.core import GenericDataType, is_adhoc_column, MediumText

config = app.config
metadata = Model.metadata # pylint: disable=no-member
Expand Down Expand Up @@ -1980,6 +1980,10 @@ def has_extra_cache_key_calls(self, query_obj: QueryObjectDict) -> bool:
templatable_statements.append(extras["where"])
if "having" in extras:
templatable_statements.append(extras["having"])
if "columns" in query_obj:
templatable_statements += [
c["sqlExpression"] for c in query_obj["columns"] if is_adhoc_column(c)
]
if self.is_rls_supported:
templatable_statements += [
f.clause for f in security_manager.get_rls_filters(self)
Expand Down
50 changes: 50 additions & 0 deletions tests/integration_tests/sqla_models_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -911,6 +911,56 @@ def test_extra_cache_keys_in_sql_expression(
assert extra_cache_keys == expected_cache_keys


@pytest.mark.usefixtures("app_context")
@pytest.mark.parametrize(
"sql_expression,expected_cache_keys,has_extra_cache_keys",
[
("'{{ current_username() }}'", ["abc"], True),
("(user != 'abc')", [], False),
],
)
@patch("superset.jinja_context.get_user_id", return_value=1)
@patch("superset.jinja_context.get_username", return_value="abc")
@patch("superset.jinja_context.get_user_email", return_value="abc@test.com")
def test_extra_cache_keys_in_columns(
mock_user_email,
mock_username,
mock_user_id,
sql_expression,
expected_cache_keys,
has_extra_cache_keys,
):
table = SqlaTable(
table_name="test_has_no_extra_cache_keys_table",
sql="SELECT 'abc' as user",
database=get_example_database(),
)
base_query_obj = {
"granularity": None,
"from_dttm": None,
"to_dttm": None,
"groupby": [],
"metrics": [],
"is_timeseries": False,
"filter": [],
}

query_obj = dict(
**base_query_obj,
columns=[
{
"label": None,
"expressionType": "SQL",
"sqlExpression": sql_expression,
}
],
)

extra_cache_keys = table.get_extra_cache_keys(query_obj)
assert table.has_extra_cache_key_calls(query_obj) == has_extra_cache_keys
assert extra_cache_keys == expected_cache_keys


@pytest.mark.usefixtures("app_context")
@pytest.mark.parametrize(
"row,dimension,result",
Expand Down

0 comments on commit a12ccf2

Please sign in to comment.