Skip to content

Commit

Permalink
[api] New, add custom filters to search (#1327)
Browse files Browse the repository at this point in the history
  • Loading branch information
dpgaspar authored Mar 30, 2020
1 parent 9d6c435 commit 7f52f65
Show file tree
Hide file tree
Showing 5 changed files with 211 additions and 28 deletions.
18 changes: 17 additions & 1 deletion examples/crud_rest_api/app/api.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from flask_appbuilder import ModelRestApi
from flask_appbuilder.api import BaseApi, expose
from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask_appbuilder.models.filters import BaseFilter
from sqlalchemy import or_

from . import appbuilder, db
from .models import Contact, ContactGroup, Gender, ModelOMChild, ModelOMParent
from .models import Contact, ContactGroup, Gender, ModelOMParent


def fill_gender():
Expand Down Expand Up @@ -52,11 +54,25 @@ def greeting(self):
appbuilder.add_api(GreetingApi)


class CustomFilter(BaseFilter):
name = "Custom Filter"
arg_name = "opr"

def apply(self, query, value):
return query.filter(
or_(
Contact.name.like(value + "%"),
Contact.address.like(value + "%"),
)
)


class ContactModelApi(ModelRestApi):
resource_name = "contact"
datamodel = SQLAInterface(Contact)
allow_browser_login = True

search_filters = {"name": [CustomFilter]}
openapi_spec_methods = {
"get_list": {
"get": {
Expand Down
9 changes: 7 additions & 2 deletions flask_appbuilder/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,6 +750,10 @@ class MyView(ModelRestApi):
search_columns = ['name', 'address']
"""
search_filters = None
"""
Override default search filters for columns
"""
search_exclude_columns = None
"""
List with columns to exclude from search. Search includes all possible
Expand Down Expand Up @@ -846,7 +850,6 @@ def _init_properties(self):
x for x in search_columns if x not in self.search_exclude_columns
]
self._gen_labels_columns(self.datamodel.get_columns_list())
self._filters = self.datamodel.get_filters(self.search_columns)

def _init_titles(self):
pass
Expand Down Expand Up @@ -1114,7 +1117,9 @@ def _init_properties(self):
]
self._gen_labels_columns(self.list_columns)
self._gen_labels_columns(self.show_columns)
self._filters = self.datamodel.get_filters(self.search_columns)
self._filters = self.datamodel.get_filters(
search_columns=self.search_columns, search_filters=self.search_filters
)
self.edit_query_rel_fields = self.edit_query_rel_fields or dict()
self.add_query_rel_fields = self.add_query_rel_fields or dict()

Expand Down
9 changes: 7 additions & 2 deletions flask_appbuilder/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,14 @@ def _get_attr_value(item, col):
return value.value
return value

def get_filters(self, search_columns=None):
def get_filters(self, search_columns=None, search_filters=None):
search_columns = search_columns or []
return Filters(self.filter_converter_class, self, search_columns)
return Filters(
self.filter_converter_class,
self,
search_columns=search_columns,
search_filters=search_filters,
)

def get_values_item(self, item, show_columns):
return [self._get_attr_value(item, col) for col in show_columns]
Expand Down
70 changes: 48 additions & 22 deletions flask_appbuilder/models/filters.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
import logging
from typing import Any, Dict, List, Tuple

from .._compat import as_unicode
from ..exceptions import (
Expand Down Expand Up @@ -69,10 +70,14 @@ class FilterRelation(BaseFilter):
Base class for all filters for relations
"""

pass
def apply(self, query, value):
"""
Override this to implement your own new filters
"""
raise NotImplementedError


class BaseFilterConverter(object):
class BaseFilterConverter:
"""
Base Filter Converter, all classes responsible
for the association of columns and possible filters
Expand Down Expand Up @@ -113,20 +118,27 @@ def convert(self, col_name):


class Filters(object):
filters = []
""" List of instanciated BaseFilter classes """
values = []
filters: List[BaseFilter] = []
""" List of instantiated BaseFilter classes """
values: List[Any] = []
""" list of values to apply to filters """
_search_filters = {}
_search_filters: Dict[str, List[BaseFilter]] = {}
""" dict like {'col_name':[BaseFilter1, BaseFilter2, ...], ... } """
_all_filters = {}

def __init__(self, filter_converter, datamodel, search_columns=None):
_all_filters: Dict[str, List[BaseFilter]] = {}

def __init__(
self,
filter_converter: BaseFilterConverter,
datamodel,
search_columns: List[str] = None,
search_filters: Dict[str, List[BaseFilter]] = None,
):
"""
:param filter_converter: Accepts BaseFilterConverter class
:param search_columns: restricts possible columns,
accepts a list of column names
:param search_filters: Add custom defined filters to specific columns
:param datamodel: Accepts BaseInterface class
"""
self.search_columns = search_columns or []
Expand All @@ -137,10 +149,14 @@ def __init__(self, filter_converter, datamodel, search_columns=None):
self._search_filters = self._get_filters(self.search_columns)
self._all_filters = self._get_filters(datamodel.get_columns_list())

if search_filters:
for k, v in search_filters.items():
self._search_filters[k] += v

def get_search_filters(self):
return self._search_filters

def _get_filters(self, cols):
def _get_filters(self, cols: List[str]):
filters = {}
for col in cols:
_filters = self.filter_converter(self.datamodel).convert(col)
Expand All @@ -156,10 +172,12 @@ def _add_filter(self, filter_instance, value):
self.filters.append(filter_instance)
self.values.append(value)

def add_filter_index(self, column_name, filter_instance_index, value):
def add_filter_index(
self, column_name: str, filter_instance_index: int, value: Any
):
self._add_filter(self._all_filters[column_name][filter_instance_index], value)

def rest_add_filters(self, data):
def rest_add_filters(self, data: List[Dict]) -> None:
"""
Adds list of dicts
Expand All @@ -174,18 +192,26 @@ def rest_add_filters(self, data):
except KeyError:
log.warning("Invalid filter")
return
# Get filter class from defaults
filter_class = map_args_filter.get(opr, None)
if filter_class:
if _filter["col"] not in self.search_columns:
if col not in self.search_columns:
raise InvalidColumnFilterFABException(
f"Filter column: {col} not allowed to filter"
)
elif not self._rest_check_valid_filter_operation(col, opr):
raise InvalidOperationFilterFABException(
f"Filter operation: {opr} not allowed on column: {col}"
)
else:
self.add_filter(col, filter_class, value)
self.add_filter(col, filter_class, value)
continue
# Get filter class from custom defined filters
filters = self._search_filters.get(col)
if filters:
for filter in filters:
if filter.arg_name == opr:
self.add_filter(col, filter, value)
break
else:
raise InvalidOperationFilterFABException(
f"Filter operation: {opr} not allowed on column: {col}"
Expand Down Expand Up @@ -215,10 +241,10 @@ def get_joined_filters(self, filters):
"""
Creates a new filters class with active filters joined
"""
retfilters = Filters(self.filter_converter, self.datamodel)
retfilters.filters = self.filters + filters.filters
retfilters.values = self.values + filters.values
return retfilters
ret_filters = Filters(self.filter_converter, self.datamodel)
ret_filters.filters = self.filters + filters.filters
ret_filters.values = self.values + filters.values
return ret_filters

def copy(self):
"""
Expand All @@ -241,13 +267,13 @@ def get_relation_cols(self):
retlst.append(flt.column_name)
return retlst

def get_filters_values(self):
def get_filters_values(self) -> List[Tuple[BaseFilter, Any]]:
"""
Returns a list of tuples [(FILTER, value),(...,...),....]
"""
return [(flt, value) for flt, value in zip(self.filters, self.values)]

def get_filter_value(self, column_name):
def get_filter_value(self, column_name: str) -> Any:
"""
Returns the filtered value for a certain column
Expand All @@ -258,7 +284,7 @@ def get_filter_value(self, column_name):
if flt.column_name == column_name:
return value

def get_filters_values_tojson(self):
def get_filters_values_tojson(self) -> List[Tuple[str, str, Any]]:
return [
(flt.column_name, as_unicode(flt.name), value)
for flt, value in zip(self.filters, self.values)
Expand Down
Loading

0 comments on commit 7f52f65

Please sign in to comment.