diff --git a/core/dbt/include/global_project/macros/adapters/common.sql b/core/dbt/include/global_project/macros/adapters/common.sql index 6bf15509175..db30c2ba286 100644 --- a/core/dbt/include/global_project/macros/adapters/common.sql +++ b/core/dbt/include/global_project/macros/adapters/common.sql @@ -69,6 +69,10 @@ {%- endmacro %} {% macro default__create_table_as(temporary, relation, sql) -%} + {%- set sql_header = config.get('sql_header', none) -%} + + {{ sql_header if sql_header is not none }} + create {% if temporary: -%}temporary{%- endif %} table {{ relation.include(database=(not temporary), schema=(not temporary)) }} as ( @@ -81,6 +85,10 @@ {%- endmacro %} {% macro default__create_view_as(relation, sql) -%} + {%- set sql_header = config.get('sql_header', none) -%} + + {{ sql_header if sql_header is not none }} + create view {{ relation }} as ( {{ sql }} ); @@ -269,3 +277,7 @@ {% do return(tmp_relation) %} {% endmacro %} + +{% macro set_sql_header(config) -%} + {{ config.set('sql_header', caller()) }} +{%- endmacro %} diff --git a/core/dbt/source_config.py b/core/dbt/source_config.py index eacbbabcad9..dd1efa00bdb 100644 --- a/core/dbt/source_config.py +++ b/core/dbt/source_config.py @@ -16,7 +16,7 @@ class SourceConfig: 'unique_key', 'database', 'severity', - + 'sql_header', 'incremental_strategy', # snapshots diff --git a/plugins/bigquery/dbt/include/bigquery/macros/adapters.sql b/plugins/bigquery/dbt/include/bigquery/macros/adapters.sql index 6f9698a4ecf..d67b3f47a08 100644 --- a/plugins/bigquery/dbt/include/bigquery/macros/adapters.sql +++ b/plugins/bigquery/dbt/include/bigquery/macros/adapters.sql @@ -61,6 +61,10 @@ {%- set raw_persist_docs = config.get('persist_docs', {}) -%} {%- set raw_kms_key_name = config.get('kms_key_name', none) -%} {%- set raw_labels = config.get('labels', []) -%} + {%- set sql_header = config.get('sql_header', none) -%} + + {{ sql_header if sql_header is not none }} + create or replace table {{ relation }} {{ partition_by(raw_partition_by) }} {{ cluster_by(raw_cluster_by) }} @@ -76,6 +80,9 @@ {% macro bigquery__create_view_as(relation, sql) -%} {%- set raw_persist_docs = config.get('persist_docs', {}) -%} {%- set raw_labels = config.get('labels', []) -%} + {%- set sql_header = config.get('sql_header', none) -%} + + {{ sql_header if sql_header is not none }} create or replace view {{ relation }} {{ bigquery_table_options(persist_docs=raw_persist_docs, temporary=false, labels=raw_labels) }} diff --git a/plugins/postgres/dbt/include/postgres/macros/adapters.sql b/plugins/postgres/dbt/include/postgres/macros/adapters.sql index 9f7ac647693..c95cb4b0b2c 100644 --- a/plugins/postgres/dbt/include/postgres/macros/adapters.sql +++ b/plugins/postgres/dbt/include/postgres/macros/adapters.sql @@ -1,5 +1,8 @@ {% macro postgres__create_table_as(temporary, relation, sql) -%} {%- set unlogged = config.get('unlogged', default=false) -%} + {%- set sql_header = config.get('sql_header', none) -%} + + {{ sql_header if sql_header is not none }} create {% if temporary -%} temporary diff --git a/plugins/redshift/dbt/include/redshift/macros/adapters.sql b/plugins/redshift/dbt/include/redshift/macros/adapters.sql index d274b9ba7d6..ea9ec5da21b 100644 --- a/plugins/redshift/dbt/include/redshift/macros/adapters.sql +++ b/plugins/redshift/dbt/include/redshift/macros/adapters.sql @@ -37,6 +37,9 @@ {%- set _sort = config.get( 'sort', validator=validation.any[list, basestring]) -%} + {%- set sql_header = config.get('sql_header', none) -%} + + {{ sql_header if sql_header is not none }} create {% if temporary -%}temporary{%- endif %} table {{ relation.include(database=(not temporary), schema=(not temporary)) }} @@ -51,6 +54,9 @@ {% macro redshift__create_view_as(relation, sql) -%} {% set bind_qualifier = '' if config.get('bind', default=True) else 'with no schema binding' %} + {%- set sql_header = config.get('sql_header', none) -%} + + {{ sql_header if sql_header is not none }} create view {{ relation }} as ( {{ sql }} diff --git a/plugins/snowflake/dbt/include/snowflake/macros/adapters.sql b/plugins/snowflake/dbt/include/snowflake/macros/adapters.sql index 8c7f48bcdd4..ffa7d6bae62 100644 --- a/plugins/snowflake/dbt/include/snowflake/macros/adapters.sql +++ b/plugins/snowflake/dbt/include/snowflake/macros/adapters.sql @@ -11,6 +11,9 @@ {% else %} {%- set cluster_by_string = none -%} {%- endif -%} + {%- set sql_header = config.get('sql_header', none) -%} + + {{ sql_header if sql_header is not none }} create or replace {% if temporary -%} temporary @@ -38,6 +41,9 @@ {% macro snowflake__create_view_as(relation, sql) -%} {%- set secure = config.get('secure', default=false) -%} {%- set copy_grants = config.get('copy_grants', default=false) -%} + {%- set sql_header = config.get('sql_header', none) -%} + + {{ sql_header if sql_header is not none }} create or replace {% if secure -%} secure {%- endif %} view {{ relation }} {% if copy_grants -%} copy grants {%- endif %} as ( diff --git a/test/integration/022_bigquery_test/models/sql_header_model.sql b/test/integration/022_bigquery_test/models/sql_header_model.sql new file mode 100644 index 00000000000..e49d82c4bc0 --- /dev/null +++ b/test/integration/022_bigquery_test/models/sql_header_model.sql @@ -0,0 +1,14 @@ +{{ config(materialized="table") }} + +{# This will fail if it is not extracted correctly #} +{% call set_sql_header(config) %} + CREATE TEMPORARY FUNCTION a_to_b(str STRING) + RETURNS STRING AS ( + CASE + WHEN LOWER(str) = 'a' THEN 'b' + ELSE str + END + ); +{% endcall %} + +select a_to_b(dupe) as dupe from {{ ref('view_model') }} diff --git a/test/integration/022_bigquery_test/test_simple_bigquery_view.py b/test/integration/022_bigquery_test/test_simple_bigquery_view.py index b1b0dd9b58f..e25704ba1b4 100644 --- a/test/integration/022_bigquery_test/test_simple_bigquery_view.py +++ b/test/integration/022_bigquery_test/test_simple_bigquery_view.py @@ -53,7 +53,7 @@ def test__bigquery_simple_run(self): self.run_dbt(['seed', '--full-refresh']) results = self.run_dbt() # Bump expected number of results when adding new model - self.assertEqual(len(results), 7) + self.assertEqual(len(results), 8) self.assert_nondupes_pass() @@ -64,7 +64,7 @@ class TestUnderscoreBigQueryRun(TestBaseBigQueryRun): def test_bigquery_run_twice(self): self.run_dbt(['seed']) results = self.run_dbt() - self.assertEqual(len(results), 7) + self.assertEqual(len(results), 8) results = self.run_dbt() - self.assertEqual(len(results), 7) + self.assertEqual(len(results), 8) self.assert_nondupes_pass()