Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #8031: Call materialization macro from adapter dispatch #8355

Closed
wants to merge 17 commits into from
Closed
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/unreleased/Fixes-20230810-183216.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Fixes
body: Call materialization macro from adapter dispatch
time: 2023-08-10T18:32:16.226142+01:00
custom:
Author: aranke
Issue: "8031"
37 changes: 10 additions & 27 deletions core/dbt/clients/jinja.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import os
import re
import tempfile
import threading
from ast import literal_eval
from collections import deque
from contextlib import contextmanager
from itertools import chain, islice
from typing import List, Union, Set, Optional, Dict, Any, Iterator, Type, NoReturn, Tuple, Callable
Expand All @@ -16,17 +16,8 @@
import jinja2.parser
import jinja2.sandbox

from dbt.utils import (
get_dbt_macro_name,
get_docs_macro_name,
get_materialization_macro_name,
get_test_macro_name,
deep_map_render,
)

from dbt.clients._jinja_blocks import BlockIterator, BlockData, BlockTag
from dbt.contracts.graph.nodes import GenericTestNode

from dbt.exceptions import (
CaughtMacroError,
CaughtMacroErrorWithNodeError,
Expand All @@ -42,7 +33,13 @@
)
from dbt.flags import get_flags
from dbt.node_types import ModelLanguage

from dbt.utils import (
get_dbt_macro_name,
get_docs_macro_name,
get_materialization_macro_name,
get_test_macro_name,
deep_map_render,
)

SUPPORTED_LANG_ARG = jinja2.nodes.Name("supported_languages", "param")

Expand Down Expand Up @@ -259,22 +256,8 @@ def call_macro(self, *args, **kwargs):
return e.value


class MacroStack(threading.local):
def __init__(self):
super().__init__()
self.call_stack = []

@property
def depth(self) -> int:
return len(self.call_stack)

def push(self, name):
self.call_stack.append(name)

def pop(self, name):
got = self.call_stack.pop()
if got != name:
raise DbtInternalError(f"popped {got}, expected {name}")
class MacroStack(deque):
aranke marked this conversation as resolved.
Show resolved Hide resolved
pass


class MacroGenerator(BaseMacroGenerator):
Expand Down
49 changes: 29 additions & 20 deletions core/dbt/context/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
Type,
Iterable,
Mapping,
Tuple,
)
from typing_extensions import Protocol

