From e1592f3263819536b0a7788c7aba517c9582029a Mon Sep 17 00:00:00 2001 From: Grace Guo Date: Tue, 29 Sep 2020 10:57:16 -0700 Subject: [PATCH] feat: enable ETag header for dashboard GET requests (#10963) * feat: add etag for dashboard load requests * fix review comments --- superset/config.py | 1 + superset/utils/decorators.py | 28 +++++++++++-- superset/views/core.py | 78 ++++++++++++++++++------------------ superset/views/utils.py | 64 +++++++++++++++++++++++++++-- tests/utils_tests.py | 12 ++++++ 5 files changed, 136 insertions(+), 47 deletions(-) diff --git a/superset/config.py b/superset/config.py index adbb95986b19d..e7bf8c1b99893 100644 --- a/superset/config.py +++ b/superset/config.py @@ -297,6 +297,7 @@ def _try_json_readsha( # pylint: disable=unused-argument # Experimental feature introducing a client (browser) cache "CLIENT_CACHE": False, "ENABLE_EXPLORE_JSON_CSRF_PROTECTION": False, + "ENABLE_DASHBOARD_ETAG_HEADER": False, "KV_STORE": False, "PRESTO_EXPAND_DATA": False, # Exposes API endpoint to compute thumbnails diff --git a/superset/utils/decorators.py b/superset/utils/decorators.py index 694e07bd2c434..ae4c5726d8ee7 100644 --- a/superset/utils/decorators.py +++ b/superset/utils/decorators.py @@ -17,7 +17,7 @@ import logging from datetime import datetime, timedelta from functools import wraps -from typing import Any, Callable, Iterator +from typing import Any, Callable, Iterator, Optional from contextlib2 import contextmanager from flask import request @@ -46,7 +46,12 @@ def stats_timing(stats_key: str, stats_logger: BaseStatsLogger) -> Iterator[floa stats_logger.timing(stats_key, now_as_float() - start_ts) -def etag_cache(max_age: int, check_perms: Callable[..., Any]) -> Callable[..., Any]: +def etag_cache( + max_age: int, + check_perms: Callable[..., Any], + get_last_modified: Optional[Callable[..., Any]] = None, + skip: Optional[Callable[..., Any]] = None, +) -> Callable[..., Any]: """ A decorator for caching views and handling etag conditional requests. @@ -69,7 +74,7 @@ def wrapper(*args: Any, **kwargs: Any) -> ETagResponseMixin: # for POST requests we can't set cache headers, use the response # cache nor use conditional requests; this will still use the # dataframe cache in `superset/viz.py`, though. - if request.method == "POST": + if request.method == "POST" or (skip and skip(*args, **kwargs)): return f(*args, **kwargs) response = None @@ -89,13 +94,28 @@ def wrapper(*args: Any, **kwargs: Any) -> ETagResponseMixin: raise logger.exception("Exception possibly due to cache backend.") + # if cache is stale? + if get_last_modified: + content_changed_time = get_last_modified(*args, **kwargs) + if ( + response + and response.last_modified + and response.last_modified.timestamp() + < content_changed_time.timestamp() + ): + response = None + else: + # if caller didn't provide content's last_modified time, assume + # its cache won't be stale. + content_changed_time = datetime.utcnow() + # if no response was cached, compute it using the wrapped function if response is None: response = f(*args, **kwargs) # add headers for caching: Last Modified, Expires and ETag response.cache_control.public = True - response.last_modified = datetime.utcnow() + response.last_modified = content_changed_time expiration = max_age if max_age != 0 else FAR_FUTURE response.expires = response.last_modified + timedelta( seconds=expiration diff --git a/superset/views/core.py b/superset/views/core.py index 71080425b47d8..36668cf082539 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -26,7 +26,17 @@ import backoff import pandas as pd import simplejson as json -from flask import abort, flash, g, Markup, redirect, render_template, request, Response +from flask import ( + abort, + flash, + g, + make_response, + Markup, + redirect, + render_template, + request, + Response, +) from flask_appbuilder import expose from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_appbuilder.security.decorators import has_access, has_access_api @@ -114,9 +124,12 @@ _deserialize_results_payload, apply_display_max_row_limit, bootstrap_user_data, + check_dashboard_perms, check_datasource_perms, check_slice_perms, get_cta_schema_name, + get_dashboard, + get_dashboard_changedon_dt, get_dashboard_extra_filters, get_datasource_info, get_form_data, @@ -1585,49 +1598,32 @@ def publish( # pylint: disable=no-self-use return json_success(json.dumps({"published": dash.published})) @has_access + @etag_cache( + 0, + check_perms=check_dashboard_perms, + get_last_modified=get_dashboard_changedon_dt, + skip=lambda _self, dashboard_id_or_slug: not is_feature_enabled( + "ENABLE_DASHBOARD_ETAG_HEADER" + ), + ) @expose("/dashboard//") def dashboard( # pylint: disable=too-many-locals self, dashboard_id_or_slug: str ) -> FlaskResponse: """Server side rendering for a dashboard""" - session = db.session() - qry = session.query(Dashboard) - if dashboard_id_or_slug.isdigit(): - qry = qry.filter_by(id=int(dashboard_id_or_slug)) - else: - qry = qry.filter_by(slug=dashboard_id_or_slug) + dash = get_dashboard(dashboard_id_or_slug) - dash = qry.one_or_none() - if not dash: - abort(404) - - datasources = defaultdict(list) + slices_by_datasources = defaultdict(list) for slc in dash.slices: datasource = slc.datasource if datasource: - datasources[datasource].append(slc) - - if config["ENABLE_ACCESS_REQUEST"]: - for datasource in datasources: - if datasource and not security_manager.can_access_datasource( - datasource - ): - flash( - __( - security_manager.get_datasource_access_error_msg(datasource) - ), - "danger", - ) - return redirect( - "superset/request_access/?" f"dashboard_id={dash.id}&" - ) - + slices_by_datasources[datasource].append(slc) # Filter out unneeded fields from the datasource payload datasources_payload = { datasource.uid: datasource.data_for_slices(slices) if is_feature_enabled("REDUCE_DASHBOARD_BOOTSTRAP_PAYLOAD") else datasource.data - for datasource, slices in datasources.items() + for datasource, slices in slices_by_datasources.items() } dash_edit_perm = check_ownership( @@ -1661,7 +1657,7 @@ def dashboard(**_: Any) -> None: if is_feature_enabled("REMOVE_SLICE_LEVEL_LABEL_COLORS"): # dashboard metadata has dashboard-level label_colors, # so remove slice-level label_colors from its form_data - for slc in dashboard_data.get("slices"): + for slc in dashboard_data.get("slices") or []: form_data = slc.get("form_data") form_data.pop("label_colors", None) @@ -1695,15 +1691,17 @@ def dashboard(**_: Any) -> None: json.dumps(bootstrap_data, default=utils.pessimistic_json_iso_dttm_ser) ) - return self.render_template( - "superset/dashboard.html", - entry="dashboard", - standalone_mode=standalone_mode, - title=dash.dashboard_title, - custom_css=dashboard_data.get("css"), - bootstrap_data=json.dumps( - bootstrap_data, default=utils.pessimistic_json_iso_dttm_ser - ), + return make_response( + self.render_template( + "superset/dashboard.html", + entry="dashboard", + standalone_mode=standalone_mode, + title=dash.dashboard_title, + custom_css=dashboard_data.get("css"), + bootstrap_data=json.dumps( + bootstrap_data, default=utils.pessimistic_json_iso_dttm_ser + ), + ) ) @api diff --git a/superset/views/utils.py b/superset/views/utils.py index eaecc5fe87031..dc164944c2b57 100644 --- a/superset/views/utils.py +++ b/superset/views/utils.py @@ -16,19 +16,27 @@ # under the License. import logging from collections import defaultdict -from datetime import date +from datetime import date, datetime from typing import Any, Callable, DefaultDict, Dict, List, Optional, Set, Tuple, Union from urllib import parse import msgpack import pyarrow as pa import simplejson as json -from flask import g, request +from flask import abort, flash, g, redirect, request from flask_appbuilder.security.sqla import models as ab_models from flask_appbuilder.security.sqla.models import User +from flask_babel import gettext as __ import superset.models.core as models -from superset import app, dataframe, db, is_feature_enabled, result_set +from superset import ( + app, + dataframe, + db, + is_feature_enabled, + result_set, + security_manager, +) from superset.connectors.connector_registry import ConnectorRegistry from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import SupersetException, SupersetSecurityException @@ -298,6 +306,36 @@ def get_time_range_endpoints( CONTAINER_TYPES = ["COLUMN", "GRID", "TABS", "TAB", "ROW"] +def get_dashboard(dashboard_id_or_slug: str,) -> Dashboard: + session = db.session() + qry = session.query(Dashboard) + if dashboard_id_or_slug.isdigit(): + qry = qry.filter_by(id=int(dashboard_id_or_slug)) + else: + qry = qry.filter_by(slug=dashboard_id_or_slug) + dashboard = qry.one_or_none() + + if not dashboard: + abort(404) + + return dashboard + + +def get_dashboard_changedon_dt(_self: Any, dashboard_id_or_slug: str) -> datetime: + """ + Get latest changed datetime for a dashboard. The change could be dashboard + metadata change, or any of its slice data change. + + This function takes `self` since it must have the same signature as the + the decorated method. + """ + dash = get_dashboard(dashboard_id_or_slug) + dash_changed_on = dash.changed_on + slices_changed_on = max([s.changed_on for s in dash.slices]) + # drop microseconds in datetime to match with last_modified header + return max(dash_changed_on, slices_changed_on).replace(microsecond=0) + + def get_dashboard_extra_filters( slice_id: int, dashboard_id: int ) -> List[Dict[str, Any]]: @@ -490,6 +528,26 @@ def check_slice_perms(_self: Any, slice_id: int) -> None: viz_obj.raise_for_access() +def check_dashboard_perms(_self: Any, dashboard_id_or_slug: str) -> None: + """ + Check if user can access a cached response from explore_json. + + This function takes `self` since it must have the same signature as the + the decorated method. + """ + + dash = get_dashboard(dashboard_id_or_slug) + datasources = list(dash.datasources) + if app.config["ENABLE_ACCESS_REQUEST"]: + for datasource in datasources: + if datasource and not security_manager.can_access_datasource(datasource): + flash( + __(security_manager.get_datasource_access_error_msg(datasource)), + "danger", + ) + redirect("superset/request_access/?" f"dashboard_id={dash.id}&") + + def _deserialize_results_payload( payload: Union[bytes, str], query: Query, use_msgpack: Optional[bool] = False ) -> Dict[str, Any]: diff --git a/tests/utils_tests.py b/tests/utils_tests.py index 035a7864feb20..c63fcc119328d 100644 --- a/tests/utils_tests.py +++ b/tests/utils_tests.py @@ -67,6 +67,7 @@ from superset.utils import schema from superset.views.utils import ( build_extra_filters, + get_dashboard_changedon_dt, get_form_data, get_time_range_endpoints, ) @@ -1134,3 +1135,14 @@ def test_get_form_data_token(self): assert get_form_data_token({"token": "token_abcdefg1"}) == "token_abcdefg1" generated_token = get_form_data_token({}) assert re.match(r"^token_[a-z0-9]{8}$", generated_token) is not None + + def test_get_dashboard_changedon_dt(self) -> None: + slug = "world_health" + dashboard = db.session.query(Dashboard).filter_by(slug=slug).one() + dashboard_last_changedon = dashboard.changed_on + slices = dashboard.slices + slices_last_changedon = max([slc.changed_on for slc in slices]) + # drop microsecond in datetime + assert get_dashboard_changedon_dt(self, slug) == max( + dashboard_last_changedon, slices_last_changedon + ).replace(microsecond=0)