diff --git a/docs/static/resources/openapi.json b/docs/static/resources/openapi.json index 8077af91c1906..d7aecdc4c8b27 100644 --- a/docs/static/resources/openapi.json +++ b/docs/static/resources/openapi.json @@ -345,7 +345,7 @@ "AnnotationLayerRestApi.get_list": { "properties": { "changed_by": { - "$ref": "#/components/schemas/AnnotationLayerRestApi.get_list.User" + "$ref": "#/components/schemas/AnnotationLayerRestApi.get_list.User1" }, "changed_on": { "format": "date-time", @@ -356,7 +356,7 @@ "readOnly": true }, "created_by": { - "$ref": "#/components/schemas/AnnotationLayerRestApi.get_list.User1" + "$ref": "#/components/schemas/AnnotationLayerRestApi.get_list.User" }, "created_on": { "format": "date-time", @@ -502,13 +502,13 @@ "AnnotationRestApi.get_list": { "properties": { "changed_by": { - "$ref": "#/components/schemas/AnnotationRestApi.get_list.User" + "$ref": "#/components/schemas/AnnotationRestApi.get_list.User1" }, "changed_on_delta_humanized": { "readOnly": true }, "created_by": { - "$ref": "#/components/schemas/AnnotationRestApi.get_list.User1" + "$ref": "#/components/schemas/AnnotationRestApi.get_list.User" }, "end_dttm": { "format": "date-time", @@ -1783,7 +1783,7 @@ "readOnly": true }, "created_by": { - "$ref": "#/components/schemas/ChartDataRestApi.get_list.User3" + "$ref": "#/components/schemas/ChartDataRestApi.get_list.User2" }, "created_on_delta_humanized": { "readOnly": true @@ -1833,7 +1833,7 @@ "$ref": "#/components/schemas/ChartDataRestApi.get_list.User" }, "owners": { - "$ref": "#/components/schemas/ChartDataRestApi.get_list.User2" + "$ref": "#/components/schemas/ChartDataRestApi.get_list.User3" }, "params": { "nullable": true, @@ -1942,16 +1942,11 @@ "last_name": { "maxLength": 64, "type": "string" - }, - "username": { - "maxLength": 64, - "type": "string" } }, "required": [ "first_name", - "last_name", - "username" + "last_name" ], "type": "object" }, @@ -1968,11 +1963,16 @@ "last_name": { "maxLength": 64, "type": "string" + }, + 
"username": { + "maxLength": 64, + "type": "string" } }, "required": [ "first_name", - "last_name" + "last_name", + "username" ], "type": "object" }, @@ -2575,7 +2575,7 @@ "readOnly": true }, "created_by": { - "$ref": "#/components/schemas/ChartRestApi.get_list.User3" + "$ref": "#/components/schemas/ChartRestApi.get_list.User2" }, "created_on_delta_humanized": { "readOnly": true @@ -2625,7 +2625,7 @@ "$ref": "#/components/schemas/ChartRestApi.get_list.User" }, "owners": { - "$ref": "#/components/schemas/ChartRestApi.get_list.User2" + "$ref": "#/components/schemas/ChartRestApi.get_list.User3" }, "params": { "nullable": true, @@ -2734,16 +2734,11 @@ "last_name": { "maxLength": 64, "type": "string" - }, - "username": { - "maxLength": 64, - "type": "string" } }, "required": [ "first_name", - "last_name", - "username" + "last_name" ], "type": "object" }, @@ -2760,11 +2755,16 @@ "last_name": { "maxLength": 64, "type": "string" + }, + "username": { + "maxLength": 64, + "type": "string" } }, "required": [ "first_name", - "last_name" + "last_name", + "username" ], "type": "object" }, @@ -3027,13 +3027,13 @@ "CssTemplateRestApi.get_list": { "properties": { "changed_by": { - "$ref": "#/components/schemas/CssTemplateRestApi.get_list.User" + "$ref": "#/components/schemas/CssTemplateRestApi.get_list.User1" }, "changed_on_delta_humanized": { "readOnly": true }, "created_by": { - "$ref": "#/components/schemas/CssTemplateRestApi.get_list.User1" + "$ref": "#/components/schemas/CssTemplateRestApi.get_list.User" }, "created_on": { "format": "date-time", @@ -3415,7 +3415,7 @@ "readOnly": true }, "created_by": { - "$ref": "#/components/schemas/DashboardRestApi.get_list.User2" + "$ref": "#/components/schemas/DashboardRestApi.get_list.User1" }, "created_on_delta_humanized": { "readOnly": true @@ -3441,7 +3441,7 @@ "type": "string" }, "owners": { - "$ref": "#/components/schemas/DashboardRestApi.get_list.User1" + "$ref": "#/components/schemas/DashboardRestApi.get_list.User2" }, 
"position_json": { "nullable": true, @@ -3515,10 +3515,6 @@ }, "DashboardRestApi.get_list.User1": { "properties": { - "email": { - "maxLength": 64, - "type": "string" - }, "first_name": { "maxLength": 64, "type": "string" @@ -3530,22 +3526,20 @@ "last_name": { "maxLength": 64, "type": "string" - }, - "username": { - "maxLength": 64, - "type": "string" } }, "required": [ - "email", "first_name", - "last_name", - "username" + "last_name" ], "type": "object" }, "DashboardRestApi.get_list.User2": { "properties": { + "email": { + "maxLength": 64, + "type": "string" + }, "first_name": { "maxLength": 64, "type": "string" @@ -3557,11 +3551,17 @@ "last_name": { "maxLength": 64, "type": "string" + }, + "username": { + "maxLength": 64, + "type": "string" } }, "required": [ + "email", "first_name", - "last_name" + "last_name", + "username" ], "type": "object" }, @@ -4895,7 +4895,7 @@ "$ref": "#/components/schemas/DatasetRestApi.get.TableColumn" }, "created_by": { - "$ref": "#/components/schemas/DatasetRestApi.get.User2" + "$ref": "#/components/schemas/DatasetRestApi.get.User1" }, "created_on": { "format": "date-time", @@ -4959,7 +4959,7 @@ "type": "integer" }, "owners": { - "$ref": "#/components/schemas/DatasetRestApi.get.User1" + "$ref": "#/components/schemas/DatasetRestApi.get.User2" }, "schema": { "maxLength": 255, @@ -5173,23 +5173,14 @@ "maxLength": 64, "type": "string" }, - "id": { - "format": "int32", - "type": "integer" - }, "last_name": { "maxLength": 64, "type": "string" - }, - "username": { - "maxLength": 64, - "type": "string" } }, "required": [ "first_name", - "last_name", - "username" + "last_name" ], "type": "object" }, @@ -5199,14 +5190,23 @@ "maxLength": 64, "type": "string" }, + "id": { + "format": "int32", + "type": "integer" + }, "last_name": { "maxLength": 64, "type": "string" + }, + "username": { + "maxLength": 64, + "type": "string" } }, "required": [ "first_name", - "last_name" + "last_name", + "username" ], "type": "object" }, @@ -6949,7 +6949,7 @@ 
"type": "integer" }, "created_by": { - "$ref": "#/components/schemas/ReportScheduleRestApi.get_list.User2" + "$ref": "#/components/schemas/ReportScheduleRestApi.get_list.User1" }, "created_on": { "format": "date-time", @@ -6999,7 +6999,7 @@ "type": "string" }, "owners": { - "$ref": "#/components/schemas/ReportScheduleRestApi.get_list.User1" + "$ref": "#/components/schemas/ReportScheduleRestApi.get_list.User2" }, "recipients": { "$ref": "#/components/schemas/ReportScheduleRestApi.get_list.ReportRecipients" @@ -7060,10 +7060,6 @@ "maxLength": 64, "type": "string" }, - "id": { - "format": "int32", - "type": "integer" - }, "last_name": { "maxLength": 64, "type": "string" @@ -7081,6 +7077,10 @@ "maxLength": 64, "type": "string" }, + "id": { + "format": "int32", + "type": "integer" + }, "last_name": { "maxLength": 64, "type": "string" @@ -9507,6 +9507,17 @@ }, "type": "object" }, + "sql_lab_export_csv_schema": { + "properties": { + "client_id": { + "type": "string" + } + }, + "required": [ + "client_id" + ], + "type": "object" + }, "sql_lab_get_results_schema": { "properties": { "key": { @@ -16686,6 +16697,99 @@ ] } }, + "/api/v1/datasource/{datasource_type}/{datasource_id}/column/{column_name}/values/": { + "get": { + "parameters": [ + { + "description": "The type of datasource", + "in": "path", + "name": "datasource_type", + "required": true, + "schema": { + "type": "string" + } + }, + { + "description": "The id of the datasource", + "in": "path", + "name": "datasource_id", + "required": true, + "schema": { + "type": "integer" + } + }, + { + "description": "The name of the column to get values for", + "in": "path", + "name": "column_name", + "required": true, + "schema": { + "type": "string" + } + } + ], + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "properties": { + "result": { + "items": { + "oneOf": [ + { + "type": "string" + }, + { + "type": "integer" + }, + { + "type": "number" + }, + { + "type": "boolean" + }, + { + "type": 
"object" + } + ] + }, + "type": "array" + } + }, + "type": "object" + } + } + }, + "description": "A List of distinct values for the column" + }, + "400": { + "$ref": "#/components/responses/400" + }, + "401": { + "$ref": "#/components/responses/401" + }, + "403": { + "$ref": "#/components/responses/403" + }, + "404": { + "$ref": "#/components/responses/404" + }, + "500": { + "$ref": "#/components/responses/500" + } + }, + "security": [ + { + "jwt": [] + } + ], + "summary": "Get possible values for a datasource column", + "tags": [ + "Datasources" + ] + } + }, "/api/v1/embedded_dashboard/{uuid}": { "get": { "description": "Get a report schedule log", @@ -19799,6 +19903,59 @@ ] } }, + "/api/v1/sqllab/export/": { + "get": { + "parameters": [ + { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/sql_lab_export_csv_schema" + } + } + }, + "in": "query", + "name": "q" + } + ], + "responses": { + "200": { + "content": { + "text/csv": { + "schema": { + "type": "string" + } + } + }, + "description": "SQL query results" + }, + "400": { + "$ref": "#/components/responses/400" + }, + "401": { + "$ref": "#/components/responses/401" + }, + "403": { + "$ref": "#/components/responses/403" + }, + "404": { + "$ref": "#/components/responses/404" + }, + "500": { + "$ref": "#/components/responses/500" + } + }, + "security": [ + { + "jwt": [] + } + ], + "summary": "Exports the SQL Query results to a CSV", + "tags": [ + "SQL Lab" + ] + } + }, "/api/v1/sqllab/results/": { "get": { "parameters": [ diff --git a/superset-frontend/src/SqlLab/components/ResultSet/index.tsx b/superset-frontend/src/SqlLab/components/ResultSet/index.tsx index 81a4e47a11368..47f8d4acdd56a 100644 --- a/superset-frontend/src/SqlLab/components/ResultSet/index.tsx +++ b/superset-frontend/src/SqlLab/components/ResultSet/index.tsx @@ -17,6 +17,7 @@ * under the License. 
*/ import React, { useCallback, useEffect, useState } from 'react'; +import rison from 'rison'; import { useDispatch } from 'react-redux'; import ButtonGroup from 'src/components/ButtonGroup'; import Alert from 'src/components/Alert'; @@ -219,6 +220,14 @@ const ResultSet = ({ } }; + const getExportCsvUrl = (clientId: string) => { + const params = rison.encode({ + client_id: clientId, + }); + + return `/api/v1/sqllab/export/?q=${params}`; + }; + const renderControls = () => { if (search || visualize || csv) { let { data } = query.results; @@ -257,7 +266,7 @@ const ResultSet = ({ /> )} {csv && ( - )} diff --git a/superset/sqllab/api.py b/superset/sqllab/api.py index 283c3ab638707..df6acbfd267d2 100644 --- a/superset/sqllab/api.py +++ b/superset/sqllab/api.py @@ -16,6 +16,7 @@ # under the License. import logging from typing import Any, cast, Dict, Optional +from urllib import parse import simplejson as json from flask import request @@ -32,6 +33,7 @@ from superset.sql_lab import get_sql_results from superset.sqllab.command_status import SqlJsonExecutionStatus from superset.sqllab.commands.execute import CommandResult, ExecuteSqlCommand +from superset.sqllab.commands.export import SqlResultExportCommand from superset.sqllab.commands.results import SqlExecutionResultsCommand from superset.sqllab.exceptions import ( QueryIsForbiddenToAccessException, @@ -42,6 +44,7 @@ from superset.sqllab.schemas import ( ExecutePayloadSchema, QueryExecutionResponseSchema, + sql_lab_export_csv_schema, sql_lab_get_results_schema, ) from superset.sqllab.sql_json_executer import ( @@ -53,7 +56,7 @@ from superset.sqllab.validators import CanAccessQueryValidatorImpl from superset.superset_typing import FlaskResponse from superset.utils import core as utils -from superset.views.base import json_success +from superset.views.base import CsvResponse, generate_download_headers, json_success from superset.views.base_api import BaseSupersetApi, requires_json, statsd_metrics config = app.config @@ 
-72,6 +75,7 @@ class SqlLabRestApi(BaseSupersetApi): apispec_parameter_schemas = { "sql_lab_get_results_schema": sql_lab_get_results_schema, + "sql_lab_export_csv_schema": sql_lab_export_csv_schema, } openapi_spec_tag = "SQL Lab" openapi_spec_component_schemas = ( @@ -79,6 +83,73 @@ class SqlLabRestApi(BaseSupersetApi): QueryExecutionResponseSchema, ) + @expose("/export/") + @protect() + @statsd_metrics + @rison(sql_lab_export_csv_schema) + @event_logger.log_this_with_context( + action=lambda self, *args, **kwargs: f"{self.__class__.__name__}" + f".export_csv", + log_to_statsd=False, + ) + def export_csv(self, **kwargs: Any) -> CsvResponse: + """Exports the SQL Query results to a CSV + --- + get: + summary: >- + Exports the SQL Query results to a CSV + parameters: + - in: query + name: q + content: + application/json: + schema: + $ref: '#/components/schemas/sql_lab_export_csv_schema' + responses: + 200: + description: SQL query results + content: + text/csv: + schema: + type: string + 400: + $ref: '#/components/responses/400' + 401: + $ref: '#/components/responses/401' + 403: + $ref: '#/components/responses/403' + 404: + $ref: '#/components/responses/404' + 500: + $ref: '#/components/responses/500' + """ + params = kwargs["rison"] + client_id = params.get("client_id") + result = SqlResultExportCommand(client_id=client_id).run() + + query = result.get("query") + data = result.get("data") + row_count = result.get("count") + + quoted_csv_name = parse.quote(query.name) + response = CsvResponse( + data, headers=generate_download_headers("csv", quoted_csv_name) + ) + event_info = { + "event_type": "data_export", + "client_id": client_id, + "row_count": row_count, + "database": query.database.name, + "schema": query.schema, + "sql": query.sql, + "exported_format": "csv", + } + event_rep = repr(event_info) + logger.debug( + "CSV exported: %s", event_rep, extra={"superset_event": event_info} + ) + return response + + @expose("/results/") + @protect() + @statsd_metrics diff 
--git a/superset/sqllab/commands/export.py b/superset/sqllab/commands/export.py new file mode 100644 index 0000000000000..1c189d674e474 --- /dev/null +++ b/superset/sqllab/commands/export.py @@ -0,0 +1,130 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=too-few-public-methods, too-many-arguments +from __future__ import annotations + +import logging +from typing import Any, cast, Dict + +import pandas as pd +from flask_babel import gettext as __, lazy_gettext as _ + +from superset import app, db, results_backend, results_backend_use_msgpack +from superset.commands.base import BaseCommand +from superset.errors import ErrorLevel, SupersetError, SupersetErrorType +from superset.exceptions import SupersetErrorException, SupersetSecurityException +from superset.models.sql_lab import Query +from superset.sql_parse import ParsedQuery +from superset.sqllab.limiting_factor import LimitingFactor +from superset.utils import core as utils, csv +from superset.utils.dates import now_as_float +from superset.views.utils import _deserialize_results_payload + +config = app.config + +logger = logging.getLogger(__name__) + + +class SqlResultExportCommand(BaseCommand): + _client_id: str + _query: Query + + def __init__( + self, + client_id: str, + ) -> None: + self._client_id = client_id + + def validate(self) -> None: + self._query = ( + db.session.query(Query).filter_by(client_id=self._client_id).one_or_none() + ) + if self._query is None: + raise SupersetErrorException( + SupersetError( + message=__( + "The query associated with these results could not be found. " + "You need to re-run the original query." 
+ ), + error_type=SupersetErrorType.RESULTS_BACKEND_ERROR, + level=ErrorLevel.ERROR, + ), + status=404, + ) + + try: + self._query.raise_for_access() + except SupersetSecurityException: + raise SupersetErrorException( + SupersetError( + message=__("Cannot access the query"), + error_type=SupersetErrorType.QUERY_SECURITY_ACCESS_ERROR, + level=ErrorLevel.ERROR, + ), + status=403, + ) + + def run( + self, + ) -> Dict[str, Any]: + self.validate() + blob = None + if results_backend and self._query.results_key: + logger.info( + "Fetching CSV from results backend [%s]", self._query.results_key + ) + blob = results_backend.get(self._query.results_key) + if blob: + logger.info("Decompressing") + payload = utils.zlib_decompress( + blob, decode=not results_backend_use_msgpack + ) + obj = _deserialize_results_payload( + payload, self._query, cast(bool, results_backend_use_msgpack) + ) + + df = pd.DataFrame( + data=obj["data"], + dtype=object, + columns=[c["name"] for c in obj["columns"]], + ) + + logger.info("Using pandas to convert to CSV") + else: + logger.info("Running a query to turn into CSV") + if self._query.select_sql: + sql = self._query.select_sql + limit = None + else: + sql = self._query.executed_sql + limit = ParsedQuery(sql).limit + if limit is not None and self._query.limiting_factor in { + LimitingFactor.QUERY, + LimitingFactor.DROPDOWN, + LimitingFactor.QUERY_AND_DROPDOWN, + }: + # remove extra row from `increased_limit` + limit -= 1 + df = self._query.database.get_df(sql, self._query.schema)[:limit] + + csv_data = csv.df_to_escaped_csv(df, index=False, **config["CSV_EXPORT"]) + + return { + "query": self._query, + "count": len(df.index), + "data": csv_data, + } diff --git a/superset/sqllab/schemas.py b/superset/sqllab/schemas.py index f238fda5c918f..428cdb89bb3e3 100644 --- a/superset/sqllab/schemas.py +++ b/superset/sqllab/schemas.py @@ -24,6 +24,14 @@ "required": ["key"], } +sql_lab_export_csv_schema = { + "type": "object", + "properties": { + "client_id": 
{"type": "string"}, + }, + "required": ["client_id"], +} + class ExecutePayloadSchema(Schema): database_id = fields.Integer(required=True) diff --git a/superset/views/core.py b/superset/views/core.py index 8d632dcde21bf..283ea5df0de72 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -2392,6 +2392,7 @@ def _create_response_from_execution_context( # pylint: disable=invalid-name, no @has_access @event_logger.log_this @expose("/csv/") + @deprecated() def csv(self, client_id: str) -> FlaskResponse: # pylint: disable=no-self-use """Download the query results as csv.""" logger.info("Exporting CSV file [%s]", client_id) diff --git a/tests/integration_tests/sql_lab/api_tests.py b/tests/integration_tests/sql_lab/api_tests.py index 4c2080ad4cc2f..52668593213b2 100644 --- a/tests/integration_tests/sql_lab/api_tests.py +++ b/tests/integration_tests/sql_lab/api_tests.py @@ -19,6 +19,9 @@ import datetime import json import random +import csv +import pandas as pd +import io import pytest import prison @@ -26,7 +29,7 @@ from unittest import mock from tests.integration_tests.test_app import app -from superset import sql_lab +from superset import db, sql_lab from superset.common.db_query_status import QueryStatus from superset.models.core import Database from superset.utils.database import get_example_database, get_main_database @@ -176,3 +179,39 @@ def test_get_results_with_display_limit(self): self.assertEqual(result_limited, expected_limited) app.config["RESULTS_BACKEND_USE_MSGPACK"] = use_msgpack + + @mock.patch("superset.models.sql_lab.Query.raise_for_access", lambda _: None) + @mock.patch("superset.models.core.Database.get_df") + def test_export_results(self, get_df_mock: mock.Mock) -> None: + self.login() + + database = Database( + database_name="my_export_database", sqlalchemy_uri="sqlite://" + ) + query_obj = Query( + client_id="test", + database=database, + tab_name="test_tab", + sql_editor_id="test_editor_id", + sql="select * from bar", + 
select_sql=None, + executed_sql="select * from bar limit 2", + limit=100, + select_as_cta=False, + rows=104, + error_message="none", + results_key="test_abc2", + ) + + db.session.add(database) + db.session.add(query_obj) + + get_df_mock.return_value = pd.DataFrame({"foo": [1, 2, 3]}) + + arguments = {"client_id": "test"} + resp = self.get_resp(f"/api/v1/sqllab/export/?q={prison.dumps(arguments)}") + data = csv.reader(io.StringIO(resp)) + expected_data = csv.reader(io.StringIO(f"foo\n1\n2")) + + self.assertEqual(list(expected_data), list(data)) + db.session.rollback() diff --git a/tests/integration_tests/sql_lab/commands_tests.py b/tests/integration_tests/sql_lab/commands_tests.py index 74c1fe7082103..4e2e6642a66d7 100644 --- a/tests/integration_tests/sql_lab/commands_tests.py +++ b/tests/integration_tests/sql_lab/commands_tests.py @@ -15,23 +15,259 @@ # specific language governing permissions and limitations # under the License. from unittest import mock, skip -from unittest.mock import patch +from unittest.mock import Mock, patch +import pandas as pd import pytest from superset import db, sql_lab from superset.common.db_query_status import QueryStatus -from superset.errors import SupersetErrorType -from superset.exceptions import SerializationError, SupersetErrorException +from superset.errors import ErrorLevel, SupersetErrorType +from superset.exceptions import ( + SerializationError, + SupersetError, + SupersetErrorException, + SupersetSecurityException, +) from superset.models.core import Database from superset.models.sql_lab import Query -from superset.sqllab.commands import results +from superset.sqllab.commands import export, results +from superset.sqllab.limiting_factor import LimitingFactor from superset.utils import core as utils from tests.integration_tests.base_tests import SupersetTestCase +class TestSqlResultExportCommand(SupersetTestCase): + def test_validation_query_not_found(self) -> None: + command = export.SqlResultExportCommand("asdf") + + 
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://") + query_obj = Query( + client_id="test1", + database=database, + tab_name="test_tab", + sql_editor_id="test_editor_id", + sql="select * from bar", + select_sql="select * from bar", + executed_sql="select * from bar", + limit=100, + select_as_cta=False, + rows=104, + error_message="none", + results_key="abc1", + ) + + db.session.add(database) + db.session.add(query_obj) + db.session.commit() + + with pytest.raises(SupersetErrorException) as ex_info: + command.run() + assert ex_info.value.error.error_type == SupersetErrorType.RESULTS_BACKEND_ERROR + + def test_validation_invalid_access(self) -> None: + command = export.SqlResultExportCommand("test2") + + database = Database(database_name="my_database2", sqlalchemy_uri="sqlite://") + query_obj = Query( + client_id="test2", + database=database, + tab_name="test_tab", + sql_editor_id="test_editor_id", + sql="select * from bar", + select_sql="select * from bar", + executed_sql="select * from bar", + limit=100, + select_as_cta=False, + rows=104, + error_message="none", + results_key="abc2", + ) + + db.session.add(database) + db.session.add(query_obj) + db.session.commit() + + with mock.patch( + "superset.security_manager.raise_for_access", + side_effect=SupersetSecurityException( + SupersetError( + "dummy", + SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR, + ErrorLevel.ERROR, + ) + ), + ): + with pytest.raises(SupersetErrorException) as ex_info: + command.run() + assert ( + ex_info.value.error.error_type + == SupersetErrorType.QUERY_SECURITY_ACCESS_ERROR + ) + + @patch("superset.models.sql_lab.Query.raise_for_access", lambda _: None) + @patch("superset.models.core.Database.get_df") + def test_run_no_results_backend_select_sql(self, get_df_mock: Mock) -> None: + command = export.SqlResultExportCommand("test3") + + database = Database(database_name="my_database3", sqlalchemy_uri="sqlite://") + query_obj = Query( + client_id="test3", + 
database=database, + tab_name="test_tab", + sql_editor_id="test_editor_id", + sql="select * from bar", + select_sql="select * from bar", + executed_sql="select * from bar", + limit=100, + select_as_cta=False, + rows=104, + error_message="none", + results_key="abc3", + ) + + db.session.add(database) + db.session.add(query_obj) + db.session.commit() + + get_df_mock.return_value = pd.DataFrame({"foo": [1, 2, 3]}) + + result = command.run() + + data = result.get("data") + count = result.get("count") + query = result.get("query") + + assert data == "foo\n1\n2\n3\n" + assert count == 3 + assert query.client_id == "test3" + + @patch("superset.models.sql_lab.Query.raise_for_access", lambda _: None) + @patch("superset.models.core.Database.get_df") + def test_run_no_results_backend_executed_sql(self, get_df_mock: Mock) -> None: + command = export.SqlResultExportCommand("test4") + + database = Database(database_name="my_database4", sqlalchemy_uri="sqlite://") + query_obj = Query( + client_id="test4", + database=database, + tab_name="test_tab", + sql_editor_id="test_editor_id", + sql="select * from bar", + select_sql=None, + executed_sql="select * from bar limit 2", + limit=100, + select_as_cta=False, + rows=104, + error_message="none", + results_key="abc4", + ) + + db.session.add(database) + db.session.add(query_obj) + db.session.commit() + + get_df_mock.return_value = pd.DataFrame({"foo": [1, 2, 3]}) + + result = command.run() + + data = result.get("data") + count = result.get("count") + query = result.get("query") + + assert data == "foo\n1\n2\n" + assert count == 2 + assert query.client_id == "test4" + + @patch("superset.models.sql_lab.Query.raise_for_access", lambda _: None) + @patch("superset.models.core.Database.get_df") + def test_run_no_results_backend_executed_sql_limiting_factor( + self, get_df_mock: Mock + ) -> None: + command = export.SqlResultExportCommand("test5") + + database = Database(database_name="my_database5", sqlalchemy_uri="sqlite://") + query_obj = 
Query( + client_id="test5", + database=database, + tab_name="test_tab", + sql_editor_id="test_editor_id", + sql="select * from bar", + select_sql=None, + executed_sql="select * from bar limit 2", + limit=100, + select_as_cta=False, + rows=104, + error_message="none", + results_key="abc5", + limiting_factor=LimitingFactor.DROPDOWN, + ) + + db.session.add(database) + db.session.add(query_obj) + db.session.commit() + + get_df_mock.return_value = pd.DataFrame({"foo": [1, 2, 3]}) + + result = command.run() + + data = result.get("data") + count = result.get("count") + query = result.get("query") + + assert data == "foo\n1\n" + assert count == 1 + assert query.client_id == "test5" + + @patch("superset.models.sql_lab.Query.raise_for_access", lambda _: None) + @patch("superset.sqllab.commands.export.results_backend_use_msgpack", False) + def test_run_with_results_backend(self) -> None: + command = export.SqlResultExportCommand("test6") + + database = Database(database_name="my_database6", sqlalchemy_uri="sqlite://") + query_obj = Query( + client_id="test6", + database=database, + tab_name="test_tab", + sql_editor_id="test_editor_id", + sql="select * from bar", + select_sql="select * from bar", + executed_sql="select * from bar", + limit=100, + select_as_cta=False, + rows=104, + error_message="none", + results_key="abc6", + ) + + db.session.add(database) + db.session.add(query_obj) + db.session.commit() + + data = [{"foo": i} for i in range(5)] + payload = { + "columns": [{"name": "foo"}], + "data": data, + } + serialized_payload = sql_lab._serialize_payload(payload, False) + compressed = utils.zlib_compress(serialized_payload) + + export.results_backend = mock.Mock() + export.results_backend.get.return_value = compressed + + result = command.run() + + data = result.get("data") + count = result.get("count") + query = result.get("query") + + assert data == "foo\n0\n1\n2\n3\n4\n" + assert count == 5 + assert query.client_id == "test6" + + class 
TestSqlExecutionResultsCommand(SupersetTestCase): - @mock.patch("superset.sqllab.commands.results.results_backend_use_msgpack", False) + @patch("superset.sqllab.commands.results.results_backend_use_msgpack", False) def test_validation_no_results_backend(self) -> None: results.results_backend = None @@ -44,7 +280,7 @@ def test_validation_no_results_backend(self) -> None: == SupersetErrorType.RESULTS_BACKEND_NOT_CONFIGURED_ERROR ) - @mock.patch("superset.sqllab.commands.results.results_backend_use_msgpack", False) + @patch("superset.sqllab.commands.results.results_backend_use_msgpack", False) def test_validation_data_cannot_be_retrieved(self) -> None: results.results_backend = mock.Mock() results.results_backend.get.return_value = None @@ -55,8 +291,8 @@ def test_validation_data_cannot_be_retrieved(self) -> None: command.run() assert ex_info.value.error.error_type == SupersetErrorType.RESULTS_BACKEND_ERROR - @mock.patch("superset.sqllab.commands.results.results_backend_use_msgpack", False) - def test_validation_query_not_found(self) -> None: + @patch("superset.sqllab.commands.results.results_backend_use_msgpack", False) + def test_validation_data_not_found(self) -> None: data = [{"col_0": i} for i in range(100)] payload = { "status": QueryStatus.SUCCESS, @@ -75,8 +311,8 @@ def test_validation_query_not_found(self) -> None: command.run() assert ex_info.value.error.error_type == SupersetErrorType.RESULTS_BACKEND_ERROR - @mock.patch("superset.sqllab.commands.results.results_backend_use_msgpack", False) - def test_validation_query_not_found2(self) -> None: + @patch("superset.sqllab.commands.results.results_backend_use_msgpack", False) + def test_validation_query_not_found(self) -> None: data = [{"col_0": i} for i in range(104)] payload = { "status": QueryStatus.SUCCESS, @@ -89,9 +325,9 @@ def test_validation_query_not_found2(self) -> None: results.results_backend = mock.Mock() results.results_backend.get.return_value = compressed - database = 
Database(database_name="my_database", sqlalchemy_uri="sqlite://") + database = Database(database_name="my_database7", sqlalchemy_uri="sqlite://") query_obj = Query( - client_id="foo", + client_id="test8", database=database, tab_name="test_tab", sql_editor_id="test_editor_id", @@ -102,11 +338,12 @@ def test_validation_query_not_found2(self) -> None: select_as_cta=False, rows=104, error_message="none", - results_key="test_abc", + results_key="abc7", ) db.session.add(database) db.session.add(query_obj) + db.session.commit() with mock.patch( "superset.views.utils._deserialize_results_payload", @@ -120,7 +357,7 @@ def test_validation_query_not_found2(self) -> None: == SupersetErrorType.RESULTS_BACKEND_ERROR ) - @mock.patch("superset.sqllab.commands.results.results_backend_use_msgpack", False) + @patch("superset.sqllab.commands.results.results_backend_use_msgpack", False) def test_run_succeeds(self) -> None: data = [{"col_0": i} for i in range(104)] payload = { @@ -134,9 +371,9 @@ def test_run_succeeds(self) -> None: results.results_backend = mock.Mock() results.results_backend.get.return_value = compressed - database = Database(database_name="my_database", sqlalchemy_uri="sqlite://") + database = Database(database_name="my_database8", sqlalchemy_uri="sqlite://") query_obj = Query( - client_id="foo", + client_id="test9", database=database, tab_name="test_tab", sql_editor_id="test_editor_id", @@ -152,6 +389,7 @@ def test_run_succeeds(self) -> None: db.session.add(database) db.session.add(query_obj) + db.session.commit() command = results.SqlExecutionResultsCommand("test_abc", 1000) result = command.run()