Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[api] New, add custom filters #1327

Merged
merged 7 commits into from
Mar 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion examples/crud_rest_api/app/api.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from flask_appbuilder import ModelRestApi
from flask_appbuilder.api import BaseApi, expose
from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask_appbuilder.models.filters import BaseFilter
from sqlalchemy import or_

from . import appbuilder, db
from .models import Contact, ContactGroup, Gender, ModelOMChild, ModelOMParent
from .models import Contact, ContactGroup, Gender, ModelOMParent


def fill_gender():
Expand Down Expand Up @@ -52,11 +54,25 @@ def greeting(self):
appbuilder.add_api(GreetingApi)


class CustomFilter(BaseFilter):
name = "Custom Filter"
arg_name = "opr"

def apply(self, query, value):
return query.filter(
or_(
Contact.name.like(value + "%"),
Contact.address.like(value + "%"),
)
)


class ContactModelApi(ModelRestApi):
resource_name = "contact"
datamodel = SQLAInterface(Contact)
allow_browser_login = True

search_filters = {"name": [CustomFilter]}
openapi_spec_methods = {
"get_list": {
"get": {
Expand Down
9 changes: 7 additions & 2 deletions flask_appbuilder/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,6 +750,10 @@ class MyView(ModelRestApi):
search_columns = ['name', 'address']

"""
search_filters = None
"""
Override default search filters for columns
"""
search_exclude_columns = None
"""
List with columns to exclude from search. Search includes all possible
Expand Down Expand Up @@ -846,7 +850,6 @@ def _init_properties(self):
x for x in search_columns if x not in self.search_exclude_columns
]
self._gen_labels_columns(self.datamodel.get_columns_list())
self._filters = self.datamodel.get_filters(self.search_columns)

def _init_titles(self):
pass
Expand Down Expand Up @@ -1114,7 +1117,9 @@ def _init_properties(self):
]
self._gen_labels_columns(self.list_columns)
self._gen_labels_columns(self.show_columns)
self._filters = self.datamodel.get_filters(self.search_columns)
self._filters = self.datamodel.get_filters(
search_columns=self.search_columns, search_filters=self.search_filters
)
self.edit_query_rel_fields = self.edit_query_rel_fields or dict()
self.add_query_rel_fields = self.add_query_rel_fields or dict()

Expand Down
9 changes: 7 additions & 2 deletions flask_appbuilder/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,14 @@ def _get_attr_value(item, col):
return value.value
return value

def get_filters(self, search_columns=None):
def get_filters(self, search_columns=None, search_filters=None):
search_columns = search_columns or []
return Filters(self.filter_converter_class, self, search_columns)
return Filters(
self.filter_converter_class,
self,
search_columns=search_columns,
search_filters=search_filters,
)

def get_values_item(self, item, show_columns):
return [self._get_attr_value(item, col) for col in show_columns]
Expand Down
70 changes: 48 additions & 22 deletions flask_appbuilder/models/filters.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
import logging
from typing import Any, Dict, List, Tuple

from .._compat import as_unicode
from ..exceptions import (
Expand Down Expand Up @@ -69,10 +70,14 @@ class FilterRelation(BaseFilter):
Base class for all filters for relations
"""

pass
def apply(self, query, value):
"""
Override this to implement your own new filters
"""
raise NotImplementedError


class BaseFilterConverter(object):
class BaseFilterConverter:
"""
Base Filter Converter, all classes responsible
for the association of columns and possible filters
Expand Down Expand Up @@ -113,20 +118,27 @@ def convert(self, col_name):


class Filters(object):
filters = []
""" List of instanciated BaseFilter classes """
values = []
filters: List[BaseFilter] = []
""" List of instantiated BaseFilter classes """
values: List[Any] = []
""" list of values to apply to filters """
_search_filters = {}
_search_filters: Dict[str, List[BaseFilter]] = {}
""" dict like {'col_name':[BaseFilter1, BaseFilter2, ...], ... } """
_all_filters = {}

def __init__(self, filter_converter, datamodel, search_columns=None):
_all_filters: Dict[str, List[BaseFilter]] = {}

def __init__(
self,
filter_converter: BaseFilterConverter,
datamodel,
search_columns: List[str] = None,
search_filters: Dict[str, List[BaseFilter]] = None,
):
"""

:param filter_converter: Accepts BaseFilterConverter class
:param search_columns: restricts possible columns,
accepts a list of column names
:param search_filters: Add custom defined filters to specific columns
:param datamodel: Accepts BaseInterface class
"""
self.search_columns = search_columns or []
Expand All @@ -137,10 +149,14 @@ def __init__(self, filter_converter, datamodel, search_columns=None):
self._search_filters = self._get_filters(self.search_columns)
self._all_filters = self._get_filters(datamodel.get_columns_list())

if search_filters:
for k, v in search_filters.items():
self._search_filters[k] += v

def get_search_filters(self):
return self._search_filters

def _get_filters(self, cols):
def _get_filters(self, cols: List[str]):
filters = {}
for col in cols:
_filters = self.filter_converter(self.datamodel).convert(col)
Expand All @@ -156,10 +172,12 @@ def _add_filter(self, filter_instance, value):
self.filters.append(filter_instance)
self.values.append(value)

def add_filter_index(self, column_name, filter_instance_index, value):
def add_filter_index(
self, column_name: str, filter_instance_index: int, value: Any
):
self._add_filter(self._all_filters[column_name][filter_instance_index], value)

def rest_add_filters(self, data):
def rest_add_filters(self, data: List[Dict]) -> None:
"""
Adds list of dicts

Expand All @@ -174,18 +192,26 @@ def rest_add_filters(self, data):
except KeyError:
log.warning("Invalid filter")
return
# Get filter class from defaults
filter_class = map_args_filter.get(opr, None)
if filter_class:
if _filter["col"] not in self.search_columns:
if col not in self.search_columns:
raise InvalidColumnFilterFABException(
f"Filter column: {col} not allowed to filter"
)
elif not self._rest_check_valid_filter_operation(col, opr):
raise InvalidOperationFilterFABException(
f"Filter operation: {opr} not allowed on column: {col}"
)
else:
self.add_filter(col, filter_class, value)
self.add_filter(col, filter_class, value)
continue
# Get filter class from custom defined filters
filters = self._search_filters.get(col)
if filters:
for filter in filters:
if filter.arg_name == opr:
self.add_filter(col, filter, value)
break
else:
raise InvalidOperationFilterFABException(
f"Filter operation: {opr} not allowed on column: {col}"
Expand Down Expand Up @@ -215,10 +241,10 @@ def get_joined_filters(self, filters):
"""
Creates a new filters class with active filters joined
"""
retfilters = Filters(self.filter_converter, self.datamodel)
retfilters.filters = self.filters + filters.filters
retfilters.values = self.values + filters.values
return retfilters
ret_filters = Filters(self.filter_converter, self.datamodel)
ret_filters.filters = self.filters + filters.filters
ret_filters.values = self.values + filters.values
return ret_filters

def copy(self):
"""
Expand All @@ -241,13 +267,13 @@ def get_relation_cols(self):
retlst.append(flt.column_name)
return retlst

def get_filters_values(self):
def get_filters_values(self) -> List[Tuple[BaseFilter, Any]]:
"""
Returns a list of tuples [(FILTER, value),(...,...),....]
"""
return [(flt, value) for flt, value in zip(self.filters, self.values)]

def get_filter_value(self, column_name):
def get_filter_value(self, column_name: str) -> Any:
"""
Returns the filtered value for a certain column

Expand All @@ -258,7 +284,7 @@ def get_filter_value(self, column_name):
if flt.column_name == column_name:
return value

def get_filters_values_tojson(self):
def get_filters_values_tojson(self) -> List[Tuple[str, str, Any]]:
return [
(flt.column_name, as_unicode(flt.name), value)
for flt, value in zip(self.filters, self.values)
Expand Down
Loading