Skip to content

Commit

Permalink
Use events.contextvar because of multiprocessing unable to pickle Con…
Browse files Browse the repository at this point in the history
…textVar (#7949) (#7981)

Cherry-picked from commit 2e7c968

* Add task contextvars to events/contextvars.py

* Use events.contextvars instead of task.contextvars

* Changie
  • Loading branch information
gshank authored Jun 30, 2023
1 parent b187400 commit 98fcb4a
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 54 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Fixes-20230626-115838.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Fixes
body: Move project_root contextvar into events.contextvars
time: 2023-06-26T11:58:38.965299-04:00
custom:
Author: gshank
Issue: "7937"
4 changes: 2 additions & 2 deletions core/dbt/contracts/graph/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
SeedExceedsLimitChecksumChanged,
ValidationWarning,
)
from dbt.events.contextvars import set_contextvars
from dbt.events.contextvars import set_log_contextvars
from dbt.flags import get_flags
from dbt.node_types import ModelLanguage, NodeType, AccessType

Expand Down Expand Up @@ -303,7 +303,7 @@ def node_info(self):
def update_event_status(self, **kwargs):
for k, v in kwargs.items():
self._event_status[k] = v
set_contextvars(node_info=self.node_info)
set_log_contextvars(node_info=self.node_info)

def clear_event_status(self):
self._event_status = dict()
Expand Down
77 changes: 54 additions & 23 deletions core/dbt/events/contextvars.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,79 +5,110 @@


LOG_PREFIX = "log_"
LOG_PREFIX_LEN = len(LOG_PREFIX)
TASK_PREFIX = "task_"

_log_context_vars: Dict[str, contextvars.ContextVar] = {}
_context_vars: Dict[str, contextvars.ContextVar] = {}


def get_contextvars() -> Dict[str, Any]:
def get_contextvars(prefix: str) -> Dict[str, Any]:
rv = {}
ctx = contextvars.copy_context()

prefix_len = len(prefix)
for k in ctx:
if k.name.startswith(LOG_PREFIX) and ctx[k] is not Ellipsis:
rv[k.name[LOG_PREFIX_LEN:]] = ctx[k]
if k.name.startswith(prefix) and ctx[k] is not Ellipsis:
rv[k.name[prefix_len:]] = ctx[k]

return rv


def get_node_info():
cvars = get_contextvars()
cvars = get_contextvars(LOG_PREFIX)
if "node_info" in cvars:
return cvars["node_info"]
else:
return {}


def clear_contextvars() -> None:
def get_project_root():
cvars = get_contextvars(TASK_PREFIX)
if "project_root" in cvars:
return cvars["project_root"]
else:
return None


def clear_contextvars(prefix: str) -> None:
ctx = contextvars.copy_context()
for k in ctx:
if k.name.startswith(LOG_PREFIX):
if k.name.startswith(prefix):
k.set(Ellipsis)


def set_log_contextvars(**kwargs: Any) -> Mapping[str, contextvars.Token]:
return set_contextvars(LOG_PREFIX, **kwargs)


def set_task_contextvars(**kwargs: Any) -> Mapping[str, contextvars.Token]:
return set_contextvars(TASK_PREFIX, **kwargs)


# put keys and values into context. Returns the contextvar.Token mapping
# Save and pass to reset_contextvars
def set_contextvars(**kwargs: Any) -> Mapping[str, contextvars.Token]:
def set_contextvars(prefix: str, **kwargs: Any) -> Mapping[str, contextvars.Token]:
cvar_tokens = {}
for k, v in kwargs.items():
log_key = f"{LOG_PREFIX}{k}"
log_key = f"{prefix}{k}"
try:
var = _log_context_vars[log_key]
var = _context_vars[log_key]
except KeyError:
var = contextvars.ContextVar(log_key, default=Ellipsis)
_log_context_vars[log_key] = var
_context_vars[log_key] = var

cvar_tokens[k] = var.set(v)

return cvar_tokens


# reset by Tokens
def reset_contextvars(**kwargs: contextvars.Token) -> None:
def reset_contextvars(prefix: str, **kwargs: contextvars.Token) -> None:
for k, v in kwargs.items():
log_key = f"{LOG_PREFIX}{k}"
var = _log_context_vars[log_key]
log_key = f"{prefix}{k}"
var = _context_vars[log_key]
var.reset(v)


# remove from contextvars
def unset_contextvars(*keys: str) -> None:
def unset_contextvars(prefix: str, *keys: str) -> None:
for k in keys:
if k in _log_context_vars:
log_key = f"{LOG_PREFIX}{k}"
_log_context_vars[log_key].set(Ellipsis)
if k in _context_vars:
log_key = f"{prefix}{k}"
_context_vars[log_key].set(Ellipsis)


