Skip to content

Commit

Permalink
refactor(ImportModelsCommand): clean code, extract logic into methods
Browse files Browse the repository at this point in the history
  • Loading branch information
michael_hoffman committed Feb 24, 2022
1 parent 0edb979 commit 358465f
Showing 1 changed file with 25 additions and 17 deletions.
42 changes: 25 additions & 17 deletions superset/commands/importers/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@


class ImportModelsCommand(BaseCommand):

"""Import models"""

dao = BaseDAO
Expand Down Expand Up @@ -73,22 +72,26 @@ def run(self) -> None:
def validate(self) -> None:
exceptions: List[ValidationError] = []

# load existing databases so we can apply the password validation
db_passwords = {
str(uuid): password
for uuid, password in db.session.query(
Database.uuid, Database.password
).all()
}

# verify that the metadata file is present and valid
try:
metadata: Optional[Dict[str, str]] = load_metadata(self.contents)
except ValidationError as exc:
exceptions.append(exc)
metadata = None

# validate that the type declared in METADATA_FILE_NAME is correct
self._validate_metadata_type(metadata, exceptions)
self._load__configs(exceptions)
self._prevent_overwrite_existing_model(exceptions)

if exceptions:
exception = CommandInvalidError(f"Error importing {self.model_name}")
exception.add_list(exceptions)
raise exception

def _validate_metadata_type(
self, metadata: Optional[Dict[str, str]], exceptions: List[ValidationError]
) -> None:
"""Validate that the type declared in METADATA_FILE_NAME is correct"""
if metadata and "type" in metadata:
type_validator = validate.Equal(self.dao.model_cls.__name__) # type: ignore
try:
Expand All @@ -97,7 +100,14 @@ def validate(self) -> None:
exc.messages = {METADATA_FILE_NAME: {"type": exc.messages}}
exceptions.append(exc)

# validate objects
def _load__configs(self, exceptions: List[ValidationError]) -> None:
# load existing databases so we can apply the password validation
db_passwords: Dict[str, str] = {
str(uuid): password
for uuid, password in db.session.query(
Database.uuid, Database.password
).all()
}
for file_name, content in self.contents.items():
# skip directories
if not content:
Expand All @@ -121,7 +131,10 @@ def validate(self) -> None:
exc.messages = {file_name: exc.messages}
exceptions.append(exc)

# check if the object exists and shouldn't be overwritten
def _prevent_overwrite_existing_model( # pylint: disable=invalid-name
self, exceptions: List[ValidationError]
) -> None:
"""check if the object exists and shouldn't be overwritten"""
if not self.overwrite:
existing_uuids = self._get_uuids()
for file_name, config in self._configs.items():
Expand All @@ -139,8 +152,3 @@ def validate(self) -> None:
}
)
)

if exceptions:
exception = CommandInvalidError(f"Error importing {self.model_name}")
exception.add_list(exceptions)
raise exception

0 comments on commit 358465f

Please sign in to comment.