Skip to content

Commit

Permalink
Merge pull request #2010 from fishtown-analytics/fix/mypy-cleanups
Browse files Browse the repository at this point in the history
Fix some mypy checking
  • Loading branch information
beckjake authored Jan 29, 2020
2 parents 4e23e7d + 3e48dc3 commit e570d22
Show file tree
Hide file tree
Showing 59 changed files with 787 additions and 466 deletions.
5 changes: 3 additions & 2 deletions core/dbt/adapters/base/connections.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import abc
import os
from multiprocessing import RLock
# multiprocessing.RLock is a function returning this type
from multiprocessing.synchronize import RLock
from threading import get_ident
from typing import (
Dict, Tuple, Hashable, Optional, ContextManager, List
Expand Down Expand Up @@ -144,7 +145,7 @@ def set_connection_name(self, name: Optional[str] = None) -> Connection:
'Opening a new connection, currently in state {}'
.format(conn.state)
)
conn.handle = LazyHandle(type(self))
conn.handle = LazyHandle(self.open)

conn.name = conn_name
return conn
Expand Down
23 changes: 16 additions & 7 deletions core/dbt/adapters/base/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from datetime import datetime
from typing import (
Optional, Tuple, Callable, Container, FrozenSet, Type, Dict, Any, List,
Mapping, Iterator, Union
Mapping, Iterator, Union, Set
)

import agate
Expand Down Expand Up @@ -125,15 +125,18 @@ def add(self, relation: BaseRelation):
key = relation.information_schema_only()
if key not in self:
self[key] = set()
self[key].add(relation.schema.lower())
lowered: Optional[str] = None
if relation.schema is not None:
lowered = relation.schema.lower()
self[key].add(lowered)

def search(self):
for information_schema_name, schemas in self.items():
for schema in schemas:
yield information_schema_name, schema

