diff --git a/.changes/unreleased/Under the Hood-20230906-132435.yaml b/.changes/unreleased/Under the Hood-20230906-132435.yaml new file mode 100644 index 00000000000..27b6da20177 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20230906-132435.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: Fix untyped functions in task/runnable.py (mypy warning) +time: 2023-09-06T13:24:35.448782-04:00 +custom: + Author: michelleark + Issue: "8402" diff --git a/core/dbt/task/runnable.py b/core/dbt/task/runnable.py index 9ae987cc2dc..cce7793b598 100644 --- a/core/dbt/task/runnable.py +++ b/core/dbt/task/runnable.py @@ -20,6 +20,7 @@ RunningStatus, RunResult, RunStatus, + BaseResult, ) from dbt.contracts.state import PreviousState from dbt.events.contextvars import log_contextvars, task_contextvars @@ -42,7 +43,7 @@ FailFastError, ) from dbt.flags import get_flags -from dbt.graph import GraphQueue, NodeSelector, SelectionSpec, parse_difference +from dbt.graph import GraphQueue, NodeSelector, SelectionSpec, parse_difference, UniqueId from dbt.logger import ( DbtProcessState, TextOnly, @@ -53,7 +54,7 @@ NodeCount, ) from dbt.parser.manifest import write_manifest -from dbt.task.base import ConfiguredTask +from dbt.task.base import ConfiguredTask, BaseRunner from .printer import ( print_run_result_error, print_run_end_messages, @@ -66,13 +67,13 @@ class GraphRunnableTask(ConfiguredTask): MARK_DEPENDENT_ERRORS_STATUSES = [NodeStatus.Error] - def __init__(self, args, config, manifest): + def __init__(self, args, config, manifest) -> None: super().__init__(args, config, manifest) self._flattened_nodes: Optional[List[ResultNode]] = None - self._raise_next_tick = None - self._skipped_children = {} + self._raise_next_tick: Optional[DbtRuntimeError] = None + self._skipped_children: Dict[str, Optional[RunResult]] = {} self.job_queue: Optional[GraphQueue] = None - self.node_results = [] + self.node_results: List[BaseResult] = [] self.num_nodes: int = 0 self.previous_state: Optional[PreviousState] = None self.previous_defer_state: Optional[PreviousState] = None @@ -168,7 +169,7 @@ def get_runner_type(self, node): def result_path(self): return os.path.join(self.config.project_target_path, RESULT_FILE_NAME) - def get_runner(self, node): + def get_runner(self, node) -> BaseRunner: adapter = get_adapter(self.config) run_count: int = 0 num_nodes: int = 0 @@ -184,7 +185,7 @@ def get_runner(self, node): cls = self.get_runner_type(node) return cls(self.config, adapter, node, run_count, num_nodes) - def call_runner(self, runner): + def call_runner(self, runner: BaseRunner) -> RunResult: uid_context = UniqueID(runner.node.unique_id) with RUNNING_STATE, uid_context, log_contextvars(node_info=runner.node.node_info): startctx = TimestampNamed("node_started_at") @@ -292,7 +293,7 @@ def callback(result): return - def _handle_result(self, result): + def _handle_result(self, result: RunResult): """Mark the result as completed, insert the `CompileResultNode` into the manifest, and mark any descendants (potentially with a 'cause' if the result was an ephemeral model) as skipped. @@ -388,10 +389,12 @@ def execute_nodes(self): return self.node_results - def _mark_dependent_errors(self, node_id, result, cause): + def _mark_dependent_errors( + self, node_id: str, result: RunResult, cause: Optional[RunResult] + ) -> None: if self.graph is None: raise DbtInternalError("graph is None in _mark_dependent_errors") - for dep_node_id in self.graph.get_dependent_nodes(node_id): + for dep_node_id in self.graph.get_dependent_nodes(UniqueId(node_id)): self._skipped_children[dep_node_id] = cause def populate_adapter_cache(