diff --git a/UPDATING.md b/UPDATING.md index c7e00b44e1756..5af1c5c501b57 100644 --- a/UPDATING.md +++ b/UPDATING.md @@ -43,6 +43,10 @@ assists people when migrating to a new version. set `SLACK_API_TOKEN` to fetch and serve Slack avatar links - [28134](https://github.com/apache/superset/pull/28134/) The default logging level was changed from DEBUG to INFO - which is the normal/sane default logging level for most software. +- [28205](https://github.com/apache/superset/pull/28205) The permission `all_database_access` now + more clearly provides access to all databases, as specified in its name. Before it only allowed + listing all databases in CRUD-view and dropdown and didn't provide access to data as it + seemed the name would imply. ## 4.0.0 diff --git a/superset/security/manager.py b/superset/security/manager.py index a84c0cec0d2a0..7903017262071 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -410,7 +410,9 @@ def can_access_all_datasources(self) -> bool: :returns: Whether the user can access all the datasources """ - return self.can_access("all_datasource_access", "all_datasource_access") + return self.can_access_all_databases() or self.can_access( + "all_datasource_access", "all_datasource_access" + ) def can_access_all_databases(self) -> bool: """ @@ -418,7 +420,6 @@ def can_access_all_databases(self) -> bool: :returns: Whether the user can access all the databases """ - return self.can_access("all_database_access", "all_database_access") def can_access_database(self, database: "Database") -> bool: @@ -2433,7 +2434,6 @@ def raise_for_ownership(self, resource: Model) -> None: if self.is_admin(): return - orig_resource = db.session.query(resource.__class__).get(resource.id) owners = orig_resource.owners if hasattr(orig_resource, "owners") else [] diff --git a/tests/integration_tests/base_tests.py b/tests/integration_tests/base_tests.py index be08db539e153..b7b3199917872 100644 --- a/tests/integration_tests/base_tests.py +++ b/tests/integration_tests/base_tests.py @@ -25,7 +25,7 @@ from unittest.mock import Mock, patch, MagicMock import pandas as pd -from flask import Response +from flask import Response, g from flask_appbuilder.security.sqla import models as ab_models from flask_testing import TestCase from sqlalchemy.engine.interfaces import Dialect @@ -42,11 +42,12 @@ from superset.models.slice import Slice from superset.models.core import Database from superset.models.dashboard import Dashboard -from superset.utils.core import get_example_default_schema +from superset.utils.core import get_example_default_schema, shortid from superset.utils.database import get_example_database from superset.views.base_api import BaseSupersetModelRestApi FAKE_DB_NAME = "fake_db_100" +DEFAULT_PASSWORD = "general" test_client = app.test_client() @@ -133,7 +134,7 @@ def create_user_with_roles( username, f"{username}@superset.com", security_manager.find_role("Gamma"), # it needs a role - password="general", + password=DEFAULT_PASSWORD, ) db.session.commit() user_to_create = security_manager.find_user(username) @@ -147,6 +148,76 @@ def create_user_with_roles( db.session.commit() return user_to_create + @contextmanager + def temporary_user( + self, + clone_user=None, + username=None, + extra_roles=None, + extra_pvms=None, + login=False, + ): + """ + Create a temporary user for testing and delete it after the test + + with self.temporary_user(login=True, extra_roles=[Role(...)]) as user: + user.do_something() + + # user is automatically logged out and deleted after the test + """ + username = username or f"temp_user_{shortid()}" + temp_user = ab_models.User( + username=username, email=f"{username}@temp.com", active=True + ) + if clone_user: + temp_user.roles = clone_user.roles + temp_user.first_name = clone_user.first_name + temp_user.last_name = clone_user.last_name + temp_user.password = clone_user.password + else: + temp_user.first_name = temp_user.last_name = username + + if clone_user: + temp_user.roles = clone_user.roles + + if extra_roles: + temp_user.roles.extend(extra_roles) + + pvms = [] + temp_role = None + if extra_pvms: + temp_role = ab_models.Role(name=f"tmp_role_{shortid()}") + for pvm in extra_pvms: + if isinstance(pvm, (tuple, list)): + pvms.append(security_manager.find_permission_view_menu(*pvm)) + else: + pvms.append(pvm) + temp_role.permissions = pvms + temp_user.roles.append(temp_role) + db.session.add(temp_role) + db.session.commit() + + # Add the temp user to the session and commit to apply changes for the test + db.session.add(temp_user) + db.session.commit() + previous_g_user = g.user if hasattr(g, "user") else None + try: + if login: + resp = self.login(username=temp_user.username) + print(resp) + else: + g.user = temp_user + yield temp_user + finally: + # Revert changes after the test + if temp_role: + db.session.delete(temp_role) + if login: + self.logout() + db.session.delete(temp_user) + db.session.commit() + g.user = previous_g_user + @staticmethod def create_user( username: str, @@ -200,7 +271,7 @@ def get_or_create(self, cls, criteria, **kwargs): db.session.commit() return obj - def login(self, username, password="general"): + def login(self, username, password=DEFAULT_PASSWORD): return login(self.client, username, password) def get_slice(self, slice_name: str) -> Slice: @@ -249,8 +320,13 @@ def get_datasource_mock() -> BaseDatasource: datasource.query = Mock(return_value=results) datasource.database = Mock() datasource.database.db_engine_spec = Mock() + datasource.database.perm = "mock_database_perm" + datasource.schema_perm = "mock_schema_perm" + datasource.perm = "mock_datasource_perm" + datasource.__class__ = SqlaTable datasource.database.db_engine_spec.mutate_expression_label = lambda x: x datasource.owners = MagicMock() + datasource.id = 99999 return datasource def get_resp( diff --git a/tests/integration_tests/charts/api_tests.py b/tests/integration_tests/charts/api_tests.py index 2bed0b36920d2..e6c68147c213e 100644 --- a/tests/integration_tests/charts/api_tests.py +++ b/tests/integration_tests/charts/api_tests.py @@ -1294,6 +1294,33 @@ def test_user_gets_none_filtered_energy_slices(self): data = json.loads(rv.data.decode("utf-8")) self.assertEqual(data["count"], 0) + @pytest.mark.usefixtures("load_energy_charts") + def test_user_gets_all_charts(self): + # test filtering on datasource_name + gamma_user = security_manager.find_user(username="gamma") + + def count_charts(): + uri = "api/v1/chart/" + rv = self.client.get(uri, "get_list") + self.assertEqual(rv.status_code, 200) + data = rv.get_json() + return data["count"] + + with self.temporary_user(gamma_user, login=True): + self.assertEqual(count_charts(), 0) + + perm = ("all_database_access", "all_database_access") + with self.temporary_user(gamma_user, extra_pvms=[perm], login=True): + assert count_charts() > 0 + + perm = ("all_datasource_access", "all_datasource_access") + with self.temporary_user(gamma_user, extra_pvms=[perm], login=True): + assert count_charts() > 0 + + # Back to normal + with self.temporary_user(gamma_user, login=True): + self.assertEqual(count_charts(), 0) + @pytest.mark.usefixtures("create_charts") def test_get_charts_favorite_filter(self): """ diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py index 92aad24a3cbb7..e3258651bb871 100644 --- a/tests/integration_tests/datasets/api_tests.py +++ b/tests/integration_tests/datasets/api_tests.py @@ -189,6 +189,40 @@ def create_dataset_import(self) -> BytesIO: buf.seek(0) return buf + @pytest.mark.usefixtures("load_energy_table_with_slice") + def test_user_gets_all_datasets(self): + # test filtering on datasource_name + gamma_user = security_manager.find_user(username="gamma") + + def count_datasets(): + uri = "api/v1/chart/" + rv = self.client.get(uri, "get_list") + print(rv.data) + self.assertEqual(rv.status_code, 200) + data = rv.get_json() + return data["count"] + + with self.temporary_user(gamma_user, login=True) as user: + assert count_datasets() == 0 + + all_db_pvm = ("all_database_access", "all_database_access") + with self.temporary_user( + gamma_user, extra_pvms=[all_db_pvm], login=True + ) as user: + self.login(username=user.username) + assert count_datasets() > 0 + + all_db_pvm = ("all_datasource_access", "all_datasource_access") + with self.temporary_user( + gamma_user, extra_pvms=[all_db_pvm], login=True + ) as user: + self.login(username=user.username) + assert count_datasets() > 0 + + # Back to normal + with self.temporary_user(gamma_user, login=True): + assert count_datasets() == 0 + def test_get_dataset_list(self): """ Dataset API: Test get dataset list diff --git a/tests/integration_tests/security_tests.py b/tests/integration_tests/security_tests.py index e93964316f830..f22847ca55931 100644 --- a/tests/integration_tests/security_tests.py +++ b/tests/integration_tests/security_tests.py @@ -28,7 +28,7 @@ import prison import pytest -from flask import current_app +from flask import current_app, g from flask_appbuilder.security.sqla.models import Role from superset.daos.datasource import DatasourceDAO # noqa: F401 from superset.models.dashboard import Dashboard @@ -1887,6 +1887,20 @@ def test_get_anonymous_roles(self): roles = security_manager.get_user_roles() self.assertEqual([security_manager.get_public_role()], roles) + def test_all_database_access(self): + gamma_user = security_manager.find_user(username="gamma") + g.user = gamma_user + + # Double checking that gamma users can't access all databases + assert not security_manager.can_access_all_databases() + assert not security_manager.can_access_datasource(self.get_datasource_mock()) + + all_db_pvm = ("all_database_access", "all_database_access") + + with self.temporary_user(gamma_user, extra_pvms=[all_db_pvm]): + assert security_manager.can_access_all_databases() + assert security_manager.can_access_datasource(self.get_datasource_mock()) + class TestDatasources(SupersetTestCase): @patch("superset.security.SupersetSecurityManager.can_access_database") diff --git a/tests/integration_tests/test_app.py b/tests/integration_tests/test_app.py index cd5692939c4fd..e88d0e5cb8d98 100644 --- a/tests/integration_tests/test_app.py +++ b/tests/integration_tests/test_app.py @@ -39,3 +39,4 @@ def login( data=dict(username=username, password=password), ).get_data(as_text=True) assert "User confirmation needed" not in resp + return resp