Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Configuration Method and expanded parameters for Database Model #14451

Merged
merged 14 commits into from
May 15, 2021
3 changes: 3 additions & 0 deletions superset/databases/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
"expose_in_sqllab",
"allow_run_async",
"allow_csv_upload",
"configuration_method",
"allow_ctas",
"allow_cvas",
"allow_dml",
Expand Down Expand Up @@ -146,6 +147,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
"allow_ctas",
"allow_cvas",
"allow_dml",
"configuration_method",
"force_ctas_schema",
"impersonate_user",
"allow_multi_schema_metadata_fetch",
Expand Down Expand Up @@ -230,6 +232,7 @@ def post(self) -> Response:
500:
$ref: '#/components/responses/500'
"""

if not request.is_json:
return self.response_400(message="Request is not JSON")
try:
Expand Down
2 changes: 0 additions & 2 deletions superset/databases/commands/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ def validate(self) -> None:
exceptions: List[ValidationError] = list()
sqlalchemy_uri: Optional[str] = self._properties.get("sqlalchemy_uri")
database_name: Optional[str] = self._properties.get("database_name")

if not sqlalchemy_uri:
exceptions.append(DatabaseRequiredFieldValidationError("sqlalchemy_uri"))
if not database_name:
Expand All @@ -87,7 +86,6 @@ def validate(self) -> None:
# Check database_name uniqueness
if not DatabaseDAO.validate_uniqueness(database_name):
exceptions.append(DatabaseExistsValidationError())

if exceptions:
exception = DatabaseInvalidError()
exception.add_list(exceptions)
Expand Down
20 changes: 19 additions & 1 deletion superset/databases/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,15 @@
from flask_babel import lazy_gettext as _
from marshmallow import EXCLUDE, fields, pre_load, Schema, validates_schema
from marshmallow.validate import Length, ValidationError
from marshmallow_enum import EnumField
from sqlalchemy import MetaData
from sqlalchemy.engine.url import make_url
from sqlalchemy.exc import ArgumentError

from superset.db_engine_specs import get_engine_specs
from superset.db_engine_specs.base import BasicParametersMixin
from superset.exceptions import CertificateException, SupersetSecurityException
from superset.models.core import PASSWORD_MASK
from superset.models.core import ConfigurationMethod, PASSWORD_MASK
from superset.security.analytics_db_safety import check_sqlalchemy_uri
from superset.utils.core import markdown, parse_ssl_cert

Expand Down Expand Up @@ -70,6 +71,11 @@
"all database schemas. For large data warehouse with thousands of "
"tables, this can be expensive and put strain on the system."
) # pylint: disable=invalid-name
configuration_method_description = (
"Configuration_method is used on the frontend to "
"inform the backend whether to explode parameters "
"or to provide only a sqlalchemy_uri."
)
impersonate_user_description = (
"If Presto, all the queries in SQL Lab are going to be executed as the "
"currently logged on user who must have permission to run them.<br/>"
Expand Down Expand Up @@ -314,6 +320,12 @@ class Meta: # pylint: disable=too-few-public-methods
allow_ctas = fields.Boolean(description=allow_ctas_description)
allow_cvas = fields.Boolean(description=allow_cvas_description)
allow_dml = fields.Boolean(description=allow_dml_description)
configuration_method = EnumField(
ConfigurationMethod,
by_value=True,
required=True,
description=configuration_method_description,
)
force_ctas_schema = fields.String(
description=force_ctas_schema_description,
allow_none=True,
Expand Down Expand Up @@ -351,6 +363,12 @@ class Meta: # pylint: disable=too-few-public-methods
description=cache_timeout_description, allow_none=True
)
expose_in_sqllab = fields.Boolean(description=expose_in_sqllab_description)
configuration_method = EnumField(
ConfigurationMethod,
by_value=True,
allow_none=True,
description=configuration_method_description,
)
allow_run_async = fields.Boolean(description=allow_run_async_description)
allow_csv_upload = fields.Boolean(description=allow_csv_upload_description)
allow_ctas = fields.Boolean(description=allow_ctas_description)
Expand Down
13 changes: 11 additions & 2 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@
# under the License.
# pylint: disable=line-too-long,unused-argument,ungrouped-imports
"""A collection of ORM sqlalchemy models for Superset"""
import enum
import json
import logging
import textwrap
from contextlib import closing
from copy import deepcopy
from datetime import datetime
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type

import numpy
Expand Down Expand Up @@ -99,6 +99,11 @@ class CssTemplate(Model, AuditMixinNullable):
css = Column(Text, default="")


class ConfigurationMethod(str, enum.Enum):
SQLALCHEMY_FORM = "sqlalchemy_form"
DYNAMIC_FORM = "dynamic_form"


class Database(
Model, AuditMixinNullable, ImportExportMixin
): # pylint: disable=too-many-public-methods
Expand All @@ -118,6 +123,9 @@ class Database(
cache_timeout = Column(Integer)
select_as_create_table_as = Column(Boolean, default=False)
expose_in_sqllab = Column(Boolean, default=True)
configuration_method = Column(
String(255), server_default=ConfigurationMethod.SQLALCHEMY_FORM.value
)
allow_run_async = Column(Boolean, default=False)
allow_csv_upload = Column(Boolean, default=False)
allow_ctas = Column(Boolean, default=False)
Expand Down Expand Up @@ -207,6 +215,7 @@ def data(self) -> Dict[str, Any]:
"id": self.id,
"name": self.database_name,
"backend": self.backend,
"configuration_method": self.configuration_method,
"allow_multi_schema_metadata_fetch": self.allow_multi_schema_metadata_fetch,
"allows_subquery": self.allows_subquery,
"allows_cost_estimate": self.allows_cost_estimate,
Expand Down Expand Up @@ -722,7 +731,7 @@ class Log(Model): # pylint: disable=too-few-public-methods
referrer = Column(String(1024))


class FavStarClassName(str, Enum):
class FavStarClassName(str, enum.Enum):
CHART = "slice"
DASHBOARD = "Dashboard"

Expand Down
124 changes: 120 additions & 4 deletions tests/databases/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from superset.db_engine_specs.mysql import MySQLEngineSpec
from superset.db_engine_specs.postgres import PostgresEngineSpec
from superset.errors import SupersetError
from superset.models.core import Database
from superset.models.core import Database, ConfigurationMethod
from superset.models.reports import ReportSchedule, ReportScheduleType
from superset.utils.core import get_example_database, get_main_database
from tests.base_tests import SupersetTestCase
Expand Down Expand Up @@ -229,6 +229,7 @@ def test_create_database(self):
database_data = {
"database_name": "test-create-database",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"configuration_method": ConfigurationMethod.SQLALCHEMY_FORM,
"server_cert": None,
"extra": json.dumps(extra),
}
Expand All @@ -239,9 +240,71 @@ def test_create_database(self):
self.assertEqual(rv.status_code, 201)
# Cleanup
model = db.session.query(Database).get(response.get("id"))
assert model.configuration_method == ConfigurationMethod.SQLALCHEMY_FORM
db.session.delete(model)
db.session.commit()

def test_create_database_invalid_configuration_method(self):
"""
Database API: Test create with an invalid configuration method.
"""
extra = {
"metadata_params": {},
"engine_params": {},
"metadata_cache_timeout": {},
"schemas_allowed_for_csv_upload": [],
}

self.login(username="admin")
example_db = get_example_database()
if example_db.backend == "sqlite":
return
database_data = {
"database_name": "test-create-database",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"configuration_method": "BAD_FORM",
"server_cert": None,
"extra": json.dumps(extra),
}

uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
assert response == {
"message": {"configuration_method": ["Invalid enum value BAD_FORM"]}
}
assert rv.status_code == 400

def test_create_database_no_configuration_method(self):
"""
Database API: Test create with no config method.
"""
extra = {
"metadata_params": {},
"engine_params": {},
"metadata_cache_timeout": {},
"schemas_allowed_for_csv_upload": [],
}

self.login(username="admin")
example_db = get_example_database()
if example_db.backend == "sqlite":
return
database_data = {
"database_name": "test-create-database",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"server_cert": None,
"extra": json.dumps(extra),
}

uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
assert response == {
"message": {"configuration_method": ["Missing data for required field."]}
}
assert rv.status_code == 400

def test_create_database_server_cert_validate(self):
"""
Database API: Test create server cert validation
Expand All @@ -254,6 +317,7 @@ def test_create_database_server_cert_validate(self):
database_data = {
"database_name": "test-create-database-invalid-cert",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"configuration_method": ConfigurationMethod.SQLALCHEMY_FORM,
"server_cert": "INVALID CERT",
}

