Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix untyped functions in core/dbt/context/base.py #8525

Merged
merged 5 commits into from
Aug 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 50 additions & 47 deletions core/dbt/context/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import json
import os
from typing import Any, Dict, NoReturn, Optional, Mapping, Iterable, Set, List
from typing import Any, Callable, Dict, NoReturn, Optional, Mapping, Iterable, Set, List
import threading

from dbt.flags import get_flags
Expand Down Expand Up @@ -86,33 +88,29 @@ def get_context_modules() -> Dict[str, Dict[str, Any]]:


class ContextMember:
def __init__(self, value, name=None):
def __init__(self, value: Any, name: Optional[str] = None) -> None:
self.name = name
self.inner = value

def key(self, default):
def key(self, default: str) -> str:
if self.name is None:
return default
return self.name


def contextmember(value):
if isinstance(value, str):
return lambda v: ContextMember(v, name=value)
return ContextMember(value)
def contextmember(value: Optional[str] = None) -> Callable:
return lambda v: ContextMember(v, name=value)


def contextproperty(value):
if isinstance(value, str):
return lambda v: ContextMember(property(v), name=value)
return ContextMember(property(value))
def contextproperty(value: Optional[str] = None) -> Callable:
return lambda v: ContextMember(property(v), name=value)


class ContextMeta(type):
def __new__(mcls, name, bases, dct):
context_members = {}
context_attrs = {}
new_dct = {}
def __new__(mcls, name, bases, dct: Dict[str, Any]) -> ContextMeta:
context_members: Dict[str, Any] = {}
context_attrs: Dict[str, Any] = {}
new_dct: Dict[str, Any] = {}

for base in bases:
context_members.update(getattr(base, "_context_members_", {}))
Expand Down Expand Up @@ -148,27 +146,28 @@ def _generate_merged(self) -> Mapping[str, Any]:
return self._cli_vars

@property
def node_name(self):
def node_name(self) -> str:
if self._node is not None:
return self._node.name
else:
return "<Configuration>"

def get_missing_var(self, var_name):
raise RequiredVarNotFoundError(var_name, self._merged, self._node)
def get_missing_var(self, var_name: str) -> NoReturn:
# TODO function name implies a non exception resolution
raise RequiredVarNotFoundError(var_name, dict(self._merged), self._node)

def has_var(self, var_name: str):
def has_var(self, var_name: str) -> bool:
return var_name in self._merged

def get_rendered_var(self, var_name):
def get_rendered_var(self, var_name: str) -> Any:
raw = self._merged[var_name]
# if bool/int/float/etc are passed in, don't compile anything
if not isinstance(raw, str):
return raw

return get_rendered(raw, self._context)
return get_rendered(raw, dict(self._context))

def __call__(self, var_name, default=_VAR_NOTSET):
def __call__(self, var_name: str, default: Any = _VAR_NOTSET) -> Any:
if self.has_var(var_name):
return self.get_rendered_var(var_name)
elif default is not self._VAR_NOTSET:
Expand All @@ -178,13 +177,17 @@ def __call__(self, var_name, default=_VAR_NOTSET):


class BaseContext(metaclass=ContextMeta):
# Set by ContextMeta
_context_members_: Dict[str, Any]
_context_attrs_: Dict[str, Any]

# subclass is TargetContext
def __init__(self, cli_vars):
self._ctx = {}
self.cli_vars = cli_vars
self.env_vars = {}
def __init__(self, cli_vars: Dict[str, Any]) -> None:
self._ctx: Dict[str, Any] = {}
self.cli_vars: Dict[str, Any] = cli_vars
self.env_vars: Dict[str, Any] = {}

def generate_builtins(self):
def generate_builtins(self) -> Dict[str, Any]:
builtins: Dict[str, Any] = {}
for key, value in self._context_members_.items():
if hasattr(value, "__get__"):
Expand All @@ -194,14 +197,14 @@ def generate_builtins(self):
return builtins

