diff --git a/core/dbt/include/global_project/macros/materializations/tests/unit.sql b/core/dbt/include/global_project/macros/materializations/tests/unit.sql index 5c21adeb9a9..b68c8b5bacd 100644 --- a/core/dbt/include/global_project/macros/materializations/tests/unit.sql +++ b/core/dbt/include/global_project/macros/materializations/tests/unit.sql @@ -4,8 +4,17 @@ {% set expected_rows = config.get('expected_rows') %} {% set tested_expected_column_names = expected_rows[0].keys() if (expected_rows | length ) > 0 else get_columns_in_query(sql) %} %} - - {% set unit_test_sql = get_unit_test_sql(sql, get_expected_sql(expected_rows), tested_expected_column_names) %} + + {%- set target_relation = this.incorporate(type='table') -%} + {%- set temp_relation = make_temp_relation(target_relation)-%} + {% do run_query(get_create_table_as_sql(True, temp_relation, get_empty_subquery_sql(sql))) %} + {%- set columns_in_relation = adapter.get_columns_in_relation(temp_relation) -%} + {%- set column_name_to_data_types = {} -%} + {%- for column in columns_in_relation -%} + {%- do column_name_to_data_types.update({column.name: column.dtype}) -%} + {%- endfor -%} + + {% set unit_test_sql = get_unit_test_sql(sql, get_expected_sql(expected_rows, column_name_to_data_types), tested_expected_column_names) %} {% call statement('main', fetch_result=True) -%} @@ -13,6 +22,8 @@ {%- endcall %} + {% do adapter.drop_relation(temp_relation) %} + {{ return({'relations': relations}) }} {%- endmaterialization -%} diff --git a/core/dbt/include/global_project/macros/unit_test_sql/get_fixture_sql.sql b/core/dbt/include/global_project/macros/unit_test_sql/get_fixture_sql.sql index e3705563670..be3cc76f780 100644 --- a/core/dbt/include/global_project/macros/unit_test_sql/get_fixture_sql.sql +++ b/core/dbt/include/global_project/macros/unit_test_sql/get_fixture_sql.sql @@ -1,24 +1,36 @@ -{% macro get_fixture_sql(rows, columns) %} +{% macro get_fixture_sql(rows, column_name_to_data_types) %} -- Fixture for {{ model.name }} {% set default_row = {} %} -{%- if not columns -%} +{%- if not column_name_to_data_types -%} {%- set columns_in_relation = adapter.get_columns_in_relation(this) -%} -{%- set columns = [] -%} +{%- set column_name_to_data_types = {} -%} {%- for column in columns_in_relation -%} -{%- do columns.append({"name": column.name, "data_type": column.dtype}) -%} +{%- do column_name_to_data_types.update({column.name: column.dtype}) -%} {%- endfor -%} {%- endif -%} -{%- if not columns -%} +{%- if not column_name_to_data_types -%} {{ exceptions.raise_compiler_error("columns not available for" ~ model.name) }} {%- endif -%} -{%- for column in columns -%} - {%- do default_row.update({column["name"]: (safe_cast("null", column["data_type"]) | trim )}) -%} +{%- for column_name, column_type in column_name_to_data_types.items() -%} + {%- do default_row.update({column_name: (safe_cast("null", column_type) | trim )}) -%} {%- endfor -%} {%- for row in rows -%} + +{#-- wrap yaml strings in quotes--#} +{%- for column_name, column_value in row.items() -%} +{% set row_update = {column_name: column_value} %} +{%- if column_value is string -%} +{%- set row_update = {column_name: safe_cast(dbt.string_literal(column_value), column_name_to_data_types[column_name]) } -%} +{%- else -%} +{%- set row_update = {column_name: safe_cast(column_value, column_name_to_data_types[column_name]) } -%} +{%- endif -%} +{%- do row.update(row_update) -%} +{%- endfor -%} + {%- set default_row_copy = default_row.copy() -%} {%- do default_row_copy.update(row) -%} select @@ -38,9 +50,21 @@ union all {% endmacro %} -{% macro get_expected_sql(rows) %} +{% macro get_expected_sql(rows, column_name_to_data_types) %} {%- for row in rows -%} + +{#-- wrap yaml strings in quotes--#} +{%- for column_name, column_value in row.items() -%} +{% set row_update = {column_name: column_value} %} +{%- if column_value is string -%} +{%- set row_update = {column_name: safe_cast(dbt.string_literal(column_value), column_name_to_data_types[column_name]) } -%} +{%- else -%} +{%- set row_update = {column_name: safe_cast(column_value, column_name_to_data_types[column_name]) } -%} +{%- endif -%} +{%- do row.update(row_update) -%} +{%- endfor -%} + select {%- for column_name, column_value in row.items() %} {{ column_value }} AS {{ column_name }}{% if not loop.last -%}, {%- endif %} {%- endfor %} diff --git a/core/dbt/parser/schemas.py b/core/dbt/parser/schemas.py index 5eeeb24083a..621154c55ac 100644 --- a/core/dbt/parser/schemas.py +++ b/core/dbt/parser/schemas.py @@ -1491,9 +1491,9 @@ def __init__(self, schema_parser: SchemaParser, yaml: YamlBlock): self.schema_parser = schema_parser self.yaml = yaml - def _build_raw_code(self, rows, columns) -> str: - return ("{{{{ get_fixture_sql({rows}, {columns}) }}}}").format( - rows=rows, columns=columns + def _build_raw_code(self, rows, column_name_to_data_types) -> str: + return ("{{{{ get_fixture_sql({rows}, {column_name_to_data_types}) }}}}").format( + rows=rows, column_name_to_data_types=column_name_to_data_types ) def parse_unit_test(self, unparsed: UnparsedUnitTestSuite): @@ -1519,7 +1519,7 @@ def parse_unit_test(self, unparsed: UnparsedUnitTestSuite): input_model_name, input_package_name, None, self.manifest ) if original_input_node.config.contract.enforced: - original_input_node_columns = [{"name": column.name, "data_type": column.data_type} for column in original_input_node.columns] + original_input_node_columns = {column.name: column.data_type for column in original_input_node.columns} elif statically_parsed["sources"]: input_package_name, input_source_name = statically_parsed["sources"][0] original_input_node = self.manifest.source_lookup.find( @@ -1577,6 +1577,9 @@ def parse_unit_test(self, unparsed: UnparsedUnitTestSuite): attached_node=actual_node.unique_id ) + self.schema_parser._update_node_database(unit_test_node, {}) + self.schema_parser._update_node_schema(unit_test_node, {}) + ctx = generate_parse_exposure( unit_test_node, self.root_project,