diff --git a/superset/examples/birth_names.py b/superset/examples/birth_names.py index ae7367a51673b..a8729940dcdb0 100644 --- a/superset/examples/birth_names.py +++ b/superset/examples/birth_names.py @@ -108,32 +108,49 @@ def load_birth_names( print(f"Creating table [{tbl_name}] reference") obj = TBL(table_name=tbl_name) db.session.add(obj) - obj.main_dttm_col = "ds" - obj.database = database - obj.filter_select_enabled = True - obj.fetch_metadata() - if not any(col.column_name == "num_california" for col in obj.columns): + _set_table_metadata(obj, database) + _add_table_metrics(obj) + + db.session.commit() + + slices, _ = create_slices(obj, admin_owner=True) + create_dashboard(slices) + + +def _set_table_metadata(datasource: "BaseDatasource", database: "Database") -> None: + datasource.main_dttm_col = "ds" # type: ignore + datasource.database = database + datasource.filter_select_enabled = True + datasource.fetch_metadata() + + +def _add_table_metrics(datasource: "BaseDatasource") -> None: + if not any(col.column_name == "num_california" for col in datasource.columns): col_state = str(column("state").compile(db.engine)) col_num = str(column("num").compile(db.engine)) - obj.columns.append( + datasource.columns.append( TableColumn( column_name="num_california", expression=f"CASE WHEN {col_state} = 'CA' THEN {col_num} ELSE 0 END", ) ) - if not any(col.metric_name == "sum__num" for col in obj.metrics): + if not any(col.metric_name == "sum__num" for col in datasource.metrics): col = str(column("num").compile(db.engine)) - obj.metrics.append(SqlMetric(metric_name="sum__num", expression=f"SUM({col})")) - - db.session.commit() + datasource.metrics.append( + SqlMetric(metric_name="sum__num", expression=f"SUM({col})") + ) - slices, _ = create_slices(obj) - create_dashboard(slices) + for col in datasource.columns: + if col.column_name == "ds": + col.is_dttm = True # type: ignore + break -def create_slices(tbl: BaseDatasource) -> Tuple[List[Slice], List[Slice]]: +def create_slices( + tbl: BaseDatasource, admin_owner: bool +) -> Tuple[List[Slice], List[Slice]]: metrics = [ { "expressionType": "SIMPLE", @@ -160,9 +177,17 @@ def create_slices(tbl: BaseDatasource) -> Tuple[List[Slice], List[Slice]]: "markup_type": "markdown", } - slice_props = dict( - datasource_id=tbl.id, datasource_type="table", owners=[admin], created_by=admin - ) + if admin_owner: + slice_props = dict( + datasource_id=tbl.id, + datasource_type="table", + owners=[admin], + created_by=admin, + ) + else: + slice_props = dict( + datasource_id=tbl.id, datasource_type="table", owners=[], created_by=admin + ) print("Creating some slices") slices = [ @@ -475,7 +500,7 @@ def create_slices(tbl: BaseDatasource) -> Tuple[List[Slice], List[Slice]]: return slices, misc_slices -def create_dashboard(slices: List[Slice]) -> None: +def create_dashboard(slices: List[Slice]) -> Dashboard: print("Creating a dashboard") dash = db.session.query(Dashboard).filter_by(slug="births").first() @@ -779,3 +804,4 @@ def create_dashboard(slices: List[Slice]) -> None: dash.position_json = json.dumps(pos, indent=4) dash.slug = "births" db.session.commit() + return dash diff --git a/tests/access_tests.py b/tests/access_tests.py index 2dec294bbb2df..4e568a57d5f39 100644 --- a/tests/access_tests.py +++ b/tests/access_tests.py @@ -19,6 +19,7 @@ import json import unittest from unittest import mock +from tests.fixtures.birth_names_dashboard import load_birth_names_dashboard_with_slices import pytest @@ -142,6 +143,7 @@ def test_override_role_permissions_is_admin_only(self): ) self.assertNotEqual(405, response.status_code) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_override_role_permissions_1_table(self): response = self.client.post( "/superset/override_role_permissions/", @@ -160,6 +162,7 @@ def test_override_role_permissions_1_table(self): "datasource_access", updated_override_me.permissions[0].permission.name ) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_override_role_permissions_druid_and_table(self): response = self.client.post( "/superset/override_role_permissions/", @@ -187,7 +190,9 @@ def test_override_role_permissions_druid_and_table(self): ) self.assertEqual(3, len(perms)) - @pytest.mark.usefixtures("load_energy_table_with_slice") + @pytest.mark.usefixtures( + "load_energy_table_with_slice", "load_birth_names_dashboard_with_slices" + ) def test_override_role_permissions_drops_absent_perms(self): override_me = security_manager.find_role("override_me") override_me.permissions.append( @@ -247,6 +252,7 @@ def test_clean_requests_after_role_extend(self): gamma_user = security_manager.find_user(username="gamma") gamma_user.roles.remove(security_manager.find_role("test_role1")) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_clean_requests_after_alpha_grant(self): session = db.session diff --git a/tests/cache_tests.py b/tests/cache_tests.py index 3ffd52a378163..43a4cf6f3fcd1 100644 --- a/tests/cache_tests.py +++ b/tests/cache_tests.py @@ -17,9 +17,12 @@ """Unit tests for Superset with caching""" import json +import pytest + from superset import app, db from superset.extensions import cache_manager from superset.utils.core import QueryStatus +from tests.fixtures.birth_names_dashboard import load_birth_names_dashboard_with_slices from .base_tests import SupersetTestCase @@ -34,6 +37,7 @@ def tearDown(self): cache_manager.cache.clear() cache_manager.data_cache.clear() + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_no_data_cache(self): data_cache_config = app.config["DATA_CACHE_CONFIG"] app.config["DATA_CACHE_CONFIG"] = {"CACHE_TYPE": "null"} @@ -54,6 +58,7 @@ def test_no_data_cache(self): self.assertFalse(resp["is_cached"]) self.assertFalse(resp_from_cache["is_cached"]) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_slice_data_cache(self): # Override cache config data_cache_config = app.config["DATA_CACHE_CONFIG"] diff --git a/tests/celery_tests.py b/tests/celery_tests.py index f492cf037db1e..689eea814bea9 100644 --- a/tests/celery_tests.py +++ b/tests/celery_tests.py @@ -23,6 +23,7 @@ import time import unittest.mock as mock from typing import Optional +from tests.fixtures.birth_names_dashboard import load_birth_names_dashboard_with_slices import pytest @@ -160,6 +161,7 @@ def test_run_sync_query_dont_exist(setup_sqllab, ctas_method): } +@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW]) def test_run_sync_query_cta(setup_sqllab, ctas_method): tmp_table_name = f"{TEST_SYNC}_{ctas_method.lower()}" @@ -173,7 +175,10 @@ def test_run_sync_query_cta(setup_sqllab, ctas_method): assert QueryStatus.SUCCESS == results["status"], results assert len(results["data"]) > 0 + delete_tmp_view_or_table(tmp_table_name, ctas_method) + +@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_run_sync_query_cta_no_data(setup_sqllab): sql_empty_result = "SELECT * FROM birth_names WHERE name='random'" result = run_sql(sql_empty_result) @@ -184,6 +189,7 @@ def test_run_sync_query_cta_no_data(setup_sqllab): assert QueryStatus.SUCCESS == query.status +@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW]) @mock.patch( "superset.views.core.get_cta_schema_name", lambda d, u, s, sql: CTAS_SCHEMA_NAME @@ -208,7 +214,10 @@ def test_run_sync_query_cta_config(setup_sqllab, ctas_method): results = run_sql(query.select_sql) assert QueryStatus.SUCCESS == results["status"], result + delete_tmp_view_or_table(f"{CTAS_SCHEMA_NAME}.{tmp_table_name}", ctas_method) + +@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW]) @mock.patch( "superset.views.core.get_cta_schema_name", lambda d, u, s, sql: CTAS_SCHEMA_NAME @@ -232,7 +241,10 @@ def test_run_async_query_cta_config(setup_sqllab, ctas_method): == query.executed_sql ) + delete_tmp_view_or_table(f"{CTAS_SCHEMA_NAME}.{tmp_table_name}", ctas_method) + +@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW]) def test_run_async_cta_query(setup_sqllab, ctas_method): table_name = f"{TEST_ASYNC_CTA}_{ctas_method.lower()}" @@ -252,7 +264,10 @@ def test_run_async_cta_query(setup_sqllab, ctas_method): assert query.select_as_cta assert query.select_as_cta_used + delete_tmp_view_or_table(table_name, ctas_method) + +@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW]) def test_run_async_cta_query_with_lower_limit(setup_sqllab, ctas_method): tmp_table = f"{TEST_ASYNC_LOWER_LIMIT}_{ctas_method.lower()}" @@ -272,6 +287,8 @@ def test_run_async_cta_query_with_lower_limit(setup_sqllab, ctas_method): assert query.select_as_cta assert query.select_as_cta_used + delete_tmp_view_or_table(tmp_table, ctas_method) + SERIALIZATION_DATA = [("a", 4, 4.0, datetime.datetime(2019, 8, 18, 16, 39, 16, 660000))] CURSOR_DESCR = ( @@ -306,6 +323,7 @@ def test_new_data_serialization(): assert isinstance(data[0], bytes) +@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_default_payload_serialization(): use_new_deserialization = False db_engine_spec = BaseEngineSpec() @@ -338,6 +356,7 @@ def test_default_payload_serialization(): assert isinstance(serialized, str) +@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_msgpack_payload_serialization(): use_new_deserialization = True db_engine_spec = BaseEngineSpec() @@ -406,3 +425,7 @@ def my_task(): my_task() finally: flask._app_ctx_stack.push(popped_app) + + +def delete_tmp_view_or_table(name: str, db_object_type: str): + db.get_engine().execute(f"DROP {db_object_type} IF EXISTS {name}") diff --git a/tests/charts/api_tests.py b/tests/charts/api_tests.py index b8a46cf0b71ae..99ab93cb85aeb 100644 --- a/tests/charts/api_tests.py +++ b/tests/charts/api_tests.py @@ -18,16 +18,19 @@ """Unit tests for Superset""" import json from typing import List, Optional -from datetime import datetime +from datetime import datetime, timedelta from io import BytesIO from unittest import mock from zipfile import is_zipfile, ZipFile +from superset.models.sql_lab import Query +from tests.fixtures.birth_names_dashboard import load_birth_names_dashboard_with_slices + import humanize import prison import pytest import yaml -from sqlalchemy import and_ +from sqlalchemy import and_, or_ from sqlalchemy.sql import func from tests.test_app import app @@ -41,7 +44,7 @@ from superset.models.reports import ReportSchedule, ReportScheduleType from superset.models.slice import Slice from superset.utils import core as utils -from superset.utils.core import AnnotationType, get_example_database +from superset.utils.core import AnnotationType, get_example_database, get_main_database from tests.base_api_tests import ApiOwnersTestCaseMixin from tests.base_tests import SupersetTestCase, post_assert_metric, test_client @@ -57,6 +60,7 @@ from tests.fixtures.query_context import get_query_context, ANNOTATION_LAYERS from tests.fixtures.unicode_dashboard import load_unicode_dashboard_with_slice from tests.annotation_layers.fixtures import create_annotation_layers +from tests.utils.get_dashboards import get_dashboards_ids CHART_DATA_URI = "api/v1/chart/data" CHARTS_FIXTURE_COUNT = 10 @@ -431,10 +435,12 @@ def test_delete_bulk_chart_not_owned(self): db.session.delete(user_alpha2) db.session.commit() + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_create_chart(self): """ Chart API: Test create chart """ + dashboards_ids = get_dashboards_ids(db, ["world_health", "births"]) admin_id = self.get_user("admin").id chart_data = { "slice_name": "name1", @@ -445,7 +451,7 @@ def test_create_chart(self): "cache_timeout": 1000, "datasource_id": 1, "datasource_type": "table", - "dashboards": [1, 2], + "dashboards": dashboards_ids, } self.login(username="admin") uri = f"api/v1/chart/" @@ -733,6 +739,7 @@ def test_get_chart_not_found(self): rv = self.get_assert_metric(uri, "get") self.assertEqual(rv.status_code, 404) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_get_chart_no_data_access(self): """ Chart API: Test get chart without data access @@ -747,8 +754,11 @@ def test_get_chart_no_data_access(self): rv = self.client.get(uri) self.assertEqual(rv.status_code, 404) - @pytest.mark.usefixtures("load_unicode_dashboard_with_slice") - @pytest.mark.usefixtures("load_energy_table_with_slice") + @pytest.mark.usefixtures( + "load_energy_table_with_slice", + "load_birth_names_dashboard_with_slices", + "load_unicode_dashboard_with_slice", + ) def test_get_charts(self): """ Chart API: Test get charts @@ -788,6 +798,7 @@ def test_get_charts_changed_on(self): db.session.delete(chart) db.session.commit() + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_get_charts_filter(self): """ Chart API: Test get charts filter @@ -995,7 +1006,9 @@ def test_get_time_range(self): self.assertEqual(len(data["result"]), 3) @pytest.mark.usefixtures( - "load_unicode_dashboard_with_slice", "load_energy_table_with_slice" + "load_unicode_dashboard_with_slice", + "load_energy_table_with_slice", + "load_birth_names_dashboard_with_slices", ) def test_get_charts_page(self): """ @@ -1028,6 +1041,7 @@ def test_get_charts_no_data_access(self): data = json.loads(rv.data.decode("utf-8")) self.assertEqual(data["count"], 0) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_chart_data_simple(self): """ Chart data API: Test chart data query @@ -1037,8 +1051,10 @@ def test_chart_data_simple(self): rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") self.assertEqual(rv.status_code, 200) data = json.loads(rv.data.decode("utf-8")) - self.assertEqual(data["result"][0]["rowcount"], 45) + expected_row_count = self.get_expected_row_count("client_id_1") + self.assertEqual(data["result"][0]["rowcount"], expected_row_count) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_chart_data_applied_time_extras(self): """ Chart data API: Test chart data query with applied time extras @@ -1060,8 +1076,10 @@ def test_chart_data_applied_time_extras(self): data["result"][0]["rejected_filters"], [{"column": "__time_origin", "reason": "not_druid_datasource"},], ) - self.assertEqual(data["result"][0]["rowcount"], 45) + expected_row_count = self.get_expected_row_count("client_id_2") + self.assertEqual(data["result"][0]["rowcount"], expected_row_count) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_chart_data_limit_offset(self): """ Chart data API: Test chart data query with limit and offset @@ -1090,6 +1108,7 @@ def test_chart_data_limit_offset(self): self.assertEqual(result["rowcount"], 5) self.assertEqual(result["data"][0]["name"], expected_name) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @mock.patch( "superset.common.query_object.config", {**app.config, "ROW_LIMIT": 7}, ) @@ -1105,6 +1124,7 @@ def test_chart_data_default_row_limit(self): result = response_payload["result"][0] self.assertEqual(result["rowcount"], 7) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @mock.patch( "superset.common.query_context.config", {**app.config, "SAMPLES_ROW_LIMIT": 5}, ) @@ -1131,6 +1151,7 @@ def test_chart_data_incorrect_result_type(self): rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") self.assertEqual(rv.status_code, 400) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_chart_data_incorrect_result_format(self): """ Chart data API: Test chart data with unsupported result format @@ -1141,6 +1162,7 @@ def test_chart_data_incorrect_result_format(self): rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") self.assertEqual(rv.status_code, 400) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_chart_data_query_result_type(self): """ Chart data API: Test chart data with query result format @@ -1151,6 +1173,7 @@ def test_chart_data_query_result_type(self): rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") self.assertEqual(rv.status_code, 200) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_chart_data_csv_result_format(self): """ Chart data API: Test chart data with CSV result format @@ -1161,6 +1184,7 @@ def test_chart_data_csv_result_format(self): rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") self.assertEqual(rv.status_code, 200) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_chart_data_mixed_case_filter_op(self): """ Chart data API: Ensure mixed case filter operator generates valid result @@ -1208,6 +1232,7 @@ def test_chart_data_prophet(self): self.assertIn("sum__num__yhat_lower", row) self.assertEqual(result["rowcount"], 47) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_chart_data_query_missing_filter(self): """ Chart data API: Ensure filter referencing missing column is ignored @@ -1223,6 +1248,7 @@ def test_chart_data_query_missing_filter(self): response_payload = json.loads(rv.data.decode("utf-8")) assert "non_existent_filter" not in response_payload["result"][0]["query"] + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_chart_data_no_data(self): """ Chart data API: Test chart data with empty result @@ -1283,6 +1309,7 @@ def test_query_exec_not_allowed(self): rv = self.post_assert_metric(CHART_DATA_URI, payload, "data") self.assertEqual(rv.status_code, 401) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_chart_data_jinja_filter_request(self): """ Chart data API: Ensure request referencing filters via jinja renders a correct query @@ -1325,6 +1352,7 @@ def test_chart_data_async(self): "superset.extensions.feature_flag_manager._feature_flags", GLOBAL_ASYNC_QUERIES=True, ) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_chart_data_async_results_type(self): """ Chart data API: Test chart data query non-JSON format (async) @@ -1353,6 +1381,7 @@ def test_chart_data_async_invalid_token(self): rv = post_assert_metric(test_client, CHART_DATA_URI, request_payload, "data") self.assertEqual(rv.status_code, 401) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @mock.patch.dict( "superset.extensions.feature_flag_manager._feature_flags", GLOBAL_ASYNC_QUERIES=True, @@ -1379,8 +1408,9 @@ def mock_run(self, **kwargs): ) data = json.loads(rv.data.decode("utf-8")) + expected_row_count = self.get_expected_row_count("client_id_3") self.assertEqual(rv.status_code, 200) - self.assertEqual(data["result"][0]["rowcount"], 45) + self.assertEqual(data["result"][0]["rowcount"], expected_row_count) @mock.patch.dict( "superset.extensions.feature_flag_manager._feature_flags", @@ -1609,7 +1639,9 @@ def test_import_chart_invalid(self): "message": {"metadata.yaml": {"type": ["Must be equal to Slice."]}} } - @pytest.mark.usefixtures("create_annotation_layers") + @pytest.mark.usefixtures( + "create_annotation_layers", "load_birth_names_dashboard_with_slices" + ) def test_chart_data_annotations(self): """ Chart data API: Test chart data query @@ -1648,3 +1680,32 @@ def test_chart_data_annotations(self): data = json.loads(rv.data.decode("utf-8")) # response should only contain interval and event data, not formula self.assertEqual(len(data["result"][0]["annotation_data"]), 2) + + def get_expected_row_count(self, client_id: str) -> int: + start_date = datetime.now() + start_date = start_date.replace( + year=start_date.year - 100, hour=0, minute=0, second=0 + ) + + quoted_table_name = self.quote_name("birth_names") + sql = f""" + SELECT COUNT(*) AS rows_count FROM ( + SELECT name AS name, SUM(num) AS sum__num + FROM {quoted_table_name} + WHERE ds >= '{start_date.strftime("%Y-%m-%d %H:%M:%S")}' + AND gender = 'boy' + GROUP BY name + ORDER BY sum__num DESC + LIMIT 100) AS inner__query + """ + resp = self.run_sql(sql, client_id, raise_on_error=True) + db.session.query(Query).delete() + db.session.commit() + return resp["data"][0]["rows_count"] + + def quote_name(self, name: str): + if get_main_database().backend in {"presto", "hive"}: + return get_example_database().inspector.engine.dialect.identifier_preparer.quote_identifier( + name + ) + return name diff --git a/tests/conftest.py b/tests/conftest.py index 0e197b84167d2..efa549f6a3346 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -43,14 +43,12 @@ def setup_sample_data() -> Any: examples.load_css_templates() examples.load_world_bank_health_n_pop(sample=True) - examples.load_birth_names(sample=True) yield with app.app_context(): engine = get_example_database().get_sqla_engine() engine.execute("DROP TABLE wb_health_population") - engine.execute("DROP TABLE birth_names") # drop sqlachemy tables diff --git a/tests/core_tests.py b/tests/core_tests.py index 97ae5eb2ada04..16620b406da24 100644 --- a/tests/core_tests.py +++ b/tests/core_tests.py @@ -25,6 +25,7 @@ import logging from typing import Dict, List from urllib.parse import quote +from tests.fixtures.birth_names_dashboard import load_birth_names_dashboard_with_slices import pytest import pytz @@ -100,6 +101,7 @@ def test_dashboard_endpoint(self): resp = self.client.get("/superset/dashboard/-1/") assert resp.status_code == 404 + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_slice_endpoint(self): self.login(username="admin") slc = self.get_slice("Girls", db.session) @@ -114,6 +116,7 @@ def test_slice_endpoint(self): resp = self.client.get("/superset/slice/-1/") assert resp.status_code == 404 + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_viz_cache_key(self): self.login(username="admin") slc = self.get_slice("Girls", db.session) @@ -327,6 +330,7 @@ def test_filter_endpoint(self): assert len(resp) > 0 assert "energy_target0" in resp + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_slice_data(self): # slice data should have some required attributes self.login(username="admin") @@ -372,6 +376,7 @@ def test_add_slice(self): resp = self.client.get(url) self.assertEqual(resp.status_code, 200) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_get_user_slices_for_owners(self): self.login(username="alpha") user = security_manager.find_user("alpha") @@ -577,7 +582,9 @@ def test_databaseview_edit(self, username="admin"): database.allow_run_async = False db.session.commit() - @pytest.mark.usefixtures("load_energy_table_with_slice") + @pytest.mark.usefixtures( + "load_energy_table_with_slice", "load_birth_names_dashboard_with_slices" + ) def test_warm_up_cache(self): self.login() slc = self.get_slice("Girls", db.session) @@ -602,6 +609,7 @@ def test_warm_up_cache(self): + quote(json.dumps([{"col": "name", "op": "in", "val": ["Jennifer"]}])) ) == [{"slice_id": slc.id, "viz_error": None, "viz_status": "success"}] + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_cache_logging(self): store_cache_keys = app.config["STORE_CACHE_KEYS_IN_METADATA_DB"] app.config["STORE_CACHE_KEYS_IN_METADATA_DB"] = True @@ -649,12 +657,21 @@ def test_gamma(self): assert "Charts" in self.get_resp("/chart/list/") assert "Dashboards" in self.get_resp("/dashboard/list/") + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_csv_endpoint(self): self.login() - sql = """ + client_id = "{}".format(random.getrandbits(64))[:10] + get_name_sql = """ + SELECT name + FROM birth_names + LIMIT 1 + """ + resp = self.run_sql(get_name_sql, client_id, raise_on_error=True) + name = resp["data"][0]["name"] + sql = f""" SELECT name FROM birth_names - WHERE name = 'James' + WHERE name = '{name}' LIMIT 1 """ client_id = "{}".format(random.getrandbits(64))[:10] @@ -662,18 +679,19 @@ def test_csv_endpoint(self): resp = self.get_resp("/superset/csv/{}".format(client_id)) data = csv.reader(io.StringIO(resp)) - expected_data = csv.reader(io.StringIO("name\nJames\n")) + expected_data = csv.reader(io.StringIO(f"name\n{name}\n")) client_id = "{}".format(random.getrandbits(64))[:10] self.run_sql(sql, client_id, raise_on_error=True) resp = self.get_resp("/superset/csv/{}".format(client_id)) data = csv.reader(io.StringIO(resp)) - expected_data = csv.reader(io.StringIO("name\nJames\n")) + expected_data = csv.reader(io.StringIO(f"name\n{name}\n")) self.assertEqual(list(expected_data), list(data)) self.logout() + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_extra_table_metadata(self): self.login() example_db = utils.get_example_database() @@ -730,6 +748,7 @@ def test_fetch_datasource_metadata(self): for k in keys: self.assertIn(k, resp.keys()) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_user_profile(self, username="admin"): self.login(username=username) slc = self.get_slice("Girls", db.session) @@ -762,6 +781,7 @@ def test_user_profile(self, username="admin"): data = self.get_json_resp(f"/superset/fave_dashboards_by_username/{username}/") self.assertNotIn("message", data) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_slice_id_is_always_logged_correctly_on_web_request(self): # superset/explore case slc = db.session.query(Slice).filter_by(slice_name="Girls").one() @@ -845,6 +865,7 @@ def test_slice_payload_no_datasource(self): "The datasource associated with this chart no longer exists", ) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_explore_json(self): tbl_id = self.table_ids.get("birth_names") form_data = { @@ -867,6 +888,7 @@ def test_explore_json(self): self.assertEqual(rv.status_code, 200) self.assertEqual(data["rowcount"], 2) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @mock.patch.dict( "superset.extensions.feature_flag_manager._feature_flags", GLOBAL_ASYNC_QUERIES=True, @@ -897,6 +919,7 @@ def test_explore_json_async(self): keys, ["channel_id", "job_id", "user_id", "status", "errors", "result_url"] ) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @mock.patch.dict( "superset.extensions.feature_flag_manager._feature_flags", GLOBAL_ASYNC_QUERIES=True, @@ -922,6 +945,7 @@ def test_explore_json_async_results_format(self): ) self.assertEqual(rv.status_code, 200) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @mock.patch( "superset.utils.cache_manager.CacheManager.cache", new_callable=mock.PropertyMock, @@ -1029,6 +1053,7 @@ def test_schemas_access_for_csv_upload_endpoint( assert data == ["this_schema_is_allowed_too"] self.delete_fake_db() + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_select_star(self): self.login(username="admin") examples_db = utils.get_example_database() diff --git a/tests/dashboard_tests.py b/tests/dashboard_tests.py index 5874de8bb7fd5..8d4df16e67848 100644 --- a/tests/dashboard_tests.py +++ b/tests/dashboard_tests.py @@ -20,6 +20,7 @@ import json import unittest from random import random +from tests.fixtures.birth_names_dashboard import load_birth_names_dashboard_with_slices import pytest from flask import escape, url_for @@ -129,6 +130,7 @@ def test_new_dashboard(self): dash_count_after = db.session.query(func.count(Dashboard.id)).first()[0] self.assertEqual(dash_count_before + 1, dash_count_after) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_dashboard_modes(self): self.login(username="admin") dash = db.session.query(Dashboard).filter_by(slug="births").first() @@ -142,6 +144,7 @@ def test_dashboard_modes(self): self.assertIn("standalone_mode": true", resp) self.assertIn('', resp) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_save_dash(self, username="admin"): self.login(username=username) dash = db.session.query(Dashboard).filter_by(slug="births").first() @@ -213,6 +216,7 @@ def test_save_dash_with_invalid_filters(self, username="admin"): new_url = updatedDash.url self.assertNotIn("region", new_url) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_save_dash_with_dashboard_title(self, username="admin"): self.login(username=username) dash = db.session.query(Dashboard).filter_by(slug="births").first() @@ -234,6 +238,7 @@ def test_save_dash_with_dashboard_title(self, username="admin"): data["dashboard_title"] = origin_title self.get_resp(url, data=dict(data=json.dumps(data))) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_save_dash_with_colors(self, username="admin"): self.login(username=username) dash = db.session.query(Dashboard).filter_by(slug="births").first() @@ -263,7 +268,9 @@ def test_save_dash_with_colors(self, username="admin"): self.get_resp(url, data=dict(data=json.dumps(data))) @pytest.mark.usefixtures( - "cleanup_copied_dash", "load_unicode_dashboard_with_position" + "load_birth_names_dashboard_with_slices", + "cleanup_copied_dash", + "load_unicode_dashboard_with_position", ) def test_copy_dash(self, username="admin"): self.login(username=username) @@ -303,7 +310,9 @@ def test_copy_dash(self, username="admin"): if key not in ["modified", "changed_on", "changed_on_humanized"]: self.assertEqual(slc[key], resp["slices"][index][key]) - @pytest.mark.usefixtures("load_energy_table_with_slice") + @pytest.mark.usefixtures( + "load_energy_table_with_slice", "load_birth_names_dashboard_with_slices" + ) def test_add_slices(self, username="admin"): self.login(username=username) dash = db.session.query(Dashboard).filter_by(slug="births").first() @@ -332,6 +341,7 @@ def test_add_slices(self, username="admin"): dash.slices = [o for o in dash.slices if o.slice_name != "Energy Force Layout"] db.session.commit() + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_remove_slices(self, username="admin"): self.login(username=username) dash = db.session.query(Dashboard).filter_by(slug="births").first() @@ -364,6 +374,7 @@ def test_remove_slices(self, username="admin"): data = dash.data self.assertEqual(len(data["slices"]), origin_slices_length - 1) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_public_user_dashboard_access(self): table = db.session.query(SqlaTable).filter_by(table_name="birth_names").one() @@ -404,6 +415,7 @@ def test_public_user_dashboard_access(self): # Cleanup self.revoke_public_access_to_table(table) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_dashboard_with_created_by_can_be_accessed_by_public_users(self): self.logout() table = db.session.query(SqlaTable).filter_by(table_name="birth_names").one() @@ -419,6 +431,7 @@ def test_dashboard_with_created_by_can_be_accessed_by_public_users(self): # Cleanup self.revoke_public_access_to_table(table) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_only_owners_can_save(self): dash = db.session.query(Dashboard).filter_by(slug="births").first() dash.owners = [] diff --git a/tests/dashboards/api_tests.py b/tests/dashboards/api_tests.py index cd91badc24c7a..299dd2de64d9e 100644 --- a/tests/dashboards/api_tests.py +++ b/tests/dashboards/api_tests.py @@ -22,6 +22,7 @@ from typing import List, Optional from unittest.mock import patch from zipfile import is_zipfile, ZipFile +from tests.fixtures.birth_names_dashboard import load_birth_names_dashboard_with_slices import pytest import prison @@ -29,7 +30,7 @@ from sqlalchemy.sql import func from freezegun import freeze_time -from sqlalchemy import and_ +from sqlalchemy import and_, or_ from superset import db, security_manager from superset.models.dashboard import Dashboard from superset.models.core import FavStar, FavStarClassName @@ -47,7 +48,7 @@ dataset_config, dataset_metadata_config, ) - +from tests.utils.get_dashboards import get_dashboards_ids DASHBOARDS_FIXTURE_COUNT = 10 @@ -654,6 +655,7 @@ def test_delete_bulk_dashboard_admin_not_owned(self): model = db.session.query(Dashboard).get(dashboard_id) self.assertEqual(model, None) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_delete_dashboard_not_owned(self): """ Dashboard API: Test delete try not owned @@ -679,6 +681,7 @@ def test_delete_dashboard_not_owned(self): db.session.delete(user_alpha2) db.session.commit() + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_delete_bulk_dashboard_not_owned(self): """ Dashboard API: Test delete bulk try not owned @@ -906,6 +909,7 @@ def test_update_dashboard(self): db.session.delete(model) db.session.commit() + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_update_dashboard_chart_owners(self): """ Dashboard API: Test update chart owners @@ -1071,6 +1075,7 @@ def test_update_published(self): db.session.delete(model) db.session.commit() + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_update_dashboard_not_owned(self): """ Dashboard API: Test update dashboard not owned @@ -1147,8 +1152,8 @@ def test_export_bundle(self): """ Dashboard API: Test dashboard export """ - argument = [1, 2] - uri = f"api/v1/dashboard/export/?q={prison.dumps(argument)}" + dashboards_ids = get_dashboards_ids(db, ["world_health", "births"]) + uri = f"api/v1/dashboard/export/?q={prison.dumps(dashboards_ids)}" self.login(username="admin") rv = self.client.get(uri) diff --git a/tests/databases/api_tests.py b/tests/databases/api_tests.py index a49ce3eddd417..46af4f06293c0 100644 --- a/tests/databases/api_tests.py +++ b/tests/databases/api_tests.py @@ -21,6 +21,7 @@ from io import BytesIO from unittest import mock from zipfile import is_zipfile, ZipFile +from tests.fixtures.birth_names_dashboard import load_birth_names_dashboard_with_slices import prison import pytest @@ -559,6 +560,7 @@ def test_delete_database_with_report(self): } self.assertEqual(response, expected_response) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_get_table_metadata(self): """ Database API: Test get table metadata info @@ -622,6 +624,7 @@ def test_get_table_metadata_no_db_permission(self): rv = self.client.get(uri) self.assertEqual(rv.status_code, 404) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_get_select_star(self): """ Database API: Test get select star @@ -843,7 +846,9 @@ def test_test_connection_unsafe_uri(self): app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = False @pytest.mark.usefixtures( - "load_unicode_dashboard_with_position", "load_energy_table_with_slice" + "load_unicode_dashboard_with_position", + "load_energy_table_with_slice", + "load_birth_names_dashboard_with_slices", ) def test_get_database_related_objects(self): """ diff --git a/tests/databases/commands_tests.py b/tests/databases/commands_tests.py index 70fcd7c6186f3..3b1767fdc2397 100644 --- a/tests/databases/commands_tests.py +++ b/tests/databases/commands_tests.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=no-self-use, invalid-name - from unittest.mock import patch import pytest @@ -31,6 +30,8 @@ from superset.models.core import Database from superset.utils.core import backend, get_example_database from tests.base_tests import SupersetTestCase +from tests.fixtures.birth_names_dashboard import load_birth_names_dashboard_with_slices +from tests.fixtures.energy_dashboard import load_energy_table_with_slice from tests.fixtures.importexport import ( database_config, database_metadata_config, @@ -41,10 +42,15 @@ class TestExportDatabasesCommand(SupersetTestCase): @patch("superset.security.manager.g") + @pytest.mark.usefixtures( + "load_birth_names_dashboard_with_slices", "load_energy_table_with_slice" + ) def test_export_database_command(self, mock_g): mock_g.user = security_manager.find_user("admin") example_db = get_example_database() + db_uuid = example_db.uuid + command = ExportDatabasesCommand([example_db.id]) contents = dict(command.run()) @@ -68,6 +74,18 @@ def test_export_database_command(self, mock_g): assert core_files.issubset(set(contents.keys())) + if example_db.backend == "postgresql": + ds_type = "TIMESTAMP WITHOUT TIME ZONE" + elif example_db.backend == "hive": + ds_type = "TIMESTAMP" + elif example_db.backend == "presto": + ds_type = "VARCHAR(255)" + else: + ds_type = "DATETIME" + if example_db.backend == "mysql": + big_int_type = "BIGINT(20)" + else: + big_int_type = "BIGINT" metadata = yaml.safe_load(contents["databases/examples.yaml"]) assert metadata == ( { @@ -87,153 +105,149 @@ def test_export_database_command(self, mock_g): metadata = yaml.safe_load(contents["datasets/examples/birth_names.yaml"]) metadata.pop("uuid") - assert metadata == { - "table_name": "birth_names", - "main_dttm_col": None, - "description": "Adding a DESCRip", - "default_endpoint": "", - "offset": 66, - "cache_timeout": 55, - "schema": "", - "sql": "", - "params": None, - "template_params": None, - "filter_select_enabled": True, - "fetch_values_predicate": None, - "extra": None, - "metrics": [ - { - "metric_name": "ratio", - "verbose_name": "Ratio Boys/Girls", - "metric_type": None, - "expression": "sum(num_boys) / sum(num_girls)", - "description": "This represents the ratio of boys/girls", - "d3format": ".2%", - "extra": None, - "warning_text": "no warning", - }, - { - "metric_name": "sum__num", - "verbose_name": "Babies", - "metric_type": None, - "expression": "SUM(num)", - "description": "", - "d3format": "", - "extra": None, - "warning_text": "", - }, + + metadata["columns"].sort(key=lambda x: x["column_name"]) + expected_metadata = { + "cache_timeout": None, + "columns": [ { - "metric_name": "count", - "verbose_name": "", - "metric_type": None, - "expression": "count(1)", + "column_name": "ds", "description": None, - "d3format": None, - "extra": None, - "warning_text": None, + "expression": None, + "filterable": True, + "groupby": True, + "is_active": True, + "is_dttm": True, + "python_date_format": None, + "type": ds_type, + "verbose_name": None, }, - ], - "columns": [ { - "column_name": "num_california", - "verbose_name": None, - "is_dttm": False, - "is_active": None, - "type": "NUMBER", - "groupby": False, - "filterable": False, - "expression": "CASE WHEN state = 'CA' THEN num ELSE 0 END", + "column_name": "gender", "description": None, + "expression": None, + "filterable": True, + "groupby": True, + "is_active": True, + "is_dttm": False, "python_date_format": None, + "type": "STRING" if example_db.backend == "hive" else "VARCHAR(16)", + "verbose_name": None, }, { - "column_name": "ds", - "verbose_name": "", - "is_dttm": True, - "is_active": None, - "type": "DATETIME", - "groupby": True, - "filterable": True, - "expression": "", + "column_name": "name", "description": None, + "expression": None, + "filterable": True, + "groupby": True, + "is_active": True, + "is_dttm": False, "python_date_format": None, + "type": "STRING" + if example_db.backend == "hive" + else "VARCHAR(255)", + "verbose_name": None, }, { - "column_name": "num_girls", - "verbose_name": None, - "is_dttm": False, - "is_active": None, - "type": "BIGINT(20)", - "groupby": False, - "filterable": False, - "expression": "", + "column_name": "num", "description": None, + "expression": None, + "filterable": True, + "groupby": True, + "is_active": True, + "is_dttm": False, "python_date_format": None, + "type": big_int_type, + "verbose_name": None, }, { - "column_name": "gender", - "verbose_name": None, - "is_dttm": False, - "is_active": None, - "type": "VARCHAR(16)", - "groupby": True, - "filterable": True, - "expression": "", + "column_name": "num_california", "description": None, + "expression": "CASE WHEN state = 'CA' THEN num ELSE 0 END", + "filterable": True, + "groupby": True, + "is_active": True, + "is_dttm": False, "python_date_format": None, + "type": None, + "verbose_name": None, }, { "column_name": "state", - "verbose_name": None, - "is_dttm": None, - "is_active": None, - "type": "VARCHAR(10)", - "groupby": True, - "filterable": True, - "expression": None, "description": None, + "expression": None, + "filterable": True, + "groupby": True, + "is_active": True, + "is_dttm": False, "python_date_format": None, + "type": "STRING" if example_db.backend == "hive" else "VARCHAR(10)", + "verbose_name": None, }, { "column_name": "num_boys", - "verbose_name": None, - "is_dttm": None, - "is_active": None, - "type": "BIGINT(20)", - "groupby": True, - "filterable": True, - "expression": None, "description": None, + "expression": None, + "filterable": True, + "groupby": True, + "is_active": True, + "is_dttm": False, "python_date_format": None, + "type": big_int_type, + "verbose_name": None, }, { - "column_name": "num", - "verbose_name": None, - "is_dttm": None, - "is_active": None, - "type": "BIGINT(20)", - "groupby": True, - "filterable": True, - "expression": None, + "column_name": "num_girls", "description": None, + "expression": None, + "filterable": True, + "groupby": True, + "is_active": True, + "is_dttm": False, "python_date_format": None, + "type": big_int_type, + "verbose_name": None, }, + ], + "database_uuid": str(db_uuid), + "default_endpoint": None, + "description": "", + "extra": None, + "fetch_values_predicate": None, + "filter_select_enabled": True, + "main_dttm_col": "ds", + "metrics": [ { - "column_name": "name", - "verbose_name": None, - "is_dttm": None, - "is_active": None, - "type": "VARCHAR(255)", - "groupby": True, - "filterable": True, - "expression": None, + "d3format": None, "description": None, - "python_date_format": None, + "expression": "COUNT(*)", + "extra": None, + "metric_name": "count", + "metric_type": "count", + "verbose_name": "COUNT(*)", + "warning_text": None, + }, + { + "d3format": None, + "description": None, + "expression": "SUM(num)", + "extra": None, + "metric_name": "sum__num", + "metric_type": None, + "verbose_name": None, + "warning_text": None, }, ], + "offset": 0, + "params": None, + "schema": None, + "sql": None, + "table_name": "birth_names", + "template_params": None, "version": "1.0.0", - "database_uuid": str(example_db.uuid), } + expected_metadata["columns"].sort(key=lambda x: x["column_name"]) + assert metadata == expected_metadata @patch("superset.security.manager.g") def test_export_database_command_no_access(self, mock_g): diff --git a/tests/datasets/api_tests.py b/tests/datasets/api_tests.py index ee0c224465fe4..ba1e9996ffb3b 100644 --- a/tests/datasets/api_tests.py +++ b/tests/datasets/api_tests.py @@ -40,6 +40,7 @@ from superset.utils.dict_import_export import export_to_dict from tests.base_tests import SupersetTestCase from tests.conftest import CTAS_SCHEMA_NAME +from tests.fixtures.birth_names_dashboard import load_birth_names_dashboard_with_slices from tests.fixtures.energy_dashboard import load_energy_table_with_slice from tests.fixtures.importexport import ( database_config, @@ -272,13 +273,12 @@ def pg_test_query_parameter(query_parameter, expected_response): ) ) schema_values = [ - "", "admin_database", "information_schema", "public", ] expected_response = { - "count": 4, + "count": 3, "result": [{"text": val, "value": val} for val in schema_values], } self.login(username="admin") @@ -301,15 +301,10 @@ def pg_test_query_parameter(query_parameter, expected_response): ) query_parameter = {"page": 0, "page_size": 1} - pg_test_query_parameter( - query_parameter, {"count": 4, "result": [{"text": "", "value": ""}]}, - ) - - query_parameter = {"page": 1, "page_size": 1} pg_test_query_parameter( query_parameter, { - "count": 4, + "count": 3, "result": [{"text": "admin_database", "value": "admin_database"}], }, ) @@ -1182,6 +1177,7 @@ def test_export_dataset_bundle_gamma(self): # gamma users by default do not have access to this dataset assert rv.status_code == 404 + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_get_dataset_related_objects(self): """ Dataset API: Test get chart and dashboard count related to a dataset diff --git a/tests/datasets/commands_tests.py b/tests/datasets/commands_tests.py index 78ed44a82ae47..01bc91c0bef52 100644 --- a/tests/datasets/commands_tests.py +++ b/tests/datasets/commands_tests.py @@ -234,7 +234,8 @@ def test_import_v0_dataset_cli_export(self): assert len(dataset.metrics) == 2 assert dataset.main_dttm_col == "ds" assert dataset.filter_select_enabled - assert [col.column_name for col in dataset.columns] == [ + dataset.columns.sort(key=lambda obj: obj.column_name) + expected_columns = [ "num_california", "ds", "state", @@ -244,6 +245,8 @@ def test_import_v0_dataset_cli_export(self): "num_girls", "num", ] + expected_columns.sort() + assert [col.column_name for col in dataset.columns] == expected_columns db.session.delete(dataset) db.session.commit() diff --git a/tests/datasource_tests.py b/tests/datasource_tests.py index 14ad01d3dcbbd..290e1351e54d2 100644 --- a/tests/datasource_tests.py +++ b/tests/datasource_tests.py @@ -18,15 +18,30 @@ import json from copy import deepcopy -from superset import app, db +import pytest + +from superset import app, ConnectorRegistry, db from superset.connectors.sqla.models import SqlaTable from superset.utils.core import get_example_database +from tests.fixtures.birth_names_dashboard import load_birth_names_dashboard_with_slices from .base_tests import SupersetTestCase from .fixtures.datasource import datasource_post class TestDatasource(SupersetTestCase): + def setUp(self): + self.original_attrs = {} + self.datasource = None + + def tearDown(self): + if self.datasource: + for key, value in self.original_attrs.items(): + setattr(self.datasource, key, value) + + db.session.commit() + + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_external_metadata_for_physical_table(self): self.login(username="admin") tbl = self.get_table_by_name("birth_names") @@ -105,6 +120,12 @@ def compare_lists(self, l1, l2, key): def test_save(self): self.login(username="admin") tbl_id = self.get_table_by_name("birth_names").id + + self.datasource = ConnectorRegistry.get_datasource("table", tbl_id, db.session) + + for key in self.datasource.export_fields: + self.original_attrs[key] = getattr(self.datasource, key) + datasource_post["id"] = tbl_id data = dict(data=json.dumps(datasource_post)) resp = self.get_json_resp("/datasource/save/", data) @@ -130,6 +151,11 @@ def test_change_database(self): db_id = tbl.database_id datasource_post["id"] = tbl_id + self.datasource = ConnectorRegistry.get_datasource("table", tbl_id, db.session) + + for key in self.datasource.export_fields: + self.original_attrs[key] = getattr(self.datasource, key) + new_db = self.create_fake_db() datasource_post["database"]["id"] = new_db.id @@ -145,6 +171,11 @@ def test_change_database(self): def test_save_duplicate_key(self): self.login(username="admin") tbl_id = self.get_table_by_name("birth_names").id + self.datasource = ConnectorRegistry.get_datasource("table", tbl_id, db.session) + + for key in self.datasource.export_fields: + self.original_attrs[key] = getattr(self.datasource, key) + datasource_post_copy = deepcopy(datasource_post) datasource_post_copy["id"] = tbl_id datasource_post_copy["columns"].extend( @@ -172,6 +203,14 @@ def test_save_duplicate_key(self): def test_get_datasource(self): self.login(username="admin") tbl = self.get_table_by_name("birth_names") + self.datasource = ConnectorRegistry.get_datasource("table", tbl.id, db.session) + + for key in self.datasource.export_fields: + self.original_attrs[key] = getattr(self.datasource, key) + + datasource_post["id"] = tbl.id + data = dict(data=json.dumps(datasource_post)) + self.get_json_resp("/datasource/save/", data) url = f"/datasource/get/{tbl.type}/{tbl.id}/" resp = self.get_json_resp(url) self.assertEqual(resp.get("type"), "table") @@ -199,6 +238,11 @@ def my_check(datasource): self.login(username="admin") tbl = self.get_table_by_name("birth_names") + self.datasource = ConnectorRegistry.get_datasource("table", tbl.id, db.session) + + for key in self.datasource.export_fields: + self.original_attrs[key] = getattr(self.datasource, key) + url = f"/datasource/get/{tbl.type}/{tbl.id}/" tbl.health_check(commit=True, force=True) resp = self.get_json_resp(url) diff --git a/tests/fixtures/birth_names_dashboard.py b/tests/fixtures/birth_names_dashboard.py new file mode 100644 index 0000000000000..d07bbf43e4fb2 --- /dev/null +++ b/tests/fixtures/birth_names_dashboard.py @@ -0,0 +1,202 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import json +import string +from datetime import date, datetime +from random import choice, getrandbits, randint, random, uniform +from typing import Any, Dict, List + +import pandas as pd +import pytest +from pandas import DataFrame +from sqlalchemy import DateTime, String, TIMESTAMP + +from superset import ConnectorRegistry, db +from superset.connectors.sqla.models import SqlaTable +from superset.models.core import Database +from superset.models.dashboard import Dashboard +from superset.models.slice import Slice +from superset.utils.core import get_example_database +from tests.dashboard_utils import create_dashboard, create_table_for_dashboard +from tests.test_app import app + + +@pytest.fixture() +def load_birth_names_dashboard_with_slices(): + dash_id_to_delete, slices_ids_to_delete = _load_data() + yield + with app.app_context(): + _cleanup(dash_id_to_delete, slices_ids_to_delete) + + +@pytest.fixture(scope="module") +def load_birth_names_dashboard_with_slices_module_scope(): + dash_id_to_delete, slices_ids_to_delete = _load_data() + yield + with app.app_context(): + _cleanup(dash_id_to_delete, slices_ids_to_delete) + + +def _load_data(): + table_name = "birth_names" + + with app.app_context(): + database = get_example_database() + df = _get_dataframe(database) + dtype = { + "ds": DateTime if database.backend != "presto" else String(255), + "gender": String(16), + "state": String(10), + "name": String(255), + } + table = _create_table(df, table_name, database, dtype) + + from superset.examples.birth_names import create_slices, create_dashboard + + slices, _ = create_slices(table, admin_owner=False) + dash = create_dashboard(slices) + slices_ids_to_delete = [slice.id for slice in slices] + dash_id_to_delete = dash.id + return dash_id_to_delete, slices_ids_to_delete + + +def _create_table( + df: DataFrame, table_name: str, database: "Database", dtype: Dict[str, Any] +): + table = create_table_for_dashboard(df, table_name, database, dtype) + from superset.examples.birth_names import _add_table_metrics, _set_table_metadata + + _set_table_metadata(table, database) + _add_table_metrics(table) + db.session.commit() + return table + + +def _cleanup(dash_id: int, slices_ids: List[int]) -> None: + table_id = db.session.query(SqlaTable).filter_by(table_name="birth_names").one().id + datasource = ConnectorRegistry.get_datasource("table", table_id, db.session) + columns = [column for column in datasource.columns] + metrics = [metric for metric in datasource.metrics] + + engine = get_example_database().get_sqla_engine() + engine.execute("DROP TABLE IF EXISTS birth_names") + for column in columns: + db.session.delete(column) + for metric in metrics: + db.session.delete(metric) + + dash = db.session.query(Dashboard).filter_by(id=dash_id).first() + + db.session.delete(dash) + for slice_id in slices_ids: + db.session.query(Slice).filter_by(id=slice_id).delete() + db.session.commit() + + +def _get_dataframe(database: Database) -> DataFrame: + data = _get_birth_names_data() + df = pd.DataFrame.from_dict(data) + if database.backend == "presto": + df.ds = df.ds.dt.strftime("%Y-%m-%d %H:%M%:%S") + return df + + +def _get_birth_names_data() -> List[Dict[Any, Any]]: + data = [] + names = generate_names() + for year in range(1960, 2020): + ds = datetime(year, 1, 1, 0, 0, 0) + for _ in range(20): + gender = "boy" if choice([True, False]) else "girl" + num = randint(1, 100000) + data.append( + { + "ds": ds, + "gender": gender, + "name": choice(names), + "num": num, + "state": choice(us_states), + "num_boys": num if gender == "boy" else 0, + "num_girls": num if gender == "girl" else 0, + } + ) + + return data + + +def generate_names() -> List[str]: + names = [] + for _ in range(250): + names.append( + "".join(choice(string.ascii_lowercase) for _ in range(randint(3, 12))) + ) + return names + + +us_states = [ + "AL", + "AK", + "AZ", + "AR", + "CA", + "CO", + "CT", + "DE", + "FL", + "GA", + "HI", + "ID", + "IL", + "IN", + "IA", + "KS", + "KY", + "LA", + "ME", + "MD", + "MA", + "MI", + "MN", + "MS", + "MO", + "MT", + "NE", + "NV", + "NH", + "NJ", + "NM", + "NY", + "NC", + "ND", + "OH", + "OK", + "OR", + "PA", + "RI", + "SC", + "SD", + "TN", + "TX", + "UT", + "VT", + "VA", + "WA", + "WV", + "WI", + "WY", + "other", +] diff --git a/tests/import_export_tests.py b/tests/import_export_tests.py index b0ef243a63d5a..2dc29f0acb6b2 100644 --- a/tests/import_export_tests.py +++ b/tests/import_export_tests.py @@ -18,6 +18,7 @@ """Unit tests for Superset""" import json import unittest +from tests.fixtures.birth_names_dashboard import load_birth_names_dashboard_with_slices import pytest from flask import g @@ -240,6 +241,7 @@ def assert_only_exported_slc_fields(self, expected_dash, actual_dash): self.assertEqual(e_slc.datasource.schema, params["schema"]) self.assertEqual(e_slc.datasource.database.name, params["database_name"]) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_export_1_dashboard(self): self.login("admin") birth_dash = self.get_dash_by_slug("births") @@ -268,6 +270,7 @@ def test_export_1_dashboard(self): self.get_table_by_name("birth_names"), exported_tables[0] ) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_export_2_dashboards(self): self.login("admin") birth_dash = self.get_dash_by_slug("births") diff --git a/tests/model_tests.py b/tests/model_tests.py index 45dfee9c802de..e0eaf4ac460ed 100644 --- a/tests/model_tests.py +++ b/tests/model_tests.py @@ -17,6 +17,7 @@ # isort:skip_file import textwrap import unittest +from tests.fixtures.birth_names_dashboard import load_birth_names_dashboard_with_slices import pandas import pytest @@ -214,6 +215,7 @@ def test_multi_statement(self): class TestSqlaTableModel(SupersetTestCase): + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_get_timestamp_expression(self): tbl = self.get_table_by_name("birth_names") ds_col = tbl.get_column("ds") @@ -233,6 +235,7 @@ def test_get_timestamp_expression(self): self.assertEqual(compiled, "DATE(DATE_ADD(ds, 1))") ds_col.expression = prev_ds_expr + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_get_timestamp_expression_epoch(self): tbl = self.get_table_by_name("birth_names") ds_col = tbl.get_column("ds") @@ -297,6 +300,7 @@ def query_with_expr_helper(self, is_timeseries, inner_join=True): self.assertFalse(qr.df.empty) return qr.df + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_query_with_expr_groupby_timeseries(self): if get_example_database().backend == "presto": # TODO(bkyryliuk): make it work for presto. @@ -313,29 +317,13 @@ def cannonicalize_df(df): name_list2 = cannonicalize_df(df1).name.values.tolist() self.assertFalse(df2.empty) - expected_namelist = [ - "Anthony", - "Brian", - "Christopher", - "Daniel", - "David", - "Eric", - "James", - "Jeffrey", - "John", - "Joseph", - "Kenneth", - "Kevin", - "Mark", - "Michael", - "Paul", - ] - assert name_list2 == expected_namelist - assert name_list1 == expected_namelist + assert name_list2 == name_list1 + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_query_with_expr_groupby(self): self.query_with_expr_helper(is_timeseries=False) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_sql_mutator(self): tbl = self.get_table_by_name("birth_names") query_obj = dict( @@ -381,14 +369,18 @@ def test_query_with_non_existent_metrics(self): self.assertTrue("Metric 'invalid' does not exist", context.exception) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_data_for_slices(self): tbl = self.get_table_by_name("birth_names") slc = ( metadata_db.session.query(Slice) - .filter_by(datasource_id=tbl.id, datasource_type=tbl.type) + .filter_by( + datasource_id=tbl.id, + datasource_type=tbl.type, + slice_name="Participants", + ) .first() ) - data_for_slices = tbl.data_for_slices([slc]) self.assertEqual(len(data_for_slices["columns"]), 0) self.assertEqual(len(data_for_slices["metrics"]), 1) diff --git a/tests/query_context_tests.py b/tests/query_context_tests.py index 43bdccd335b7c..2201900477cdc 100644 --- a/tests/query_context_tests.py +++ b/tests/query_context_tests.py @@ -14,6 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import pytest + from superset import db from superset.charts.schemas import ChartDataQueryContextSchema from superset.connectors.connector_registry import ConnectorRegistry @@ -25,6 +27,7 @@ TimeRangeEndpoint, ) from tests.base_tests import SupersetTestCase +from tests.fixtures.birth_names_dashboard import load_birth_names_dashboard_with_slices from tests.fixtures.query_context import get_query_context @@ -141,7 +144,7 @@ def test_handle_metrics_field(self): self.login(username="admin") adhoc_metric = { "expressionType": "SIMPLE", - "column": {"column_name": "sum_boys", "type": "BIGINT(20)"}, + "column": {"column_name": "num_boys", "type": "BIGINT(20)"}, "aggregate": "SUM", "label": "Boys", "optionName": "metric_11", @@ -166,6 +169,7 @@ def test_convert_deprecated_fields(self): self.assertEqual(query_object.granularity, "timecol") self.assertIn("having_druid", query_object.extras) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_csv_response_format(self): """ Ensure that CSV result format works @@ -224,6 +228,7 @@ def test_sql_injection_via_metrics(self): query_payload = query_context.get_payload() assert query_payload["queries"][0].get("error") is not None + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_samples_response_type(self): """ Ensure that samples result type works @@ -240,6 +245,7 @@ def test_samples_response_type(self): self.assertEqual(len(data), 5) self.assertNotIn("sum__num", data[0]) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_query_response_type(self): """ Ensure that query result type works diff --git a/tests/schedules_test.py b/tests/schedules_test.py index b18007e2f77a3..8ff7a528b97af 100644 --- a/tests/schedules_test.py +++ b/tests/schedules_test.py @@ -41,6 +41,9 @@ from superset.models.slice import Slice from tests.base_tests import SupersetTestCase from tests.utils import read_fixture +from tests.fixtures.birth_names_dashboard import ( + load_birth_names_dashboard_with_slices_module_scope, +) class TestSchedules(SupersetTestCase): @@ -138,6 +141,7 @@ def test_wider_schedules(self): else: self.assertEqual(len(schedules), 0) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices_module_scope") def test_complex_schedule(self): # Run the job on every Friday of March and May # On these days, run the job at diff --git a/tests/security_tests.py b/tests/security_tests.py index d8665276984a8..7ddc326c77f88 100644 --- a/tests/security_tests.py +++ b/tests/security_tests.py @@ -47,6 +47,7 @@ ) from .fixtures.energy_dashboard import load_energy_table_with_slice from .fixtures.unicode_dashboard import load_unicode_dashboard_with_slice +from tests.fixtures.birth_names_dashboard import load_birth_names_dashboard_with_slices NEW_SECURITY_CONVERGE_VIEWS = ( "Annotation", @@ -1149,6 +1150,7 @@ def test_multiple_table_filter_alters_another_tables_query(self): assert tbl.get_extra_cache_keys(self.query_obj) == [1] assert "value > 1" in sql + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_rls_filter_alters_gamma_birth_names_query(self): g.user = self.get_user(username="gamma") tbl = self.get_table_by_name("birth_names") @@ -1161,6 +1163,7 @@ def test_rls_filter_alters_gamma_birth_names_query(self): in sql ) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_rls_filter_alters_no_role_user_birth_names_query(self): g.user = self.get_user(username="NoRlsRoleUser") tbl = self.get_table_by_name("birth_names") @@ -1173,6 +1176,7 @@ def test_rls_filter_alters_no_role_user_birth_names_query(self): # base query should be present assert self.BASE_FILTER_REGEX.search(sql) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_rls_filter_doesnt_alter_admin_birth_names_query(self): g.user = self.get_user(username="admin") tbl = self.get_table_by_name("birth_names") diff --git a/tests/sqla_models_tests.py b/tests/sqla_models_tests.py index f6bb4fc765ec4..95c9c16b1bdcc 100644 --- a/tests/sqla_models_tests.py +++ b/tests/sqla_models_tests.py @@ -27,6 +27,7 @@ from superset.exceptions import QueryObjectValidationError from superset.models.core import Database from superset.utils.core import DbColumnType, get_example_database, FilterOperator +from tests.fixtures.birth_names_dashboard import load_birth_names_dashboard_with_slices from .base_tests import SupersetTestCase @@ -165,6 +166,7 @@ def test_extra_cache_keys(self, flask_g): db.session.delete(table) db.session.commit() + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_where_operators(self): class FilterTestCase(NamedTuple): operator: str diff --git a/tests/sqllab_tests.py b/tests/sqllab_tests.py index 8c2b10a2557a7..3959485567cde 100644 --- a/tests/sqllab_tests.py +++ b/tests/sqllab_tests.py @@ -18,17 +18,19 @@ """Unit tests for Sql Lab""" import json from datetime import datetime, timedelta -from random import random -from unittest import mock +import pytest from parameterized import parameterized +from random import random +from unittest import mock +from superset.extensions import db import prison -import pytest from superset import db, security_manager from superset.connectors.sqla.models import SqlaTable from superset.db_engine_specs import BaseEngineSpec from superset.errors import ErrorLevel, SupersetErrorType +from superset.models.core import Database from superset.models.sql_lab import Query, SavedQuery from superset.result_set import SupersetResultSet from superset.sql_lab import execute_sql_statements, SqlLabException @@ -41,6 +43,7 @@ from .base_tests import SupersetTestCase from .conftest import CTAS_SCHEMA_NAME +from tests.fixtures.birth_names_dashboard import load_birth_names_dashboard_with_slices QUERY_1 = "SELECT * FROM birth_names LIMIT 1" QUERY_2 = "SELECT * FROM NO_TABLE" @@ -64,6 +67,7 @@ def tearDown(self): db.session.commit() db.session.close() + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_sql_json(self): self.login("admin") @@ -84,6 +88,7 @@ def test_sql_json(self): ] } + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_sql_json_to_saved_query_info(self): """ SQLLab: Test SQLLab query execution info propagation to saved queries @@ -115,6 +120,7 @@ def test_sql_json_to_saved_query_info(self): db.session.commit() @parameterized.expand([CtasMethod.TABLE, CtasMethod.VIEW]) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_sql_json_cta_dynamic_db(self, ctas_method): examples_db = get_example_database() if examples_db.backend == "sqlite": @@ -146,8 +152,9 @@ def test_sql_json_cta_dynamic_db(self, ctas_method): data = engine.execute( f"SELECT * FROM admin_database.{tmp_table_name}" ).fetchall() + names_count = engine.execute(f"SELECT COUNT(*) FROM birth_names").first() self.assertEqual( - 100, len(data) + names_count[0], len(data) ) # SQL_MAX_ROW not applied due to the SQLLAB_CTAS_NO_LIMIT set to True # cleanup @@ -155,6 +162,7 @@ def test_sql_json_cta_dynamic_db(self, ctas_method): examples_db.allow_ctas = old_allow_ctas db.session.commit() + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_multi_sql(self): self.login("admin") @@ -165,12 +173,14 @@ def test_multi_sql(self): data = self.run_sql(multi_sql, "2234") self.assertLess(0, len(data["data"])) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_explain(self): self.login("admin") data = self.run_sql("EXPLAIN SELECT * FROM birth_names", "1") self.assertLess(0, len(data["data"])) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_sql_json_has_access(self): examples_db = get_example_database() examples_db_permission_view = security_manager.add_permission_view_menu( @@ -312,6 +322,7 @@ def test_search_query_on_user(self): self.assertEqual(1, len(data)) self.assertEqual(data[0]["userId"], user_id) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_search_query_on_status(self): self.run_some_queries() self.login("admin") @@ -481,6 +492,7 @@ def test_sqllab_table_viz(self): ) db.session.commit() + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_sql_limit(self): self.login("admin") test_limit = 1 @@ -589,6 +601,7 @@ def test_api_database(self): ) self.delete_fake_db() + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @mock.patch.dict( "superset.extensions.feature_flag_manager._feature_flags", {"ENABLE_TEMPLATE_PROCESSING": True}, diff --git a/tests/strategy_tests.py b/tests/strategy_tests.py index 29f736c989742..bba560f642b3d 100644 --- a/tests/strategy_tests.py +++ b/tests/strategy_tests.py @@ -19,6 +19,7 @@ import datetime import json from unittest.mock import MagicMock +from tests.fixtures.birth_names_dashboard import load_birth_names_dashboard_with_slices from sqlalchemy import String, Date, Float @@ -184,6 +185,7 @@ def test_get_form_data(self): } self.assertEqual(result, expected) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_top_n_dashboards_strategy(self): # create a top visited dashboard db.session.query(Log).delete() @@ -204,7 +206,9 @@ def reset_tag(self, tag): db.session.delete(o) db.session.commit() - @pytest.mark.usefixtures("load_unicode_dashboard_with_slice") + @pytest.mark.usefixtures( + "load_unicode_dashboard_with_slice", "load_birth_names_dashboard_with_slices" + ) def test_dashboard_tags(self): tag1 = get_tag("tag1", db.session, TagTypes.custom) # delete first to make test idempotent diff --git a/tests/tasks/async_queries_tests.py b/tests/tasks/async_queries_tests.py index f816e0bb60e65..e44a51510e255 100644 --- a/tests/tasks/async_queries_tests.py +++ b/tests/tasks/async_queries_tests.py @@ -32,6 +32,7 @@ load_explore_json_into_cache, ) from tests.base_tests import SupersetTestCase +from tests.fixtures.birth_names_dashboard import load_birth_names_dashboard_with_slices from tests.fixtures.query_context import get_query_context from tests.test_app import app @@ -42,6 +43,7 @@ def get_table_by_name(name: str) -> SqlaTable: class TestAsyncQueries(SupersetTestCase): + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @mock.patch.object(async_query_manager, "update_job") def test_load_chart_data_into_cache(self, mock_update_job): async_query_manager.init_app(app) @@ -79,6 +81,7 @@ def test_load_chart_data_into_cache_error(self, mock_update_job, mock_run_comman errors = [{"message": "Error: foo"}] mock_update_job.assert_called_with(job_metadata, "error", errors=errors) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @mock.patch.object(async_query_manager, "update_job") def test_load_explore_json_into_cache(self, mock_update_job): async_query_manager.init_app(app) diff --git a/tests/utils/get_dashboards.py b/tests/utils/get_dashboards.py new file mode 100644 index 0000000000000..03260fb94d07f --- /dev/null +++ b/tests/utils/get_dashboards.py @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import List + +from flask_appbuilder import SQLA + +from superset.models.dashboard import Dashboard + + +def get_dashboards_ids(db: SQLA, dashboard_slugs: List[str]) -> List[int]: + result = ( + db.session.query(Dashboard.id).filter(Dashboard.slug.in_(dashboard_slugs)).all() + ) + return [row[0] for row in result] diff --git a/tests/utils_tests.py b/tests/utils_tests.py index b47f7d104f8c4..a16db4b27b31b 100644 --- a/tests/utils_tests.py +++ b/tests/utils_tests.py @@ -24,8 +24,10 @@ import os import re from unittest.mock import Mock, patch +from tests.fixtures.birth_names_dashboard import load_birth_names_dashboard_with_slices import numpy +import pytest from flask import Flask, g import marshmallow from sqlalchemy.exc import ArgumentError @@ -990,6 +992,7 @@ def test_get_form_data_globals(self) -> None: self.assertEqual(slc, None) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_log_this(self) -> None: # TODO: Add additional scenarios. self.login(username="admin")