fix: handle empty catalog when DB supports them #29840

Merged · 8 commits · Aug 13, 2024
2 changes: 1 addition & 1 deletion scripts/change_detector.py
@@ -52,7 +52,7 @@
 def fetch_files_github_api(url: str):  # type: ignore
     """Fetches data using GitHub API."""
     req = Request(url)
-    req.add_header("Authorization", f"token {GITHUB_TOKEN}")
+    req.add_header("Authorization", f"Bearer {GITHUB_TOKEN}")
Member Author:

Note: In most cases, you can use Authorization: Bearer or Authorization: token to pass a token. However, if you are passing a JSON web token (JWT), you must use Authorization: Bearer. [source]

Contributor:

Nice!

req.add_header("Accept", "application/vnd.github.v3+json")

print(f"Fetching from {url}")
Expand Down
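As a quick illustration of the note above, a minimal sketch of both header styles with `urllib` (the token value and URL are placeholders):

```python
from urllib.request import Request

GITHUB_TOKEN = "ghp_placeholder"  # placeholder, not a real token
url = "https://api.github.com/repos/apache/superset/pulls"

# Bearer works for every token type, including JWTs:
req = Request(url)
req.add_header("Authorization", f"Bearer {GITHUB_TOKEN}")
req.add_header("Accept", "application/vnd.github.v3+json")

# The older "token" scheme still works for non-JWT tokens:
legacy = Request(url)
legacy.add_header("Authorization", f"token {GITHUB_TOKEN}")
```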
2 changes: 1 addition & 1 deletion scripts/python_tests.sh
@@ -33,4 +33,4 @@ superset load-test-users

 echo "Running tests"

-pytest --durations-min=2 --maxfail=1 --cov-report= --cov=superset ./tests/integration_tests "$@"
+pytest --durations-min=2 --cov-report= --cov=superset ./tests/integration_tests "$@"
4 changes: 4 additions & 0 deletions superset/cachekeys/schemas.py
@@ -32,6 +32,10 @@ class Datasource(Schema):
     datasource_name = fields.String(
         metadata={"description": datasource_name_description},
     )
+    catalog = fields.String(
+        allow_none=True,
+        metadata={"description": "Datasource catalog"},
+    )
     schema = fields.String(
         metadata={"description": "Datasource schema"},
     )
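A minimal sketch of what `allow_none=True` buys here, using a reduced version of the schema (metadata omitted): an explicit `null` catalog in the payload deserializes cleanly instead of failing validation.

```python
from marshmallow import Schema, ValidationError, fields

class Datasource(Schema):  # reduced: descriptions omitted
    datasource_name = fields.String()
    catalog = fields.String(allow_none=True)  # None is a legal value
    schema = fields.String()

payload = {"datasource_name": "orders", "catalog": None, "schema": "public"}
print(Datasource().load(payload))
# {'datasource_name': 'orders', 'catalog': None, 'schema': 'public'}

# Without allow_none=True the same payload raises:
class Strict(Schema):
    catalog = fields.String()

try:
    Strict().load({"catalog": None})
except ValidationError as err:
    print(err.messages)  # {'catalog': ['Field may not be null.']}
```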
19 changes: 12 additions & 7 deletions superset/commands/dataset/create.py
@@ -54,23 +54,28 @@ def run(self) -> Model:
     def validate(self) -> None:
         exceptions: list[ValidationError] = []
         database_id = self._properties["database"]
-        schema = self._properties.get("schema")
         catalog = self._properties.get("catalog")
+        schema = self._properties.get("schema")
+        table_name = self._properties["table_name"]
         sql = self._properties.get("sql")
         owner_ids: Optional[list[int]] = self._properties.get("owners")

-        table = Table(self._properties["table_name"], schema, catalog)
-
-        # Validate uniqueness
-        if not DatasetDAO.validate_uniqueness(database_id, table):
-            exceptions.append(DatasetExistsValidationError(table))
-
         # Validate/Populate database
         database = DatasetDAO.get_database_by_id(database_id)
         if not database:
             exceptions.append(DatabaseNotFoundValidationError())
         self._properties["database"] = database

+        # Validate uniqueness
+        if database:
+            if not catalog:
+                catalog = self._properties["catalog"] = database.get_default_catalog()
+
+            table = Table(table_name, schema, catalog)
+
+            if not DatasetDAO.validate_uniqueness(database, table):
+                exceptions.append(DatasetExistsValidationError(table))
+
         # Validate table exists on dataset if sql is not provided
         # This should be validated when the dataset is physical
         if (
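One detail worth calling out in the new block: `catalog = self._properties["catalog"] = database.get_default_catalog()` is a chained assignment, so the resolved default is used for the uniqueness check and also written back into the properties that get persisted. A toy sketch of the behavior (class names are stand-ins, not Superset code):

```python
class FakeDatabase:  # stand-in for superset.models.core.Database
    def get_default_catalog(self):
        return "main"  # engines without catalog support would return None

properties = {"table_name": "orders", "schema": "public"}
database = FakeDatabase()

catalog = properties.get("catalog")
if not catalog:
    # chained assignment: updates the local *and* the persisted properties
    catalog = properties["catalog"] = database.get_default_catalog()

print(catalog, properties)
# main {'table_name': 'orders', 'schema': 'public', 'catalog': 'main'}
```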
2 changes: 1 addition & 1 deletion superset/commands/dataset/importers/v1/utils.py
@@ -166,7 +166,7 @@ def import_dataset(

     try:
         table_exists = dataset.database.has_table(
-            Table(dataset.table_name, dataset.schema),
+            Table(dataset.table_name, dataset.schema, dataset.catalog),
         )
     except Exception:  # pylint: disable=broad-except
         # MySQL doesn't play nice with GSheets table names
15 changes: 13 additions & 2 deletions superset/commands/dataset/update.py
@@ -79,10 +79,12 @@ def run(self) -> Model:
     def validate(self) -> None:
         exceptions: list[ValidationError] = []
         owner_ids: Optional[list[int]] = self._properties.get("owners")
+
         # Validate/populate model exists
         self._model = DatasetDAO.find_by_id(self._model_id)
         if not self._model:
             raise DatasetNotFoundError()
+
         # Check ownership
         try:
             security_manager.raise_for_ownership(self._model)
@@ -91,22 +93,30 @@ def validate(self) -> None:

         database_id = self._properties.get("database")

+        catalog = self._properties.get("catalog")
+        if not catalog:
+            catalog = self._properties["catalog"] = (
+                self._model.database.get_default_catalog()
+            )
+
         table = Table(
             self._properties.get("table_name"),  # type: ignore
             self._properties.get("schema"),
-            self._properties.get("catalog"),
+            catalog,
         )

         # Validate uniqueness
         if not DatasetDAO.validate_update_uniqueness(
-            self._model.database_id,
+            self._model.database,
             table,
             self._model_id,
         ):
             exceptions.append(DatasetExistsValidationError(table))
+
         # Validate/Populate database not allowed to change
         if database_id and database_id != self._model.database_id:
             exceptions.append(DatabaseChangeValidationError())
+
         # Validate/Populate owner
         try:
             owners = self.compute_owners(
@@ -116,6 +126,7 @@ def validate(self) -> None:
             self._properties["owners"] = owners
         except ValidationError as ex:
             exceptions.append(ex)
+
         # Validate columns
         if columns := self._properties.get("columns"):
             self._validate_columns(columns, exceptions)
10 changes: 6 additions & 4 deletions superset/connectors/sqla/models.py
@@ -461,9 +461,11 @@ def data_for_slices(  # pylint: disable=too-many-locals
             )
         else:
             _columns = [
-                utils.get_column_name(column_)
-                if utils.is_adhoc_column(column_)
-                else column_
+                (
+                    utils.get_column_name(column_)
+                    if utils.is_adhoc_column(column_)
+                    else column_
+                )
                 for column_param in COLUMN_FORM_DATA_PARAMS
                 for column_ in utils.as_list(form_data.get(column_param) or [])
             ]
@@ -1963,7 +1965,7 @@ class and any keys added via `ExtraCache`.
         if self.has_extra_cache_key_calls(query_obj):
             sqla_query = self.get_sqla_query(**query_obj)
             extra_cache_keys += sqla_query.extra_cache_keys
-        return extra_cache_keys
+        return list(set(extra_cache_keys))

     @property
     def quote_identifier(self) -> Callable[[str], str]:
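The `list(set(...))` change in the second hunk deduplicates the extra cache keys; note that it does not guarantee a stable order. A small illustration:

```python
# Duplicate keys can accumulate when several expressions resolve to the
# same template value; the set round-trip collapses them (order not kept):
extra_cache_keys = ["user_1", "user_1", "2024-08-13"]
deduped = list(set(extra_cache_keys))
print(sorted(deduped))  # ['2024-08-13', 'user_1']
```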
20 changes: 14 additions & 6 deletions superset/daos/dataset.py
@@ -84,15 +84,19 @@ def validate_table_exists(

     @staticmethod
     def validate_uniqueness(
-        database_id: int,
+        database: Database,
         table: Table,
         dataset_id: int | None = None,
     ) -> bool:
+        # The catalog might not be set even if the database supports catalogs, in case
+        # multi-catalog is disabled.
+        catalog = table.catalog or database.get_default_catalog()
+
         dataset_query = db.session.query(SqlaTable).filter(
             SqlaTable.table_name == table.table,
             SqlaTable.schema == table.schema,
-            SqlaTable.catalog == table.catalog,
-            SqlaTable.database_id == database_id,
+            SqlaTable.catalog == catalog,
+            SqlaTable.database_id == database.id,
         )

         if dataset_id:
@@ -103,15 +107,19 @@ def validate_uniqueness(

     @staticmethod
     def validate_update_uniqueness(
-        database_id: int,
+        database: Database,
         table: Table,
         dataset_id: int,
     ) -> bool:
+        # The catalog might not be set even if the database supports catalogs, in case
+        # multi-catalog is disabled.
+        catalog = table.catalog or database.get_default_catalog()
+
         dataset_query = db.session.query(SqlaTable).filter(
             SqlaTable.table_name == table.table,
-            SqlaTable.database_id == database_id,
+            SqlaTable.database_id == database.id,
             SqlaTable.schema == table.schema,
-            SqlaTable.catalog == table.catalog,
+            SqlaTable.catalog == catalog,
             SqlaTable.id != dataset_id,
         )
         return not db.session.query(dataset_query.exists()).scalar()
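A self-contained sketch of the query shape both validators now build: filter on the normalized catalog, then ask the database whether any matching row exists. The model here is a reduced stand-in for Superset's `SqlaTable`, and the values are made up.

```python
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class SqlaTable(Base):  # reduced stand-in for Superset's model
    __tablename__ = "tables"
    id = Column(Integer, primary_key=True)
    table_name = Column(String)
    schema = Column(String)
    catalog = Column(String)
    database_id = Column(Integer)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(
        SqlaTable(table_name="orders", schema="public", catalog="main", database_id=1)
    )
    session.commit()

    # Same shape as validate_uniqueness: build the filtered query, then ask
    # the database whether any row matches.
    catalog = None or "main"  # table.catalog or database.get_default_catalog()
    q = session.query(SqlaTable).filter(
        SqlaTable.table_name == "orders",
        SqlaTable.schema == "public",
        SqlaTable.catalog == catalog,
        SqlaTable.database_id == 1,
    )
    print(not session.query(q.exists()).scalar())  # False: a duplicate exists
```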
2 changes: 1 addition & 1 deletion superset/databases/api.py
@@ -1159,7 +1159,7 @@ def select_star(
         self.incr_stats("init", self.select_star.__name__)
         try:
             result = database.select_star(
-                Table(table_name, schema_name),
+                Table(table_name, schema_name, database.get_default_catalog()),
Contributor:

Should catalog be added to the API path as well (when the table is not from the default one)? Not related to this PR, just asking if it's something we need to do.

Member Author:

Yeah, but I couldn't find any API calls to this endpoint; it seems like we return this in the database request now, instead of calling /select_star/, which is why I didn't change the code to pass a parameter. I think we should keep this endpoint until the next major version and then get rid of it.

Contributor:

gotcha! thank you

                 latest_partition=True,
             )
         except NoSuchTableError:
2 changes: 1 addition & 1 deletion superset/jinja_context.py
@@ -565,7 +565,7 @@ def process_template(self, sql: str, **kwargs: Any) -> str:
         """
         Makes processing a template a noop
         """
-        return sql
+        return str(sql)


 class PrestoTemplateProcessor(JinjaTemplateProcessor):
2 changes: 1 addition & 1 deletion superset/models/core.py
@@ -490,7 +490,7 @@ def _get_sqla_engine(  # pylint: disable=too-many-locals
                     g.user.id,
                     self.db_engine_spec,
                 )
-                if hasattr(g, "user") and hasattr(g.user, "id") and oauth2_config
+                if oauth2_config and hasattr(g, "user") and hasattr(g.user, "id")
                 else None
             )
             # If using MySQL or Presto for example, will set url.username
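The reorder looks cosmetic but is not: `and` short-circuits left to right, so checking `oauth2_config` first skips the `g` and `g.user` probes entirely when OAuth2 is not configured (touching Flask's `g` outside a request context can raise). A minimal illustration with plain stand-in objects:

```python
class G:  # stand-in for flask.g with no user attribute
    pass

g = G()
oauth2_config = None

# Old order: evaluates hasattr(g, "user") even though oauth2_config is None.
old = hasattr(g, "user") and hasattr(g.user, "id") and oauth2_config

# New order: bails out on the falsy first operand, never touching g.
new = oauth2_config and hasattr(g, "user") and hasattr(g.user, "id")

print(old, new)  # False None
```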
1 change: 1 addition & 0 deletions superset/models/slice.py
@@ -369,6 +369,7 @@ def set_related_perm(_mapper: Mapper, _connection: Connection, target: Slice) ->
     ds = db.session.query(src_class).filter_by(id=int(id_)).first()
     if ds:
         target.perm = ds.perm
+        target.catalog_perm = ds.catalog_perm
         target.schema_perm = ds.schema_perm
21 changes: 12 additions & 9 deletions superset/security/manager.py
@@ -774,6 +774,9 @@ def get_schemas_accessible_by_user(
         # pylint: disable=import-outside-toplevel
         from superset.connectors.sqla.models import SqlaTable

+        default_catalog = database.get_default_catalog()
+        catalog = catalog or default_catalog
+
         if hierarchical and (
             self.can_access_database(database)
             or (catalog and self.can_access_catalog(database, catalog))
@@ -783,7 +786,6 @@
         # schema_access
         accessible_schemas: set[str] = set()
         schema_access = self.user_view_menu_names("schema_access")
-        default_catalog = database.get_default_catalog()
         default_schema = database.get_default_schema(default_catalog)

         for perm in schema_access:
@@ -800,7 +802,7 @@
             # [database].[catalog].[schema] matches when the catalog is equal to the
             # requested catalog or, when no catalog specified, it's equal to the default
             # catalog.
-            elif len(parts) == 3 and parts[1] == (catalog or default_catalog):
+            elif len(parts) == 3 and parts[1] == catalog:
                 accessible_schemas.add(parts[2])

         # datasource_access
@@ -906,16 +908,16 @@ def get_datasources_accessible_by_user(  # pylint: disable=invalid-name
         if self.can_access_database(database):
             return datasource_names

+        catalog = catalog or database.get_default_catalog()
         if catalog:
             catalog_perm = self.get_catalog_perm(database.database_name, catalog)
             if catalog_perm and self.can_access("catalog_access", catalog_perm):
                 return datasource_names

         if schema:
-            default_catalog = database.get_default_catalog()
             schema_perm = self.get_schema_perm(
                 database.database_name,
-                catalog or default_catalog,
+                catalog,
                 schema,
             )
             if schema_perm and self.can_access("schema_access", schema_perm):
@@ -2183,6 +2185,7 @@ def raise_for_access(
             database = query.database

         database = cast("Database", database)
+        default_catalog = database.get_default_catalog()

         if self.can_access_database(database):
             return
@@ -2196,19 +2199,19 @@
             # from the SQLAlchemy URI if possible; if not, we use the SQLAlchemy
             # inspector to read it.
             default_schema = database.get_default_schema_for_query(query)
-            # Determining the default catalog is much easier, because DB engine
-            # specs need explicit support for catalogs.
-            default_catalog = database.get_default_catalog()
             tables = {
                 Table(
                     table_.table,
                     table_.schema or default_schema,
-                    table_.catalog or default_catalog,
+                    table_.catalog or query.catalog or default_catalog,
                 )
                 for table_ in extract_tables_from_jinja_sql(query.sql, database)
             }
         elif table:
-            tables = {table}
+            # Make sure table has the default catalog, if not specified.
+            tables = {
+                Table(table.table, table.schema, table.catalog or default_catalog)
+            }

         denied = set()
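A hedged sketch of the three-part match in the `get_schemas_accessible_by_user` hunk, assuming schema_access view-menu names look like `[database].[catalog].[schema]` (the parsing helper and sample perms below are illustrative, not Superset's actual code):

```python
import re

PART = re.compile(r"\[([^\]]+)\]")

def schemas_for_catalog(perms: set[str], catalog: str) -> set[str]:
    accessible: set[str] = set()
    for perm in perms:
        parts = PART.findall(perm)
        # Matches when the perm's catalog equals the requested catalog,
        # which, after this change, has already been defaulted upstream.
        if len(parts) == 3 and parts[1] == catalog:
            accessible.add(parts[2])
    return accessible

perms = {"[examples].[main].[public]", "[examples].[other].[internal]"}
print(schemas_for_catalog(perms, "main"))  # {'public'}
```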
1 change: 1 addition & 0 deletions superset/sqllab/api.py
@@ -284,6 +284,7 @@ def export_csv(self, client_id: str) -> CsvResponse:
             "client_id": client_id,
             "row_count": row_count,
             "database": query.database.name,
+            "catalog": query.catalog,
             "schema": query.schema,
             "sql": query.sql,
             "exported_format": "csv",
2 changes: 2 additions & 0 deletions superset/sqllab/sqllab_execution_context.py
@@ -125,6 +125,8 @@ def select_as_cta(self) -> bool:
     def set_database(self, database: Database) -> None:
         self._validate_db(database)
         self.database = database
+        if self.catalog is None:
+            self.catalog = database.get_default_catalog()
         if self.select_as_cta:
             schema_name = self._get_ctas_target_schema_name(database)
             self.create_table_as_select.target_schema_name = schema_name  # type: ignore
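The effect of the two added lines is that the catalog is pinned as soon as the database is set, so downstream consumers in this PR (the CSV export metadata above, the `TableSchema` rows below) record a concrete value. A reduced sketch with stand-in classes:

```python
class FakeDatabase:  # stand-in for superset.models.core.Database
    def get_default_catalog(self):
        return "main"

class Ctx:  # reduced sketch of the SQL Lab execution context
    def __init__(self, catalog=None):
        self.catalog = catalog

    def set_database(self, database):
        self.database = database
        if self.catalog is None:  # only fill in when the user sent nothing
            self.catalog = database.get_default_catalog()

ctx = Ctx()
ctx.set_database(FakeDatabase())
print(ctx.catalog)  # main: later consumers see a concrete catalog, not None
```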
2 changes: 2 additions & 0 deletions superset/views/sql_lab/views.py
@@ -239,13 +239,15 @@ def post(self) -> FlaskResponse:
             db.session.query(TableSchema).filter(
                 TableSchema.tab_state_id == table["queryEditorId"],
                 TableSchema.database_id == table["dbId"],
+                TableSchema.catalog == table["catalog"],
                 TableSchema.schema == table["schema"],
                 TableSchema.table == table["name"],
             ).delete(synchronize_session=False)

             table_schema = TableSchema(
                 tab_state_id=table["queryEditorId"],
                 database_id=table["dbId"],
+                catalog=table["catalog"],
                 schema=table["schema"],
                 table=table["name"],
                 description=json.dumps(table),
28 changes: 0 additions & 28 deletions tests/integration_tests/databases/api_tests.py
@@ -1563,34 +1563,6 @@ def test_get_select_star_not_allowed(self):
         rv = self.client.get(uri)
         self.assertEqual(rv.status_code, 404)

-    def test_get_select_star_datasource_access(self):
Contributor:

Non-blocking comment: would be good if you could mention why these tests were removed.

Member Author:

These tests were poorly written and relied on weird side effects.

Contributor:

thank you!

"""
Database API: Test get select star with datasource access
"""
table = SqlaTable(
schema="main", table_name="ab_permission", database=get_main_database()
)
db.session.add(table)
db.session.commit()

tmp_table_perm = security_manager.find_permission_view_menu(
"datasource_access", table.get_perm()
)
gamma_role = security_manager.find_role("Gamma")
security_manager.add_permission_role(gamma_role, tmp_table_perm)

self.login(GAMMA_USERNAME)
main_db = get_main_database()
uri = f"api/v1/database/{main_db.id}/select_star/ab_permission/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)

# rollback changes
security_manager.del_permission_role(gamma_role, tmp_table_perm)
db.session.delete(table)
db.session.delete(main_db)
db.session.commit()

def test_get_select_star_not_found_database(self):
"""
Database API: Test get select star not found database
Expand Down