# Context manager or decorator to set and unset the context vars
@contextlib.contextmanager
def log_contextvars(**kwargs: Any) -> Generator[None, None, None]:
context = get_contextvars()
context = get_contextvars(LOG_PREFIX)
saved = {k: context[k] for k in context.keys() & kwargs.keys()}

set_contextvars(LOG_PREFIX, **kwargs)
try:
yield
finally:
unset_contextvars(LOG_PREFIX, *kwargs.keys())
set_contextvars(LOG_PREFIX, **saved)


# Context manager for earlier in task.run
@contextlib.contextmanager
def task_contextvars(**kwargs: Any) -> Generator[None, None, None]:
context = get_contextvars(TASK_PREFIX)
saved = {k: context[k] for k in context.keys() & kwargs.keys()}

set_contextvars(**kwargs)
set_contextvars(TASK_PREFIX, **kwargs)
try:
yield
finally:
unset_contextvars(*kwargs.keys())
set_contextvars(**saved)
unset_contextvars(TASK_PREFIX, *kwargs.keys())
set_contextvars(TASK_PREFIX, **saved)
8 changes: 6 additions & 2 deletions core/dbt/graph/selector_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
DbtRuntimeError,
)
from dbt.node_types import NodeType
from dbt.task.contextvars import cv_project_root
from dbt.events.contextvars import get_project_root


SELECTOR_GLOB = "*"
Expand Down Expand Up @@ -326,7 +326,11 @@ class PathSelectorMethod(SelectorMethod):
def search(self, included_nodes: Set[UniqueId], selector: str) -> Iterator[UniqueId]:
"""Yields nodes from included that match the given path."""
# get project root from contextvar
root = Path(cv_project_root.get())
project_root = get_project_root()
if project_root:
root = Path(project_root)
else:
root = Path.cwd()
paths = set(p.relative_to(root) for p in root.glob(selector))
for node, real_node in self.all_nodes(included_nodes):
ofp = Path(real_node.original_file_path)
Expand Down
3 changes: 0 additions & 3 deletions core/dbt/task/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
from dbt.graph import Graph
from dbt.logger import log_manager
from .printer import print_run_result_error
from dbt.task.contextvars import cv_project_root


class NoneConfig:
Expand Down Expand Up @@ -76,8 +75,6 @@ def __init__(self, args, config, project=None):
self.args = args
self.config = config
self.project = config if isinstance(config, Project) else project
if self.config:
cv_project_root.set(self.config.project_root)

@classmethod
def pre_init_hook(cls, args):
Expand Down
6 changes: 0 additions & 6 deletions core/dbt/task/contextvars.py

This file was deleted.

41 changes: 23 additions & 18 deletions core/dbt/task/runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
EndRunResult,
NothingToDo,
)
from dbt.events.contextvars import log_contextvars
from dbt.events.contextvars import log_contextvars, task_contextvars
from dbt.contracts.graph.nodes import SourceDefinition, ResultNode
from dbt.contracts.results import NodeStatus, RunExecutionResult, RunningStatus
from dbt.contracts.state import PreviousState
Expand Down Expand Up @@ -422,25 +422,30 @@ def run(self):
"""
Run dbt for the query, based on the graph.
"""
self._runtime_initialize()
# We set up a context manager here with "task_contextvars" because we
# we need the project_root in runtime_initialize.
with task_contextvars(project_root=self.config.project_root):
self._runtime_initialize()

if self._flattened_nodes is None:
raise DbtInternalError("after _runtime_initialize, _flattened_nodes was still None")
if self._flattened_nodes is None:
raise DbtInternalError(
"after _runtime_initialize, _flattened_nodes was still None"
)

if len(self._flattened_nodes) == 0:
with TextOnly():
fire_event(Formatting(""))
warn_or_error(NothingToDo())
result = self.get_result(
results=[],
generated_at=datetime.utcnow(),
elapsed_time=0.0,
)
else:
with TextOnly():
fire_event(Formatting(""))
selected_uids = frozenset(n.unique_id for n in self._flattened_nodes)
result = self.execute_with_hooks(selected_uids)
if len(self._flattened_nodes) == 0:
with TextOnly():
fire_event(Formatting(""))
warn_or_error(NothingToDo())
result = self.get_result(
results=[],
generated_at=datetime.utcnow(),
elapsed_time=0.0,
)
else:
with TextOnly():
fire_event(Formatting(""))
selected_uids = frozenset(n.unique_id for n in self._flattened_nodes)
result = self.execute_with_hooks(selected_uids)

# We have other result types here too, including FreshnessResult
if isinstance(result, RunExecutionResult):
Expand Down

0 comments on commit 98fcb4a

Please sign in to comment.