Skip to content

Commit

Permalink
feat: implement cache invalidation api (apache#10761)
Browse files Browse the repository at this point in the history
* Add cache endpoints

* Implement cache endpoint

* Tests and address feedback

* Set cache config

* Address feedback

* Expose only invalidate endpoint

Co-authored-by: bogdan kyryliuk <bogdankyryliuk@dropbox.com>
  • Loading branch information
2 people authored and auxten committed Nov 20, 2020
1 parent 63096b2 commit f5ecb89
Show file tree
Hide file tree
Showing 10 changed files with 422 additions and 27 deletions.
2 changes: 2 additions & 0 deletions superset/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def init_views(self) -> None:
#
# pylint: disable=too-many-locals
# pylint: disable=too-many-statements
from superset.cachekeys.api import CacheRestApi
from superset.charts.api import ChartRestApi
from superset.connectors.druid.views import (
Druid,
Expand Down Expand Up @@ -194,6 +195,7 @@ def init_views(self) -> None:
#
# Setup API views
#
appbuilder.add_api(CacheRestApi)
appbuilder.add_api(ChartRestApi)
appbuilder.add_api(DashboardRestApi)
appbuilder.add_api(DatabaseRestApi)
Expand Down
16 changes: 16 additions & 0 deletions superset/cachekeys/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
123 changes: 123 additions & 0 deletions superset/cachekeys/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging

from flask import request, Response
from flask_appbuilder import expose
from flask_appbuilder.api import safe
from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask_appbuilder.security.decorators import protect
from marshmallow.exceptions import ValidationError
from sqlalchemy.exc import SQLAlchemyError

from superset.cachekeys.schemas import CacheInvalidationRequestSchema
from superset.connectors.connector_registry import ConnectorRegistry
from superset.extensions import cache_manager, db, event_logger
from superset.models.cache import CacheKey
from superset.views.base_api import BaseSupersetModelRestApi, statsd_metrics

logger = logging.getLogger(__name__)


class CacheRestApi(BaseSupersetModelRestApi):
datamodel = SQLAInterface(CacheKey)
resource_name = "cachekey"
allow_browser_login = True
class_permission_name = "CacheRestApi"
include_route_methods = {
"invalidate",
}

openapi_spec_component_schemas = (CacheInvalidationRequestSchema,)

@expose("/invalidate", methods=["POST"])
@event_logger.log_this
@protect()
@safe
@statsd_metrics
def invalidate(self) -> Response:
"""
Takes a list of datasources, finds the associated cache records and
invalidates them and removes the database records
---
post:
description: >-
Takes a list of datasources, finds the associated cache records and
invalidates them and removes the database records
requestBody:
description: >-
A list of datasources uuid or the tuples of database and datasource names
required: true
content:
application/json:
schema:
$ref: "#/components/schemas/CacheInvalidationRequestSchema"
responses:
201:
description: cache was successfully invalidated
400:
$ref: '#/components/responses/400'
500:
$ref: '#/components/responses/500'
"""
try:
datasources = CacheInvalidationRequestSchema().load(request.json)
except KeyError:
return self.response_400(message="Request is incorrect")
except ValidationError as error:
return self.response_400(message=str(error))
datasource_uids = set(datasources.get("datasource_uids", []))
for ds in datasources.get("datasources", []):
ds_obj = ConnectorRegistry.get_datasource_by_name(
session=db.session,
datasource_type=ds.get("datasource_type"),
datasource_name=ds.get("datasource_name"),
schema=ds.get("schema"),
database_name=ds.get("database_name"),
)
if ds_obj:
datasource_uids.add(ds_obj.uid)

cache_key_objs = (
db.session.query(CacheKey)
.filter(CacheKey.datasource_uid.in_(datasource_uids))
.all()
)
cache_keys = [c.cache_key for c in cache_key_objs]
if cache_key_objs:
all_keys_deleted = cache_manager.cache.delete_many(*cache_keys)

if not all_keys_deleted:
# expected behavior as keys may expire and cache is not a
# persistent storage
logger.info(
"Some of the cache keys were not deleted in the list %s", cache_keys
)

try:
delete_stmt = CacheKey.__table__.delete().where( # pylint: disable=no-member
CacheKey.cache_key.in_(cache_keys)
)
db.session.execute(delete_stmt)
db.session.commit()
except SQLAlchemyError as ex: # pragma: no cover
logger.error(ex)
db.session.rollback()
return self.response_500(str(ex))
db.session.commit()
return self.response(201)
45 changes: 45 additions & 0 deletions superset/cachekeys/schemas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# RISON/JSON schemas for query parameters
from marshmallow import fields, Schema, validate

from superset.charts.schemas import (
datasource_name_description,
datasource_type_description,
datasource_uid_description,
)


class Datasource(Schema):
database_name = fields.String(description="Datasource name",)
datasource_name = fields.String(description=datasource_name_description,)
schema = fields.String(description="Datasource schema",)
datasource_type = fields.String(
description=datasource_type_description,
validate=validate.OneOf(choices=("druid", "table", "view")),
required=True,
)


class CacheInvalidationRequestSchema(Schema):
datasource_uids = fields.List(
fields.String(), description=datasource_uid_description,
)
datasources = fields.List(
fields.Nested(Datasource),
description="A list of the data source and database names",
)
4 changes: 4 additions & 0 deletions superset/charts/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@
"A complete datasource identification needs `datasouce_id` "
"and `datasource_type`."
)
datasource_uid_description = (
"The uid of the dataset/datasource this new chart will use. "
"A complete datasource identification needs `datasouce_uid` "
)
datasource_type_description = (
"The type of dataset/datasource identified on `datasource_id`."
)
Expand Down
63 changes: 45 additions & 18 deletions tests/base_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from unittest.mock import Mock, patch

import pandas as pd
import pytest
from flask import Response
from flask_appbuilder.security.sqla import models as ab_models
from flask_testing import TestCase
Expand All @@ -42,6 +43,7 @@
from superset.views.base_api import BaseSupersetModelRestApi

FAKE_DB_NAME = "fake_db_100"
test_client = app.test_client()


def login(client: Any, username: str = "admin", password: str = "general"):
Expand Down Expand Up @@ -69,6 +71,39 @@ def get_resp(
return resp.data.decode("utf-8")


def post_assert_metric(
client: Any, uri: str, data: Dict[str, Any], func_name: str
) -> Response:
"""
Simple client post with an extra assertion for statsd metrics
:param client: test client for superset api requests
:param uri: The URI to use for the HTTP POST
:param data: The JSON data payload to be posted
:param func_name: The function name that the HTTP POST triggers
for the statsd metric assertion
:return: HTTP Response
"""
with patch.object(
BaseSupersetModelRestApi, "incr_stats", return_value=None
) as mock_method:
rv = client.post(uri, json=data)
if 200 <= rv.status_code < 400:
mock_method.assert_called_once_with("success", func_name)
else:
mock_method.assert_called_once_with("error", func_name)
return rv


@pytest.fixture
def logged_in_admin():
"""Fixture with app context and logged in admin user."""
with app.app_context():
login(test_client, username="admin")
yield
test_client.get("/logout/", follow_redirects=True)


class SupersetTestCase(TestCase):

default_schema_backend_map = {
Expand All @@ -84,6 +119,15 @@ class SupersetTestCase(TestCase):
def create_app(self):
return app

@staticmethod
def get_birth_names_dataset():
example_db = get_example_database()
return (
db.session.query(SqlaTable)
.filter_by(database=example_db, table_name="birth_names")
.one()
)

@staticmethod
def create_user_with_roles(username: str, roles: List[str]):
user_to_create = security_manager.find_user(username)
Expand Down Expand Up @@ -422,24 +466,7 @@ def delete_assert_metric(self, uri: str, func_name: str) -> Response:
def post_assert_metric(
self, uri: str, data: Dict[str, Any], func_name: str
) -> Response:
"""
Simple client post with an extra assertion for statsd metrics
:param uri: The URI to use for the HTTP POST
:param data: The JSON data payload to be posted
:param func_name: The function name that the HTTP POST triggers
for the statsd metric assertion
:return: HTTP Response
"""
with patch.object(
BaseSupersetModelRestApi, "incr_stats", return_value=None
) as mock_method:
rv = self.client.post(uri, json=data)
if 200 <= rv.status_code < 400:
mock_method.assert_called_once_with("success", func_name)
else:
mock_method.assert_called_once_with("error", func_name)
return rv
return post_assert_metric(self.client, uri, data, func_name)

def put_assert_metric(
self, uri: str, data: Dict[str, Any], func_name: str
Expand Down
16 changes: 16 additions & 0 deletions tests/cachekeys/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
Loading

0 comments on commit f5ecb89

Please sign in to comment.