# no dbtClassMixin so this is not an actual override
def to_dict(self):
def to_dict(self) -> Dict[str, Any]:
self._ctx["context"] = self._ctx
builtins = self.generate_builtins()
self._ctx["builtins"] = builtins
self._ctx.update(builtins)
return self._ctx

@contextproperty
@contextproperty()
def dbt_version(self) -> str:
"""The `dbt_version` variable returns the installed version of dbt that
is currently running. It can be used for debugging or auditing
Expand All @@ -221,7 +224,7 @@ def dbt_version(self) -> str:
"""
return dbt_version

@contextproperty
@contextproperty()
def var(self) -> Var:
"""Variables can be passed from your `dbt_project.yml` file into models
during compilation. These variables are useful for configuring packages
Expand Down Expand Up @@ -290,7 +293,7 @@ def var(self) -> Var:
"""
return Var(self._ctx, self.cli_vars)

@contextmember
@contextmember()
def env_var(self, var: str, default: Optional[str] = None) -> str:
"""The env_var() function. Return the environment variable named 'var'.
If there is no such environment variable set, return the default.
Expand Down Expand Up @@ -318,7 +321,7 @@ def env_var(self, var: str, default: Optional[str] = None) -> str:

if os.environ.get("DBT_MACRO_DEBUGGING"):

@contextmember
@contextmember()
@staticmethod
def debug():
"""Enter a debugger at this line in the compiled jinja code."""
Expand Down Expand Up @@ -357,7 +360,7 @@ def _return(data: Any) -> NoReturn:
"""
raise MacroReturn(data)

@contextmember
@contextmember()
@staticmethod
def fromjson(string: str, default: Any = None) -> Any:
"""The `fromjson` context method can be used to deserialize a json
Expand All @@ -378,7 +381,7 @@ def fromjson(string: str, default: Any = None) -> Any:
except ValueError:
return default

@contextmember
@contextmember()
@staticmethod
def tojson(value: Any, default: Any = None, sort_keys: bool = False) -> Any:
"""The `tojson` context method can be used to serialize a Python
Expand All @@ -401,7 +404,7 @@ def tojson(value: Any, default: Any = None, sort_keys: bool = False) -> Any:
except ValueError:
return default

@contextmember
@contextmember()
@staticmethod
def fromyaml(value: str, default: Any = None) -> Any:
"""The fromyaml context method can be used to deserialize a yaml string
Expand Down Expand Up @@ -432,7 +435,7 @@ def fromyaml(value: str, default: Any = None) -> Any:

# safe_dump defaults to sort_keys=True, but we act like json.dumps (the
# opposite)
@contextmember
@contextmember()
@staticmethod
def toyaml(
value: Any, default: Optional[str] = None, sort_keys: bool = False
Expand Down Expand Up @@ -477,7 +480,7 @@ def _set(value: Iterable[Any], default: Any = None) -> Optional[Set[Any]]:
except TypeError:
return default

@contextmember
@contextmember()
@staticmethod
def set_strict(value: Iterable[Any]) -> Set[Any]:
"""The `set_strict` context method can be used to convert any iterable
Expand Down Expand Up @@ -519,7 +522,7 @@ def _zip(*args: Iterable[Any], default: Any = None) -> Optional[Iterable[Any]]:
except TypeError:
return default

@contextmember
@contextmember()
@staticmethod
def zip_strict(*args: Iterable[Any]) -> Iterable[Any]:
"""The `zip_strict` context method can be used to used to return
Expand All @@ -541,7 +544,7 @@ def zip_strict(*args: Iterable[Any]) -> Iterable[Any]:
except TypeError as e:
raise ZipStrictWrongTypeError(e)

@contextmember
@contextmember()
@staticmethod
def log(msg: str, info: bool = False) -> str:
"""Logs a line to either the log file or stdout.
Expand All @@ -562,7 +565,7 @@ def log(msg: str, info: bool = False) -> str:
fire_event(JinjaLogDebug(msg=msg, node_info=get_node_info()))
return ""

