diff --git a/UPDATING.md b/UPDATING.md index 67843c765162f..38c40bbdd0dfa 100644 --- a/UPDATING.md +++ b/UPDATING.md @@ -38,6 +38,9 @@ assists people when migrating to a new version. - [27697](https://github.com/apache/superset/pull/27697) [minor] flask-session bump leads to them deprecating `SESSION_USE_SIGNER`, check your configs as this flag won't do anything moving forward. +- [27849](https://github.com/apache/superset/pull/27849/) More of an FYI, but we have a + new config `SLACK_ENABLE_AVATARS` (False by default) that works in conjunction with + set `SLACK_API_TOKEN` to fetch and serve Slack avatar links ## 4.0.0 diff --git a/superset-frontend/src/components/FacePile/index.tsx b/superset-frontend/src/components/FacePile/index.tsx index 44cc62ce1d624..721f2f084cb3d 100644 --- a/superset-frontend/src/components/FacePile/index.tsx +++ b/superset-frontend/src/components/FacePile/index.tsx @@ -33,12 +33,14 @@ interface FacePileProps { const colorList = getCategoricalSchemeRegistry().get()?.colors ?? []; -const customAvatarStyler = (theme: SupersetTheme) => ` - width: ${theme.gridUnit * 6}px; - height: ${theme.gridUnit * 6}px; - line-height: ${theme.gridUnit * 6}px; - font-size: ${theme.typography.sizes.m}px; -`; +const customAvatarStyler = (theme: SupersetTheme) => { + const size = theme.gridUnit * 8; + return ` + width: ${size}px; + height: ${size}px; + line-height: ${size}px; + font-size: ${theme.typography.sizes.s}px;`; +}; const StyledAvatar = styled(Avatar)` ${({ theme }) => customAvatarStyler(theme)} @@ -58,6 +60,7 @@ export default function FacePile({ users, maxCount = 4 }: FacePileProps) { const name = `${first_name} ${last_name}`; const uniqueKey = `${id}-${first_name}-${last_name}`; const color = getRandomColor(uniqueKey, colorList); + const avatarUrl = `/api/v1/user/${id}/avatar.png`; return ( {first_name?.[0]?.toLocaleUpperCase()} {last_name?.[0]?.toLocaleUpperCase()} diff --git a/superset-frontend/src/pages/DashboardList/index.tsx b/superset-frontend/src/pages/DashboardList/index.tsx index 23261aff84a67..2e5a5a2cd87be 100644 --- a/superset-frontend/src/pages/DashboardList/index.tsx +++ b/superset-frontend/src/pages/DashboardList/index.tsx @@ -357,6 +357,9 @@ function DashboardList(props: DashboardListProps) { Header: t('Owners'), accessor: 'owners', disableSortBy: true, + cellProps: { + style: { padding: '0px' }, + }, size: 'xl', }, { diff --git a/superset/config.py b/superset/config.py index 4b9969bbf5e3b..787d52fa1453a 100644 --- a/superset/config.py +++ b/superset/config.py @@ -1330,6 +1330,11 @@ def EMAIL_HEADER_MUTATOR( # pylint: disable=invalid-name,unused-argument SLACK_API_TOKEN: Callable[[], str] | str | None = None SLACK_PROXY = None +# Whether Superset should use Slack avatars for users. +# If on, you'll want to add "https://avatars.slack-edge.com" to the list of allowed +# domains in your TALISMAN_CONFIG +SLACK_ENABLE_AVATARS = False + # The webdriver to use for generating reports. Use one of the following # firefox # Requires: geckodriver and firefox installations @@ -1454,6 +1459,7 @@ def EMAIL_HEADER_MUTATOR( # pylint: disable=invalid-name,unused-argument "data:", "https://apachesuperset.gateway.scarf.sh", "https://static.scarf.sh/", + # "https://avatars.slack-edge.com", # Uncomment when SLACK_ENABLE_AVATARS is True ], "worker-src": ["'self'", "blob:"], "connect-src": [ @@ -1483,6 +1489,7 @@ def EMAIL_HEADER_MUTATOR( # pylint: disable=invalid-name,unused-argument "data:", "https://apachesuperset.gateway.scarf.sh", "https://static.scarf.sh/", + "https://avatars.slack-edge.com", ], "worker-src": ["'self'", "blob:"], "connect-src": [ diff --git a/superset/daos/user.py b/superset/daos/user.py new file mode 100644 index 0000000000000..cc6696cbdcc74 --- /dev/null +++ b/superset/daos/user.py @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import logging + +from flask_appbuilder.security.sqla.models import User + +from superset.daos.base import BaseDAO +from superset.extensions import db +from superset.models.user_attributes import UserAttribute + +logger = logging.getLogger(__name__) + + +class UserDAO(BaseDAO[User]): + @staticmethod + def get_by_id(user_id: int) -> User: + return db.session.query(User).filter_by(id=user_id).one() + + @staticmethod + def set_avatar_url(user: User, url: str) -> None: + if user.extra_attributes: + user.extra_attributes[0].avatar_url = url + else: + attrs = UserAttribute(avatar_url=url, user_id=user.id) + user.extra_attributes = [attrs] + db.session.add(attrs) + db.session.commit() diff --git a/superset/initialization/__init__.py b/superset/initialization/__init__.py index eae06bcf8f964..e9060d0076a58 100644 --- a/superset/initialization/__init__.py +++ b/superset/initialization/__init__.py @@ -189,7 +189,7 @@ def init_views(self) -> None: ) from superset.views.sqllab import SqllabView from superset.views.tags import TagModelView, TagView - from superset.views.users.api import CurrentUserRestApi + from superset.views.users.api import CurrentUserRestApi, UserRestApi # # Setup API views @@ -204,6 +204,7 @@ def init_views(self) -> None: appbuilder.add_api(ChartDataRestApi) appbuilder.add_api(CssTemplateRestApi) appbuilder.add_api(CurrentUserRestApi) + appbuilder.add_api(UserRestApi) appbuilder.add_api(DashboardFilterStateRestApi) appbuilder.add_api(DashboardPermalinkRestApi) appbuilder.add_api(DashboardRestApi) diff --git a/superset/migrations/versions/2024-04-01_22-44_c22cb5c2e546_user_attr_avatar_url.py b/superset/migrations/versions/2024-04-01_22-44_c22cb5c2e546_user_attr_avatar_url.py new file mode 100644 index 0000000000000..0a5430c684edb --- /dev/null +++ b/superset/migrations/versions/2024-04-01_22-44_c22cb5c2e546_user_attr_avatar_url.py @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# # Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""empty message + +Revision ID: c22cb5c2e546 +Revises: be1b217cd8cd +Create Date: 2024-04-01 22:44:40.386543 + +""" +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "c22cb5c2e546" +down_revision = "be1b217cd8cd" + + +def upgrade(): + op.add_column( + "user_attribute", sa.Column("avatar_url", sa.String(length=100), nullable=True) + ) + + +def downgrade(): + op.drop_column("user_attribute", "avatar_url") diff --git a/superset/migrations/versions/2024-04-11_00-49_bbf146925528_.py b/superset/migrations/versions/2024-04-11_00-49_bbf146925528_.py new file mode 100644 index 0000000000000..c9161c863c5ca --- /dev/null +++ b/superset/migrations/versions/2024-04-11_00-49_bbf146925528_.py @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""empty message + +Revision ID: bbf146925528 +Revises: ('678eefb4ab44', 'c22cb5c2e546') +Create Date: 2024-04-11 00:49:51.592325 + +""" + +# revision identifiers, used by Alembic. +revision = "bbf146925528" +down_revision = ("678eefb4ab44", "c22cb5c2e546") + +import sqlalchemy as sa +from alembic import op + + +def upgrade(): + pass + + +def downgrade(): + pass diff --git a/superset/migrations/versions/2024-04-15_16-06_0dc386701747_.py b/superset/migrations/versions/2024-04-15_16-06_0dc386701747_.py new file mode 100644 index 0000000000000..d8087a41c036f --- /dev/null +++ b/superset/migrations/versions/2024-04-15_16-06_0dc386701747_.py @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""empty message + +Revision ID: 0dc386701747 +Revises: ('5ad7321c2169', 'bbf146925528') +Create Date: 2024-04-15 16:06:29.946059 + +""" + +# revision identifiers, used by Alembic. +revision = "0dc386701747" +down_revision = ("5ad7321c2169", "bbf146925528") + +import sqlalchemy as sa +from alembic import op + + +def upgrade(): + pass + + +def downgrade(): + pass diff --git a/superset/models/user_attributes.py b/superset/models/user_attributes.py index b2af44a188945..55b6d8abad513 100644 --- a/superset/models/user_attributes.py +++ b/superset/models/user_attributes.py @@ -14,8 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + from flask_appbuilder import Model -from sqlalchemy import Column, ForeignKey, Integer +from sqlalchemy import Column, ForeignKey, Integer, String from sqlalchemy.orm import relationship from superset import security_manager @@ -39,6 +40,6 @@ class UserAttribute(Model, AuditMixinNullable): user = relationship( security_manager.user_model, backref="extra_attributes", foreign_keys=[user_id] ) - welcome_dashboard_id = Column(Integer, ForeignKey("dashboards.id")) welcome_dashboard = relationship("Dashboard") + avatar_url = Column(String(100)) diff --git a/superset/reports/notifications/slack.py b/superset/reports/notifications/slack.py index 345e9d5b26373..a7072ca20c4d5 100644 --- a/superset/reports/notifications/slack.py +++ b/superset/reports/notifications/slack.py @@ -24,7 +24,6 @@ import pandas as pd from flask import g from flask_babel import gettext as __ -from slack_sdk import WebClient from slack_sdk.errors import ( BotUserAccessError, SlackApiError, @@ -36,7 +35,6 @@ SlackTokenRotationError, ) -from superset import app from superset.reports.models import ReportRecipientType from superset.reports.notifications.base import BaseNotification from superset.reports.notifications.exceptions import ( @@ -47,6 +45,7 @@ ) from superset.utils.core import get_email_address_list from superset.utils.decorators import statsd_gauge +from superset.utils.slack import get_slack_client logger = logging.getLogger(__name__) @@ -181,10 +180,7 @@ def send(self) -> None: body = self._get_body() global_logs_context = getattr(g, "logs_context", {}) or {} try: - token = app.config["SLACK_API_TOKEN"] - if callable(token): - token = token() - client = WebClient(token=token, proxy=app.config["SLACK_PROXY"]) + client = get_slack_client() # files_upload returns SlackResponse as we run it in sync mode. if files: for file in files: diff --git a/superset/utils/slack.py b/superset/utils/slack.py new file mode 100644 index 0000000000000..8fa9013dfa547 --- /dev/null +++ b/superset/utils/slack.py @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +from flask import current_app +from slack_sdk import WebClient + + +class SlackClientError(Exception): + pass + + +def get_slack_client() -> WebClient: + token: str = current_app.config["SLACK_API_TOKEN"] + if callable(token): + token = token() + return WebClient(token=token, proxy=current_app.config["SLACK_PROXY"]) + + +def get_user_avatar(email: str, client: WebClient = None) -> str: + client = client or get_slack_client() + try: + response = client.users_lookupByEmail(email=email) + except Exception as ex: + raise SlackClientError(f"Failed to lookup user by email: {email}") from ex + + user = response.data.get("user") + if user is None: + raise SlackClientError("No user found with that email.") + + profile = user.get("profile") + if profile is None: + raise SlackClientError("User found but no profile available.") + + avatar_url = profile.get("image_192") + if avatar_url is None: + raise SlackClientError("Profile image is not available.") + + return avatar_url diff --git a/superset/views/users/api.py b/superset/views/users/api.py index 5324975637d36..a7000b6b96c00 100644 --- a/superset/views/users/api.py +++ b/superset/views/users/api.py @@ -14,10 +14,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from flask import g, Response +from flask import g, redirect, Response from flask_appbuilder.api import expose, safe from flask_jwt_extended.exceptions import NoAuthorizationError +from sqlalchemy.orm.exc import NoResultFound +from superset import app +from superset.daos.user import UserDAO +from superset.utils.slack import get_user_avatar, SlackClientError from superset.views.base_api import BaseSupersetApi from superset.views.users.schemas import UserResponseSchema from superset.views.utils import bootstrap_user_data @@ -93,3 +97,68 @@ def get_my_roles(self) -> Response: return self.response_401() user = bootstrap_user_data(g.user, include_perms=True) return self.response(200, result=user) + + +class UserRestApi(BaseSupersetApi): + """An API to get information about users""" + + resource_name = "user" + openapi_spec_tag = "User" + openapi_spec_component_schemas = (UserResponseSchema,) + + @expose("//avatar.png", methods=("GET",)) + @safe + def avatar(self, user_id: int) -> Response: + """Get a redirect to the avatar's URL for the user with the given ID. + --- + get: + summary: Get the user avatar + description: >- + Gets the avatar URL for the user with the given ID, or returns a 401 error + if the user is unauthenticated. + parameters: + - in: path + name: user_id + required: true + description: The ID of the user + schema: + type: string + responses: + 301: + description: A redirect to the user's avatar URL + 401: + $ref: '#/components/responses/401' + 404: + $ref: '#/components/responses/404' + """ + avatar_url = None + try: + user = UserDAO.get_by_id(user_id) + except NoResultFound: + return self.response_404() + + if not user: + return self.response_404() + + # fetch from the one-to-one relationship + if len(user.extra_attributes) > 0: + avatar_url = user.extra_attributes[0].avatar_url + + should_fetch_slack_avatar = app.config.get( + "SLACK_ENABLE_AVATARS" + ) and app.config.get("SLACK_API_TOKEN") + if not avatar_url and should_fetch_slack_avatar: + try: + # Fetching the avatar url from slack + avatar_url = get_user_avatar(user.email) + except SlackClientError: + return self.response_404() + + UserDAO.set_avatar_url(user, avatar_url) + + # Return a permanent redirect to the avatar URL + if avatar_url: + return redirect(avatar_url, code=301) + + # No avatar found, return a "no-content" response + return Response(status=204) diff --git a/tests/integration_tests/reports/commands_tests.py b/tests/integration_tests/reports/commands_tests.py index 9e92841a629a4..79102654d5126 100644 --- a/tests/integration_tests/reports/commands_tests.py +++ b/tests/integration_tests/reports/commands_tests.py @@ -1123,7 +1123,7 @@ def test_email_dashboard_report_schedule_force_screenshot( @pytest.mark.usefixtures( "load_birth_names_dashboard_with_slices", "create_report_slack_chart" ) -@patch("superset.reports.notifications.slack.WebClient.files_upload") +@patch("superset.utils.slack.WebClient.files_upload") @patch("superset.utils.screenshots.ChartScreenshot.get_screenshot") def test_slack_chart_report_schedule( screenshot_mock, @@ -1157,7 +1157,7 @@ def test_slack_chart_report_schedule( @pytest.mark.usefixtures( "load_birth_names_dashboard_with_slices", "create_report_slack_chart" ) -@patch("superset.reports.notifications.slack.WebClient") +@patch("superset.utils.slack.WebClient") @patch("superset.utils.screenshots.ChartScreenshot.get_screenshot") def test_slack_chart_report_schedule_with_errors( screenshot_mock, @@ -1211,7 +1211,7 @@ def test_slack_chart_report_schedule_with_errors( @pytest.mark.usefixtures( "load_birth_names_dashboard_with_slices", "create_report_slack_chart_with_csv" ) -@patch("superset.reports.notifications.slack.WebClient.files_upload") +@patch("superset.utils.slack.WebClient.files_upload") @patch("superset.utils.csv.urllib.request.urlopen") @patch("superset.utils.csv.urllib.request.OpenerDirector.open") @patch("superset.utils.csv.get_chart_csv_data") @@ -1250,7 +1250,7 @@ def test_slack_chart_report_schedule_with_csv( @pytest.mark.usefixtures( "load_birth_names_dashboard_with_slices", "create_report_slack_chart_with_text" ) -@patch("superset.reports.notifications.slack.WebClient.chat_postMessage") +@patch("superset.utils.slack.WebClient.chat_postMessage") @patch("superset.utils.csv.urllib.request.urlopen") @patch("superset.utils.csv.urllib.request.OpenerDirector.open") @patch("superset.utils.csv.get_chart_dataframe") @@ -1378,7 +1378,7 @@ def test_report_schedule_success_grace(create_alert_slack_chart_success): @pytest.mark.usefixtures("create_alert_slack_chart_grace") -@patch("superset.reports.notifications.slack.WebClient.files_upload") +@patch("superset.utils.slack.WebClient.files_upload") @patch("superset.utils.screenshots.ChartScreenshot.get_screenshot") def test_report_schedule_success_grace_end( screenshot_mock, file_upload_mock, create_alert_slack_chart_grace @@ -1547,7 +1547,7 @@ def test_slack_chart_alert_no_attachment(email_mock, create_alert_email_chart): @pytest.mark.usefixtures( "load_birth_names_dashboard_with_slices", "create_report_slack_chart" ) -@patch("superset.reports.notifications.slack.WebClient") +@patch("superset.utils.slack.WebClient") @patch("superset.utils.screenshots.ChartScreenshot.get_screenshot") def test_slack_token_callable_chart_report( screenshot_mock, diff --git a/tests/integration_tests/security_tests.py b/tests/integration_tests/security_tests.py index 90151bb6bdd60..02e60a927724c 100644 --- a/tests/integration_tests/security_tests.py +++ b/tests/integration_tests/security_tests.py @@ -1542,6 +1542,7 @@ def test_views_are_secured(self): ["AuthDBView", "logout"], ["CurrentUserRestApi", "get_me"], ["CurrentUserRestApi", "get_my_roles"], + ["UserRestApi", "avatar"], # TODO (embedded) remove Dashboard:embedded after uuids have been shipped ["Dashboard", "embedded"], ["EmbeddedView", "embedded"], diff --git a/tests/integration_tests/users/api_tests.py b/tests/integration_tests/users/api_tests.py index 5d7ebd61fbd6a..44711d96f27a9 100644 --- a/tests/integration_tests/users/api_tests.py +++ b/tests/integration_tests/users/api_tests.py @@ -20,10 +20,13 @@ from unittest.mock import patch from superset import security_manager +from superset.utils import slack from tests.integration_tests.base_tests import SupersetTestCase +from tests.integration_tests.conftest import with_config from tests.integration_tests.constants import ADMIN_USERNAME meUri = "/api/v1/me/" +AVATAR_URL = "/internal/avatar.png" class TestCurrentUserApi(SupersetTestCase): @@ -62,3 +65,27 @@ def test_get_me_anonymous(self, mock_g): mock_g.user = security_manager.get_anonymous_user rv = self.client.get(meUri) self.assertEqual(401, rv.status_code) + + +class TestUserApi(SupersetTestCase): + def test_avatar_with_invalid_user(self): + self.login(ADMIN_USERNAME) + response = self.client.get("/api/v1/user/NOT_A_USER/avatar.png") + assert response.status_code == 404 # Assuming no user found leads to 404 + response = self.client.get("/api/v1/user/999/avatar.png") + assert response.status_code == 404 # Assuming no user found leads to 404 + + def test_avatar_valid_user_no_avatar(self): + self.login(ADMIN_USERNAME) + + response = self.client.get("/api/v1/user/1/avatar.png", follow_redirects=False) + assert response.status_code == 204 + + @with_config({"SLACK_API_TOKEN": "dummy", "SLACK_ENABLE_AVATARS": True}) + @patch("superset.views.users.api.get_user_avatar", return_value=AVATAR_URL) + def test_avatar_with_valid_user(self, mock): + self.login(ADMIN_USERNAME) + response = self.client.get("/api/v1/user/1/avatar.png", follow_redirects=False) + mock.assert_called_once_with("admin@fab.org") + assert response.status_code == 301 + assert response.headers["Location"] == AVATAR_URL diff --git a/tests/unit_tests/dao/user_test.py b/tests/unit_tests/dao/user_test.py new file mode 100644 index 0000000000000..3808be28c4cd8 --- /dev/null +++ b/tests/unit_tests/dao/user_test.py @@ -0,0 +1,94 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from unittest.mock import MagicMock + +import pytest +from flask_appbuilder.security.sqla.models import User +from sqlalchemy.orm import Query +from sqlalchemy.orm.exc import NoResultFound + +from superset.daos.user import db, UserDAO +from superset.models.user_attributes import UserAttribute + + +@pytest.fixture +def mock_db_session(mocker): + db = mocker.patch("superset.daos.user.db", autospec=True) + db.session = MagicMock() + db.session.query = MagicMock() + db.session.commit = MagicMock() + db.session.query.return_value = MagicMock() + return db.session + + +def test_get_by_id_found(mock_db_session): + # Setup + user_id = 1 + mock_user = User() + mock_user.id = user_id + mock_query = mock_db_session.query.return_value + mock_query.filter_by.return_value.one.return_value = mock_user + + # Execute + result = UserDAO.get_by_id(user_id) + + # Assert + mock_db_session.query.assert_called_with(User) + mock_query.filter_by.assert_called_with(id=user_id) + + +def test_get_by_id_not_found(mock_db_session): + # Setup + user_id = 1 + mock_query = mock_db_session.query.return_value + mock_query.filter_by.return_value.one.side_effect = NoResultFound + + # Execute & Assert + with pytest.raises(NoResultFound): + UserDAO.get_by_id(user_id) + + +def test_set_avatar_url_with_existing_attributes(mock_db_session): + # Setup + user = User() + user.id = 1 + user.extra_attributes = [UserAttribute(user_id=user.id, avatar_url="old_url")] + + # Execute + new_url = "http://newurl.com" + UserDAO.set_avatar_url(user, new_url) + + # Assert + assert user.extra_attributes[0].avatar_url == new_url + mock_db_session.add.assert_not_called() # No new attributes should be added + + +def test_set_avatar_url_without_existing_attributes(mock_db_session): + # Setup + user = User() + user.id = 1 + user.extra_attributes = [] + + # Execute + new_url = "http://newurl.com" + UserDAO.set_avatar_url(user, new_url) + + # Assert + assert len(user.extra_attributes) == 1 + assert user.extra_attributes[0].avatar_url == new_url + mock_db_session.add.assert_called() # New attribute should be added + mock_db_session.commit.assert_called() diff --git a/tests/unit_tests/notifications/slack_tests.py b/tests/unit_tests/notifications/slack_tests.py index 7e6cc3afc1ff4..e423527df8201 100644 --- a/tests/unit_tests/notifications/slack_tests.py +++ b/tests/unit_tests/notifications/slack_tests.py @@ -31,7 +31,8 @@ def test_send_slack( # requires app context from superset.reports.models import ReportRecipients, ReportRecipientType from superset.reports.notifications.base import NotificationContent - from superset.reports.notifications.slack import SlackNotification, WebClient + from superset.reports.notifications.slack import SlackNotification + from superset.utils.slack import WebClient execution_id = uuid.uuid4() flask_global_mock.logs_context = {"execution_id": execution_id}