diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index aaa4252e7f33b..947551cf5bf21 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -1073,7 +1073,7 @@ def mutator(df: pd.DataFrame) -> None: def get_sqla_table_object(self) -> Table: return self.database.get_table(self.table_name, schema=self.schema) - def fetch_metadata(self) -> None: + def fetch_metadata(self, commit=True) -> None: """Fetches the metadata for the table and merges it in""" try: table = self.get_sqla_table_object() @@ -1086,7 +1086,6 @@ def fetch_metadata(self) -> None: ).format(self.table_name) ) - M = SqlMetric metrics = [] any_date_col = None db_engine_spec = self.database.db_engine_spec @@ -1123,7 +1122,7 @@ def fetch_metadata(self) -> None: any_date_col = col.name metrics.append( - M( + SqlMetric( metric_name="count", verbose_name="COUNT(*)", metric_type="count", @@ -1134,7 +1133,8 @@ def fetch_metadata(self) -> None: self.main_dttm_col = any_date_col self.add_missing_metrics(metrics) db.session.merge(self) - db.session.commit() + if commit: + db.session.commit() @classmethod def import_obj(cls, i_datasource, import_time=None) -> int: diff --git a/superset/datasets/api.py b/superset/datasets/api.py index 2e12629c42f10..75e4f8fba640e 100644 --- a/superset/datasets/api.py +++ b/superset/datasets/api.py @@ -30,8 +30,10 @@ DatasetForbiddenError, DatasetInvalidError, DatasetNotFoundError, + DatasetRefreshFailedError, DatasetUpdateFailedError, ) +from superset.datasets.commands.refresh import RefreshDatasetCommand from superset.datasets.commands.update import UpdateDatasetCommand from superset.datasets.schemas import DatasetPostSchema, DatasetPutSchema from superset.views.base import DatasourceFilter @@ -49,7 +51,9 @@ class DatasetRestApi(BaseSupersetModelRestApi): allow_browser_login = True class_permission_name = "TableModelView" - include_route_methods = RouteMethod.REST_MODEL_VIEW_CRUD_SET | {RouteMethod.RELATED} + 
include_route_methods = ( + RouteMethod.REST_MODEL_VIEW_CRUD_SET | {RouteMethod.RELATED} | {"refresh"} + ) list_columns = [ "changed_by_name", @@ -268,3 +272,47 @@ def delete(self, pk: int) -> Response: # pylint: disable=arguments-differ except DatasetDeleteFailedError as e: logger.error(f"Error deleting model {self.__class__.__name__}: {e}") return self.response_422(message=str(e)) + + @expose("/<pk>/refresh", methods=["PUT"]) + @protect() + @safe + def refresh(self, pk: int) -> Response: # pylint: disable=invalid-name + """Refresh a Dataset + --- + put: + description: >- + Refresh updates columns for a dataset + parameters: + - in: path + schema: + type: integer + name: pk + responses: + 200: + description: Dataset refreshed + content: + application/json: + schema: + type: object + properties: + message: + type: string + 401: + $ref: '#/components/responses/401' + 403: + $ref: '#/components/responses/403' + 404: + $ref: '#/components/responses/404' + 500: + $ref: '#/components/responses/500' + """ + try: + RefreshDatasetCommand(g.user, pk).run() + return self.response(200, message="OK") + except DatasetNotFoundError: + return self.response_404() + except DatasetForbiddenError: + return self.response_403() + except DatasetRefreshFailedError as e: + logger.error(f"Error refreshing dataset {self.__class__.__name__}: {e}") + return self.response_422(message=str(e)) diff --git a/superset/datasets/commands/exceptions.py b/superset/datasets/commands/exceptions.py index a6d0ed7deda3b..e96313755aaad 100644 --- a/superset/datasets/commands/exceptions.py +++ b/superset/datasets/commands/exceptions.py @@ -101,3 +101,7 @@ class DatasetDeleteFailedError(DeleteFailedError): class DatasetForbiddenError(ForbiddenError): message = _("Changing this dataset is forbidden") + + +class DatasetRefreshFailedError(UpdateFailedError): + message = _("Dataset could not be updated.") diff --git a/superset/datasets/commands/refresh.py b/superset/datasets/commands/refresh.py new file mode 100644 
index 0000000000000..2c58cc4b2afcd --- /dev/null +++ b/superset/datasets/commands/refresh.py @@ -0,0 +1,61 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import logging +from typing import Optional + +from flask_appbuilder.security.sqla.models import User + +from superset.commands.base import BaseCommand +from superset.connectors.sqla.models import SqlaTable +from superset.datasets.commands.exceptions import ( + DatasetForbiddenError, + DatasetNotFoundError, + DatasetRefreshFailedError, +) +from superset.datasets.dao import DatasetDAO +from superset.exceptions import SupersetSecurityException +from superset.views.base import check_ownership + +logger = logging.getLogger(__name__) + + +class RefreshDatasetCommand(BaseCommand): + def __init__(self, user: User, model_id: int): + self._actor = user + self._model_id = model_id + self._model: Optional[SqlaTable] = None + + def run(self): + self.validate() + try: + # Updates columns and metrics from the dataset + self._model.fetch_metadata() + except Exception as e: + logger.exception(e) + raise DatasetRefreshFailedError() + return self._model + + def validate(self) -> None: + # Validate/populate model exists + self._model = DatasetDAO.find_by_id(self._model_id) + if not self._model: 
+ raise DatasetNotFoundError() + # Check ownership + try: + check_ownership(self._model) + except SupersetSecurityException: + raise DatasetForbiddenError() diff --git a/superset/views/base_api.py b/superset/views/base_api.py index fdfafe1707b0e..013b0fd0ecbc9 100644 --- a/superset/views/base_api.py +++ b/superset/views/base_api.py @@ -74,6 +74,7 @@ class BaseSupersetModelRestApi(ModelRestApi): "bulk_delete": "delete", "info": "list", "related": "list", + "refresh": "edit", } order_rel_fields: Dict[str, Tuple[str, str]] = {} diff --git a/tests/dataset_api_tests.py b/tests/dataset_api_tests.py index a55140a5dae7b..0959813b3e81d 100644 --- a/tests/dataset_api_tests.py +++ b/tests/dataset_api_tests.py @@ -20,9 +20,10 @@ from unittest.mock import patch import prison +from sqlalchemy.sql import func from superset import db, security_manager -from superset.connectors.sqla.models import SqlaTable +from superset.connectors.sqla.models import SqlaTable, TableColumn from superset.dao.exceptions import ( DAOCreateFailedError, DAODeleteFailedError, @@ -452,3 +453,55 @@ def test_delete_dataset_sqlalchemy_error(self, mock_dao_delete): self.assertEqual(data, {"message": "Dataset could not be deleted."}) db.session.delete(table) db.session.commit() + + def test_dataset_item_refresh(self): + """ + Dataset API: Test item refresh + """ + dataset = self.insert_default_dataset() + # delete a column + id_column = ( + db.session.query(TableColumn) + .filter_by(table_id=dataset.id, column_name="id") + .one() + ) + db.session.delete(id_column) + db.session.commit() + + self.login(username="admin") + uri = f"api/v1/dataset/{dataset.id}/refresh" + rv = self.client.put(uri) + self.assertEqual(rv.status_code, 200) + # Assert the column is restored on refresh + id_column = ( + db.session.query(TableColumn) + .filter_by(table_id=dataset.id, column_name="id") + .one() + ) + self.assertIsNotNone(id_column) + db.session.delete(dataset) + db.session.commit() + + def 
test_dataset_item_refresh_not_found(self): + """ + Dataset API: Test item refresh not found dataset + """ + max_id = db.session.query(func.max(SqlaTable.id)).scalar() + + self.login(username="admin") + uri = f"api/v1/dataset/{max_id + 1}/refresh" + rv = self.client.put(uri) + self.assertEqual(rv.status_code, 404) + + def test_dataset_item_refresh_not_owned(self): + """ + Dataset API: Test item refresh not owned dataset + """ + dataset = self.insert_default_dataset() + self.login(username="alpha") + uri = f"api/v1/dataset/{dataset.id}/refresh" + rv = self.client.put(uri) + self.assertEqual(rv.status_code, 403) + + db.session.delete(dataset) + db.session.commit()