diff --git a/.changes/unreleased/Fixes-20230810-183216.yaml b/.changes/unreleased/Fixes-20230810-183216.yaml new file mode 100644 index 00000000000..715bc536b2d --- /dev/null +++ b/.changes/unreleased/Fixes-20230810-183216.yaml @@ -0,0 +1,6 @@ +kind: Fixes +body: Call materialization macro from adapter dispatch +time: 2023-08-10T18:32:16.226142+01:00 +custom: + Author: aranke + Issue: "8031" diff --git a/core/dbt/clients/jinja.py b/core/dbt/clients/jinja.py index 37097dbd805..1c14b9048fa 100644 --- a/core/dbt/clients/jinja.py +++ b/core/dbt/clients/jinja.py @@ -271,10 +271,11 @@ def depth(self) -> int: def push(self, name): self.call_stack.append(name) - def pop(self, name): + def pop(self, name: Optional[str] = None): got = self.call_stack.pop() - if got != name: + if name and got != name: raise DbtInternalError(f"popped {got}, expected {name}") + return got class MacroGenerator(BaseMacroGenerator): diff --git a/core/dbt/context/providers.py b/core/dbt/context/providers.py index 6b981091682..2468ed60c65 100644 --- a/core/dbt/context/providers.py +++ b/core/dbt/context/providers.py @@ -11,13 +11,18 @@ Type, Iterable, Mapping, + Tuple, ) + +import agate from typing_extensions import Protocol +from dbt import selected_resources from dbt.adapters.base.column import Column from dbt.adapters.factory import get_adapter, get_adapter_package_names, get_adapter_type_names from dbt.clients import agate_helper from dbt.clients.jinja import get_rendered, MacroGenerator, MacroStack +from dbt.config import IsFQNResource from dbt.config import RuntimeConfig, Project from dbt.constants import SECRET_ENV_PREFIX, DEFAULT_ENV_PLACEHOLDER from dbt.context.base import contextmember, contextproperty, Var @@ -29,6 +34,7 @@ from dbt.context.manifest import ManifestContext from dbt.contracts.connection import AdapterResponse from dbt.contracts.graph.manifest import Manifest, Disabled +from dbt.contracts.graph.metrics import MetricReference, ResolvedMetricReference from dbt.contracts.graph.nodes import ( Macro, Exposure, @@ -40,7 +46,6 @@ AccessType, SemanticModel, ) -from dbt.contracts.graph.metrics import MetricReference, ResolvedMetricReference from dbt.contracts.graph.unparsed import NodeVersion from dbt.events.functions import get_metadata_vars from dbt.exceptions import ( @@ -69,16 +74,9 @@ DbtValidationError, DbtReferenceError, ) -from dbt.config import IsFQNResource from dbt.node_types import NodeType, ModelLanguage - from dbt.utils import merge, AttrDict, MultiDict, args_to_dict, cast_to_str -from dbt import selected_resources - -import agate - - _MISSING = object() @@ -156,6 +154,7 @@ def dispatch( self, macro_name: str, macro_namespace: Optional[str] = None, + stack: Optional[MacroStack] = None, packages: Optional[List[str]] = None, # eventually remove since it's fully deprecated ) -> MacroGenerator: search_packages: List[Optional[str]] @@ -174,30 +173,40 @@ def dispatch( raise MacroDispatchArgError(macro_name) search_packages = self._get_search_packages(macro_namespace) - - attempts = [] + macro = None + potential_macros: List[Tuple[Optional[str], str]] = [] + failed_macros: List[Tuple[Optional[str], str]] = [] for package_name in search_packages: - for prefix in self._get_adapter_macro_prefixes(): - search_name = f"{prefix}__{macro_name}" - try: - # this uses the namespace from the context - macro = self._namespace.get_from_package(package_name, search_name) - except CompilationError: - # Only raise CompilationError if macro is not found in - # any package - macro = None - - if package_name is None: - attempts.append(search_name) - else: - attempts.append(f"{package_name}.{search_name}") - - if macro is not None: + if macro_name.startswith("materialization_"): + potential_macros.append((package_name, macro_name)) + potential_macros.append(("dbt", macro_name)) + else: + for prefix in self._get_adapter_macro_prefixes(): + potential_macros.append((package_name, f"{prefix}__{macro_name}")) + potential_macros.append( + (package_name, f"materialization_{macro_name}_{prefix}") + ) + + for package_name, search_name in potential_macros: + try: + macro = self._namespace.get_from_package(package_name, search_name) + if macro: + macro.stack = stack + except CompilationError: + # Only raise CompilationError if macro is not found in + # any package + pass + finally: + if macro: return macro + else: + failed_macros.append((package_name, search_name)) - searched = ", ".join(repr(a) for a in attempts) - msg = f"In dispatch: No macro named '{macro_name}' found within namespace: '{macro_namespace}'\n Searched for: {searched}" + msg = ( + f"In dispatch: No macro named '{macro_name}' found within namespace: '{macro_namespace}'\n" + f"Searched for: {failed_macros}" + ) raise CompilationError(msg) diff --git a/core/dbt/task/clone.py b/core/dbt/task/clone.py index 87fb1a78106..135280a2fa7 100644 --- a/core/dbt/task/clone.py +++ b/core/dbt/task/clone.py @@ -2,7 +2,6 @@ from typing import AbstractSet, Any, List, Iterable, Set from dbt.adapters.base import BaseRelation -from dbt.clients.jinja import MacroGenerator from dbt.context.providers import generate_runtime_model_context from dbt.contracts.results import RunStatus, RunResult from dbt.dataclass_schema import dbtClassMixin @@ -80,7 +79,9 @@ def execute(self, model, manifest): hook_ctx = self.adapter.pre_model_hook(context_config) try: - result = MacroGenerator(materialization_macro, context)() + result = context["adapter"].dispatch( + materialization_macro.name, stack=context["context_macro_stack"] + )() finally: self.adapter.post_model_hook(context_config, hook_ctx) diff --git a/core/dbt/task/run.py b/core/dbt/task/run.py index 5c2f2d1f094..ff8f7aa59c1 100644 --- a/core/dbt/task/run.py +++ b/core/dbt/task/run.py @@ -1,32 +1,18 @@ import functools import threading import time +from datetime import datetime from typing import List, Dict, Any, Iterable, Set, Tuple, Optional, AbstractSet -from dbt.dataclass_schema import dbtClassMixin - -from .compile import CompileRunner, CompileTask - -from .printer import ( - print_run_end_messages, - get_counts, -) -from datetime import datetime from dbt import tracking from dbt import utils from dbt.adapters.base import BaseRelation -from dbt.clients.jinja import MacroGenerator from dbt.context.providers import generate_runtime_model_context from dbt.contracts.graph.model_config import Hook from dbt.contracts.graph.nodes import HookNode, ResultNode from dbt.contracts.results import NodeStatus, RunResult, RunStatus, RunningStatus, BaseResult -from dbt.exceptions import ( - CompilationError, - DbtInternalError, - MissingMaterializationError, - DbtRuntimeError, - DbtValidationError, -) +from dbt.dataclass_schema import dbtClassMixin +from dbt.events.base_types import EventLevel from dbt.events.functions import fire_event, get_invocation_id from dbt.events.types import ( DatabaseErrorRunningHook, @@ -38,7 +24,15 @@ LogHookEndLine, LogHookStartLine, ) -from dbt.events.base_types import EventLevel +from dbt.exceptions import ( + CompilationError, + DbtInternalError, + MissingMaterializationError, + DbtRuntimeError, + DbtValidationError, +) +from dbt.graph import ResourceTypeSelector +from dbt.hooks import get_hook_dict from dbt.logger import ( TextOnly, HookMetadata, @@ -46,9 +40,12 @@ TimestampNamed, DbtModelState, ) -from dbt.graph import ResourceTypeSelector -from dbt.hooks import get_hook_dict from dbt.node_types import NodeType, RunHookType +from dbt.task.compile import CompileRunner, CompileTask +from dbt.task.printer import ( + print_run_end_messages, + get_counts, +) class Timer: @@ -288,8 +285,8 @@ def execute(self, model, manifest): hook_ctx = self.adapter.pre_model_hook(context_config) try: - result = MacroGenerator( - materialization_macro, context, stack=context["context_macro_stack"] + result = context["adapter"].dispatch( + materialization_macro.name, stack=context["context_macro_stack"] )() finally: self.adapter.post_model_hook(context_config, hook_ctx) @@ -327,7 +324,6 @@ def _hook_keyfunc(self, hook: HookNode) -> Tuple[str, Optional[int]]: return package_name, hook.index def get_hooks_by_type(self, hook_type: RunHookType) -> List[HookNode]: - if self.manifest is None: raise DbtInternalError("self.manifest was None in get_hooks_by_type") diff --git a/tests/functional/materializations/conftest.py b/tests/functional/materializations/conftest.py index b808c1a6a7b..1f1ed85b1ee 100644 --- a/tests/functional/materializations/conftest.py +++ b/tests/functional/materializations/conftest.py @@ -242,7 +242,7 @@ """ override_view_default_dep__macros__default_view_sql = """ -{%- materialization view, default -%} +{%- materialization view, adapter = 'postgres' -%} {{ exceptions.raise_compiler_error('intentionally raising an error in the default view materialization') }} {%- endmaterialization -%} diff --git a/tests/functional/materializations/test_nested_materializations.py b/tests/functional/materializations/test_nested_materializations.py new file mode 100644 index 00000000000..8db9512cd0d --- /dev/null +++ b/tests/functional/materializations/test_nested_materializations.py @@ -0,0 +1,44 @@ +import pytest + +from dbt.tests.util import run_dbt + +parent_materialization = """ +{% materialization parent, default %} + {%- set target_relation = this.incorporate(type='table') %} + {% call statement('main') -%} + set session time zone 'Asia/Kolkata'; + {%- endcall %} + {{ return({'relations': [target_relation]}) }} +{% endmaterialization %} +""" + +child_materialization = """ +{% materialization child, default %} + {%- set relations = adapter.dispatch('parent')() %} + {{ return({'relations': relations['relations'] }) }} +{% endmaterialization %} +""" + +my_model_sql = """ +{{ config(materialized='child') }} +select current_setting('timezone') as current_tz +""" + + +class TestMaterializationOverride: + @pytest.fixture(scope="class") + def macros(self): + return { + "parent.sql": parent_materialization, + "child.sql": child_materialization, + } + + @pytest.fixture(scope="class") + def models(self): + return { + "model.sql": my_model_sql, + } + + def test_foo(self, project): + res = run_dbt(["run"]) + print(res) diff --git a/tests/unit/test_context.py b/tests/unit/test_context.py index b51e8e76de5..2470926e77d 100644 --- a/tests/unit/test_context.py +++ b/tests/unit/test_context.py @@ -464,6 +464,23 @@ def test_macro_namespace_duplicates(config_postgres, manifest_fx): mn.add_macros(mock_macro("macro_a", "dbt"), {}) +def test_macro_stack(config_postgres, manifest_fx): + stack = MacroStack() + stack.push("foo") + stack.push("bar") + mn = macros.MacroNamespaceBuilder("root", "search", stack, ["dbt_postgres", "dbt"]) + mn.add_macros(manifest_fx.macros.values(), {}) + + stack = mn.thread_ctx + assert stack.depth == 2 + assert stack.pop() == "bar" + + with pytest.raises(dbt.exceptions.DbtInternalError): + stack.pop("bar") + + assert stack.depth == 0 + + def test_macro_namespace(config_postgres, manifest_fx): mn = macros.MacroNamespaceBuilder("root", "search", MacroStack(), ["dbt_postgres", "dbt"])