diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a83a59d4e73e9..22a6a0c1e6c10 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -16,7 +16,7 @@ # repos: - repo: https://github.com/ambv/black - rev: 19.3b0 + rev: 19.10b0 hooks: - id: black language_version: python3 diff --git a/requirements-dev.txt b/requirements-dev.txt index e3c6f957cd68b..77820f4d49c46 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -14,13 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # -black==19.3b0 +black==19.10b0 coverage==4.5.3 flask-cors==3.0.7 flask-testing==0.7.1 ipdb==0.12 isort==4.3.21 -mypy==0.670 +mypy==0.770 nose==1.3.7 pip-tools==4.5.1 pre-commit==1.17.0 diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py index 517a458d4c4f5..a1e15c6f0c27b 100644 --- a/superset/connectors/druid/models.py +++ b/superset/connectors/druid/models.py @@ -810,7 +810,7 @@ def granularity( "year": "P1Y", } - granularity = {"type": "period"} + granularity: Dict[str, Union[str, float]] = {"type": "period"} if timezone: granularity["timeZone"] = timezone @@ -831,7 +831,7 @@ def granularity( granularity["period"] = period_name else: granularity["type"] = "duration" - granularity["duration"] = ( # type: ignore + granularity["duration"] = ( utils.parse_human_timedelta(period_name).total_seconds() * 1000 ) return granularity @@ -941,23 +941,24 @@ def metrics_and_post_aggs( adhoc_agg_configs = [] postagg_names = [] for metric in metrics: - if utils.is_adhoc_metric(metric): + if isinstance(metric, dict) and utils.is_adhoc_metric(metric): adhoc_agg_configs.append(metric) - elif metrics_dict[metric].metric_type != POST_AGG_TYPE: # type: ignore - saved_agg_names.add(metric) - else: - postagg_names.append(metric) + elif isinstance(metric, str): + if metrics_dict[metric].metric_type != POST_AGG_TYPE: + saved_agg_names.add(metric) + else: + postagg_names.append(metric) # Create the post aggregations, maintain order since postaggs # may depend on previous ones post_aggs: "OrderedDict[str, Postaggregator]" = OrderedDict() visited_postaggs = set() for postagg_name in postagg_names: - postagg = metrics_dict[postagg_name] # type: ignore + postagg = metrics_dict[postagg_name] visited_postaggs.add(postagg_name) DruidDatasource.resolve_postagg( postagg, post_aggs, saved_agg_names, visited_postaggs, metrics_dict ) - aggs = DruidDatasource.get_aggregations( # type: ignore + aggs = DruidDatasource.get_aggregations( metrics_dict, saved_agg_names, adhoc_agg_configs ) return aggs, post_aggs diff --git a/superset/forms.py b/superset/forms.py index fd2d078f834b4..175903af20b94 100644 --- a/superset/forms.py +++ b/superset/forms.py @@ -23,7 +23,7 @@ class CommaSeparatedListField(Field): widget = BS3TextFieldWidget() - data = [] # type: List[str] + data: List[str] = [] def _value(self): if self.data: diff --git a/superset/models/core.py b/superset/models/core.py index 81392d7591edf..1f56019ff3700 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -460,7 +460,7 @@ def get_all_table_names_in_schema( self, schema: str, cache: bool = False, - cache_timeout: int = None, + cache_timeout: Optional[int] = None, force: bool = False, ) -> List[utils.DatasourceName]: """Parameters need to be passed as keyword arguments. @@ -492,7 +492,7 @@ def get_all_view_names_in_schema( self, schema: str, cache: bool = False, - cache_timeout: int = None, + cache_timeout: Optional[int] = None, force: bool = False, ) -> List[utils.DatasourceName]: """Parameters need to be passed as keyword arguments. diff --git a/superset/utils/core.py b/superset/utils/core.py index 95f103298c5c1..778c71eac713d 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -51,7 +51,7 @@ from cryptography.hazmat.backends.openssl.x509 import _Certificate from dateutil.parser import parse from dateutil.relativedelta import relativedelta -from flask import current_app, flash, Flask, g, Markup, render_template +from flask import current_app, flash, g, Markup, render_template from flask_appbuilder import SQLA from flask_appbuilder.security.sqla.models import User from flask_babel import gettext as __, lazy_gettext as _ @@ -1057,15 +1057,11 @@ def get_since_until( else: rel, num, grain = time_range.split() if rel == "Last": - since = relative_start - relativedelta( # type: ignore - **{grain: int(num)} - ) + since = relative_start - relativedelta(**{grain: int(num)}) # type: ignore until = relative_end else: # rel == 'Next' since = relative_start - until = relative_end + relativedelta( # type: ignore - **{grain: int(num)} - ) + until = relative_end + relativedelta(**{grain: int(num)}) # type: ignore else: since = since or "" if since: @@ -1184,7 +1180,7 @@ def parse_ssl_cert(certificate: str) -> _Certificate: return x509.load_pem_x509_certificate( certificate.encode("utf-8"), default_backend() ) - except ValueError as e: + except ValueError: raise CertificateException("Invalid certificate") diff --git a/superset/views/core.py b/superset/views/core.py index 6c5497d4c3f86..02fd207486c33 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -188,7 +188,10 @@ def check_datasource_perms( except SupersetException as e: raise SupersetSecurityException(str(e)) - viz_obj = get_viz( # type: ignore + if datasource_type is None: + raise SupersetSecurityException("Could not determine datasource type") + + viz_obj = get_viz( datasource_type=datasource_type, datasource_id=datasource_id, form_data=form_data, diff --git a/superset/views/schedules.py b/superset/views/schedules.py index ad7861e75af5c..e84e3412c8fa4 100644 --- a/superset/views/schedules.py +++ b/superset/views/schedules.py @@ -48,7 +48,7 @@ class EmailScheduleView( ): # pylint: disable=too-many-ancestors include_route_methods = RouteMethod.CRUD_SET _extra_data = {"test_email": False, "test_email_recipients": None} - schedule_type: Optional[Type] = None + schedule_type: Optional[str] = None schedule_type_model: Optional[Type] = None page_size = 20 diff --git a/tests/access_tests.py b/tests/access_tests.py index 63cbfb5c7054d..58affb640e429 100644 --- a/tests/access_tests.py +++ b/tests/access_tests.py @@ -567,10 +567,7 @@ def test_request_access(self): self.get_resp(ACCESS_REQUEST.format("druid", druid_ds_4_id, "go")) access_request4 = self.get_access_requests("gamma", "druid", druid_ds_4_id) - self.assertEqual( - access_request4.roles_with_datasource, - "".format(access_request4.id), - ) + self.assertEqual(access_request4.roles_with_datasource, "") # Case 5. Roles exist that contains the druid datasource. # add druid ds to the existing roles diff --git a/tests/db_engine_specs/hive_tests.py b/tests/db_engine_specs/hive_tests.py index 4d24c0bdf25af..15de3ce6693df 100644 --- a/tests/db_engine_specs/hive_tests.py +++ b/tests/db_engine_specs/hive_tests.py @@ -58,7 +58,7 @@ def test_job_1_launched_stage_1(self): self.assertEqual(0, HiveEngineSpec.progress(log)) def test_job_1_launched_stage_1_map_40_progress( - self + self, ): # pylint: disable=invalid-name log = """ 17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2 @@ -71,7 +71,7 @@ def test_job_1_launched_stage_1_map_40_progress( self.assertEqual(10, HiveEngineSpec.progress(log)) def test_job_1_launched_stage_1_map_80_reduce_40_progress( - self + self, ): # pylint: disable=invalid-name log = """ 17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2 @@ -85,7 +85,7 @@ def test_job_1_launched_stage_1_map_80_reduce_40_progress( self.assertEqual(30, HiveEngineSpec.progress(log)) def test_job_1_launched_stage_2_stages_progress( - self + self, ): # pylint: disable=invalid-name log = """ 17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2 @@ -101,7 +101,7 @@ def test_job_1_launched_stage_2_stages_progress( self.assertEqual(12, HiveEngineSpec.progress(log)) def test_job_2_launched_stage_2_stages_progress( - self + self, ): # pylint: disable=invalid-name log = """ 17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2 @@ -145,7 +145,7 @@ def test_hive_error_msg(self): ) def test_hive_get_view_names_return_empty_list( - self + self, ): # pylint: disable=invalid-name self.assertEqual( [], HiveEngineSpec.get_view_names(mock.ANY, mock.ANY, mock.ANY) diff --git a/tests/db_engine_specs/presto_tests.py b/tests/db_engine_specs/presto_tests.py index cf62b282d4a47..0ca06359a0461 100644 --- a/tests/db_engine_specs/presto_tests.py +++ b/tests/db_engine_specs/presto_tests.py @@ -32,7 +32,7 @@ def test_get_datatype_presto(self): self.assertEqual("STRING", PrestoEngineSpec.get_datatype("string")) def test_presto_get_view_names_return_empty_list( - self + self, ): # pylint: disable=invalid-name self.assertEqual( [], PrestoEngineSpec.get_view_names(mock.ANY, mock.ANY, mock.ANY) diff --git a/tests/superset_test_config.py b/tests/superset_test_config.py index d95b91a089e71..ff0a1e57f2449 100644 --- a/tests/superset_test_config.py +++ b/tests/superset_test_config.py @@ -14,9 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# type: ignore from copy import copy -from superset.config import * # type: ignore +from superset.config import * AUTH_USER_REGISTRATION_ROLE = "alpha" SQLALCHEMY_DATABASE_URI = "sqlite:///" + os.path.join(DATA_DIR, "unittests.db") diff --git a/tests/superset_test_config_sqllab_backend_persist.py b/tests/superset_test_config_sqllab_backend_persist.py index 86619a2ff739a..27d721ff5fd60 100644 --- a/tests/superset_test_config_sqllab_backend_persist.py +++ b/tests/superset_test_config_sqllab_backend_persist.py @@ -15,10 +15,11 @@ # specific language governing permissions and limitations # under the License. # flake8: noqa +# type: ignore import os from copy import copy -from superset.config import * # type: ignore +from superset.config import * AUTH_USER_REGISTRATION_ROLE = "alpha" SQLALCHEMY_DATABASE_URI = "sqlite:///" + os.path.join(DATA_DIR, "unittests.db")