@contextproperty
@contextproperty()
def run_started_at(self) -> Optional[datetime.datetime]:
"""`run_started_at` outputs the timestamp that this run started, e.g.
`2017-04-21 01:23:45.678`. The `run_started_at` variable is a Python
Expand Down Expand Up @@ -590,19 +593,19 @@ def run_started_at(self) -> Optional[datetime.datetime]:
else:
return None

@contextproperty
@contextproperty()
def invocation_id(self) -> Optional[str]:
"""invocation_id outputs a UUID generated for this dbt run (useful for
auditing)
"""
return get_invocation_id()

@contextproperty
@contextproperty()
def thread_id(self) -> str:
"""thread_id outputs an ID for the current thread (useful for auditing)"""
return threading.current_thread().name

@contextproperty
@contextproperty()
def modules(self) -> Dict[str, Any]:
"""The `modules` variable in the Jinja context contains useful Python
modules for operating on data.
Expand All @@ -627,7 +630,7 @@ def modules(self) -> Dict[str, Any]:
""" # noqa
return get_context_modules()

@contextproperty
@contextproperty()
def flags(self) -> Any:
"""The `flags` variable contains true/false values for flags provided
on the command line.
Expand All @@ -644,7 +647,7 @@ def flags(self) -> Any:
"""
return flags_module.get_flag_obj()

@contextmember
@contextmember()
@staticmethod
def print(msg: str) -> str:
"""Prints a line to stdout.
Expand All @@ -662,7 +665,7 @@ def print(msg: str) -> str:
print(msg)
return ""

@contextmember
@contextmember()
@staticmethod
def diff_of_two_dicts(
dict_a: Dict[str, List[str]], dict_b: Dict[str, List[str]]
Expand Down Expand Up @@ -691,7 +694,7 @@ def diff_of_two_dicts(
dict_diff.update({k: dict_a[k]})
return dict_diff

@contextmember
@contextmember()
@staticmethod
def local_md5(value: str) -> str:
"""Calculates an MD5 hash of the given string.
Expand Down
8 changes: 4 additions & 4 deletions core/dbt/context/configured.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(self, config: AdapterRequiredConfig) -> None:
super().__init__(config.to_target_dict(), config.cli_vars)
self.config = config

@contextproperty
@contextproperty()
def project_name(self) -> str:
return self.config.project_name

Expand Down Expand Up @@ -80,11 +80,11 @@ def __init__(self, config, project_name: str, schema_yaml_vars: Optional[SchemaY
self._project_name = project_name
self.schema_yaml_vars = schema_yaml_vars

@contextproperty
@contextproperty()
def var(self) -> ConfiguredVar:
return ConfiguredVar(self._ctx, self.config, self._project_name)

@contextmember
@contextmember()
def env_var(self, var: str, default: Optional[str] = None) -> str:
return_value = None
if var.startswith(SECRET_ENV_PREFIX):
Expand Down Expand Up @@ -113,7 +113,7 @@ class MacroResolvingContext(ConfiguredContext):
def __init__(self, config):
super().__init__(config)

@contextproperty
@contextproperty()
def var(self) -> ConfiguredVar:
return ConfiguredVar(self._ctx, self.config, self.config.project_name)

Expand Down
2 changes: 1 addition & 1 deletion core/dbt/context/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(
self.node = node
self.manifest = manifest

@contextmember
@contextmember()
def doc(self, *args: str) -> str:
"""The `doc` function is used to reference docs blocks in schema.yml
files. It is analogous to the `ref` function. For more information,
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/context/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def to_dict(self):
dct.update(self.namespace)
return dct

@contextproperty
@contextproperty()
def context_macro_stack(self):
return self.macro_stack

Expand Down
Loading