Skip to content

Commit

Permalink
feat(saved_queries): add custom api filter for all string & text fiel…
Browse files Browse the repository at this point in the history
…ds (apache#11031)
  • Loading branch information
nytai authored and auxten committed Nov 20, 2020
1 parent 5835a43 commit 18d4ae6
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 3 deletions.
3 changes: 2 additions & 1 deletion superset-frontend/src/components/ListView/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ export interface Filter {
| 'rel_m_m'
| 'rel_o_m'
| 'title_or_slug'
| 'name_or_description';
| 'name_or_description'
| 'all_text';
input?: 'text' | 'textarea' | 'select' | 'checkbox' | 'search';
unfilteredLabel?: string;
selects?: SelectOption[];
Expand Down
7 changes: 6 additions & 1 deletion superset/queries/saved_queries/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@
SavedQueryBulkDeleteFailedError,
SavedQueryNotFoundError,
)
from superset.queries.saved_queries.filters import SavedQueryFilter
from superset.queries.saved_queries.filters import (
SavedQueryAllTextFilter,
SavedQueryFilter,
)
from superset.queries.saved_queries.schemas import (
get_delete_ids_schema,
openapi_spec_methods_override,
Expand Down Expand Up @@ -93,6 +96,8 @@ class SavedQueryRestApi(BaseSupersetModelRestApi):
"database.database_name",
]

search_filters = {"label": [SavedQueryAllTextFilter]}

apispec_parameter_schemas = {
"get_delete_ids_schema": get_delete_ids_schema,
}
Expand Down
21 changes: 21 additions & 0 deletions superset/queries/saved_queries/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,33 @@
from typing import Any

from flask import g
from flask_babel import lazy_gettext as _
from flask_sqlalchemy import BaseQuery
from sqlalchemy import or_
from sqlalchemy.orm.query import Query

from superset.models.sql_lab import SavedQuery
from superset.views.base import BaseFilter


class SavedQueryAllTextFilter(BaseFilter): # pylint: disable=too-few-public-methods
name = _("All Text")
arg_name = "all_text"

def apply(self, query: Query, value: Any) -> Query:
if not value:
return query
ilike_value = f"%{value}%"
return query.filter(
or_(
SavedQuery.schema.ilike(ilike_value),
SavedQuery.label.ilike(ilike_value),
SavedQuery.description.ilike(ilike_value),
SavedQuery.sql.ilike(ilike_value),
)
)


class SavedQueryFilter(BaseFilter): # pylint: disable=too-few-public-methods
def apply(self, query: BaseQuery, value: Any) -> BaseQuery:
"""
Expand Down
94 changes: 93 additions & 1 deletion tests/queries/saved_queries/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def insert_saved_query(
db_id: Optional[int] = None,
created_by=None,
schema: Optional[str] = "",
description: Optional[str] = "",
) -> SavedQuery:
database = None
if db_id:
Expand All @@ -53,6 +54,7 @@ def insert_saved_query(
sql=sql,
label=label,
schema=schema,
description=description,
)
db.session.add(query)
db.session.commit()
Expand All @@ -69,6 +71,7 @@ def insert_default_saved_query(
db_id=example_db.id,
created_by=admin,
schema=schema,
description="cool description",
)

@pytest.fixture()
Expand Down Expand Up @@ -195,6 +198,95 @@ def test_get_list_filter_saved_query(self):
data = json.loads(rv.data.decode("utf-8"))
assert data["count"] == len(all_queries)

@pytest.mark.usefixtures("create_saved_queries")
def test_get_list_custom_filter_schema_saved_query(self):
"""
Saved Query API: Test get list and custom filter (schema) saved query
"""
self.login(username="admin")
admin = self.get_user("admin")

all_queries = (
db.session.query(SavedQuery)
.filter(SavedQuery.created_by == admin)
.filter(SavedQuery.schema.ilike("%2%"))
.all()
)
query_string = {
"filters": [{"col": "label", "opr": "all_text", "value": "schema2"}],
}
uri = f"api/v1/saved_query/?q={prison.dumps(query_string)}"
rv = self.get_assert_metric(uri, "get_list")
assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8"))
assert data["count"] == len(all_queries)

@pytest.mark.usefixtures("create_saved_queries")
def test_get_list_custom_filter_label_saved_query(self):
"""
Saved Query API: Test get list and custom filter (label) saved query
"""
self.login(username="admin")
admin = self.get_user("admin")
all_queries = (
db.session.query(SavedQuery)
.filter(SavedQuery.created_by == admin)
.filter(SavedQuery.label.ilike("%3%"))
.all()
)
query_string = {
"filters": [{"col": "label", "opr": "all_text", "value": "label3"}],
}
uri = f"api/v1/saved_query/?q={prison.dumps(query_string)}"
rv = self.get_assert_metric(uri, "get_list")
assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8"))
assert data["count"] == len(all_queries)

@pytest.mark.usefixtures("create_saved_queries")
def test_get_list_custom_filter_sql_saved_query(self):
"""
Saved Query API: Test get list and custom filter (sql) saved query
"""
self.login(username="admin")
admin = self.get_user("admin")
all_queries = (
db.session.query(SavedQuery)
.filter(SavedQuery.created_by == admin)
.filter(SavedQuery.sql.ilike("%table%"))
.all()
)
query_string = {
"filters": [{"col": "label", "opr": "all_text", "value": "table"}],
}
uri = f"api/v1/saved_query/?q={prison.dumps(query_string)}"
rv = self.get_assert_metric(uri, "get_list")
assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8"))
assert data["count"] == len(all_queries)

@pytest.mark.usefixtures("create_saved_queries")
def test_get_list_custom_filter_description_saved_query(self):
"""
Saved Query API: Test get list and custom filter (description) saved query
"""
self.login(username="admin")
admin = self.get_user("admin")
all_queries = (
db.session.query(SavedQuery)
.filter(SavedQuery.created_by == admin)
.filter(SavedQuery.description.ilike("%cool%"))
.all()
)
query_string = {
"filters": [{"col": "label", "opr": "all_text", "value": "cool"}],
}
uri = f"api/v1/saved_query/?q={prison.dumps(query_string)}"
rv = self.get_assert_metric(uri, "get_list")
assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8"))
assert data["count"] == len(all_queries)

def test_info_saved_query(self):
"""
SavedQuery API: Test info
Expand Down Expand Up @@ -281,7 +373,7 @@ def test_get_saved_query(self):
expected_result = {
"id": saved_query.id,
"database": {"id": saved_query.database.id, "database_name": "examples"},
"description": None,
"description": "cool description",
"created_by": {
"first_name": saved_query.created_by.first_name,
"id": saved_query.created_by.id,
Expand Down

0 comments on commit 18d4ae6

Please sign in to comment.