Expand Down Expand Up @@ -156,6 +157,7 @@ def dispatch(
self,
macro_name: str,
macro_namespace: Optional[str] = None,
stack: Optional[MacroStack] = None,
packages: Optional[List[str]] = None, # eventually remove since it's fully deprecated
) -> MacroGenerator:
search_packages: List[Optional[str]]
Expand All @@ -174,30 +176,37 @@ def dispatch(
raise MacroDispatchArgError(macro_name)

search_packages = self._get_search_packages(macro_namespace)

attempts = []
macro = None
potential_macros: List[Tuple[Optional[str], str]] = []
failed_macros: List[Tuple[Optional[str], str]] = []

for package_name in search_packages:
for prefix in self._get_adapter_macro_prefixes():
search_name = f"{prefix}__{macro_name}"
try:
# this uses the namespace from the context
macro = self._namespace.get_from_package(package_name, search_name)
except CompilationError:
# Only raise CompilationError if macro is not found in
# any package
macro = None

if package_name is None:
attempts.append(search_name)
else:
attempts.append(f"{package_name}.{search_name}")

if macro is not None:
if macro_name.startswith("materialization_"):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we made a fix recently for scenarios where macros that started with "materialization_" were incorrectly being picked up as materializations. For example, this was being flagged as a materialization instead of a macro:

{% macro materialization_setup() %}
    ...
{% endmacro %}

Have you verified this doesn't reintroduce that issue?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have tests for that case now, so this PR shouldn't reintroduce the issue (or if it does, it'll be obvious before merging).

#8181

potential_macros.append((package_name, macro_name))
potential_macros.append(("dbt", macro_name))
else:
for prefix in self._get_adapter_macro_prefixes():
potential_macros.append((package_name, f"{prefix}__{macro_name}"))

for package_name, search_name in potential_macros:
try:
macro = self._namespace.get_from_package(package_name, search_name)
if macro:
macro.stack = stack
except CompilationError:
# Only raise CompilationError if macro is not found in
# any package
pass
finally:
if macro:
return macro
else:
failed_macros.append((package_name, search_name))

searched = ", ".join(repr(a) for a in attempts)
msg = f"In dispatch: No macro named '{macro_name}' found within namespace: '{macro_namespace}'\n Searched for: {searched}"
msg = (
f"In dispatch: No macro named '{macro_name}' found within namespace: '{macro_namespace}'\n"
f"Searched for: {failed_macros}"
)
raise CompilationError(msg)


Expand Down
5 changes: 3 additions & 2 deletions core/dbt/task/clone.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from typing import AbstractSet, Any, List, Iterable, Set

from dbt.adapters.base import BaseRelation
from dbt.clients.jinja import MacroGenerator
from dbt.context.providers import generate_runtime_model_context
from dbt.contracts.results import RunStatus, RunResult
from dbt.dataclass_schema import dbtClassMixin
Expand Down Expand Up @@ -80,7 +79,9 @@ def execute(self, model, manifest):

hook_ctx = self.adapter.pre_model_hook(context_config)
try:
result = MacroGenerator(materialization_macro, context)()
result = context["adapter"].dispatch(
materialization_macro.name, stack=context["context_macro_stack"]
)()
finally:
self.adapter.post_model_hook(context_config, hook_ctx)

Expand Down
42 changes: 19 additions & 23 deletions core/dbt/task/run.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,18 @@
import functools
import threading
import time
from datetime import datetime
from typing import List, Dict, Any, Iterable, Set, Tuple, Optional, AbstractSet

from dbt.dataclass_schema import dbtClassMixin

from .compile import CompileRunner, CompileTask

from .printer import (
print_run_end_messages,
get_counts,
)
from datetime import datetime
from dbt import tracking
from dbt import utils
from dbt.adapters.base import BaseRelation
from dbt.clients.jinja import MacroGenerator
from dbt.context.providers import generate_runtime_model_context
from dbt.contracts.graph.model_config import Hook
from dbt.contracts.graph.nodes import HookNode, ResultNode
from dbt.contracts.results import NodeStatus, RunResult, RunStatus, RunningStatus, BaseResult
from dbt.exceptions import (
CompilationError,
DbtInternalError,
MissingMaterializationError,
DbtRuntimeError,
DbtValidationError,
)
from dbt.dataclass_schema import dbtClassMixin
from dbt.events.base_types import EventLevel
from dbt.events.functions import fire_event, get_invocation_id
from dbt.events.types import (
DatabaseErrorRunningHook,
Expand All @@ -38,17 +24,28 @@
LogHookEndLine,
LogHookStartLine,
)
from dbt.events.base_types import EventLevel
from dbt.exceptions import (
CompilationError,
DbtInternalError,
MissingMaterializationError,
DbtRuntimeError,
DbtValidationError,
)
from dbt.graph import ResourceTypeSelector
from dbt.hooks import get_hook_dict
from dbt.logger import (
TextOnly,
HookMetadata,
UniqueID,
TimestampNamed,
DbtModelState,
)
from dbt.graph import ResourceTypeSelector
from dbt.hooks import get_hook_dict
from dbt.node_types import NodeType, RunHookType
from dbt.task.compile import CompileRunner, CompileTask
from dbt.task.printer import (
print_run_end_messages,
get_counts,
)


class Timer:
Expand Down Expand Up @@ -288,8 +285,8 @@ def execute(self, model, manifest):

hook_ctx = self.adapter.pre_model_hook(context_config)
try:
result = MacroGenerator(
materialization_macro, context, stack=context["context_macro_stack"]
result = context["adapter"].dispatch(
materialization_macro.name, stack=context["context_macro_stack"]
)()
finally:
self.adapter.post_model_hook(context_config, hook_ctx)
Expand Down Expand Up @@ -327,7 +324,6 @@ def _hook_keyfunc(self, hook: HookNode) -> Tuple[str, Optional[int]]:
return package_name, hook.index

def get_hooks_by_type(self, hook_type: RunHookType) -> List[HookNode]:

if self.manifest is None:
raise DbtInternalError("self.manifest was None in get_hooks_by_type")

Expand Down
2 changes: 1 addition & 1 deletion tests/functional/materializations/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@
"""

override_view_default_dep__macros__default_view_sql = """
{%- materialization view, default -%}
{%- materialization view, adapter = 'postgres' -%}
{{ exceptions.raise_compiler_error('intentionally raising an error in the default view materialization') }}
{%- endmaterialization -%}

Expand Down
14 changes: 14 additions & 0 deletions tests/unit/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,20 @@ def test_macro_namespace_duplicates(config_postgres, manifest_fx):
mn.add_macros(mock_macro("macro_a", "dbt"), {})


def test_macro_stack(config_postgres, manifest_fx):
stack = MacroStack()
stack.append("foo")
stack.append("bar")
mn = macros.MacroNamespaceBuilder("root", "search", stack, ["dbt_postgres", "dbt"])
mn.add_macros(manifest_fx.macros.values(), {})

stack = mn.thread_ctx
assert len(stack) == 2
assert stack.pop() == "bar"
assert stack.pop() == "foo"
assert len(stack) == 0


def test_macro_namespace(config_postgres, manifest_fx):
mn = macros.MacroNamespaceBuilder("root", "search", MacroStack(), ["dbt_postgres", "dbt"])

Expand Down