diff --git a/CHANGELOG.md b/CHANGELOG.md index fedcfd19f..c09835845 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,9 +2,10 @@ ### Fixes - Closes the connection properly ([#280](https://github.com/dbt-labs/dbt-spark/issues/280), [#285](https://github.com/dbt-labs/dbt-spark/pull/285)) +- Make internal macros use macro dispatch to be overridable in child adapters ([#319](https://github.com/dbt-labs/dbt-spark/issues/319), [#320](https://github.com/dbt-labs/dbt-spark/pull/320)) ### Contributors -- [@ueshin](https://github.com/ueshin) ([#285](https://github.com/dbt-labs/dbt-spark/pull/285)) +- [@ueshin](https://github.com/ueshin) ([#285](https://github.com/dbt-labs/dbt-spark/pull/285), [#320](https://github.com/dbt-labs/dbt-spark/pull/320)) ## dbt-spark 1.0.0 (December 3, 2021) diff --git a/dbt/include/spark/macros/adapters.sql b/dbt/include/spark/macros/adapters.sql index 2542af811..e96501c45 100644 --- a/dbt/include/spark/macros/adapters.sql +++ b/dbt/include/spark/macros/adapters.sql @@ -1,11 +1,20 @@ {% macro file_format_clause() %} + {{ return(adapter.dispatch('file_format_clause', 'dbt')()) }} +{%- endmacro -%} + +{% macro spark__file_format_clause() %} {%- set file_format = config.get('file_format', validator=validation.any[basestring]) -%} {%- if file_format is not none %} using {{ file_format }} {%- endif %} {%- endmacro -%} + {% macro location_clause() %} + {{ return(adapter.dispatch('location_clause', 'dbt')()) }} +{%- endmacro -%} + +{% macro spark__location_clause() %} {%- set location_root = config.get('location_root', validator=validation.any[basestring]) -%} {%- set identifier = model['alias'] -%} {%- if location_root is not none %} @@ -13,7 +22,12 @@ {%- endif %} {%- endmacro -%} + {% macro options_clause() -%} + {{ return(adapter.dispatch('options_clause', 'dbt')()) }} +{%- endmacro -%} + +{% macro spark__options_clause() -%} {%- set options = config.get('options') -%} {%- if config.get('file_format') == 'hudi' -%} {%- set unique_key = config.get('unique_key') -%} @@ -35,7 +49,12 @@ {%- endif %} {%- endmacro -%} + {% macro comment_clause() %} + {{ return(adapter.dispatch('comment_clause', 'dbt')()) }} +{%- endmacro -%} + +{% macro spark__comment_clause() %} {%- set raw_persist_docs = config.get('persist_docs', {}) -%} {%- if raw_persist_docs is mapping -%} @@ -48,7 +67,12 @@ {% endif %} {%- endmacro -%} + {% macro partition_cols(label, required=false) %} + {{ return(adapter.dispatch('partition_cols', 'dbt')(label, required)) }} +{%- endmacro -%} + +{% macro spark__partition_cols(label, required=false) %} {%- set cols = config.get('partition_by', validator=validation.any[list, basestring]) -%} {%- if cols is not none %} {%- if cols is string -%} @@ -65,6 +89,10 @@ {% macro clustered_cols(label, required=false) %} + {{ return(adapter.dispatch('clustered_cols', 'dbt')(label, required)) }} +{%- endmacro -%} + +{% macro spark__clustered_cols(label, required=false) %} {%- set cols = config.get('clustered_by', validator=validation.any[list, basestring]) -%} {%- set buckets = config.get('buckets', validator=validation.any[int]) -%} {%- if (cols is not none) and (buckets is not none) %} @@ -80,6 +108,7 @@ {%- endif %} {%- endmacro -%} + {% macro fetch_tbl_properties(relation) -%} {% call statement('list_properties', fetch_result=True) -%} SHOW TBLPROPERTIES {{ relation }} @@ -88,12 +117,17 @@ {%- endmacro %} -{#-- We can't use temporary tables with `create ... as ()` syntax #} {% macro create_temporary_view(relation, sql) -%} + {{ return(adapter.dispatch('create_temporary_view', 'dbt')(relation, sql)) }} +{%- endmacro -%} + +{#-- We can't use temporary tables with `create ... as ()` syntax #} +{% macro spark__create_temporary_view(relation, sql) -%} create temporary view {{ relation.include(schema=false) }} as {{ sql }} {% endmacro %} + {% macro spark__create_table_as(temporary, relation, sql) -%} {% if temporary -%} {{ create_temporary_view(relation, sql) }} diff --git a/tests/unit/test_macros.py b/tests/unit/test_macros.py index 06ce202a7..220a74db7 100644 --- a/tests/unit/test_macros.py +++ b/tests/unit/test_macros.py @@ -15,7 +15,9 @@ def setUp(self): 'validation': mock.Mock(), 'model': mock.Mock(), 'exceptions': mock.Mock(), - 'config': mock.Mock() + 'config': mock.Mock(), + 'adapter': mock.Mock(), + 'return': lambda r: r, } self.default_context['config'].get = lambda key, default=None, **kwargs: self.config.get(key, default) @@ -24,6 +26,11 @@ def __get_template(self, template_filename): def __run_macro(self, template, name, temporary, relation, sql): self.default_context['model'].alias = relation + + def dispatch(macro_name, macro_namespace=None, packages=None): + return getattr(template.module, f'spark__{macro_name}') + self.default_context['adapter'].dispatch = dispatch + value = getattr(template.module, name)(temporary, relation, sql) return re.sub(r'\s\s+', ' ', value)