refactor: Re-enable pylint on 5 files (#10106)
* Re-enable lint on 5 files

* revert something questionable

* Address PR feedback

* One more PR comment...

* black?

* Update code wrapping

* Disable bugged check

* Add a disable for a failure that's only showing up in CI.

* Fix bad refactor

* A little more lint fixing, bug fixing
Will Barrett authored Jun 25, 2020
1 parent 16cffd0 commit 0017b61
Showing 5 changed files with 184 additions and 132 deletions.
154 changes: 87 additions & 67 deletions superset/connectors/sqla/models.py
@@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# pylint: disable=C,R,W
import logging
from collections import OrderedDict
from datetime import datetime, timedelta
@@ -102,7 +101,7 @@ def query(self, query_obj: QueryObjectDict) -> QueryResult:
status = utils.QueryStatus.SUCCESS
try:
df = pd.read_sql_query(qry.statement, db.engine)
-except Exception as ex:
+except Exception as ex:  # pylint: disable=broad-except
df = pd.DataFrame()
status = utils.QueryStatus.FAILED
logger.exception(ex)
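The broad except here is deliberate, since any driver or engine failure should map to a FAILED status rather than crash, so the check is silenced on that one line instead of file-wide. A minimal, self-contained sketch of the same pattern, using a hypothetical run_query helper:

import logging

import pandas as pd

logger = logging.getLogger(__name__)


def run_query(query_fn) -> pd.DataFrame:
    try:
        return query_fn()
    except Exception as ex:  # pylint: disable=broad-except
        # Deliberately broad: any failure degrades to an empty frame plus a log.
        logger.exception(ex)
        return pd.DataFrame()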
@@ -228,7 +227,7 @@ def get_timestamp_expression(
"""
label = utils.DTTM_ALIAS

-db = self.table.database
+db_ = self.table.database
pdf = self.python_date_format
is_epoch = pdf in ("epoch_s", "epoch_ms")
if not self.expression and not time_grain and not is_epoch:
@@ -238,7 +237,7 @@
col = literal_column(self.expression)
else:
col = column(self.column_name)
-time_expr = db.db_engine_spec.get_timestamp_expr(
+time_expr = db_.db_engine_spec.get_timestamp_expr(
col, pdf, time_grain, self.type
)
return self.table.make_sqla_column_compatible(time_expr, label)
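The db to db_ rename avoids pylint's redefined-outer-name warning: these modules import a module-level db handle, and a local variable of the same name shadows it. The trailing underscore is the PEP 8 convention for dodging such a clash. A self-contained sketch, with db standing in for the module-level import:

db = "module-level handle"  # stands in for e.g. `from superset import db`


def fetch_before():
    db = "local database reference"  # shadows the outer db: redefined-outer-name
    return db


def fetch_after():
    db_ = "local database reference"  # the rename silences the warning
    return db_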
@@ -387,7 +386,9 @@ def lookup_obj(lookup_metric: SqlMetric) -> SqlMetric:
)


-class SqlaTable(Model, BaseDatasource):
+class SqlaTable(  # pylint: disable=too-many-public-methods,too-many-instance-attributes
+    Model, BaseDatasource
+):

"""An ORM object for SqlAlchemy table references"""

@@ -461,7 +462,7 @@ def make_sqla_column_compatible(
if db_engine_spec.allows_column_aliases:
label = db_engine_spec.make_label_compatible(label_expected)
sqla_col = sqla_col.label(label)
-sqla_col._df_label_expected = label_expected
+sqla_col._df_label_expected = label_expected  # pylint: disable=protected-access
return sqla_col

def __repr__(self) -> str:
@@ -560,8 +561,7 @@ def any_dttm_col(self) -> Optional[str]:

@property
def html(self) -> str:
-t = ((c.column_name, c.type) for c in self.columns)
-df = pd.DataFrame(t)
+df = pd.DataFrame((c.column_name, c.type) for c in self.columns)
df.columns = ["field", "type"]
return df.to_html(
index=False,
@@ -598,18 +598,18 @@ def select_star(self) -> Optional[str]:

@property
def data(self) -> Dict[str, Any]:
-d = super().data
+data_ = super().data
if self.type == "table":
grains = self.database.grains() or []
if grains:
grains = [(g.duration, g.name) for g in grains]
d["granularity_sqla"] = utils.choicify(self.dttm_cols)
d["time_grain_sqla"] = grains
d["main_dttm_col"] = self.main_dttm_col
d["fetch_values_predicate"] = self.fetch_values_predicate
d["template_params"] = self.template_params
d["is_sqllab_view"] = self.is_sqllab_view
return d
data_["granularity_sqla"] = utils.choicify(self.dttm_cols)
data_["time_grain_sqla"] = grains
data_["main_dttm_col"] = self.main_dttm_col
data_["fetch_values_predicate"] = self.fetch_values_predicate
data_["template_params"] = self.template_params
data_["is_sqllab_view"] = self.is_sqllab_view
return data_

def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]:
"""Runs query against sqla to retrieve some
@@ -642,10 +642,10 @@ def mutate_query_from_config(self, sql: str) -> str:
"""Apply config's SQL_QUERY_MUTATOR
Typically adds comments to the query with context"""
-SQL_QUERY_MUTATOR = config["SQL_QUERY_MUTATOR"]
-if SQL_QUERY_MUTATOR:
+sql_query_mutator = config["SQL_QUERY_MUTATOR"]
+if sql_query_mutator:
username = utils.get_username()
-sql = SQL_QUERY_MUTATOR(sql, username, security_manager, self.database)
+sql = sql_query_mutator(sql, username, security_manager, self.database)
return sql

def get_template_processor(self, **kwargs: Any) -> BaseTemplateProcessor:
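SQL_QUERY_MUTATOR is a config hook: a callable that receives the SQL text, the username, the security manager, and the database (the four positional arguments visible in the call above) and returns the, possibly rewritten, statement. A minimal sketch of one for a superset_config.py; the comment format is an illustrative choice:

def sql_query_mutator(sql, username, security_manager, database):
    """Prepend an audit comment so queries are attributable in the DB's logs."""
    return f"-- superset user: {username}\n{sql}"


SQL_QUERY_MUTATOR = sql_query_mutator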
@@ -717,9 +717,11 @@ def _get_sqla_row_level_filters(
self, template_processor: BaseTemplateProcessor
) -> List[str]:
"""
-Return the appropriate row level security filters for this table and the current user.
+Return the appropriate row level security filters for
+this table and the current user.

-:param BaseTemplateProcessor template_processor: The template processor to apply to the filters.
+:param BaseTemplateProcessor template_processor: The template
+    processor to apply to the filters.
:returns: A list of SQL clauses to be ANDed together.
:rtype: List[str]
"""
@@ -728,15 +730,17 @@
for f in security_manager.get_rls_filters(self)
]

-def get_sqla_query(  # sqla
+def get_sqla_query(  # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements
self,
metrics: List[Metric],
granularity: str,
from_dttm: Optional[datetime],
to_dttm: Optional[datetime],
columns: Optional[List[str]] = None,
groupby: Optional[List[str]] = None,
-filter: Optional[List[Dict[str, Any]]] = None,
+filter: Optional[  # pylint: disable=redefined-builtin
+    List[Dict[str, Any]]
+] = None,  # pylint: disable=bad-whitespace
is_timeseries: bool = True,
timeseries_limit: int = 15,
timeseries_limit_metric: Optional[Metric] = None,
@@ -793,14 +797,14 @@ def get_sqla_query(  # sqla
):
raise Exception(_("Empty query?"))
metrics_exprs: List[ColumnElement] = []
-for m in metrics:
-    if utils.is_adhoc_metric(m):
-        assert isinstance(m, dict)
-        metrics_exprs.append(self.adhoc_metric_to_sqla(m, cols))
-    elif isinstance(m, str) and m in metrics_dict:
-        metrics_exprs.append(metrics_dict[m].get_sqla_col())
+for metric in metrics:
+    if utils.is_adhoc_metric(metric):
+        assert isinstance(metric, dict)
+        metrics_exprs.append(self.adhoc_metric_to_sqla(metric, cols))
+    elif isinstance(metric, str) and metric in metrics_dict:
+        metrics_exprs.append(metrics_dict[metric].get_sqla_col())
    else:
-        raise Exception(_("Metric '%(metric)s' does not exist", metric=m))
+        raise Exception(_("Metric '%(metric)s' does not exist", metric=metric))
if metrics_exprs:
main_metric_expr = metrics_exprs[0]
else:
@@ -817,21 +821,21 @@
groupby = list(dict.fromkeys(columns_))

select_exprs = []
-for s in groupby:
-    if s in cols:
-        outer = cols[s].get_sqla_col()
+for selected in groupby:
+    if selected in cols:
+        outer = cols[selected].get_sqla_col()
    else:
-        outer = literal_column(f"({s})")
-    outer = self.make_sqla_column_compatible(outer, s)
+        outer = literal_column(f"({selected})")
+    outer = self.make_sqla_column_compatible(outer, selected)

groupby_exprs_sans_timestamp[outer.name] = outer
select_exprs.append(outer)
elif columns:
-for s in columns:
+for selected in columns:
    select_exprs.append(
-        cols[s].get_sqla_col()
-        if s in cols
-        else self.make_sqla_column_compatible(literal_column(s))
+        cols[selected].get_sqla_col()
+        if selected in cols
+        else self.make_sqla_column_compatible(literal_column(selected))
)
metrics_exprs = []

@@ -865,7 +869,10 @@ def get_sqla_query(  # sqla

select_exprs += metrics_exprs

-labels_expected = [c._df_label_expected for c in select_exprs]
+labels_expected = [
+    c._df_label_expected  # pylint: disable=protected-access
+    for c in select_exprs
+]

select_exprs = db_engine_spec.make_select_compatible(
groupby_exprs_with_timestamp.values(), select_exprs
@@ -902,7 +909,11 @@ def get_sqla_query(  # sqla
):
cond = col_obj.get_sqla_col().in_(eq)
if isinstance(eq, str) and NULL_STRING in eq:
-cond = or_(cond, col_obj.get_sqla_col() is None)
+cond = or_(
+    cond,
+    col_obj.get_sqla_col()  # pylint: disable=singleton-comparison
+    == None,
+)
if op == utils.FilterOperator.NOT_IN.value:
cond = ~cond
where_clause_and.append(cond)
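This hunk fixes a genuine bug along with the lint: col_obj.get_sqla_col() is None is a Python identity test, always False for a column object, so the old or_ added nothing to the filter. SQLAlchemy overloads == and != on column expressions to emit IS NULL / IS NOT NULL, hence the == None with a singleton-comparison disable. A standalone sketch:

from sqlalchemy import column, or_

col = column("owner")
is_null = col == None  # pylint: disable=singleton-comparison

print(col is None)                        # False: identity test, never becomes SQL
print(is_null)                            # owner IS NULL
print(or_(col.in_(["a", "b"]), is_null))  # roughly: owner IN (...) OR owner IS NULL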
@@ -924,9 +935,15 @@
elif op == utils.FilterOperator.LIKE.value:
where_clause_and.append(col_obj.get_sqla_col().like(eq))
elif op == utils.FilterOperator.IS_NULL.value:
-where_clause_and.append(col_obj.get_sqla_col() == None)
+where_clause_and.append(
+    col_obj.get_sqla_col()  # pylint: disable=singleton-comparison
+    == None
+)
elif op == utils.FilterOperator.IS_NOT_NULL.value:
-where_clause_and.append(col_obj.get_sqla_col() != None)
+where_clause_and.append(
+    col_obj.get_sqla_col()  # pylint: disable=singleton-comparison
+    != None
+)
else:
raise Exception(
_("Invalid filter operation type: %(op)s", op=op)
@@ -953,7 +970,9 @@ def get_sqla_query(  # sqla

# To ensure correct handling of the ORDER BY labeling we need to reference the
# metric instance if defined in the SELECT clause.
-metrics_exprs_by_label = {m._label: m for m in metrics_exprs}
+metrics_exprs_by_label = {
+    m._label: m for m in metrics_exprs  # pylint: disable=protected-access
+}

for col, ascending in orderby:
direction = asc if ascending else desc
@@ -962,8 +981,10 @@
elif col in cols:
col = cols[col].get_sqla_col()

-if isinstance(col, Label) and col._label in metrics_exprs_by_label:
-    col = metrics_exprs_by_label[col._label]
+if isinstance(col, Label):
+    label = col._label  # pylint: disable=protected-access
+    if label in metrics_exprs_by_label:
+        col = metrics_exprs_by_label[label]

qry = qry.order_by(direction(col))

@@ -973,7 +994,7 @@
qry = qry.offset(row_offset)

if (
-is_timeseries
+is_timeseries  # pylint: disable=too-many-boolean-expressions
and timeseries_limit
and not time_groupby_inline
and ((is_sip_38 and columns) or (not is_sip_38 and groupby))
@@ -1087,14 +1108,14 @@ def _get_timeseries_orderby(

return ob

-def _get_top_groups(
+def _get_top_groups(  # pylint: disable=no-self-use
self,
df: pd.DataFrame,
dimensions: List[str],
groupby_exprs: "OrderedDict[str, Any]",
) -> ColumnElement:
groups = []
-for unused, row in df.iterrows():
+for _unused, row in df.iterrows():
group = []
for dimension in dimensions:
group.append(groupby_exprs[dimension] == row[dimension])
@@ -1126,15 +1147,16 @@ def mutator(df: pd.DataFrame) -> None:
f"For {sql}, df.columns: {df.columns}"
f" differs from {labels_expected}"
)
-else:
-    df.columns = labels_expected
+df.columns = labels_expected

try:
df = self.database.get_df(sql, self.schema, mutator)
-except Exception as ex:
+except Exception as ex:  # pylint: disable=broad-except
df = pd.DataFrame()
status = utils.QueryStatus.FAILED
logger.warning(f"Query {sql} on schema {self.schema} failed", exc_info=True)
logger.warning(
"Query %s on schema %s failed", sql, self.schema, exc_info=True
)
db_engine_spec = self.database.db_engine_spec
errors = db_engine_spec.extract_errors(ex)
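The logging rewrite is more than style: with %-style lazy formatting the message is only interpolated when the record is actually emitted, and pylint's logging-fstring-interpolation check flags the eager f-string form. A small sketch of the two:

import logging

logger = logging.getLogger(__name__)
sql, schema = "SELECT 1", "main"

# Eager: the f-string is built even if WARNING is filtered out (pylint flags this).
logger.warning(f"Query {sql} on schema {schema} failed")

# Lazy: args are interpolated by the logging framework only when emitted.
logger.warning("Query %s on schema %s failed", sql, schema)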

@@ -1152,7 +1174,7 @@ def get_sqla_table_object(self) -> Table:
def fetch_metadata(self, commit: bool = True) -> None:
"""Fetches the metadata for the table and merges it in"""
try:
-table = self.get_sqla_table_object()
+table_ = self.get_sqla_table_object()
except Exception as ex:
logger.exception(ex)
raise Exception(
@@ -1169,24 +1191,22 @@ def fetch_metadata(self, commit: bool = True) -> None:
dbcols = (
db.session.query(TableColumn)
.filter(TableColumn.table == self)
-.filter(or_(TableColumn.column_name == col.name for col in table.columns))
+.filter(or_(TableColumn.column_name == col.name for col in table_.columns))
)
dbcols = {dbcol.column_name: dbcol for dbcol in dbcols}

-for col in table.columns:
+for col in table_.columns:
try:
datatype = db_engine_spec.column_datatype_to_string(
col.type, db_dialect
)
-except Exception as ex:
+except Exception as ex:  # pylint: disable=broad-except
datatype = "UNKNOWN"
logger.error("Unrecognized data type in {}.{}".format(table, col.name))
logger.error("Unrecognized data type in %s.%s", table_, col.name)
logger.exception(ex)
dbcol = dbcols.get(col.name, None)
if not dbcol:
dbcol = TableColumn(column_name=col.name, type=datatype, table=self)
-dbcol.sum = dbcol.is_numeric
-dbcol.avg = dbcol.is_numeric
dbcol.is_dttm = dbcol.is_temporal
db_engine_spec.alter_new_orm_column(dbcol)
else:
@@ -1227,30 +1247,30 @@ def import_obj(
superset instances. Audit metadata isn't copied over.
"""

-def lookup_sqlatable(table: "SqlaTable") -> "SqlaTable":
+def lookup_sqlatable(table_: "SqlaTable") -> "SqlaTable":
return (
db.session.query(SqlaTable)
.join(Database)
.filter(
-SqlaTable.table_name == table.table_name,
-SqlaTable.schema == table.schema,
-Database.id == table.database_id,
+SqlaTable.table_name == table_.table_name,
+SqlaTable.schema == table_.schema,
+Database.id == table_.database_id,
)
.first()
)

-def lookup_database(table: SqlaTable) -> Database:
+def lookup_database(table_: SqlaTable) -> Database:
try:
return (
db.session.query(Database)
-.filter_by(database_name=table.params_dict["database_name"])
+.filter_by(database_name=table_.params_dict["database_name"])
.one()
)
except NoResultFound:
raise DatabaseNotFound(
_(
"Database '%(name)s' is not found",
-name=table.params_dict["database_name"],
+name=table_.params_dict["database_name"],
)
)
