From be40048e5616c6bde95cd3b72b3c452062f4a3a8 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Tue, 14 Apr 2020 09:20:43 +0300 Subject: [PATCH 1/7] Add OpenAPI docs to /api/v1/chart/data EP --- superset/charts/api.py | 45 +-- superset/charts/schemas.py | 414 ++++++++++++++++++++++++ superset/connectors/base/models.py | 19 +- superset/connectors/druid/models.py | 68 ++-- superset/connectors/sqla/models.py | 46 +-- superset/examples/birth_names.py | 2 +- superset/examples/world_bank.py | 2 +- superset/typing.py | 2 + superset/utils/core.py | 49 ++- superset/utils/pandas_postprocessing.py | 9 +- 10 files changed, 556 insertions(+), 100 deletions(-) diff --git a/superset/charts/api.py b/superset/charts/api.py index be0df118b20c3..66af2437d08e1 100644 --- a/superset/charts/api.py +++ b/superset/charts/api.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=line-too-long import logging from typing import Any, Dict @@ -41,6 +42,7 @@ from superset.charts.commands.update import UpdateChartCommand from superset.charts.filters import ChartFilter, ChartNameOrDescriptionFilter from superset.charts.schemas import ( + CHART_DATA_SCHEMAS, ChartPostSchema, ChartPutSchema, get_delete_ids_schema, @@ -381,42 +383,14 @@ def data(self) -> Response: Takes a query context constructed in the client and returns payload data response for the given query. requestBody: - description: Query context schema + description: >- + A query context consists of a datasource from which to fetch data + and one or many query objects. required: true content: application/json: schema: - type: object - properties: - datasource: - type: object - description: The datasource where the query will run - properties: - id: - type: integer - type: - type: string - queries: - type: array - items: - type: object - properties: - granularity: - type: string - groupby: - type: array - items: - type: string - metrics: - type: array - items: - type: object - filters: - type: array - items: - type: string - row_limit: - type: integer + $ref: "#/components/schemas/ChartDataQueryContext" responses: 200: description: Query result @@ -533,3 +507,10 @@ def thumbnail( return Response( FileWrapper(screenshot), mimetype="image/png", direct_passthrough=True ) + + def add_apispec_components(self, api_spec): + for chart_type in CHART_DATA_SCHEMAS: + api_spec.components.schema( + chart_type.__name__, schema=chart_type, + ) + super().add_apispec_components(api_spec) diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py index bf1b57b321922..132431a9f9f74 100644 --- a/superset/charts/schemas.py +++ b/superset/charts/schemas.py @@ -59,3 +59,417 @@ class ChartPutSchema(Schema): datasource_id = fields.Integer(allow_none=True) datasource_type = fields.String(allow_none=True) dashboards = fields.List(fields.Integer()) + + +class ChartDataColumn(Schema): + column_name = fields.String( + description="The name of the target column", example="mycol", + ) + type = fields.String(description="Type of target column", example="BIGINT",) + + +class ChartDataAdhocMetric(Schema): + """ + Ad-hoc metrics are used to define metrics outside the datasource. + """ + + expressionType = fields.String( + description="Simple or SQL metric", + required=True, + enum=["SIMPLE", "SQL"], + example="SQL", + ) + aggregate = fields.String( + description="Aggregation operator. Only required for simple expression types.", + required=False, + enum=["AVG", "COUNT", "COUNT_DISTINCT", "MAX", "MIN", "SUM"], + ) + column = fields.Nested(ChartDataColumn) + sqlExpression = fields.String( + description="The metric as defined by a SQL aggregate expression. " + "Only required for SQL expression type.", + required=False, + example="SUM(weight * observations) / SUM(weight)", + ) + label = fields.String( + description="Label for the metric. Is automatically generated unless " + "hasCustomLabel is true, in which case label must be defined.", + required=False, + example="Weighted observations", + ) + hasCustomLabel = fields.Boolean( + description="When false, the label will be automatically generated based on " + "the aggregate expression. When true, a custom label has to be " + "specified.", + required=False, + example=True, + ) + optionName = fields.String( + description="Unique identifier. Can be any string value, as long as all " + "metrics have a unique identifier. If undefined, a random name " + "will be generated.", + required=False, + example="metric_aec60732-fac0-4b17-b736-93f1a5c93e30", + ) + + +class ChartDataAggregateConfig(fields.Dict): + def __init__(self): + super().__init__( + description="The keys are the name of the aggregate column to be created, " + "and the values specify the details of how to apply the " + "aggregation. If an operator requires additional options, " + "these can be passed here to be unpacked in the operator call. The following " + "numpy operators are supported: average, argmin, argmax, cumsum, cumprod, " + "max, mean, median, nansum, nanmin, nanmax, nanmean, nanmedian, min, " + "percentile, prod, product, std, sum, var. Any options required by the " + "operator can be passed to the `options` object.\n\n" + "In the example, a new column `first_quantile` is created based on values " + "in the column `my_col` using the `percentile` operator with " + "the `q=0.25` parameter.", + example={ + "first_quantile": { + "operator": "percentile", + "column": "my_col", + "options": {"q": 0.25}, + } + }, + ) + + +class ChartDataPostProcessingOperationOptions(Schema): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + +class ChartDataPostProcessingAggregateOptions(ChartDataPostProcessingOperationOptions): + """ + Aggregate operation config. + """ + + groupby = ( + fields.List( + fields.String( + allow_none=False, description="Columns by which to group by", + ), + minLength=1, + required=True, + ), + ) + aggregates = ChartDataAggregateConfig() + + +class ChartDataPostProcessingRollingOptions(ChartDataPostProcessingOperationOptions): + """ + Rolling operation config. + """ + + columns = ( + fields.Dict( + description="columns on which to perform rolling, mapping source column to " + "target column. For instance, `{'y': 'y'}` will replace the " + "column `y` with the rolling value in `y`, while `{'y': 'y2'}` " + "will add a column `y2` based on rolling values calculated " + "from `y`, leaving the original column `y` unchanged.", + example={"weekly_rolling_sales": "sales"}, + ), + ) + rolling_type = fields.String( + description="Type of rolling window. Any numpy function will work.", + enum=[ + "average", + "argmin", + "argmax", + "cumsum", + "cumprod", + "max", + "mean", + "median", + "nansum", + "nanmin", + "nanmax", + "nanmean", + "nanmedian", + "min", + "percentile", + "prod", + "product", + "std", + "sum", + "var", + ], + required=True, + example="percentile", + ) + window = fields.Integer( + description="Size of the rolling window in days.", required=True, example=7, + ) + rolling_type_options = fields.Dict( + desctiption="Optional options to pass to rolling method. Needed for " + "e.g. quantile operation.", + required=False, + example={}, + ) + center = fields.Boolean( + description="Should the label be at the center of the window. Default: `false`", + required=False, + example=False, + ) + win_type = fields.String( + description="Type of window function. See " + "[SciPy window functions](https://docs.scipy.org/doc/scipy/reference/signal.windows.html#module-scipy.signal.windows) " + "for more details. Some window functions require passing " + "additional parameters to `rolling_type_options`. For instance, " + "to use `gaussian`, the parameter `std` needs to be provided.", + required=False, + enum=[ + "boxcar", + "triang", + "blackman", + "hamming", + "bartlett", + "parzen", + "bohman", + "blackmanharris", + "nuttall", + "barthann", + "kaiser", + "gaussian", + "general_gaussian", + "slepian", + "exponential", + ], + ) + min_periods = fields.Integer( + description="The minimum amount of periods required for a row to be included " + "in the result set.", + required=False, + example=7, + ) + + +class ChartDataPostProcessingSelectOptions(ChartDataPostProcessingOperationOptions): + """ + Sort operation config. + """ + + columns = fields.List( + fields.String(), + description="Columns which to select from the input data, in the desired " + "order. If columns are renamed, the old column name should be " + "referenced here.", + example=["country", "gender", "age"], + ) + rename = fields.List( + fields.Dict(), + description="columns which to rename, mapping source column to target column. " + "For instance, `{'y': 'y2'}` will rename the column `y` to `y2`.", + example=[{"age": "average_age"}], + ) + + +class ChartDataPostProcessingSortOptions(ChartDataPostProcessingOperationOptions): + """ + Sort operation config. + """ + + columns = fields.Dict( + description="columns by by which to sort. The key specifies the column name, " + "value specifies if sorting in ascending order.", + example={"country": True, "gender": False}, + required=True, + ) + aggregates = ChartDataAggregateConfig() + + +class ChartDataPostProcessingPivotOptions(ChartDataPostProcessingOperationOptions): + """ + Pivot operation config. + """ + + index = ( + fields.List( + fields.String( + allow_none=False, + description="Columns to group by on the table index (=rows)", + ), + minLength=1, + required=True, + ), + ) + columns = fields.List( + fields.String( + allow_none=False, description="Columns to group by on the table columns", + ), + minLength=1, + required=True, + ) + metric_fill_value = fields.Number( + required=False, + description="Value to replace missing values with in aggregate calculations.", + ) + column_fill_value = fields.String( + required=False, description="Value to replace missing pivot columns names with." + ) + drop_missing_columns = fields.Boolean( + description="Do not include columns whose entries are all missing " + "(default: `true`).", + required=False, + ) + marginal_distributions = fields.Boolean( + description="Add totals for row/column. (default: `false`)", required=False, + ) + marginal_distribution_name = fields.String( + description="Name of marginal distribution row/column. (default: `All`)", + required=False, + ) + aggregates = ChartDataAggregateConfig() + + +class ChartDataPostProcessingOperation(Schema): + operation = fields.String( + description="Post processing operation type", + required=True, + enum=["aggregate", "pivot", "rolling", "select", "sort"], + example="aggregate", + ) + options = fields.Nested( + ChartDataPostProcessingOperationOptions(), + description="Options specifying how to perform the operation. Please refer " + "to the respective post processing operation option schemas. " + "For example, `ChartDataPostProcessingOperationOptions` specifies " + "the required options for the pivot operation.", + example={ + "groupby": ["country", "gender"], + "aggregates": { + "age_q1": { + "operator": "percentile", + "column": "age", + "options": {"q": 0.25}, + }, + "age_mean": {"operator": "mean", "column": "age",}, + }, + }, + ) + + +class ChartDataQueryObjectFilter(Schema): + col = fields.String( + description="The column to filter.", required=True, example="country" + ) + op = fields.String( + description="The comparison operator.", + enum=[filter_op.value for filter_op in utils.FilterOperationType], + required=True, + example="IN", + ) + val = fields.Raw( + description="The value or values to compare against. Can be a string, " + "integer, decimal or list, depending on the operator.", + example=["China", "France", "Japan"], + ) + + +class ChartDataQueryObject(Schema): + filters = ChartDataQueryObjectFilter() + granularity = fields.String( + description="To what level of granularity should the temporal column be " + "aggregated. Supports " + "[ISO 8601](https://en.wikipedia.org/wiki/ISO_8601#Durations) " + "durations.", + enum=[ + "PT1S", + "PT1M", + "PT5M", + "PT10M", + "PT15M", + "PT0.5H", + "PT1H", + "P1D", + "P1W", + "P1M", + "P0.25Y", + "P1Y", + ], + required=False, + example="P1D", + groupby=fields.List( + fields.String(description="Columns by which to group the query.",), + ), + ) + metrics = fields.List( + # TODO: add string type when support for `anyOf` is added to Marshmallow. + # strings are used to reference matrics stored in the datasource. + fields.Nested(ChartDataAdhocMetric), + description="Aggregate expressions. Metrics can be passed as both " + "references to datasource metrics (strings), or ad-hoc metrics" + "which are defined only within the query object.", + ) + post_processing = fields.List( + fields.Nested(ChartDataPostProcessingOperation), + description="Post processing operations to be applied to the result set. " + "Operations are applied to the result set in sequential order.", + ) + + +class ChartDataDatasource(Schema): + description = "Chart datasource" + id = fields.Integer(description="Datasource id", required=True,) + type = fields.String(description="Datasource type", enum=["druid", "sql"]) + + +def shape_schema_serialization_disambiguation(base_object, parent_obj): + class_to_schema = { + ChartDataDatasource.__name__: ChartDataDatasource, + ChartDataQueryContext.__name__: ChartDataQueryContext, + } + try: + return class_to_schema[base_object.__class__.__name__]() + except KeyError: + pass + + raise TypeError( + "Could not detect type. " + "Did not have a base or a length. " + "Are you sure this is a shape?" + ) + + +def shape_schema_deserialization_disambiguation(object_dict, parent_object_dict): + if object_dict.get("base"): + return ChartDataDatasource() + elif object_dict.get("length"): + return ChartDataQueryContext() + + raise TypeError( + "Could not detect type. " + "Did not have a base or a length. " + "Are you sure this is a shape?" + ) + + +class ContrivedShapeClass(object): + def __init__(self, main, others): + self.main = main + self.others = others + + def __eq__(self, other): + return self.__dict__ == other.__dict__ + + +class ChartDataQueryContext(Schema): + datasource = fields.Nested(ChartDataDatasource) + queries = fields.List(fields.Nested(ChartDataQueryObject)) + + +CHART_DATA_SCHEMAS = ( + ChartDataQueryContext, + # TODO: These should optimally be included in the QueryContext schema as an `anyOf` + # in ChartDataPostPricessingOperation.options, but since `anyOf` is not yet + # supported by Marshmallow/apispec, this is not currently possible. + ChartDataPostProcessingAggregateOptions, + ChartDataPostProcessingPivotOptions, + ChartDataPostProcessingRollingOptions, + ChartDataPostProcessingSelectOptions, + ChartDataPostProcessingSortOptions, +) diff --git a/superset/connectors/base/models.py b/superset/connectors/base/models.py index 8e1acc74adecd..f6cbd7cd51687 100644 --- a/superset/connectors/base/models.py +++ b/superset/connectors/base/models.py @@ -25,6 +25,7 @@ from superset.constants import NULL_STRING from superset.models.helpers import AuditMixinNullable, ImportMixin, QueryResult from superset.models.slice import Slice +from superset.typing import FilterValues from superset.utils import core as utils METRIC_FORM_DATA_PARAMS = [ @@ -301,21 +302,23 @@ def data_for_slices(self, slices: List[Slice]) -> Dict[str, Any]: @staticmethod def filter_values_handler( - values, target_column_is_numeric=False, is_list_target=False + values: Optional[FilterValues], + target_column_is_numeric: bool = False, + is_list_target: bool = False, ): - def handle_single_value(v): + def handle_single_value(value): # backward compatibility with previous components if isinstance(value, str): value = value.strip("\t\n'\"") @@ -321,11 +324,11 @@ def handle_single_value(value): return value if isinstance(values, (list, tuple)): - values = [handle_single_value(v) for v in values] + values = [handle_single_value(v) for v in values] # type: ignore else: values = handle_single_value(values) if is_list_target and not isinstance(values, (tuple, list)): - values = [values] + values = [values] # type: ignore elif not is_list_target and isinstance(values, (tuple, list)): if values: values = values[0] diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py index 93e1e01bd563b..e843d5ba2a325 100644 --- a/superset/connectors/druid/models.py +++ b/superset/connectors/druid/models.py @@ -24,7 +24,7 @@ from datetime import datetime, timedelta from distutils.version import LooseVersion from multiprocessing.pool import ThreadPool -from typing import Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import cast, Dict, Iterable, List, Optional, Set, Tuple, Union import pandas as pd import sqlalchemy as sa @@ -1490,7 +1490,7 @@ def get_filters(cls, raw_filters, num_cols, columns_dict) -> "Filter": filters = None for flt in raw_filters: col: Optional[str] = flt.get("col") - op: Optional[str] = flt.get["op"].upper() if "op" in flt else None + op: Optional[str] = flt["op"].upper() if "op" in flt else None eq: Optional[FilterValues] = flt.get("val") if ( not col @@ -1526,9 +1526,11 @@ def get_filters(cls, raw_filters, num_cols, columns_dict) -> "Filter": target_column_is_numeric=is_numeric_col, ) + if eq is None: + continue # For these two ops, could have used Dimension, # but it doesn't support extraction functions - if op == FilterOperationType.EQUALS.value: + elif op == FilterOperationType.EQUALS.value: cond = Filter( dimension=col, value=eq, extraction_function=extraction_fn ) @@ -1536,7 +1538,8 @@ def get_filters(cls, raw_filters, num_cols, columns_dict) -> "Filter": cond = ~Filter( dimension=col, value=eq, extraction_function=extraction_fn ) - elif op in (FilterOperationType.IN.value, FilterOperationType.NOT_IN.value): + elif is_list_target: + eq = cast(list, eq) fields = [] # ignore the filter if it has no value if not len(eq): diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index f40a7c1c5a5bc..8aac7e091e783 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -860,7 +860,7 @@ def get_sqla_query( # sqla utils.FilterOperationType.NOT_IN.value, ): cond = col_obj.get_sqla_col().in_(eq) - if NULL_STRING in eq: + if isinstance(eq, str) and NULL_STRING in eq: cond = or_(cond, col_obj.get_sqla_col() is None) if op == utils.FilterOperationType.NOT_IN.value: cond = ~cond From c8b78369b3e39bd867d6d2eba9eba37a1bfc08bb Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Thu, 16 Apr 2020 19:11:00 +0300 Subject: [PATCH 3/7] Fix unit test errors --- superset/charts/api.py | 3 +- superset/charts/schemas.py | 82 ++++++++++++++++------------- superset/connectors/druid/models.py | 4 +- 3 files changed, 47 insertions(+), 42 deletions(-) diff --git a/superset/charts/api.py b/superset/charts/api.py index 337f27b1e7127..518bcfc49a41b 100644 --- a/superset/charts/api.py +++ b/superset/charts/api.py @@ -44,6 +44,7 @@ from superset.charts.filters import ChartFilter, ChartNameOrDescriptionFilter from superset.charts.schemas import ( CHART_DATA_SCHEMAS, + ChartDataQueryContextSchema, ChartPostSchema, ChartPutSchema, get_delete_ids_schema, @@ -432,7 +433,7 @@ def data(self) -> Response: if not request.is_json: return self.response_400(message="Request is not JSON") try: - query_context = QueryContext(**request.json) + query_context, errors = ChartDataQueryContextSchema().load(request.json) except KeyError: return self.response_400(message="Request is incorrect") try: diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py index 7cb0dc45459ee..19faccdeb14d9 100644 --- a/superset/charts/schemas.py +++ b/superset/charts/schemas.py @@ -14,11 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Union +from typing import Any, Dict, Union -from marshmallow import fields, Schema, ValidationError +from marshmallow import fields, post_load, Schema, ValidationError from marshmallow.validate import Length +from superset.common.query_context import QueryContext from superset.exceptions import SupersetException from superset.utils import core as utils @@ -61,14 +62,14 @@ class ChartPutSchema(Schema): dashboards = fields.List(fields.Integer()) -class ChartDataColumn(Schema): +class ChartDataColumnSchema(Schema): column_name = fields.String( description="The name of the target column", example="mycol", ) type = fields.String(description="Type of target column", example="BIGINT",) -class ChartDataAdhocMetric(Schema): +class ChartDataAdhocMetricSchema(Schema): """ Ad-hoc metrics are used to define metrics outside the datasource. """ @@ -84,7 +85,7 @@ class ChartDataAdhocMetric(Schema): required=False, enum=["AVG", "COUNT", "COUNT_DISTINCT", "MAX", "MIN", "SUM"], ) - column = fields.Nested(ChartDataColumn) + column = fields.Nested(ChartDataColumnSchema) sqlExpression = fields.String( description="The metric as defined by a SQL aggregate expression. " "Only required for SQL expression type.", @@ -113,17 +114,18 @@ class ChartDataAdhocMetric(Schema): ) -class ChartDataAggregateConfig(fields.Dict): +class ChartDataAggregateConfigField(fields.Dict): def __init__(self) -> None: super().__init__( description="The keys are the name of the aggregate column to be created, " "and the values specify the details of how to apply the " "aggregation. If an operator requires additional options, " - "these can be passed here to be unpacked in the operator call. The following " - "numpy operators are supported: average, argmin, argmax, cumsum, cumprod, " - "max, mean, median, nansum, nanmin, nanmax, nanmean, nanmedian, min, " - "percentile, prod, product, std, sum, var. Any options required by the " - "operator can be passed to the `options` object.\n\n" + "these can be passed here to be unpacked in the operator call. The " + "following numpy operators are supported: average, argmin, argmax, cumsum, " + "cumprod, max, mean, median, nansum, nanmin, nanmax, nanmean, nanmedian, " + "min, percentile, prod, product, std, sum, var. Any options required by " + "the operator can be passed to the `options` object.\n" + "\n" "In the example, a new column `first_quantile` is created based on values " "in the column `my_col` using the `percentile` operator with " "the `q=0.25` parameter.", @@ -137,12 +139,12 @@ def __init__(self) -> None: ) -class ChartDataPostProcessingOperationOptions(Schema): +class ChartDataPostProcessingOperationOptionsSchema(Schema): def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) -class ChartDataPostProcessingAggregateOptions(ChartDataPostProcessingOperationOptions): +class ChartDataAggregateOptionsSchema(ChartDataPostProcessingOperationOptionsSchema): """ Aggregate operation config. """ @@ -156,10 +158,10 @@ class ChartDataPostProcessingAggregateOptions(ChartDataPostProcessingOperationOp required=True, ), ) - aggregates = ChartDataAggregateConfig() + aggregates = ChartDataAggregateConfigField() -class ChartDataPostProcessingRollingOptions(ChartDataPostProcessingOperationOptions): +class ChartDataRollingOptionsSchema(ChartDataPostProcessingOperationOptionsSchema): """ Rolling operation config. """ @@ -248,7 +250,7 @@ class ChartDataPostProcessingRollingOptions(ChartDataPostProcessingOperationOpti ) -class ChartDataPostProcessingSelectOptions(ChartDataPostProcessingOperationOptions): +class ChartDataSelectOptionsSchema(ChartDataPostProcessingOperationOptionsSchema): """ Sort operation config. """ @@ -268,7 +270,7 @@ class ChartDataPostProcessingSelectOptions(ChartDataPostProcessingOperationOptio ) -class ChartDataPostProcessingSortOptions(ChartDataPostProcessingOperationOptions): +class ChartDataSortOptionsSchema(ChartDataPostProcessingOperationOptionsSchema): """ Sort operation config. """ @@ -279,10 +281,10 @@ class ChartDataPostProcessingSortOptions(ChartDataPostProcessingOperationOptions example={"country": True, "gender": False}, required=True, ) - aggregates = ChartDataAggregateConfig() + aggregates = ChartDataAggregateConfigField() -class ChartDataPostProcessingPivotOptions(ChartDataPostProcessingOperationOptions): +class ChartDataPivotOptionsSchema(ChartDataPostProcessingOperationOptionsSchema): """ Pivot operation config. """ @@ -323,10 +325,10 @@ class ChartDataPostProcessingPivotOptions(ChartDataPostProcessingOperationOption description="Name of marginal distribution row/column. (default: `All`)", required=False, ) - aggregates = ChartDataAggregateConfig() + aggregates = ChartDataAggregateConfigField() -class ChartDataPostProcessingOperation(Schema): +class ChartDataPostProcessingOperationSchema(Schema): operation = fields.String( description="Post processing operation type", required=True, @@ -334,7 +336,7 @@ class ChartDataPostProcessingOperation(Schema): example="aggregate", ) options = fields.Nested( - ChartDataPostProcessingOperationOptions(), + ChartDataPostProcessingOperationOptionsSchema(), description="Options specifying how to perform the operation. Please refer " "to the respective post processing operation option schemas. " "For example, `ChartDataPostProcessingOperationOptions` specifies " @@ -353,7 +355,7 @@ class ChartDataPostProcessingOperation(Schema): ) -class ChartDataQueryObjectFilter(Schema): +class ChartDataFilterSchema(Schema): col = fields.String( description="The column to filter.", required=True, example="country" ) @@ -370,7 +372,7 @@ class ChartDataQueryObjectFilter(Schema): ) -class ChartDataQueryObjectExtras(Schema): +class ChartDataExtrasSchema(Schema): time_range_endpoints = fields.List( fields.String(enum=["INCLUSIVE", "EXCLUSIVE"]), description="A list with two values, stating if start/end should be " @@ -391,8 +393,8 @@ class ChartDataQueryObjectExtras(Schema): ) -class ChartDataQueryObject(Schema): - filters = fields.Nested(ChartDataQueryObjectFilter()) +class ChartDataQueryObjectSchema(Schema): + filters = fields.Nested(ChartDataFilterSchema()) granularity = fields.String( description="To what level of granularity should the temporal column be " "aggregated. Supports " @@ -421,13 +423,13 @@ class ChartDataQueryObject(Schema): metrics = fields.List( # TODO: add string type when support for `anyOf` is added to Marshmallow. # strings are used to reference matrics stored in the datasource. - fields.Nested(ChartDataAdhocMetric), + fields.Nested(ChartDataAdhocMetricSchema), description="Aggregate expressions. Metrics can be passed as both " "references to datasource metrics (strings), or ad-hoc metrics" "which are defined only within the query object.", ) post_processing = fields.List( - fields.Nested(ChartDataPostProcessingOperation), + fields.Nested(ChartDataPostProcessingOperationSchema), description="Post processing operations to be applied to the result set. " "Operations are applied to the result set in sequential order.", ) @@ -482,25 +484,29 @@ class ChartDataQueryObject(Schema): ) -class ChartDataDatasource(Schema): +class ChartDataDatasourceSchema(Schema): description = "Chart datasource" id = fields.Integer(description="Datasource id", required=True,) type = fields.String(description="Datasource type", enum=["druid", "sql"]) -class ChartDataQueryContext(Schema): - datasource = fields.Nested(ChartDataDatasource) - queries = fields.List(fields.Nested(ChartDataQueryObject)) +class ChartDataQueryContextSchema(Schema): + datasource = fields.Nested(ChartDataDatasourceSchema) + queries = fields.List(fields.Nested(ChartDataQueryObjectSchema)) + + @post_load + def make_query_context(self, data: Dict[str, Any]) -> QueryContext: + return QueryContext(**data) CHART_DATA_SCHEMAS = ( - ChartDataQueryContext, + ChartDataQueryContextSchema, # TODO: These should optimally be included in the QueryContext schema as an `anyOf` # in ChartDataPostPricessingOperation.options, but since `anyOf` is not yet # supported by Marshmallow/apispec, this is not currently possible. - ChartDataPostProcessingAggregateOptions, - ChartDataPostProcessingPivotOptions, - ChartDataPostProcessingRollingOptions, - ChartDataPostProcessingSelectOptions, - ChartDataPostProcessingSortOptions, + ChartDataAggregateOptionsSchema, + ChartDataPivotOptionsSchema, + ChartDataRollingOptionsSchema, + ChartDataSelectOptionsSchema, + ChartDataSortOptionsSchema, ) diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py index e843d5ba2a325..20dd732d3842c 100644 --- a/superset/connectors/druid/models.py +++ b/superset/connectors/druid/models.py @@ -1526,11 +1526,9 @@ def get_filters(cls, raw_filters, num_cols, columns_dict) -> "Filter": target_column_is_numeric=is_numeric_col, ) - if eq is None: - continue # For these two ops, could have used Dimension, # but it doesn't support extraction functions - elif op == FilterOperationType.EQUALS.value: + if op == FilterOperationType.EQUALS.value: cond = Filter( dimension=col, value=eq, extraction_function=extraction_fn ) From 9583ac0e214b4000267974ec5db1fa0f207d6093 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Thu, 16 Apr 2020 20:00:53 +0300 Subject: [PATCH 4/7] abc --- superset/charts/api.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/superset/charts/api.py b/superset/charts/api.py index 518bcfc49a41b..d2aab8f6e7b4c 100644 --- a/superset/charts/api.py +++ b/superset/charts/api.py @@ -50,7 +50,6 @@ get_delete_ids_schema, thumbnail_query_schema, ) -from superset.common.query_context import QueryContext from superset.constants import RouteMethod from superset.exceptions import SupersetSecurityException from superset.extensions import event_logger, security_manager @@ -434,6 +433,8 @@ def data(self) -> Response: return self.response_400(message="Request is not JSON") try: query_context, errors = ChartDataQueryContextSchema().load(request.json) + if errors: + raise self.response_400(message=_("Request is incorrect")) except KeyError: return self.response_400(message="Request is incorrect") try: From 5d348a5be18940d48b8ddb84d5bb7061eb3c3632 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Fri, 17 Apr 2020 01:02:20 +0300 Subject: [PATCH 5/7] Fix errors uncovered by schema validation and add unit test for invalid payload --- superset/charts/api.py | 6 ++++-- superset/charts/schemas.py | 33 ++++++++++++++++++++------------- superset/common/query_object.py | 8 ++++---- tests/charts/api_tests.py | 11 +++++++++++ 4 files changed, 39 insertions(+), 19 deletions(-) diff --git a/superset/charts/api.py b/superset/charts/api.py index d2aab8f6e7b4c..0db37383daf65 100644 --- a/superset/charts/api.py +++ b/superset/charts/api.py @@ -23,7 +23,7 @@ from flask import g, make_response, redirect, request, Response, url_for from flask_appbuilder.api import expose, protect, rison, safe from flask_appbuilder.models.sqla.interface import SQLAInterface -from flask_babel import ngettext +from flask_babel import gettext as _, ngettext from werkzeug.wrappers import Response as WerkzeugResponse from werkzeug.wsgi import FileWrapper @@ -434,7 +434,9 @@ def data(self) -> Response: try: query_context, errors = ChartDataQueryContextSchema().load(request.json) if errors: - raise self.response_400(message=_("Request is incorrect")) + return self.response_400( + message=_("Request is incorrect: %(error)s", error=errors) + ) except KeyError: return self.response_400(message="Request is incorrect") try: diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py index 19faccdeb14d9..a52119e551560 100644 --- a/superset/charts/schemas.py +++ b/superset/charts/schemas.py @@ -219,7 +219,8 @@ class ChartDataRollingOptionsSchema(ChartDataPostProcessingOperationOptionsSchem ) win_type = fields.String( description="Type of window function. See " - "[SciPy window functions](https://docs.scipy.org/doc/scipy/reference/signal.windows.html#module-scipy.signal.windows) " + "[SciPy window functions](https://docs.scipy.org/doc/scipy/reference" + "/signal.windows.html#module-scipy.signal.windows) " "for more details. Some window functions require passing " "additional parameters to `rolling_type_options`. For instance, " "to use `gaussian`, the parameter `std` needs to be provided.", @@ -359,7 +360,7 @@ class ChartDataFilterSchema(Schema): col = fields.String( description="The column to filter.", required=True, example="country" ) - op = fields.String( + op = fields.String( # pylint: disable=invalid-name description="The comparison operator.", enum=[filter_op.value for filter_op in utils.FilterOperationType], required=True, @@ -373,6 +374,7 @@ class ChartDataFilterSchema(Schema): class ChartDataExtrasSchema(Schema): + time_range_endpoints = fields.List( fields.String(enum=["INCLUSIVE", "EXCLUSIVE"]), description="A list with two values, stating if start/end should be " @@ -394,7 +396,7 @@ class ChartDataExtrasSchema(Schema): class ChartDataQueryObjectSchema(Schema): - filters = fields.Nested(ChartDataFilterSchema()) + filters = fields.List(fields.Nested(ChartDataFilterSchema), required=False) granularity = fields.String( description="To what level of granularity should the temporal column be " "aggregated. Supports " @@ -416,22 +418,22 @@ class ChartDataQueryObjectSchema(Schema): ], required=False, example="P1D", - groupby=fields.List( - fields.String(description="Columns by which to group the query.",), - ), + ) + groupby = fields.List( + fields.String(description="Columns by which to group the query.",), ) metrics = fields.List( - # TODO: add string type when support for `anyOf` is added to Marshmallow. - # strings are used to reference matrics stored in the datasource. - fields.Nested(ChartDataAdhocMetricSchema), + fields.Raw(), description="Aggregate expressions. Metrics can be passed as both " "references to datasource metrics (strings), or ad-hoc metrics" - "which are defined only within the query object.", + "which are defined only within the query object. See " + "`ChartDataAdhocMetricSchema` for the structure of ad-hoc metrics.", ) post_processing = fields.List( fields.Nested(ChartDataPostProcessingOperationSchema), description="Post processing operations to be applied to the result set. " "Operations are applied to the result set in sequential order.", + required=False, ) time_range = fields.String( description="A time rage, either expressed as a colon separated string " @@ -494,16 +496,21 @@ class ChartDataQueryContextSchema(Schema): datasource = fields.Nested(ChartDataDatasourceSchema) queries = fields.List(fields.Nested(ChartDataQueryObjectSchema)) + # pylint: disable=no-self-use @post_load def make_query_context(self, data: Dict[str, Any]) -> QueryContext: - return QueryContext(**data) + query_context = QueryContext(**data) + return query_context + + # pylint: enable=no-self-use CHART_DATA_SCHEMAS = ( ChartDataQueryContextSchema, # TODO: These should optimally be included in the QueryContext schema as an `anyOf` - # in ChartDataPostPricessingOperation.options, but since `anyOf` is not yet - # supported by Marshmallow/apispec, this is not currently possible. + # in ChartDataPostPricessingOperation.options, but since `anyOf` is not + # by Marshmallow<3, this is not currently possible. + ChartDataAdhocMetricSchema, ChartDataAggregateOptionsSchema, ChartDataPivotOptionsSchema, ChartDataRollingOptionsSchema, diff --git a/superset/common/query_object.py b/superset/common/query_object.py index 9051341befb8b..31a62418bff22 100644 --- a/superset/common/query_object.py +++ b/superset/common/query_object.py @@ -47,9 +47,9 @@ class QueryObject: is_timeseries: bool time_shift: Optional[timedelta] groupby: List[str] - metrics: List[Union[Dict, str]] + metrics: List[Union[Dict[str, Any], str]] row_limit: int - filter: List[str] + filter: List[Dict[str, Any]] timeseries_limit: int timeseries_limit_metric: Optional[Dict] order_desc: bool @@ -61,9 +61,9 @@ class QueryObject: def __init__( self, granularity: str, - metrics: List[Union[Dict, str]], + metrics: List[Union[Dict[str, Any], str]], groupby: Optional[List[str]] = None, - filters: Optional[List[str]] = None, + filters: Optional[List[Dict[str, Any]]] = None, time_range: Optional[str] = None, time_shift: Optional[str] = None, is_timeseries: bool = False, diff --git a/tests/charts/api_tests.py b/tests/charts/api_tests.py index 257b89ba37df0..68c4ce8d41de6 100644 --- a/tests/charts/api_tests.py +++ b/tests/charts/api_tests.py @@ -659,6 +659,17 @@ def test_chart_data(self): data = json.loads(rv.data.decode("utf-8")) self.assertEqual(data[0]["rowcount"], 100) + def test_invalid_chart_data(self): + """ + Query API: Test chart data query + """ + self.login(username="admin") + query_context = self._get_query_context() + query_context["datasource"] = "abc" + uri = "api/v1/chart/data" + rv = self.client.post(uri, json=query_context) + self.assertEqual(rv.status_code, 400) + def test_query_exec_not_allowed(self): """ Query API: Test chart data query not allowed From a8d7fe004011c093702b5c87694535c59efd5f03 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Fri, 17 Apr 2020 16:02:51 +0300 Subject: [PATCH 6/7] Add schema for response --- superset/charts/api.py | 31 +++-------------------- superset/charts/schemas.py | 52 +++++++++++++++++++++++++++++++++++++- tests/charts/api_tests.py | 2 +- 3 files changed, 55 insertions(+), 30 deletions(-) diff --git a/superset/charts/api.py b/superset/charts/api.py index 0db37383daf65..c01663e2b7c6f 100644 --- a/superset/charts/api.py +++ b/superset/charts/api.py @@ -391,39 +391,14 @@ def data(self) -> Response: content: application/json: schema: - $ref: "#/components/schemas/ChartDataQueryContext" + $ref: "#/components/schemas/ChartDataQueryContextSchema" responses: 200: description: Query result content: application/json: schema: - type: array - items: - type: object - properties: - cache_key: - type: string - cached_dttm: - type: string - cache_timeout: - type: integer - error: - type: string - is_cached: - type: boolean - query: - type: string - status: - type: string - stacktrace: - type: string - rowcount: - type: integer - data: - type: array - items: - type: object + $ref: "#/components/schemas/ChartDataResponseSchema" 400: $ref: '#/components/responses/400' 500: @@ -445,7 +420,7 @@ def data(self) -> Response: return self.response_401() payload_json = query_context.get_payload() response_data = simplejson.dumps( - payload_json, default=json_int_dttm_ser, ignore_nan=True + {"result": payload_json}, default=json_int_dttm_ser, ignore_nan=True ) resp = make_response(response_data, 200) resp.headers["Content-Type"] = "application/json; charset=utf-8" diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py index a52119e551560..0a7035c6c0ca7 100644 --- a/superset/charts/schemas.py +++ b/superset/charts/schemas.py @@ -337,7 +337,7 @@ class ChartDataPostProcessingOperationSchema(Schema): example="aggregate", ) options = fields.Nested( - ChartDataPostProcessingOperationOptionsSchema(), + ChartDataPostProcessingOperationOptionsSchema, description="Options specifying how to perform the operation. Please refer " "to the respective post processing operation option schemas. " "For example, `ChartDataPostProcessingOperationOptions` specifies " @@ -505,8 +505,58 @@ def make_query_context(self, data: Dict[str, Any]) -> QueryContext: # pylint: enable=no-self-use +class ChartDataResponseResult(Schema): + cache_key = fields.String( + description="Unique cache key for query object", required=True, allow_none=True, + ) + cached_dttm = fields.String( + description="Cache timestamp", required=True, allow_none=True, + ) + cache_timeout = fields.Integer( + description="Cache timeout in following order: custom timeout, datasource " + "timeout, default config timeout.", + required=True, + allow_none=True, + ) + error = fields.String(description="Error", allow_none=True,) + is_cached = fields.Boolean( + description="Is the result cached", required=True, allow_none=None, + ) + query = fields.String( + description="The executed query statement", required=True, allow_none=False, + ) + status = fields.String( + description="Status of the query", + enum=[ + "stopped", + "failed", + "pending", + "running", + "scheduled", + "success", + "timed_out", + ], + allow_none=False, + ) + stacktrace = fields.String( + desciption="Stacktrace if there was an error", allow_none=True, + ) + rowcount = fields.Integer( + description="Amount of rows in result set", allow_none=False, + ) + data = fields.List(fields.Dict(), description="A list with results") + + +class ChartDataResponseSchema(Schema): + result = fields.List( + fields.Nested(ChartDataResponseResult), + description="A list of results for each corresponding query in the request.", + ) + + CHART_DATA_SCHEMAS = ( ChartDataQueryContextSchema, + ChartDataResponseSchema, # TODO: These should optimally be included in the QueryContext schema as an `anyOf` # in ChartDataPostPricessingOperation.options, but since `anyOf` is not # by Marshmallow<3, this is not currently possible. diff --git a/tests/charts/api_tests.py b/tests/charts/api_tests.py index 68c4ce8d41de6..1f64bae703ee9 100644 --- a/tests/charts/api_tests.py +++ b/tests/charts/api_tests.py @@ -657,7 +657,7 @@ def test_chart_data(self): rv = self.client.post(uri, json=query_context) self.assertEqual(rv.status_code, 200) data = json.loads(rv.data.decode("utf-8")) - self.assertEqual(data[0]["rowcount"], 100) + self.assertEqual(data["result"][0]["rowcount"], 100) def test_invalid_chart_data(self): """ From 20fb8b2b24ea7c31e76acd13bdc0392b2787fb34 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Fri, 17 Apr 2020 16:35:44 +0300 Subject: [PATCH 7/7] Remove unnecessary pylint disable --- superset/charts/api.py | 1 - 1 file changed, 1 deletion(-) diff --git a/superset/charts/api.py b/superset/charts/api.py index c01663e2b7c6f..be4f40747b093 100644 --- a/superset/charts/api.py +++ b/superset/charts/api.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=line-too-long import logging from typing import Any, Dict