From e7943e3b23570b08d174597a8b0f11d5f5fb2239 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Wed, 1 Sep 2021 12:02:40 +0300 Subject: [PATCH 01/40] add support for adhoc columns to api and sqla model --- setup.cfg | 2 +- superset/charts/schemas.py | 6 ++-- superset/common/query_context.py | 5 +-- superset/common/query_object.py | 13 ++++---- superset/connectors/sqla/models.py | 35 +++++++++++++++----- superset/typing.py | 8 +++++ superset/utils/core.py | 53 ++++++++++++++++++++++++++++++ 7 files changed, 102 insertions(+), 20 deletions(-) diff --git a/setup.cfg b/setup.cfg index 9a108f76a480e..aa6737842e7df 100644 --- a/setup.cfg +++ b/setup.cfg @@ -30,7 +30,7 @@ combine_as_imports = true include_trailing_comma = true line_length = 88 known_first_party = superset -known_third_party =alembic,apispec,backoff,bleach,cachelib,celery,click,colorama,cron_descriptor,croniter,cryptography,dateutil,deprecation,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_jwt_extended,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,freezegun,geohash,geopy,graphlib,holidays,humanize,isodate,jinja2,jwt,markdown,markupsafe,marshmallow,marshmallow_enum,msgpack,numpy,pandas,parameterized,parsedatetime,pgsanity,pkg_resources,polyline,prison,progress,pyarrow,pyhive,pyparsing,pytest,pytest_mock,pytz,redis,requests,selenium,setuptools,simplejson,slack,sqlalchemy,sqlalchemy_utils,sqlparse,typing_extensions,urllib3,werkzeug,wtforms,wtforms_json,yaml +known_third_party =alembic,apispec,backoff,bleach,cachelib,celery,click,colorama,cron_descriptor,croniter,cryptography,dateutil,deprecation,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_jwt_extended,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,freezegun,geohash,geopy,holidays,humanize,isodate,jinja2,jwt,markdown,markupsafe,marshmallow,marshmallow_enum,msgpack,numpy,pandas,parameterized,parsedatetime,pgsanity,pkg_resources,polyline,prison,progress,pyarrow,pyhive,pyparsing,pytest,pytest_mock,pytz,redis,requests,selenium,setuptools,simplejson,slack,sqlalchemy,sqlalchemy_utils,sqlparse,typing_extensions,urllib3,werkzeug,wtforms,wtforms_json,yaml multi_line_output = 3 order_by_type = false diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py index 9916c0b221ec2..17610bfebb821 100644 --- a/superset/charts/schemas.py +++ b/superset/charts/schemas.py @@ -572,7 +572,8 @@ class ChartDataBoxplotOptionsSchema(ChartDataPostProcessingOperationOptionsSchem """ groupby = fields.List( - fields.String(description="Columns by which to group the query.",), + fields.Raw(), + description="Columns by which to group the query.", allow_none=True, ) @@ -582,6 +583,7 @@ class ChartDataBoxplotOptionsSchema(ChartDataPostProcessingOperationOptionsSchem "references to datasource metrics (strings), or ad-hoc metrics" "which are defined only within the query object. See " "`ChartDataAdhocMetricSchema` for the structure of ad-hoc metrics.", + allow_none=True, ) whisker_type = fields.String( @@ -1041,7 +1043,7 @@ class Meta: # pylint: disable=too-few-public-methods allow_none=True, ) columns = fields.List( - fields.String(), + fields.Raw(), description="Columns which to select in the query.", allow_none=True, ) diff --git a/superset/common/query_context.py b/superset/common/query_context.py index 566b01c8613be..4944f67ac8426 100644 --- a/superset/common/query_context.py +++ b/superset/common/query_context.py @@ -46,6 +46,7 @@ DatasourceDict, DTTM_ALIAS, error_msg_from_exception, + get_column_names_from_columns, get_column_names_from_metrics, get_metric_names, normalize_dttm_col, @@ -450,8 +451,8 @@ def get_df_payload( try: invalid_columns = [ col - for col in query_obj.columns - + query_obj.groupby + for col in get_column_names_from_columns(query_obj.columns) + + get_column_names_from_columns(query_obj.groupby) + get_column_names_from_metrics(query_obj.metrics or []) if col not in self.datasource.column_names and col != DTTM_ALIAS ] diff --git a/superset/common/query_object.py b/superset/common/query_object.py index bdf5f89e964db..c15ee1b25234d 100644 --- a/superset/common/query_object.py +++ b/superset/common/query_object.py @@ -25,13 +25,14 @@ from superset.connectors.base.models import BaseDatasource from superset.connectors.connector_registry import ConnectorRegistry from superset.exceptions import QueryObjectValidationError -from superset.typing import Metric, OrderBy +from superset.typing import Column, Metric, OrderBy from superset.utils import pandas_postprocessing from superset.utils.core import ( ChartDataResultType, DatasourceDict, DTTM_ALIAS, find_duplicates, + get_column_names, get_metric_names, is_adhoc_metric, json_int_dttm_ser, @@ -81,7 +82,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes inner_to_dttm: Optional[datetime] is_timeseries: bool time_shift: Optional[timedelta] - groupby: List[str] + groupby: List[Column] metrics: Optional[List[Metric]] row_limit: int row_offset: int @@ -90,7 +91,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes timeseries_limit_metric: Optional[Metric] order_desc: bool extras: Dict[str, Any] - columns: List[str] + columns: List[Column] orderby: List[OrderBy] post_processing: List[Dict[str, Any]] datasource: Optional[BaseDatasource] @@ -107,7 +108,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals apply_fetch_values_predicate: bool = False, granularity: Optional[str] = None, metrics: Optional[List[Metric]] = None, - groupby: Optional[List[str]] = None, + groupby: Optional[List[Column]] = None, filters: Optional[List[QueryObjectFilterClause]] = None, time_range: Optional[str] = None, time_shift: Optional[str] = None, @@ -118,7 +119,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals timeseries_limit_metric: Optional[Metric] = None, order_desc: bool = True, extras: Optional[Dict[str, Any]] = None, - columns: Optional[List[str]] = None, + columns: Optional[List[Column]] = None, orderby: Optional[List[OrderBy]] = None, post_processing: Optional[List[Optional[Dict[str, Any]]]] = None, is_rowcount: bool = False, @@ -246,7 +247,7 @@ def metric_names(self) -> List[str]: def column_names(self) -> List[str]: """Return column names (labels). Reserved for future adhoc calculated columns.""" - return self.columns + return get_column_names((self.columns or []) + (self.groupby or [])) def validate( self, raise_exceptions: Optional[bool] = True diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 33bacc6be5b66..bd236e6ddf5dd 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -84,7 +84,7 @@ from superset.models.core import Database from superset.models.helpers import AuditMixinNullable, QueryResult from superset.sql_parse import ParsedQuery -from superset.typing import AdhocMetric, Metric, OrderBy, QueryObjectDict +from superset.typing import AdhocColumn, AdhocMetric, Metric, OrderBy, QueryObjectDict from superset.utils import core as utils from superset.utils.core import ( GenericDataType, @@ -857,6 +857,20 @@ def adhoc_metric_to_sqla( return self.make_sqla_column_compatible(sqla_metric, label) + def adhoc_column_to_sqla(self, column: AdhocColumn) -> ColumnElement: + """ + Turn an adhoc metric into a sqlalchemy column. + + :param dict column: Adhoc column definition + :returns: The metric defined as a sqlalchemy column + :rtype: sqlalchemy.sql.column + """ + label = utils.get_column_name(column) + tp = self.get_template_processor() + expression = tp.process_template(cast(str, column["sqlExpression"])) + sqla_metric = literal_column(expression) + return self.make_sqla_column_compatible(sqla_metric, label) + def make_sqla_column_compatible( self, sqla_col: ColumnElement, label: Optional[str] = None ) -> ColumnElement: @@ -1066,15 +1080,18 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma columns = groupby or columns for selected in columns: # if groupby field/expr equals granularity field/expr - table_col = columns_by_name.get(selected) - if table_col and table_col.type_generic == GenericDataType.TEMPORAL: - outer = table_col.get_timestamp_expression(time_grain, selected) - # if groupby field equals a selected column - elif table_col: - outer = table_col.get_sqla_col() + if isinstance(selected, str): + table_col = columns_by_name.get(selected) + if table_col and table_col.type_generic == GenericDataType.TEMPORAL: + outer = table_col.get_timestamp_expression(time_grain, selected) + # if groupby field equals a selected column + elif table_col: + outer = table_col.get_sqla_col() + else: + outer = literal_column(f"({selected})") + outer = self.make_sqla_column_compatible(outer, selected) else: - outer = literal_column(f"({selected})") - outer = self.make_sqla_column_compatible(outer, selected) + outer = self.adhoc_column_to_sqla(selected) groupby_exprs_sans_timestamp[outer.name] = outer select_exprs.append(outer) elif columns: diff --git a/superset/typing.py b/superset/typing.py index d076402df0ac9..66b6cd4c38491 100644 --- a/superset/typing.py +++ b/superset/typing.py @@ -58,6 +58,13 @@ class AdhocMetric(TypedDict, total=False): aggregate: str column: Optional[AdhocMetricColumn] expressionType: Literal["SIMPLE", "SQL"] + hasCustomLabel: Optional[bool] + label: Optional[str] + sqlExpression: Optional[str] + + +class AdhocColumn(TypedDict, total=False): + hasCustomLabel: Optional[bool] label: Optional[str] sqlExpression: Optional[str] @@ -72,6 +79,7 @@ class AdhocMetric(TypedDict, total=False): FilterValues = Union[FilterValue, List[FilterValue], Tuple[FilterValue]] FormData = Dict[str, Any] Granularity = Union[str, Dict[str, Union[str, float]]] +Column = Union[AdhocColumn, str] Metric = Union[AdhocMetric, str] OrderBy = Tuple[Metric, bool] QueryObjectDict = Dict[str, Any] diff --git a/superset/utils/core.py b/superset/utils/core.py index 646b04c0227bc..b52e9bb18dc2e 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -97,8 +97,10 @@ SupersetTimeoutException, ) from superset.typing import ( + AdhocColumn, AdhocMetric, AdhocMetricColumn, + Column, FilterValues, FlaskResponse, FormData, @@ -1280,6 +1282,29 @@ def is_adhoc_metric(metric: Metric) -> TypeGuard[AdhocMetric]: return isinstance(metric, dict) and "expressionType" in metric +def is_adhoc_column(column: Metric) -> TypeGuard[AdhocColumn]: + return isinstance(column, dict) + + +def get_column_name(column: Column) -> str: + """ + Extract label from column + + :param column: object to extract label from + :return: String representation of column + :raises ValueError: if metric object is invalid + """ + if isinstance(column, dict): + label = column.get("label") + if label: + return label + expr = column.get("sqlExpression") + if expr: + return expr + raise Exception("Missing label") + return column + + def get_metric_name(metric: Metric) -> str: """ Extract label from metric @@ -1309,6 +1334,10 @@ def get_metric_name(metric: Metric) -> str: return metric # type: ignore +def get_column_names(columns: Sequence[Column]) -> List[str]: + return [column for column in map(get_column_name, columns) if column] + + def get_metric_names(metrics: Sequence[Metric]) -> List[str]: return [metric for metric in map(get_metric_name, metrics) if metric] @@ -1533,6 +1562,30 @@ def get_form_data_token(form_data: Dict[str, Any]) -> str: return form_data.get("token") or "token_" + uuid.uuid4().hex[:8] +def get_column_name_from_column(column: Column) -> Optional[str]: + """ + Extract the column that a metric is referencing. If the metric isn't + a simple metric, always returns `None`. + + :param column: Ad-hoc metric + :return: column name if simple metric, otherwise None + """ + if isinstance(column, str): + return column + return None + + +def get_column_names_from_columns(columns: List[Column]) -> List[str]: + """ + Extract the column that a metric is referencing. If the metric isn't + a simple metric, always returns `None`. + + :param columns: Ad-hoc metric + :return: column name if simple metric, otherwise None + """ + return [col for col in map(get_column_name_from_column, columns) if col] + + def get_column_name_from_metric(metric: Metric) -> Optional[str]: """ Extract the column that a metric is referencing. If the metric isn't From 0da5b0b3016d1010c301fecf510aa74339289095 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Thu, 2 Sep 2021 07:14:35 +0300 Subject: [PATCH 02/40] fix some types --- .../components/GroupBy/GroupByFilterPlugin.tsx | 18 ++++++++++-------- .../components/Range/RangeFilterPlugin.tsx | 4 +++- .../filters/components/Select/buildQuery.ts | 6 ++++-- 3 files changed, 17 insertions(+), 11 deletions(-) diff --git a/superset-frontend/src/filters/components/GroupBy/GroupByFilterPlugin.tsx b/superset-frontend/src/filters/components/GroupBy/GroupByFilterPlugin.tsx index b8452cac4425c..e3c202c5df23c 100644 --- a/superset-frontend/src/filters/components/GroupBy/GroupByFilterPlugin.tsx +++ b/superset-frontend/src/filters/components/GroupBy/GroupByFilterPlugin.tsx @@ -16,7 +16,13 @@ * specific language governing permissions and limitations * under the License. */ -import { ensureIsArray, ExtraFormData, t, tn } from '@superset-ui/core'; +import { + ensureIsArray, + ExtraFormData, + getColumnLabel, + t, + tn, +} from '@superset-ui/core'; import React, { useEffect, useState } from 'react'; import { FormItemProps } from 'antd/lib/form'; import { Select } from 'src/components'; @@ -62,15 +68,11 @@ export default function PluginFilterGroupBy(props: PluginFilterGroupByProps) { // so we can process it like this `JSON.stringify` or start to use `Immer` }, [JSON.stringify(defaultValue), multiSelect]); - const groupby = formData?.groupby?.[0]?.length - ? formData?.groupby?.[0] - : null; + const groupbys = ensureIsArray(formData.groupby).map(getColumnLabel); + const groupby = groupbys[0].length ? groupbys[0] : null; const withData = groupby - ? data.filter(dataItem => - // @ts-ignore - groupby.includes(dataItem.column_name), - ) + ? data.filter(row => groupby.includes(row.column_name as string)) : data; const columns = data ? withData : []; diff --git a/superset-frontend/src/filters/components/Range/RangeFilterPlugin.tsx b/superset-frontend/src/filters/components/Range/RangeFilterPlugin.tsx index 2cd8129edcf1f..557b7d6563911 100644 --- a/superset-frontend/src/filters/components/Range/RangeFilterPlugin.tsx +++ b/superset-frontend/src/filters/components/Range/RangeFilterPlugin.tsx @@ -17,6 +17,8 @@ * under the License. */ import { + ensureIsArray, + getColumnLabel, getNumberFormatter, NumberFormats, styled, @@ -91,7 +93,7 @@ export default function RangeFilterPlugin(props: PluginFilterRangeProps) { // @ts-ignore const { min, max }: { min: number; max: number } = row; const { groupby, defaultValue, inputRef } = formData; - const [col = ''] = groupby || []; + const [col = ''] = ensureIsArray(groupby).map(getColumnLabel); const [value, setValue] = useState<[number, number]>( defaultValue ?? [min, max], ); diff --git a/superset-frontend/src/filters/components/Select/buildQuery.ts b/superset-frontend/src/filters/components/Select/buildQuery.ts index a66a855011f3f..572cd23275f6f 100644 --- a/superset-frontend/src/filters/components/Select/buildQuery.ts +++ b/superset-frontend/src/filters/components/Select/buildQuery.ts @@ -19,6 +19,7 @@ import { buildQueryContext, GenericDataType, + getColumnLabel, QueryObject, QueryObjectFilterClause, } from '@superset-ui/core'; @@ -36,14 +37,15 @@ const buildQuery: BuildQuery = ( const extraFilters: QueryObjectFilterClause[] = []; if (search) { columns.forEach(column => { - if (coltypeMap[column] === GenericDataType.STRING) { + const label = getColumnLabel(column); + if (coltypeMap[label] === GenericDataType.STRING) { extraFilters.push({ col: column, op: 'ILIKE', val: `%${search}%`, }); } else if ( - coltypeMap[column] === GenericDataType.NUMERIC && + coltypeMap[label] === GenericDataType.NUMERIC && !Number.isNaN(Number(search)) ) { // for numeric columns we apply a >= where clause From 7368ed24e3a4df9949a5942df4f46575c0439456 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Thu, 2 Sep 2021 10:24:15 +0300 Subject: [PATCH 03/40] fix duplicates in column names --- superset/common/query_object.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/superset/common/query_object.py b/superset/common/query_object.py index c15ee1b25234d..ad9b017621436 100644 --- a/superset/common/query_object.py +++ b/superset/common/query_object.py @@ -247,7 +247,7 @@ def metric_names(self) -> List[str]: def column_names(self) -> List[str]: """Return column names (labels). Reserved for future adhoc calculated columns.""" - return get_column_names((self.columns or []) + (self.groupby or [])) + return get_column_names(list(set((self.columns or []) + (self.groupby or [])))) def validate( self, raise_exceptions: Optional[bool] = True From e2b1d5d132eb87271bb98b67f02a1bfbe86740bd Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Thu, 2 Sep 2021 11:42:03 +0300 Subject: [PATCH 04/40] fix more lint --- .../DndFilterSelect.test.tsx | 17 +++++++++++++++-- .../components/Select/SelectFilterPlugin.tsx | 3 ++- .../src/filters/components/Select/buildQuery.ts | 3 ++- 3 files changed, 19 insertions(+), 4 deletions(-) diff --git a/superset-frontend/src/explore/components/controls/DndColumnSelectControl/DndFilterSelect.test.tsx b/superset-frontend/src/explore/components/controls/DndColumnSelectControl/DndFilterSelect.test.tsx index 40962c7c60b67..c3668e1bdf7d1 100644 --- a/superset-frontend/src/explore/components/controls/DndColumnSelectControl/DndFilterSelect.test.tsx +++ b/superset-frontend/src/explore/components/controls/DndColumnSelectControl/DndFilterSelect.test.tsx @@ -43,6 +43,11 @@ const defaultProps: DndFilterSelectProps = { actions: { setControlValue: jest.fn() }, }; +const extraFormData = { + viz_type: 'my_viz', + datasource: 'table__1', +}; + test('renders with default props', () => { render(, { useDnd: true }); expect(screen.getByText('Drop columns or metrics here')).toBeInTheDocument(); @@ -63,7 +68,11 @@ test('renders options with saved metric', () => { render( , { useDnd: true, @@ -100,7 +109,11 @@ test('renders options with adhoc metric', () => { render( , { useDnd: true, diff --git a/superset-frontend/src/filters/components/Select/SelectFilterPlugin.tsx b/superset-frontend/src/filters/components/Select/SelectFilterPlugin.tsx index 718193706e0d6..90b7c2deca230 100644 --- a/superset-frontend/src/filters/components/Select/SelectFilterPlugin.tsx +++ b/superset-frontend/src/filters/components/Select/SelectFilterPlugin.tsx @@ -24,6 +24,7 @@ import { ensureIsArray, ExtraFormData, GenericDataType, + getColumnLabel, JsonObject, smartDateDetailedFormatter, t, @@ -97,7 +98,7 @@ export default function PluginFilterSelect(props: PluginFilterSelectProps) { defaultToFirstItem, searchAllOptions, } = formData; - const groupby = ensureIsArray(formData.groupby); + const groupby = ensureIsArray(formData.groupby).map(getColumnLabel); const [col] = groupby; const [initialColtypeMap] = useState(coltypeMap); const [dataMask, dispatchDataMask] = useImmerReducer(reducer, { diff --git a/superset-frontend/src/filters/components/Select/buildQuery.ts b/superset-frontend/src/filters/components/Select/buildQuery.ts index 572cd23275f6f..e621056a18b9b 100644 --- a/superset-frontend/src/filters/components/Select/buildQuery.ts +++ b/superset-frontend/src/filters/components/Select/buildQuery.ts @@ -20,6 +20,7 @@ import { buildQueryContext, GenericDataType, getColumnLabel, + isPhysicalColumn, QueryObject, QueryObjectFilterClause, } from '@superset-ui/core'; @@ -36,7 +37,7 @@ const buildQuery: BuildQuery = ( const { columns = [], filters = [] } = baseQueryObject; const extraFilters: QueryObjectFilterClause[] = []; if (search) { - columns.forEach(column => { + columns.filter(isPhysicalColumn).forEach(column => { const label = getColumnLabel(column); if (coltypeMap[label] === GenericDataType.STRING) { extraFilters.push({ From 634ddbf1d542f9d4e44632f3bac6c99ef3dc584e Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Thu, 2 Sep 2021 19:47:26 +0300 Subject: [PATCH 05/40] fix schema and dedup --- superset/charts/schemas.py | 3 ++- superset/common/query_object.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py index 17610bfebb821..94dd5ed91aa3b 100644 --- a/superset/charts/schemas.py +++ b/superset/charts/schemas.py @@ -965,7 +965,8 @@ class Meta: # pylint: disable=too-few-public-methods deprecated=True, ) groupby = fields.List( - fields.String(description="Columns by which to group the query.",), + fields.Raw(), + description="Columns by which to group the query.", allow_none=True, ) metrics = fields.List( diff --git a/superset/common/query_object.py b/superset/common/query_object.py index ad9b017621436..0db34d64a7c32 100644 --- a/superset/common/query_object.py +++ b/superset/common/query_object.py @@ -247,7 +247,7 @@ def metric_names(self) -> List[str]: def column_names(self) -> List[str]: """Return column names (labels). Reserved for future adhoc calculated columns.""" - return get_column_names(list(set((self.columns or []) + (self.groupby or [])))) + return list(set(get_column_names((self.columns or []) + (self.groupby or [])))) def validate( self, raise_exceptions: Optional[bool] = True From 43ee0ed1fcea3e905d7be0daf204e641cb65fc78 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Fri, 3 Sep 2021 09:57:02 +0300 Subject: [PATCH 06/40] clean up some logic --- superset/charts/schemas.py | 3 +-- superset/common/query_object.py | 8 +++++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py index 94dd5ed91aa3b..94bc919e0ee05 100644 --- a/superset/charts/schemas.py +++ b/superset/charts/schemas.py @@ -572,8 +572,7 @@ class ChartDataBoxplotOptionsSchema(ChartDataPostProcessingOperationOptionsSchem """ groupby = fields.List( - fields.Raw(), - description="Columns by which to group the query.", + fields.String(description="Columns by which to group the query.",), allow_none=True, ) diff --git a/superset/common/query_object.py b/superset/common/query_object.py index 0db34d64a7c32..66aad53024ed0 100644 --- a/superset/common/query_object.py +++ b/superset/common/query_object.py @@ -245,9 +245,11 @@ def metric_names(self) -> List[str]: @property def column_names(self) -> List[str]: - """Return column names (labels). Reserved for future adhoc calculated - columns.""" - return list(set(get_column_names((self.columns or []) + (self.groupby or [])))) + """Return column names (labels). Gives priority to groupbys if both groupbys + and metrics are non-empty, otherwise returns column labels.""" + return get_column_names( + self.groupby if self.metrics and self.groupby else self.columns + ) def validate( self, raise_exceptions: Optional[bool] = True From 71981bcc9e4987050b424aa8a937248fb65323dd Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Fri, 3 Sep 2021 13:25:08 +0300 Subject: [PATCH 07/40] first pass at fixing viz.py --- setup.cfg | 2 +- superset/utils/core.py | 34 +++++------ superset/viz.py | 135 ++++++++++++++++++++++++----------------- 3 files changed, 98 insertions(+), 73 deletions(-) diff --git a/setup.cfg b/setup.cfg index aa6737842e7df..9a108f76a480e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -30,7 +30,7 @@ combine_as_imports = true include_trailing_comma = true line_length = 88 known_first_party = superset -known_third_party =alembic,apispec,backoff,bleach,cachelib,celery,click,colorama,cron_descriptor,croniter,cryptography,dateutil,deprecation,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_jwt_extended,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,freezegun,geohash,geopy,holidays,humanize,isodate,jinja2,jwt,markdown,markupsafe,marshmallow,marshmallow_enum,msgpack,numpy,pandas,parameterized,parsedatetime,pgsanity,pkg_resources,polyline,prison,progress,pyarrow,pyhive,pyparsing,pytest,pytest_mock,pytz,redis,requests,selenium,setuptools,simplejson,slack,sqlalchemy,sqlalchemy_utils,sqlparse,typing_extensions,urllib3,werkzeug,wtforms,wtforms_json,yaml +known_third_party =alembic,apispec,backoff,bleach,cachelib,celery,click,colorama,cron_descriptor,croniter,cryptography,dateutil,deprecation,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_jwt_extended,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,freezegun,geohash,geopy,graphlib,holidays,humanize,isodate,jinja2,jwt,markdown,markupsafe,marshmallow,marshmallow_enum,msgpack,numpy,pandas,parameterized,parsedatetime,pgsanity,pkg_resources,polyline,prison,progress,pyarrow,pyhive,pyparsing,pytest,pytest_mock,pytz,redis,requests,selenium,setuptools,simplejson,slack,sqlalchemy,sqlalchemy_utils,sqlparse,typing_extensions,urllib3,werkzeug,wtforms,wtforms_json,yaml multi_line_output = 3 order_by_type = false diff --git a/superset/utils/core.py b/superset/utils/core.py index b52e9bb18dc2e..0202552f27d56 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -1282,7 +1282,7 @@ def is_adhoc_metric(metric: Metric) -> TypeGuard[AdhocMetric]: return isinstance(metric, dict) and "expressionType" in metric -def is_adhoc_column(column: Metric) -> TypeGuard[AdhocColumn]: +def is_adhoc_column(column: Column) -> TypeGuard[AdhocColumn]: return isinstance(column, dict) @@ -1334,15 +1334,15 @@ def get_metric_name(metric: Metric) -> str: return metric # type: ignore -def get_column_names(columns: Sequence[Column]) -> List[str]: - return [column for column in map(get_column_name, columns) if column] +def get_column_names(columns: Optional[Sequence[Column]]) -> List[str]: + return [column for column in map(get_column_name, columns or []) if column] -def get_metric_names(metrics: Sequence[Metric]) -> List[str]: - return [metric for metric in map(get_metric_name, metrics) if metric] +def get_metric_names(metrics: Optional[Sequence[Metric]]) -> List[str]: + return [metric for metric in map(get_metric_name, metrics or []) if metric] -def get_first_metric_name(metrics: Sequence[Metric]) -> Optional[str]: +def get_first_metric_name(metrics: Optional[Sequence[Metric]]) -> Optional[str]: metric_labels = get_metric_names(metrics) return metric_labels[0] if metric_labels else None @@ -1564,24 +1564,24 @@ def get_form_data_token(form_data: Dict[str, Any]) -> str: def get_column_name_from_column(column: Column) -> Optional[str]: """ - Extract the column that a metric is referencing. If the metric isn't - a simple metric, always returns `None`. + Extract the physical column that a column is referencing. If the column is + an adhoc column, always returns `None`. - :param column: Ad-hoc metric - :return: column name if simple metric, otherwise None + :param column: Physical and ad-hoc column + :return: column name if physical column, otherwise None """ - if isinstance(column, str): - return column - return None + if is_adhoc_column(column): + return None + return column # type: ignore def get_column_names_from_columns(columns: List[Column]) -> List[str]: """ - Extract the column that a metric is referencing. If the metric isn't - a simple metric, always returns `None`. + Extract the physical columns that a list of columns are referencing. Ignore + adhoc columns - :param columns: Ad-hoc metric - :return: column name if simple metric, otherwise None + :param columns: Physical and adhoc columns + :return: column names of all physical columns """ return [col for col in map(get_column_name_from_column, columns) if col] diff --git a/superset/viz.py b/superset/viz.py index 357b6c8f9bc5d..4d62c6fd8fe9d 100644 --- a/superset/viz.py +++ b/superset/viz.py @@ -21,7 +21,6 @@ Superset can render. """ import copy -import inspect import logging import math import re @@ -53,7 +52,7 @@ from geopy.point import Point from pandas.tseries.frequencies import to_offset -from superset import app, db, is_feature_enabled +from superset import app, is_feature_enabled from superset.constants import NULL_STRING from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import ( @@ -64,14 +63,18 @@ SupersetSecurityException, ) from superset.extensions import cache_manager, security_manager -from superset.models.cache import CacheKey from superset.models.helpers import QueryResult -from superset.typing import Metric, QueryObjectDict, VizData, VizPayload +from superset.typing import Column, Metric, QueryObjectDict, VizData, VizPayload from superset.utils import core as utils, csv from superset.utils.cache import set_and_log_cache from superset.utils.core import ( DTTM_ALIAS, ExtraFiltersReasonType, + get_column_name, + get_column_names, + get_column_names_from_columns, + get_metric_names, + is_adhoc_column, JS_MAX_INTEGER, merge_extra_filters, QueryMode, @@ -139,7 +142,7 @@ def __init__( self.query = "" self.token = utils.get_form_data_token(form_data) - self.groupby: List[str] = self.form_data.get("groupby") or [] + self.groupby: List[Column] = self.form_data.get("groupby") or [] self.time_shift = timedelta() self.status: Optional[str] = None @@ -311,21 +314,32 @@ def process_query_filters(self) -> None: merge_extra_filters(self.form_data) utils.split_adhoc_filters_into_base_filters(self.form_data) + @staticmethod + def dedup_columns(*columns_args: Optional[List[Column]]) -> List[Column]: + # dedup groupby and columns while preserving order + labels: List[str] = [] + deduped_columns: List[Column] = [] + for columns in columns_args: + for column in columns or []: + label = get_column_name(column) + if label not in labels: + deduped_columns.append(column) + return deduped_columns + def query_obj(self) -> QueryObjectDict: """Building a query object""" form_data = self.form_data self.process_query_filters() - gb = self.groupby metrics = self.all_metrics or [] - columns = form_data.get("columns") or [] - # merge list and dedup while preserving order - groupby = list(OrderedDict.fromkeys(gb + columns)) + + groupby = self.dedup_columns(self.groupby, form_data.get("columns")) + groupby_labels = get_column_names(groupby) is_timeseries = self.is_timeseries - if DTTM_ALIAS in groupby: - groupby.remove(DTTM_ALIAS) + if DTTM_ALIAS in groupby_labels: + del groupby[groupby_labels.index(DTTM_ALIAS)] is_timeseries = True granularity = form_data.get("granularity") or form_data.get("granularity_sqla") @@ -526,10 +540,12 @@ def get_df_payload( try: invalid_columns = [ col - for col in (query_obj.get("columns") or []) - + (query_obj.get("groupby") or []) + for col in get_column_names_from_columns( + query_obj.get("columns") or [] + ) + + get_column_names_from_columns(query_obj.get("groupby") or []) + utils.get_column_names_from_metrics( - cast(List[Metric], query_obj.get("metrics") or [],) + cast(List[Metric], query_obj.get("metrics") or []) ) if col not in self.datasource.column_names ] @@ -689,10 +705,12 @@ def process_metrics(self) -> None: percent_columns: List[str] = [] # percent columns that needs extra computation if self.query_mode == QueryMode.RAW: - columns = utils.get_metric_names(fd.get("all_columns") or []) + columns = get_metric_names(fd.get("all_columns")) else: - columns = utils.get_metric_names(self.groupby + (fd.get("metrics") or [])) - percent_columns = utils.get_metric_names(fd.get("percent_metrics") or []) + columns = utils.get_column_names(self.groupby) + get_metric_names( + fd.get("metrics") + ) + percent_columns = get_metric_names(fd.get("percent_metrics") or []) self.columns = columns self.percent_columns = percent_columns @@ -728,7 +746,7 @@ def query_obj(self) -> QueryObjectDict: sort_by = fd.get("timeseries_limit_metric") if sort_by: sort_by_label = utils.get_metric_name(sort_by) - if sort_by_label not in utils.get_metric_names(d["metrics"]): + if sort_by_label not in get_metric_names(d["metrics"]): d["metrics"].append(sort_by) d["orderby"] = [(sort_by, not fd.get("order_desc", True))] elif d["metrics"]: @@ -805,7 +823,7 @@ def get_data(self, df: pd.DataFrame) -> VizData: values: Union[List[str], str] = self.metric_labels if fd.get("groupby"): values = self.metric_labels[0] - columns = fd.get("groupby") + columns = get_column_names(fd["groupby"]) pt = df.pivot_table(index=DTTM_ALIAS, columns=columns, values=values) pt.index = pt.index.map(str) pt = pt.sort_index() @@ -851,12 +869,14 @@ def query_obj(self) -> QueryObjectDict: ) if not metrics: raise QueryObjectValidationError(_("Please choose at least one metric")) - if set(groupby) & set(columns): + deduped_cols = self.dedup_columns(groupby, columns) + + if len(deduped_cols) < (len(groupby) + len(columns)): raise QueryObjectValidationError(_("Group By' and 'Columns' can't overlap")) sort_by = self.form_data.get("timeseries_limit_metric") if sort_by: sort_by_label = utils.get_metric_name(sort_by) - if sort_by_label not in utils.get_metric_names(d["metrics"]): + if sort_by_label not in get_metric_names(d["metrics"]): d["metrics"].append(sort_by) if self.form_data.get("order_desc"): d["orderby"] = [(sort_by, not self.form_data.get("order_desc", True))] @@ -914,18 +934,22 @@ def get_data(self, df: pd.DataFrame) -> VizData: groupby = self.form_data.get("groupby") or [] columns = self.form_data.get("columns") or [] - for column_name in groupby + columns: - column = self.datasource.get_column(column_name) - if column and column.is_temporal: - ts = df[column_name].apply(self._format_datetime) - df[column_name] = ts + for column in groupby + columns: + if is_adhoc_column(column): + # TODO: check data type + pass + else: + column_obj = self.datasource.get_column(column) + if column_obj and column_obj.is_temporal: + ts = df[column].apply(self._format_datetime) + df[column] = ts if self.form_data.get("transpose_pivot"): groupby, columns = columns, groupby df = df.pivot_table( - index=groupby, - columns=columns, + index=get_column_names(groupby), + columns=get_column_names(columns), values=metrics, aggfunc=aggfuncs, margins=self.form_data.get("pivot_margins"), @@ -963,7 +987,7 @@ def query_obj(self) -> QueryObjectDict: sort_by = self.form_data.get("timeseries_limit_metric") if sort_by: sort_by_label = utils.get_metric_name(sort_by) - if sort_by_label not in utils.get_metric_names(d["metrics"]): + if sort_by_label not in get_metric_names(d["metrics"]): d["metrics"].append(sort_by) if self.form_data.get("order_desc"): d["orderby"] = [(sort_by, not self.form_data.get("order_desc", True))] @@ -984,7 +1008,7 @@ def get_data(self, df: pd.DataFrame) -> VizData: if df.empty: return None - df = df.set_index(self.form_data.get("groupby")) + df = df.set_index(get_column_names(self.form_data.get("groupby"))) chart_data = [ {"name": metric, "children": self._nest(metric, df)} for metric in df.columns @@ -1100,7 +1124,7 @@ def query_obj(self) -> QueryObjectDict: d["groupby"].append(form_data.get("series")) # dedup groupby if it happens to be the same - d["groupby"] = list(dict.fromkeys(d["groupby"])) + d["groupby"] = self.dedup_columns(d["groupby"]) self.x_metric = form_data["x"] self.y_metric = form_data["y"] @@ -1124,7 +1148,7 @@ def get_data(self, df: pd.DataFrame) -> VizData: df["y"] = df[[utils.get_metric_name(self.y_metric)]] df["size"] = df[[utils.get_metric_name(self.z_metric)]] df["shape"] = "circle" - df["group"] = df[[self.series]] + df["group"] = df[[get_column_name(self.series)]] # type: ignore series: Dict[Any, List[Any]] = defaultdict(list) for row in df.to_dict(orient="records"): @@ -1237,7 +1261,7 @@ def query_obj(self) -> QueryObjectDict: is_asc = not self.form_data.get("order_desc") if sort_by: sort_by_label = utils.get_metric_name(sort_by) - if sort_by_label not in utils.get_metric_names(d["metrics"]): + if sort_by_label not in get_metric_names(d["metrics"]): d["metrics"].append(sort_by) d["orderby"] = [(sort_by, is_asc)] return d @@ -1316,7 +1340,7 @@ def process_data(self, df: pd.DataFrame, aggregate: bool = False) -> VizData: if aggregate: df = df.pivot_table( index=DTTM_ALIAS, - columns=fd.get("groupby"), + columns=get_column_names(fd.get("groupby")), values=self.metric_labels, fill_value=0, aggfunc=sum, @@ -1324,7 +1348,7 @@ def process_data(self, df: pd.DataFrame, aggregate: bool = False) -> VizData: else: df = df.pivot_table( index=DTTM_ALIAS, - columns=fd.get("groupby"), + columns=get_column_names(fd.get("groupby")), values=self.metric_labels, fill_value=self.pivot_fill_value, ) @@ -1700,15 +1724,15 @@ def get_data(self, df: pd.DataFrame) -> VizData: chart_data = [] if len(self.groupby) > 0: - groups = df.groupby(self.groupby) + groups = df.groupby(get_column_names(self.groupby)) else: groups = [((), df)] for keys, data in groups: chart_data.extend( [ { - "key": self.labelify(keys, column), - "values": data[column].tolist(), + "key": self.labelify(keys, get_column_name(column)), + "values": data[get_column_name(column)].tolist(), } for column in self.columns ] @@ -1741,7 +1765,7 @@ def query_obj(self) -> QueryObjectDict: sort_by = fd.get("timeseries_limit_metric") if sort_by: sort_by_label = utils.get_metric_name(sort_by) - if sort_by_label not in utils.get_metric_names(d["metrics"]): + if sort_by_label not in get_metric_names(d["metrics"]): d["metrics"].append(sort_by) d["orderby"] = [(sort_by, not fd.get("order_desc", True))] elif d["metrics"]: @@ -1757,21 +1781,22 @@ def get_data(self, df: pd.DataFrame) -> VizData: fd = self.form_data metrics = self.metric_labels - columns = fd.get("columns") or [] + columns = get_column_names(fd.get("columns")) + groupby = get_column_names(self.groupby) # pandas will throw away nulls when grouping/pivoting, # so we substitute NULL_STRING for any nulls in the necessary columns - filled_cols = self.groupby + columns + filled_cols = groupby + columns df = df.copy() df[filled_cols] = df[filled_cols].fillna(value=NULL_STRING) sortby = utils.get_metric_name( self.form_data.get("timeseries_limit_metric") or metrics[0] ) - row = df.groupby(self.groupby).sum()[sortby].copy() + row = df.groupby(groupby).sum()[sortby].copy() is_asc = not self.form_data.get("order_desc") row.sort_values(ascending=is_asc, inplace=True) - pt = df.pivot_table(index=self.groupby, columns=columns, values=metrics) + pt = df.pivot_table(index=groupby, columns=columns, values=metrics) if fd.get("contribution"): pt = pt.T pt = (pt / pt.sum()).T @@ -1781,7 +1806,7 @@ def get_data(self, df: pd.DataFrame) -> VizData: pt = pt[metrics] chart_data = [] for name, ys in pt.items(): - if pt[name].dtype.kind not in "biufc" or name in self.groupby: + if pt[name].dtype.kind not in "biufc" or name in groupby: continue if isinstance(name, str): series_title = name @@ -1817,7 +1842,7 @@ def get_data(self, df: pd.DataFrame) -> VizData: if df.empty: return None fd = copy.deepcopy(self.form_data) - cols = fd.get("groupby") or [] + cols = get_column_names(fd.get("groupby")) cols.extend(["m1", "m2"]) metric = utils.get_metric_name(fd["metric"]) secondary_metric = ( @@ -1872,7 +1897,7 @@ def query_obj(self) -> QueryObjectDict: def get_data(self, df: pd.DataFrame) -> VizData: if df.empty: return None - source, target = self.groupby + source, target = get_column_names(self.groupby) (value,) = self.metric_labels df.rename( columns={source: "source", target: "target", value: "value",}, inplace=True, @@ -1976,7 +2001,7 @@ def get_data(self, df: pd.DataFrame) -> VizData: if df.empty: return None fd = self.form_data - cols = [fd.get("entity")] + cols = get_column_names([fd.get("entity")]) # type: ignore metric = self.metric_labels[0] cols += [metric] ndf = df[cols] @@ -2009,7 +2034,7 @@ def get_data(self, df: pd.DataFrame) -> VizData: from superset.examples import countries fd = self.form_data - cols = [fd.get("entity")] + cols = get_column_names([fd.get("entity")]) # type: ignore metric = utils.get_metric_name(fd["metric"]) secondary_metric = ( utils.get_metric_name(fd["secondary_metric"]) @@ -2136,7 +2161,7 @@ def query_obj(self) -> QueryObjectDict: sort_by = self.form_data.get("timeseries_limit_metric") if sort_by: sort_by_label = utils.get_metric_name(sort_by) - if sort_by_label not in utils.get_metric_names(d["metrics"]): + if sort_by_label not in get_metric_names(d["metrics"]): d["metrics"].append(sort_by) if self.form_data.get("order_desc"): d["orderby"] = [(sort_by, not self.form_data.get("order_desc", True))] @@ -2174,8 +2199,8 @@ def get_data(self, df: pd.DataFrame) -> VizData: return None fd = self.form_data - x = fd.get("all_columns_x") - y = fd.get("all_columns_y") + x = get_column_name(fd.get("all_columns_x")) # type: ignore + y = get_column_name(fd.get("all_columns_y")) # type: ignore v = self.metric_labels[0] if x == y: df.columns = ["x", "y", "v"] @@ -2763,7 +2788,7 @@ def query_obj(self) -> QueryObjectDict: return d def get_properties(self, d: Dict[str, Any]) -> Dict[str, Any]: - geojson = d[self.form_data["geojson"]] + geojson = d[get_column_name(self.form_data["geojson"])] return json.loads(geojson) @@ -2849,7 +2874,7 @@ def query_obj(self) -> QueryObjectDict: sort_by = self.form_data.get("timeseries_limit_metric") if sort_by: sort_by_label = utils.get_metric_name(sort_by) - if sort_by_label not in utils.get_metric_names(d["metrics"]): + if sort_by_label not in get_metric_names(d["metrics"]): d["metrics"].append(sort_by) if self.form_data.get("order_desc"): d["orderby"] = [(sort_by, not self.form_data.get("order_desc", True))] @@ -2872,7 +2897,7 @@ def get_data(self, df: pd.DataFrame) -> VizData: return None fd = self.form_data - groups = fd.get("groupby") + groups = get_column_names(fd.get("groupby")) metrics = self.metric_labels df = df.pivot_table(index=DTTM_ALIAS, columns=groups, values=metrics) cols = [] @@ -3095,7 +3120,7 @@ def get_data(self, df: pd.DataFrame) -> VizData: if df.empty: return None fd = self.form_data - groups = fd.get("groupby", []) + groups = get_column_names(fd.get("groupby")) time_op = fd.get("time_series_option", "not_time") if not len(groups): raise ValueError("Please choose at least one groupby") From ee38f5404256a31a88707a378a814377f3bdcc2b Mon Sep 17 00:00:00 2001 From: Kamil Gabryjelski Date: Fri, 3 Sep 2021 14:06:24 +0200 Subject: [PATCH 08/40] Add frontend support for adhoc columns --- .../ColumnSelectPopover.tsx | 98 ++++++++++++++----- .../ColumnSelectPopoverTrigger.tsx | 5 +- .../DndColumnSelect.tsx | 28 ++++-- .../DndColumnSelectControl/OptionWrapper.tsx | 24 +++-- .../controls/DndColumnSelectControl/types.ts | 4 +- .../utils/optionSelector.ts | 59 ++++++++--- 6 files changed, 163 insertions(+), 55 deletions(-) diff --git a/superset-frontend/src/explore/components/controls/DndColumnSelectControl/ColumnSelectPopover.tsx b/superset-frontend/src/explore/components/controls/DndColumnSelectControl/ColumnSelectPopover.tsx index c58b3d382e6d7..7df02b8e4271a 100644 --- a/superset-frontend/src/explore/components/controls/DndColumnSelectControl/ColumnSelectPopover.tsx +++ b/superset-frontend/src/explore/components/controls/DndColumnSelectControl/ColumnSelectPopover.tsx @@ -18,14 +18,20 @@ */ /* eslint-disable camelcase */ import React, { useCallback, useMemo, useState } from 'react'; +import { + AdhocColumn, + isAdhocColumn, + isSavedExpression, + t, + styled, +} from '@superset-ui/core'; +import { ColumnMeta } from '@superset-ui/chart-controls'; import Tabs from 'src/components/Tabs'; import Button from 'src/components/Button'; import { NativeSelect as Select } from 'src/components/Select'; -import { t, styled } from '@superset-ui/core'; - import { Form, FormItem } from 'src/components/Form'; +import { SQLEditor } from 'src/components/AsyncAceEditor'; import { StyledColumnOption } from 'src/explore/components/optionRenderers'; -import { ColumnMeta } from '@superset-ui/chart-controls'; const StyledSelect = styled(Select)` .metric-option { @@ -41,8 +47,8 @@ const StyledSelect = styled(Select)` interface ColumnSelectPopoverProps { columns: ColumnMeta[]; - editedColumn?: ColumnMeta; - onChange: (column: ColumnMeta) => void; + editedColumn?: ColumnMeta | AdhocColumn; + onChange: (column: ColumnMeta | AdhocColumn) => void; onClose: () => void; } @@ -52,18 +58,24 @@ const ColumnSelectPopover = ({ onChange, onClose, }: ColumnSelectPopoverProps) => { - const [ - initialCalculatedColumn, - initialSimpleColumn, - ] = editedColumn?.expression - ? [editedColumn, undefined] - : [undefined, editedColumn]; - const [selectedCalculatedColumn, setSelectedCalculatedColumn] = useState( - initialCalculatedColumn, - ); - const [selectedSimpleColumn, setSelectedSimpleColumn] = useState( - initialSimpleColumn, + const [initialAdhocColumn, initialCalculatedColumn, initialSimpleColumn]: [ + AdhocColumn?, + ColumnMeta?, + ColumnMeta?, + ] = isAdhocColumn(editedColumn) + ? [editedColumn, undefined, undefined] + : isSavedExpression(editedColumn) + ? [undefined, editedColumn, undefined] + : [undefined, undefined, editedColumn as ColumnMeta]; + const [adhocColumn, setAdhocColumn] = useState( + initialAdhocColumn, ); + const [selectedCalculatedColumn, setSelectedCalculatedColumn] = useState< + ColumnMeta | undefined + >(initialCalculatedColumn); + const [selectedSimpleColumn, setSelectedSimpleColumn] = useState< + ColumnMeta | undefined + >(initialSimpleColumn); const [calculatedColumns, simpleColumns] = useMemo( () => @@ -81,6 +93,12 @@ const ColumnSelectPopover = ({ [columns], ); + const onSqlExpressionChange = useCallback(sqlExpression => { + setAdhocColumn({ label: 'test', sqlExpression }); + setSelectedSimpleColumn(undefined); + setSelectedCalculatedColumn(undefined); + }, []); + const onCalculatedColumnChange = useCallback( selectedColumnName => { const selectedColumn = calculatedColumns.find( @@ -88,6 +106,7 @@ const ColumnSelectPopover = ({ ); setSelectedCalculatedColumn(selectedColumn); setSelectedSimpleColumn(undefined); + setAdhocColumn(undefined); }, [calculatedColumns], ); @@ -99,33 +118,52 @@ const ColumnSelectPopover = ({ ); setSelectedCalculatedColumn(undefined); setSelectedSimpleColumn(selectedColumn); + setAdhocColumn(undefined); }, [simpleColumns], ); - const defaultActiveTabKey = - initialSimpleColumn || calculatedColumns.length === 0 ? 'simple' : 'saved'; + const defaultActiveTabKey = initialAdhocColumn + ? 'sqlExpression' + : initialSimpleColumn || calculatedColumns.length === 0 + ? 'simple' + : 'saved'; const onSave = useCallback(() => { - const selectedColumn = selectedCalculatedColumn || selectedSimpleColumn; + const selectedColumn = + adhocColumn || selectedCalculatedColumn || selectedSimpleColumn; if (!selectedColumn) { return; } onChange(selectedColumn); onClose(); - }, [onChange, onClose, selectedCalculatedColumn, selectedSimpleColumn]); + }, [ + adhocColumn, + onChange, + onClose, + selectedCalculatedColumn, + selectedSimpleColumn, + ]); const onResetStateAndClose = useCallback(() => { setSelectedCalculatedColumn(initialCalculatedColumn); setSelectedSimpleColumn(initialSimpleColumn); + setAdhocColumn(initialAdhocColumn); onClose(); - }, [initialCalculatedColumn, initialSimpleColumn, onClose]); + }, [ + initialAdhocColumn, + initialCalculatedColumn, + initialSimpleColumn, + onClose, + ]); - const stateIsValid = selectedCalculatedColumn || selectedSimpleColumn; + const stateIsValid = + adhocColumn || selectedCalculatedColumn || selectedSimpleColumn; const hasUnsavedChanges = selectedCalculatedColumn?.column_name !== initialCalculatedColumn?.column_name || - selectedSimpleColumn?.column_name !== initialSimpleColumn?.column_name; + selectedSimpleColumn?.column_name !== initialSimpleColumn?.column_name || + adhocColumn?.sqlExpression !== initialAdhocColumn?.sqlExpression; const filterOption = useCallback( (input, option) => @@ -199,6 +237,20 @@ const ColumnSelectPopover = ({ + + +