Skip to content

Commit

Permalink
better type handling
Browse files Browse the repository at this point in the history
  • Loading branch information
MichelleArk committed May 3, 2023
1 parent f8bdd8b commit 2184a4d
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,26 @@

{% set expected_rows = config.get('expected_rows') %}
{% set tested_expected_column_names = expected_rows[0].keys() if (expected_rows | length ) > 0 else get_columns_in_query(sql) %} %}

{% set unit_test_sql = get_unit_test_sql(sql, get_expected_sql(expected_rows), tested_expected_column_names) %}

{%- set target_relation = this.incorporate(type='table') -%}
{%- set temp_relation = make_temp_relation(target_relation)-%}
{% do run_query(get_create_table_as_sql(True, temp_relation, get_empty_subquery_sql(sql))) %}
{%- set columns_in_relation = adapter.get_columns_in_relation(temp_relation) -%}
{%- set column_name_to_data_types = {} -%}
{%- for column in columns_in_relation -%}
{%- do column_name_to_data_types.update({column.name: column.dtype}) -%}
{%- endfor -%}

{% set unit_test_sql = get_unit_test_sql(sql, get_expected_sql(expected_rows, column_name_to_data_types), tested_expected_column_names) %}

{% call statement('main', fetch_result=True) -%}

{{ unit_test_sql }}

{%- endcall %}

{% do adapter.drop_relation(temp_relation) %}

{{ return({'relations': relations}) }}

{%- endmaterialization -%}
Original file line number Diff line number Diff line change
@@ -1,24 +1,36 @@
{% macro get_fixture_sql(rows, columns) %}
{% macro get_fixture_sql(rows, column_name_to_data_types) %}
-- Fixture for {{ model.name }}
{% set default_row = {} %}

{%- if not columns -%}
{%- if not column_name_to_data_types -%}
{%- set columns_in_relation = adapter.get_columns_in_relation(this) -%}
{%- set columns = [] -%}
{%- set column_name_to_data_types = {} -%}
{%- for column in columns_in_relation -%}
{%- do columns.append({"name": column.name, "data_type": column.dtype}) -%}
{%- do column_name_to_data_types.update({column.name: column.dtype}) -%}
{%- endfor -%}
{%- endif -%}

{%- if not columns -%}
{%- if not column_name_to_data_types -%}
{{ exceptions.raise_compiler_error("columns not available for" ~ model.name) }}
{%- endif -%}

{%- for column in columns -%}
{%- do default_row.update({column["name"]: (safe_cast("null", column["data_type"]) | trim )}) -%}
{%- for column_name, column_type in column_name_to_data_types.items() -%}
{%- do default_row.update({column_name: (safe_cast("null", column_type) | trim )}) -%}
{%- endfor -%}

{%- for row in rows -%}

{#-- wrap yaml strings in quotes--#}
{%- for column_name, column_value in row.items() -%}
{% set row_update = {column_name: column_value} %}
{%- if column_value is string -%}
{%- set row_update = {column_name: safe_cast(dbt.string_literal(column_value), column_name_to_data_types[column_name]) } -%}
{%- else -%}
{%- set row_update = {column_name: safe_cast(column_value, column_name_to_data_types[column_name]) } -%}
{%- endif -%}
{%- do row.update(row_update) -%}
{%- endfor -%}

{%- set default_row_copy = default_row.copy() -%}
{%- do default_row_copy.update(row) -%}
select
Expand All @@ -38,9 +50,21 @@ union all
{% endmacro %}


{% macro get_expected_sql(rows) %}
{% macro get_expected_sql(rows, column_name_to_data_types) %}

{%- for row in rows -%}

{#-- wrap yaml strings in quotes--#}
{%- for column_name, column_value in row.items() -%}
{% set row_update = {column_name: column_value} %}
{%- if column_value is string -%}
{%- set row_update = {column_name: safe_cast(dbt.string_literal(column_value), column_name_to_data_types[column_name]) } -%}
{%- else -%}
{%- set row_update = {column_name: safe_cast(column_value, column_name_to_data_types[column_name]) } -%}
{%- endif -%}
{%- do row.update(row_update) -%}
{%- endfor -%}

select
{%- for column_name, column_value in row.items() %} {{ column_value }} AS {{ column_name }}{% if not loop.last -%}, {%- endif %}
{%- endfor %}
Expand Down
11 changes: 7 additions & 4 deletions core/dbt/parser/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -1491,9 +1491,9 @@ def __init__(self, schema_parser: SchemaParser, yaml: YamlBlock):
self.schema_parser = schema_parser
self.yaml = yaml

def _build_raw_code(self, rows, columns) -> str:
return ("{{{{ get_fixture_sql({rows}, {columns}) }}}}").format(
rows=rows, columns=columns
def _build_raw_code(self, rows, column_name_to_data_types) -> str:
return ("{{{{ get_fixture_sql({rows}, {column_name_to_data_types}) }}}}").format(
rows=rows, column_name_to_data_types=column_name_to_data_types
)

def parse_unit_test(self, unparsed: UnparsedUnitTestSuite):
Expand All @@ -1519,7 +1519,7 @@ def parse_unit_test(self, unparsed: UnparsedUnitTestSuite):
input_model_name, input_package_name, None, self.manifest
)
if original_input_node.config.contract.enforced:
original_input_node_columns = [{"name": column.name, "data_type": column.data_type} for column in original_input_node.columns]
original_input_node_columns = {column.name: column.data_type for column in original_input_node.columns}
elif statically_parsed["sources"]:
input_package_name, input_source_name = statically_parsed["sources"][0]
original_input_node = self.manifest.source_lookup.find(
Expand Down Expand Up @@ -1577,6 +1577,9 @@ def parse_unit_test(self, unparsed: UnparsedUnitTestSuite):
attached_node=actual_node.unique_id
)

self.schema_parser._update_node_database(unit_test_node, {})
self.schema_parser._update_node_schema(unit_test_node, {})

ctx = generate_parse_exposure(
unit_test_node,
self.root_project,
Expand Down

0 comments on commit 2184a4d

Please sign in to comment.