Expand All @@ -276,6 +340,7 @@ def test_create_database_json_validate(self):
database_data = {
"database_name": "test-create-database-invalid-json",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"configuration_method": ConfigurationMethod.SQLALCHEMY_FORM,
"encrypted_extra": '{"A": "a", "B", "C"}',
"extra": '["A": "a", "B", "C"]',
}
Expand Down Expand Up @@ -316,6 +381,7 @@ def test_create_database_extra_metadata_validate(self):
database_data = {
"database_name": "test-create-database-invalid-extra",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"configuration_method": ConfigurationMethod.SQLALCHEMY_FORM,
"extra": json.dumps(extra),
}

Expand Down Expand Up @@ -345,6 +411,7 @@ def test_create_database_unique_validate(self):
database_data = {
"database_name": "examples",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"configuration_method": ConfigurationMethod.SQLALCHEMY_FORM,
}

uri = "api/v1/database/"
Expand All @@ -364,6 +431,7 @@ def test_create_database_uri_validate(self):
database_data = {
"database_name": "test-database-invalid-uri",
"sqlalchemy_uri": "wrong_uri",
"configuration_method": ConfigurationMethod.SQLALCHEMY_FORM,
}

uri = "api/v1/database/"
Expand All @@ -385,6 +453,7 @@ def test_create_database_fail_sqllite(self):
database_data = {
"database_name": "test-create-sqlite-database",
"sqlalchemy_uri": "sqlite:////some.db",
"configuration_method": ConfigurationMethod.SQLALCHEMY_FORM,
}

