Skip to content

Commit

Permalink
chore(models): Adding encrypted field checks (apache#28436)
Browse files Browse the repository at this point in the history
  • Loading branch information
craig-rueda authored and aehanno committed May 13, 2024
1 parent 24c747a commit 552483a
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 11 deletions.
15 changes: 6 additions & 9 deletions superset/databases/ssh_tunnel/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@
from flask import current_app
from flask_appbuilder import Model
from sqlalchemy.orm import backref, relationship
from sqlalchemy_utils import EncryptedType
from sqlalchemy.types import Text

from superset.constants import PASSWORD_MASK
from superset.extensions import encrypted_field_factory
from superset.models.core import Database
from superset.models.helpers import (
AuditMixinNullable,
Expand Down Expand Up @@ -53,19 +54,15 @@ class SSHTunnel(AuditMixinNullable, ExtraJSONMixin, ImportExportMixin, Model):

server_address = sa.Column(sa.Text)
server_port = sa.Column(sa.Integer)
username = sa.Column(EncryptedType(sa.String, app_config["SECRET_KEY"]))
username = sa.Column(encrypted_field_factory.create(Text))

# basic authentication
password = sa.Column(
EncryptedType(sa.String, app_config["SECRET_KEY"]), nullable=True
)
password = sa.Column(encrypted_field_factory.create(Text), nullable=True)

# password protected pkey authentication
private_key = sa.Column(
EncryptedType(sa.String, app_config["SECRET_KEY"]), nullable=True
)
private_key = sa.Column(encrypted_field_factory.create(Text), nullable=True)
private_key_password = sa.Column(
EncryptedType(sa.String, app_config["SECRET_KEY"]), nullable=True
encrypted_field_factory.create(Text), nullable=True
)

export_fields = [
Expand Down
9 changes: 8 additions & 1 deletion superset/utils/encrypt.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from sqlalchemy.engine import Connection, Dialect, Row
from sqlalchemy_utils import EncryptedType

ENC_ADAPTER_TAG_ATTR_NAME = "__created_by_enc_field_adapter__"
logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -70,12 +71,18 @@ def create(
self, *args: list[Any], **kwargs: Optional[dict[str, Any]]
) -> TypeDecorator:
if self._concrete_type_adapter:
return self._concrete_type_adapter.create(self._config, *args, **kwargs)
adapter = self._concrete_type_adapter.create(self._config, *args, **kwargs)
setattr(adapter, ENC_ADAPTER_TAG_ATTR_NAME, True)
return adapter

raise Exception( # pylint: disable=broad-exception-raised
"App not initialized yet. Please call init_app first"
)

@staticmethod
def created_by_enc_field_factory(field: TypeDecorator) -> bool:
return getattr(field, ENC_ADAPTER_TAG_ATTR_NAME, False)


class SecretsMigrator:
def __init__(self, previous_secret_key: str) -> None:
Expand Down
27 changes: 26 additions & 1 deletion tests/integration_tests/utils/encrypt_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@
from sqlalchemy_utils.types.encrypted.encrypted_type import StringEncryptedType

from superset.extensions import encrypted_field_factory
from superset.utils.encrypt import AbstractEncryptedFieldAdapter, SQLAlchemyUtilsAdapter
from superset.utils.encrypt import (
AbstractEncryptedFieldAdapter,
SecretsMigrator,
SQLAlchemyUtilsAdapter,
)
from tests.integration_tests.base_tests import SupersetTestCase


Expand Down Expand Up @@ -60,4 +64,25 @@ def test_custom_adapter(self):
field = encrypted_field_factory.create(String(1024))
self.assertTrue(isinstance(field, StringEncryptedType))
self.assertFalse(isinstance(field, EncryptedType))
self.assertTrue(getattr(field, "__created_by_enc_field_adapter__"))
self.assertEqual(self.app.config["SECRET_KEY"], field.key)

def test_ensure_encrypted_field_factory_is_used(self):
"""
Ensure that the EncryptedFieldFactory is used everywhere
that an encrypted field is needed.
:return:
"""
from superset.extensions import encrypted_field_factory

migrator = SecretsMigrator("")
encrypted_fields = migrator.discover_encrypted_fields()
for table_name, cols in encrypted_fields.items():
for col_name, field in cols.items():
if not encrypted_field_factory.created_by_enc_field_factory(field):
self.fail(
f"The encrypted column [{col_name}]"
f" in the table [{table_name}]"
" was not created using the"
" encrypted_field_factory"
)

0 comments on commit 552483a

Please sign in to comment.