def schemas_searched(self):
result = set()
result: Set[Tuple[str, str]] = set()
for information_schema_name, schemas in self.items():
result.update(
(information_schema_name.database, schema)
Expand Down Expand Up @@ -907,13 +910,17 @@ def convert_time_type(cls, agate_table: agate.Table, col_idx: int) -> str:

@available
@classmethod
def convert_type(cls, agate_table, col_idx):
def convert_type(
cls, agate_table: agate.Table, col_idx: int
) -> Optional[str]:
return cls.convert_agate_type(agate_table, col_idx)

@classmethod
def convert_agate_type(cls, agate_table, col_idx):
agate_type = agate_table.column_types[col_idx]
conversions = [
def convert_agate_type(
cls, agate_table: agate.Table, col_idx: int
) -> Optional[str]:
agate_type: Type = agate_table.column_types[col_idx]
conversions: List[Tuple[Type, Callable[..., str]]] = [
(agate.Text, cls.convert_text_type),
(agate.Number, cls.convert_number_type),
(agate.Boolean, cls.convert_boolean_type),
Expand All @@ -925,6 +932,8 @@ def convert_agate_type(cls, agate_table, col_idx):
if isinstance(agate_type, agate_cls):
return func(agate_table, col_idx)

return None

###
# Operations involving the manifest
###
Expand Down
22 changes: 16 additions & 6 deletions core/dbt/adapters/base/meta.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import abc
from functools import wraps
from typing import Callable, Optional, Any, FrozenSet, Dict
from typing import Callable, Optional, Any, FrozenSet, Dict, Set

from dbt.deprecations import warn, renamed_method

Expand Down Expand Up @@ -86,16 +86,26 @@ def parse_list(self, func: Callable) -> Callable:


class AdapterMeta(abc.ABCMeta):
_available_: FrozenSet[str]
_parse_replacements_: Dict[str, Callable]

def __new__(mcls, name, bases, namespace, **kwargs):
cls = super().__new__(mcls, name, bases, namespace, **kwargs)
# mypy does not like the `**kwargs`. But `ABCMeta` itself takes
# `**kwargs` in its argspec here (and passes them to `type.__new__`).
# I'm not sure there is any benefit to it after poking around a bit,
# but having it doesn't hurt on the python side (and omitting it could
# hurt for obscure metaclass reasons, for all I know)
cls = abc.ABCMeta.__new__( # type: ignore
mcls, name, bases, namespace, **kwargs
)

# this is very much inspired by ABCMeta's own implementation

# dict mapping the method name to whether the model name should be
# injected into the arguments. All methods in here are exposed to the
# context.
available = set()
replacements = {}
available: Set[str] = set()
replacements: Dict[str, Any] = {}

# collect base class data first
for base in bases:
Expand All @@ -110,7 +120,7 @@ def __new__(mcls, name, bases, namespace, **kwargs):
if parse_replacement is not None:
replacements[name] = parse_replacement

cls._available_: FrozenSet[str] = frozenset(available)
cls._available_ = frozenset(available)
# should this be a namedtuple so it will be immutable like _available_?
cls._parse_replacements_: Dict[str, Callable] = replacements
cls._parse_replacements_ = replacements
return cls
16 changes: 8 additions & 8 deletions core/dbt/adapters/base/query_headers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from threading import local
from typing import Optional, Callable
from typing import Optional, Callable, Dict, Any

from dbt.clients.jinja import QueryStringGenerator

Expand Down Expand Up @@ -56,7 +56,7 @@ def add(self, sql: str) -> str:
return '/* {} */\n{}'.format(self.query_comment.strip(), sql)

def set(self, comment: Optional[str]):
if '*/' in comment:
if isinstance(comment, str) and '*/' in comment:
# tell the user "no" so they don't hurt themselves by writing
# garbage
raise RuntimeException(
Expand All @@ -65,7 +65,7 @@ def set(self, comment: Optional[str]):
self.query_comment = comment


QueryStringFunc = Callable[[str, Optional[CompileResultNode]], str]
QueryStringFunc = Callable[[str, Optional[NodeWrapper]], str]


class QueryStringSetter:
Expand All @@ -77,13 +77,14 @@ def __init__(self, config: AdapterRequiredConfig):
self.generator: QueryStringFunc = lambda name, model: ''
# if the comment value was None or the empty string, just skip it
if comment_macro:
assert isinstance(comment_macro, str)
macro = '\n'.join((
'{%- macro query_comment_macro(connection_name, node) -%}',
self._get_comment_macro(),
comment_macro,
'{% endmacro %}'
))
ctx = self._get_context()
self.generator: QueryStringFunc = QueryStringGenerator(macro, ctx)
self.generator = QueryStringGenerator(macro, ctx)
self.comment = _QueryComment(None)
self.reset()

Expand All @@ -105,10 +106,9 @@ def reset(self):
self.set('master', None)

def set(self, name: str, node: Optional[CompileResultNode]):
wrapped: Optional[NodeWrapper] = None
if node is not None:
wrapped = NodeWrapper(node)
else:
wrapped = None
comment_str = self.generator(name, wrapped)
self.comment.set(comment_str)

Expand All @@ -127,5 +127,5 @@ def _get_comment_macro(self):
else:
return super()._get_comment_macro()

def _get_context(self):
def _get_context(self) -> Dict[str, Any]:
return QueryHeaderContext(self.config).to_dict(self.manifest.macros)
9 changes: 6 additions & 3 deletions core/dbt/adapters/base/relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,10 +422,13 @@ def External(cls) -> str:
return str(RelationType.External)

@classproperty
def RelationType(cls) -> Type[RelationType]:
def get_relation_type(cls) -> Type[RelationType]:
return RelationType


Info = TypeVar('Info', bound='InformationSchema')


@dataclass(frozen=True, eq=False, repr=False)
class InformationSchema(BaseRelation):
information_schema_view: Optional[str] = None
Expand Down Expand Up @@ -470,10 +473,10 @@ def get_quote_policy(

@classmethod
def from_relation(
cls: Self,
cls: Type[Info],
relation: BaseRelation,
information_schema_view: Optional[str],
) -> Self:
) -> Info:
include_policy = cls.get_include_policy(
relation, information_schema_view
)
Expand Down
40 changes: 24 additions & 16 deletions core/dbt/adapters/cache.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections import namedtuple
from copy import deepcopy
from typing import List, Iterable, Optional
from typing import List, Iterable, Optional, Dict, Set, Tuple, Any
import threading

from dbt.logger import CACHE_LOGGER as logger
Expand Down Expand Up @@ -177,20 +177,24 @@ class RelationsCache:
The adapters also hold this lock while filling the cache.
:attr Set[str] schemas: The set of known/cached schemas, all lowercased.
"""
def __init__(self):
self.relations = {}
def __init__(self) -> None:
self.relations: Dict[_ReferenceKey, _CachedRelation] = {}
self.lock = threading.RLock()
self.schemas = set()
self.schemas: Set[Tuple[Optional[str], Optional[str]]] = set()

def add_schema(self, database: str, schema: str):
def add_schema(
self, database: Optional[str], schema: Optional[str],
) -> None:
"""Add a schema to the set of known schemas (case-insensitive)
:param database: The database name to add.
:param schema: The schema name to add.
"""
self.schemas.add((_lower(database), _lower(schema)))

def drop_schema(self, database: str, schema: str):
def drop_schema(
self, database: Optional[str], schema: Optional[str],
) -> None:
"""Drop the given schema and remove it from the set of known schemas.
Then remove all its contents (and their dependents, etc) as well.
Expand All @@ -208,21 +212,21 @@ def drop_schema(self, database: str, schema: str):
# handle a drop_schema race by using discard() over remove()
self.schemas.discard(key)

def update_schemas(self, schemas: Iterable[str]):
def update_schemas(self, schemas: Iterable[Tuple[Optional[str], str]]):
"""Add multiple schemas to the set of known schemas (case-insensitive)
:param schemas: An iterable of the schema names to add.
"""
self.schemas.update((_lower(d), _lower(s)) for (d, s) in schemas)
self.schemas.update((_lower(d), s.lower()) for (d, s) in schemas)

def __contains__(self, schema_id):
def __contains__(self, schema_id: Tuple[Optional[str], str]):
"""A schema is 'in' the relations cache if it is in the set of cached
schemas.
:param Tuple[str, str] schema: The db name and schema name to look up.
:param schema_id: The db name and schema name to look up.
"""
db, schema = schema_id
return (_lower(db), _lower(schema)) in self.schemas
return (_lower(db), schema.lower()) in self.schemas

def dump_graph(self):
"""Dump a key-only representation of the schema to a dictionary. Every
Expand All @@ -238,7 +242,7 @@ def dump_graph(self):
for k, v in self.relations.items()
}

def _setdefault(self, relation):
def _setdefault(self, relation: _CachedRelation):
"""Add a relation to the cache, or return it if it already exists.
:param _CachedRelation relation: The relation to set or get.
Expand Down Expand Up @@ -275,6 +279,8 @@ def _add_link(self, referenced_key, dependent_key):
.format(dependent_key)
)

assert dependent is not None # we just raised!

referenced.add_reference(dependent)

def add_link(self, referenced, dependent):
Expand Down Expand Up @@ -305,15 +311,15 @@ def add_link(self, referenced, dependent):
if ref_key not in self.relations:
# Insert a dummy "external" relation.
referenced = referenced.replace(
type=referenced.RelationType.External
type=referenced.External
)
self.add(referenced)

dep_key = _make_key(dependent)
if dep_key not in self.relations:
# Insert a dummy "external" relation.
dependent = dependent.replace(
type=referenced.RelationType.External
type=referenced.External
)
self.add(dependent)
logger.debug(
Expand Down Expand Up @@ -469,7 +475,9 @@ def rename(self, old, new):

lazy_log('after rename: {!s}', self.dump_graph)

def get_relations(self, database, schema):
def get_relations(
self, database: Optional[str], schema: Optional[str]
) -> List[Any]:
"""Case-insensitively yield all relations matching the given schema.
:param str schema: The case-insensitive schema name to list from.
Expand Down Expand Up @@ -498,7 +506,7 @@ def clear(self):
self.schemas.clear()

def _list_relations_in_schema(
self, database: str, schema: str
self, database: Optional[str], schema: Optional[str]
) -> List[_CachedRelation]:
"""Get the relations in a schema. Callers should hold the lock."""
key = (_lower(database), _lower(schema))
Expand Down
5 changes: 3 additions & 2 deletions core/dbt/adapters/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ def load_plugin(self, name: str) -> Type[Credentials]:
# and adapter_type entries with the same value, as they're all
# singletons
try:
mod = import_module('.' + name, 'dbt.adapters')
# mypy doesn't think modules have any attributes.
mod: Any = import_module('.' + name, 'dbt.adapters')
except ModuleNotFoundError as exc:
# if we failed to import the target module in particular, inform
# the user about it via a runtime error
Expand All @@ -56,7 +57,7 @@ def load_plugin(self, name: str) -> Type[Credentials]:
# library. Log the stack trace.
logger.debug('', exc_info=True)
raise
plugin = mod.Plugin # type: AdapterPlugin
plugin: AdapterPlugin = mod.Plugin
plugin_type = plugin.adapter.type()

if plugin_type != name:
Expand Down
4 changes: 2 additions & 2 deletions core/dbt/adapters/sql/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,9 +212,9 @@ def list_relations_without_caching(
}
for _database, name, _schema, _type in results:
try:
_type = self.Relation.RelationType(_type)
_type = self.Relation.get_relation_type(_type)
except ValueError:
_type = self.Relation.RelationType.External
_type = self.Relation.External
relations.append(self.Relation.create(
database=_database,
schema=_schema,
Expand Down
1 change: 1 addition & 0 deletions core/dbt/clients/_jinja_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,7 @@ def find_blocks(self, allowed_blocks=None, collect_raw_data=True):

elif self.is_current_end(tag):
self.last_position = tag.end
assert self.current is not None
yield BlockTag(
block_type_name=self.current.block_type_name,
block_name=self.current.block_name,
Expand Down
4 changes: 4 additions & 0 deletions core/dbt/clients/git.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ def clone_and_checkout(repo, cwd, dirname=None, remove_git_dir=False,
logger.debug('Updating existing dependency {}.', directory)
else:
matches = re.match("Cloning into '(.+)'", err.decode('utf-8'))
if matches is None:
raise dbt.exceptions.RuntimeException(
f'Error cloning {repo} - never saw "Cloning into ..." from git'
)
directory = matches.group(1)
logger.debug('Pulling new dependency {}.', directory)
full_path = os.path.join(cwd, directory)
Expand Down
Loading

0 comments on commit e570d22

Please sign in to comment.