uri = "api/v1/database/"
Expand Down Expand Up @@ -413,6 +482,7 @@ def test_create_database_conn_fail(self):
database_data = {
"database_name": "test-create-database-wrong-password",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"configuration_method": ConfigurationMethod.SQLALCHEMY_FORM,
}

uri = "api/v1/database/"
Expand All @@ -434,7 +504,10 @@ def test_update_database(self):
"test-database", example_db.sqlalchemy_uri_decrypted
)
self.login(username="admin")
database_data = {"database_name": "test-database-updated"}
database_data = {
"database_name": "test-database-updated",
"configuration_method": ConfigurationMethod.DYNAMIC_FORM,
}
uri = f"api/v1/database/{test_database.id}"
rv = self.client.put(uri, json=database_data)
self.assertEqual(rv.status_code, 200)
Expand Down Expand Up @@ -535,6 +608,49 @@ def test_update_database_uri_validate(self):
db.session.delete(test_database)
db.session.commit()

def test_update_database_with_invalid_configuration_method(self):
"""
Database API: Test update
"""
example_db = get_example_database()
test_database = self.insert_database(
"test-database", example_db.sqlalchemy_uri_decrypted
)
self.login(username="admin")
database_data = {
"database_name": "test-database-updated",
"configuration_method": "BAD_FORM",
}
uri = f"api/v1/database/{test_database.id}"
rv = self.client.put(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
assert response == {
"message": {"configuration_method": ["Invalid enum value BAD_FORM"]}
}
assert rv.status_code == 400

db.session.delete(test_database)
db.session.commit()

def test_update_database_with_no_configuration_method(self):
"""
Database API: Test update
"""
example_db = get_example_database()
test_database = self.insert_database(
"test-database", example_db.sqlalchemy_uri_decrypted
)
self.login(username="admin")
database_data = {
"database_name": "test-database-updated",
}
uri = f"api/v1/database/{test_database.id}"
rv = self.client.put(uri, json=database_data)
assert rv.status_code == 200

db.session.delete(test_database)
db.session.commit()

def test_delete_database(self):
"""
Database API: Test delete
Expand Down Expand Up @@ -731,8 +847,8 @@ def test_database_schemas(self):
"""
Database API: Test database schemas
"""
self.login("admin")
database = db.session.query(Database).first()
self.login(username="admin")
database = db.session.query(Database).filter_by(database_name="examples").one()
schemas = database.get_all_schema_names()

rv = self.client.get(f"api/v1/database/{database.id}/schemas/")
Expand Down
2 changes: 0 additions & 2 deletions tests/databases/commands_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ def test_export_database_command(self, mock_g):
"metadata.yaml",
"databases/examples.yaml",
"datasets/examples/energy_usage.yaml",
"datasets/examples/wb_health_population.yaml",
"datasets/examples/birth_names.yaml",
}
expected_extra = {
Expand All @@ -88,7 +87,6 @@ def test_export_database_command(self, mock_g):
**expected_extra,
"engine_params": {"connect_args": {"poll_interval": 0.1}},
}

assert core_files.issubset(set(contents.keys()))

if example_db.backend == "postgresql":
Expand Down