Skip to content

Commit

Permalink
add typing to partially typed methods in runnable.py (#8569)
Browse files Browse the repository at this point in the history
  • Loading branch information
MichelleArk authored Sep 6, 2023
1 parent 9aff3ca commit 7578150
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 11 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20230906-132435.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: Fix untyped functions in task/runnable.py (mypy warning)
time: 2023-09-06T13:24:35.448782-04:00
custom:
Author: michelleark
Issue: "8402"
25 changes: 14 additions & 11 deletions core/dbt/task/runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
RunningStatus,
RunResult,
RunStatus,
BaseResult,
)
from dbt.contracts.state import PreviousState
from dbt.events.contextvars import log_contextvars, task_contextvars
Expand All @@ -42,7 +43,7 @@
FailFastError,
)
from dbt.flags import get_flags
from dbt.graph import GraphQueue, NodeSelector, SelectionSpec, parse_difference
from dbt.graph import GraphQueue, NodeSelector, SelectionSpec, parse_difference, UniqueId
from dbt.logger import (
DbtProcessState,
TextOnly,
Expand All @@ -53,7 +54,7 @@
NodeCount,
)
from dbt.parser.manifest import write_manifest
from dbt.task.base import ConfiguredTask
from dbt.task.base import ConfiguredTask, BaseRunner
from .printer import (
print_run_result_error,
print_run_end_messages,
Expand All @@ -66,13 +67,13 @@
class GraphRunnableTask(ConfiguredTask):
MARK_DEPENDENT_ERRORS_STATUSES = [NodeStatus.Error]

def __init__(self, args, config, manifest):
def __init__(self, args, config, manifest) -> None:
super().__init__(args, config, manifest)
self._flattened_nodes: Optional[List[ResultNode]] = None
self._raise_next_tick = None
self._skipped_children = {}
self._raise_next_tick: Optional[DbtRuntimeError] = None
self._skipped_children: Dict[str, Optional[RunResult]] = {}
self.job_queue: Optional[GraphQueue] = None
self.node_results = []
self.node_results: List[BaseResult] = []
self.num_nodes: int = 0
self.previous_state: Optional[PreviousState] = None
self.previous_defer_state: Optional[PreviousState] = None
Expand Down Expand Up @@ -168,7 +169,7 @@ def get_runner_type(self, node):
def result_path(self):
return os.path.join(self.config.project_target_path, RESULT_FILE_NAME)

def get_runner(self, node):
def get_runner(self, node) -> BaseRunner:
adapter = get_adapter(self.config)
run_count: int = 0
num_nodes: int = 0
Expand All @@ -184,7 +185,7 @@ def get_runner(self, node):
cls = self.get_runner_type(node)
return cls(self.config, adapter, node, run_count, num_nodes)

def call_runner(self, runner):
def call_runner(self, runner: BaseRunner) -> RunResult:
uid_context = UniqueID(runner.node.unique_id)
with RUNNING_STATE, uid_context, log_contextvars(node_info=runner.node.node_info):
startctx = TimestampNamed("node_started_at")
Expand Down Expand Up @@ -292,7 +293,7 @@ def callback(result):

return

def _handle_result(self, result):
def _handle_result(self, result: RunResult):
"""Mark the result as completed, insert the `CompileResultNode` into
the manifest, and mark any descendants (potentially with a 'cause' if
the result was an ephemeral model) as skipped.
Expand Down Expand Up @@ -388,10 +389,12 @@ def execute_nodes(self):

return self.node_results

def _mark_dependent_errors(self, node_id, result, cause):
def _mark_dependent_errors(
self, node_id: str, result: RunResult, cause: Optional[RunResult]
) -> None:
if self.graph is None:
raise DbtInternalError("graph is None in _mark_dependent_errors")
for dep_node_id in self.graph.get_dependent_nodes(node_id):
for dep_node_id in self.graph.get_dependent_nodes(UniqueId(node_id)):
self._skipped_children[dep_node_id] = cause

def populate_adapter_cache(
Expand Down

0 comments on commit 7578150

Please sign in to comment.