Skip to content

Commit

Permalink
first pass
Browse files Browse the repository at this point in the history
  • Loading branch information
MichelleArk committed Apr 30, 2023
1 parent a88f640 commit 0ab9222
Show file tree
Hide file tree
Showing 16 changed files with 466 additions and 70 deletions.
15 changes: 14 additions & 1 deletion core/dbt/clients/jinja.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
)

from dbt.clients._jinja_blocks import BlockIterator, BlockData, BlockTag
from dbt.contracts.graph.nodes import GenericTestNode
from dbt.contracts.graph.nodes import GenericTestNode, UnitTestNode

from dbt.exceptions import (
CaughtMacroError,
Expand Down Expand Up @@ -620,6 +620,7 @@ def extract_toplevel_blocks(


GENERIC_TEST_KWARGS_NAME = "_dbt_generic_test_kwargs"
UNIT_TEST_KWARGS_NAME = "_dbt_unit_test_kwargs"


def add_rendered_test_kwargs(
Expand Down Expand Up @@ -654,6 +655,18 @@ def _convert_function(value: Any, keypath: Tuple[Union[str, int], ...]) -> Any:
context[GENERIC_TEST_KWARGS_NAME] = kwargs


def add_rendered_unit_test_kwargs(
context: Dict[str, Any],
node: UnitTestNode,
capture_macros: bool = False,
) -> None:
"""Render each of the test kwargs in the given context using the native
renderer, then insert that value into the given context as the special test
keyword arguments member.
"""
context[UNIT_TEST_KWARGS_NAME] = node.unit_test_metadata.kwargs


def get_supported_languages(node: jinja2.nodes.Macro) -> List[ModelLanguage]:
if "materialization" not in node.name:
raise MaterializtionMacroNotUsedError(node=node)
Expand Down
6 changes: 5 additions & 1 deletion core/dbt/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
GraphMemberNode,
InjectedCTE,
SeedNode,
UnitTestNode,
)
from dbt.exceptions import (
GraphDependencyNotFoundError,
Expand Down Expand Up @@ -178,14 +179,17 @@ def _create_node_context(
manifest: Manifest,
extra_context: Dict[str, Any],
) -> Dict[str, Any]:

context = generate_runtime_model_context(node, self.config, manifest)
context.update(extra_context)

if isinstance(node, GenericTestNode):
# for test nodes, add a special keyword args value to the context
jinja.add_rendered_test_kwargs(context, node)

elif isinstance(node, UnitTestNode):
# for test nodes, add a special keyword args value to the context
jinja.add_rendered_unit_test_kwargs(context, node)

return context

def add_ephemeral_prefix(self, name: str):
Expand Down
7 changes: 5 additions & 2 deletions core/dbt/context/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
Resource,
ManifestNode,
RefArgs,
UnitTestNode,
)
from dbt.contracts.graph.metrics import MetricReference, ResolvedMetricReference
from dbt.contracts.graph.unparsed import NodeVersion
Expand Down Expand Up @@ -477,6 +478,8 @@ def resolve(
target_package: Optional[str] = None,
target_version: Optional[NodeVersion] = None,
) -> RelationProxy:
if isinstance(self.model, UnitTestNode):
target_name = f"{self.model.name}__{target_name}"
target_model = self.manifest.resolve_ref(
target_name,
target_package,
Expand Down Expand Up @@ -1310,7 +1313,7 @@ class ModelContext(ProviderContext):

@contextproperty
def pre_hooks(self) -> List[Dict[str, Any]]:
if self.model.resource_type in [NodeType.Source, NodeType.Test]:
if self.model.resource_type in [NodeType.Source, NodeType.Test, NodeType.Unit]:
return []
# TODO CT-211
return [
Expand All @@ -1319,7 +1322,7 @@ def pre_hooks(self) -> List[Dict[str, Any]]:

@contextproperty
def post_hooks(self) -> List[Dict[str, Any]]:
if self.model.resource_type in [NodeType.Source, NodeType.Test]:
if self.model.resource_type in [NodeType.Source, NodeType.Test, NodeType.Unit]:
return []
# TODO CT-211
return [
Expand Down
7 changes: 7 additions & 0 deletions core/dbt/contracts/graph/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
GraphMemberNode,
ResultNode,
BaseNode,
UnitTestNode,
)
from dbt.contracts.graph.unparsed import SourcePatch, NodeVersion
from dbt.contracts.graph.manifest_upgrade import upgrade_manifest_json
Expand Down Expand Up @@ -1117,6 +1118,12 @@ def add_group(self, source_file: SchemaSourceFile, group: Group):
self.groups[group.unique_id] = group
source_file.groups.append(group.unique_id)

def add_unit_test(self, unit_test_node: UnitTestNode, inputs: List[ManifestNode]):
# TODO: _check_duplicates(group, self.groups)
self.nodes[unit_test_node.unique_id] = unit_test_node
for input in inputs:
self.nodes[input.unique_id] = input

def add_disabled_nofile(self, node: GraphMemberNode):
# There can be multiple disabled nodes for the same unique_id
if node.unique_id in self.disabled:
Expand Down
1 change: 1 addition & 0 deletions core/dbt/contracts/graph/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,7 @@ def finalize_and_validate(self):
NodeType.Source: SourceConfig,
NodeType.Seed: SeedConfig,
NodeType.Test: TestConfig,
NodeType.Unit: TestConfig,
NodeType.Model: NodeConfig,
NodeType.Snapshot: SnapshotConfig,
}
Expand Down
38 changes: 34 additions & 4 deletions core/dbt/contracts/graph/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -843,13 +843,27 @@ class TestMetadata(dbtClassMixin, Replaceable):
namespace: Optional[str] = None


@dataclass
class UnitTestMetadata(dbtClassMixin, Replaceable):
# kwargs are the args that are left in the test builder after
# removing configs. They are set from the test builder when
# the test node is created.
kwargs: Dict[str, Any] = field(default_factory=dict)
namespace: Optional[str] = None


# This has to be separated out because it has no default and so
# has to be included as a superclass, not an attribute
@dataclass
class HasTestMetadata(dbtClassMixin):
test_metadata: TestMetadata


@dataclass
class HasUnitTestMetadata(dbtClassMixin):
unit_test_metadata: UnitTestMetadata


@dataclass
class GenericTestNode(TestShouldStoreFailures, CompiledNode, HasTestMetadata):
resource_type: NodeType = field(metadata={"restrict": [NodeType.Test]})
Expand All @@ -871,6 +885,25 @@ def test_node_type(self):
return "generic"


@dataclass
class UnitTestNode(TestShouldStoreFailures, CompiledNode, HasUnitTestMetadata):
resource_type: NodeType = field(metadata={"restrict": [NodeType.Unit]})
# file_key_name: Optional[str] = None
# Was not able to make mypy happy and keep the code working. We need to
# refactor the various configs.
attached_node: Optional[str] = None

def same_contents(self, other) -> bool:
if other is None:
return False

return self.same_config(other) and self.same_fqn(other) and True

@property
def test_node_type(self):
return "unit"


# ====================================
# Snapshot node
# ====================================
Expand Down Expand Up @@ -1403,7 +1436,4 @@ class ParsedMacroPatch(ParsedPatch):
Group,
]

TestNode = Union[
SingularTestNode,
GenericTestNode,
]
TestNode = Union[SingularTestNode, GenericTestNode, UnitTestNode]
29 changes: 29 additions & 0 deletions core/dbt/contracts/graph/unparsed.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,3 +652,32 @@ def validate(cls, data):
super(UnparsedGroup, cls).validate(data)
if data["owner"].get("name") is None and data["owner"].get("email") is None:
raise ValidationError("Group owner must have at least one of 'name' or 'email'.")


@dataclass
class UnparsedFixture(dbtClassMixin):
defaults: Optional[Dict[str, str]] = field(default_factory=dict)
rows: List[Dict[str, Any]] = field(default_factory=list)


@dataclass
class UnparsedInputFixtureMandatory:
name: str = ""


@dataclass
class UnparsedInputFixture(UnparsedInputFixtureMandatory, UnparsedFixture):
pass


@dataclass
class UnparsedUnitTestCase(dbtClassMixin):
name: str
inputs: Sequence[UnparsedInputFixture]
expected_output: UnparsedFixture


@dataclass
class UnparsedUnitTestSuite(dbtClassMixin):
name: str
tests: Sequence[UnparsedUnitTestCase]
3 changes: 3 additions & 0 deletions core/dbt/graph/selector_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
ResultNode,
ManifestNode,
ModelNode,
UnitTestNode,
)
from dbt.contracts.graph.unparsed import UnparsedVersion
from dbt.contracts.state import PreviousState
Expand Down Expand Up @@ -435,6 +436,8 @@ def search(self, included_nodes: Set[UniqueId], selector: str) -> Iterator[Uniqu
search_type = GenericTestNode
elif selector in ("singular", "data"):
search_type = SingularTestNode
elif selector in ("unit"):
search_type = UnitTestNode
else:
raise DbtRuntimeError(
f'Invalid test type selector {selector}: expected "generic" or ' '"singular"'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,22 @@
{{ "limit " ~ limit if limit != none }}
) dbt_internal_test
{%- endmacro %}


{% macro get_unit_test_sql(main_sql, expected_fixture_sql, expected_column_names) -%}
{{ adapter.dispatch('get_unit_test_sql', 'dbt')(main_sql, expected_fixture_sql, expected_column_names) }}
{%- endmacro %}

{% macro default__get_unit_test_sql(main_sql, expected_fixture_sql, expected_column_names) -%}
select
{% for expected_column_name in expected_column_names %}{{expected_column_name}}{% if not loop.last -%},{% endif %}{%- endfor -%}, "actual" as actual_or_expected
from (
{{ main_sql }}
) dbt_internal_unit_test_actual
union all
select
{% for expected_column_name in expected_column_names %}{{expected_column_name}}{% if not loop.last -%},{% endif %}{%- endfor -%}, "expected" as actual_or_expected
from (
{{ expected_fixture_sql }}
) dbt_internal_unit_test_expected
{%- endmacro %}
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
{%- materialization unit, default -%}

{% set relations = [] %}

{% set expected_rows = config.get('expected_rows') %}
{% set all_expected_column_names = get_columns_in_query(sql) %}
{% set tested_expected_column_names = expected_rows[0].keys() if (expected_rows | length ) > 0 else all_expected_column_names %}
{% set unit_test_sql = get_unit_test_sql(sql, get_fixture_sql(expected_rows, all_expected_column_names), tested_expected_column_names) %}

{% if should_store_failures() %}

{% set identifier = model['alias'] %}
{% set old_relation = adapter.get_relation(database=database, schema=schema, identifier=identifier) %}
{% set target_relation = api.Relation.create(
identifier=identifier, schema=schema, database=database, type='table') -%} %}

{% if old_relation %}
{% do adapter.drop_relation(old_relation) %}
{% endif %}

{% call statement(auto_begin=True) %}
{{ create_table_as(False, target_relation, unit_test_sql) }}
{% endcall %}

{% do relations.append(target_relation) %}

{% set main_sql %}
select *
from {{ target_relation }}
{% endset %}

{{ adapter.commit() }}

{% else %}

{% set main_sql = unit_test_sql %}

{% endif %}

{% call statement('main', fetch_result=True) -%}

{{ main_sql }}

{%- endcall %}

{{ return({'relations': relations}) }}

{%- endmaterialization -%}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
{% macro get_fixture_sql(rows, column_names) %}

{% set default_row = {} %}
{%- for column_name in column_names -%}
{%- do default_row.update({column_name: "null"}) -%}
{%- endfor -%}

{% for row in rows -%}
{%- set default_row_copy = default_row.copy() -%}
{%- do default_row_copy.update(row) -%}
SELECT
{%- for column_name, column_value in default_row_copy.items() %} {{ column_value }} AS {{ column_name }}{% if not loop.last -%},{%- endif %}
{%- endfor %}
{%- if not loop.last %}
UNION ALL
{% endif %}
{%- endfor %}


{%- if (rows | length) == 0 %}
SELECT
{%- for column_name, column_value in default_row.items() %} {{ column_value }} AS {{ column_name }}{% if not loop.last -%},{%- endif %}
{%- endfor %}
LIMIT 0
{% endif %}

{% endmacro %}
2 changes: 2 additions & 0 deletions core/dbt/node_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class NodeType(StrEnum):
Exposure = "exposure"
Metric = "metric"
Group = "group"
Unit = "unit test"

@classmethod
def executable(cls) -> List["NodeType"]:
Expand All @@ -46,6 +47,7 @@ def executable(cls) -> List["NodeType"]:
cls.Documentation,
cls.RPCCall,
cls.SqlOperation,
cls.Unit,
]

@classmethod
Expand Down
9 changes: 8 additions & 1 deletion core/dbt/parser/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
SeedNode,
ManifestNode,
ResultNode,
UnitTestNode,
)
from dbt.contracts.graph.unparsed import NodeVersion
from dbt.contracts.util import Writable
Expand Down Expand Up @@ -1339,9 +1340,15 @@ def _process_refs_for_node(manifest: Manifest, current_project: str, node: Manif
if isinstance(node, SeedNode):
return

unit_test_name = None
if isinstance(node, UnitTestNode):
unit_test_name = node.name

for ref in node.refs:
target_model: Optional[Union[Disabled, ManifestNode]] = None
target_model_name: str = ref.name
target_model_name: str = (
ref.name if unit_test_name is None else f"{unit_test_name}__{ref.name}"
)
target_model_package: Optional[str] = ref.package
target_model_version: Optional[NodeVersion] = ref.version

Expand Down
Loading

0 comments on commit 0ab9222

Please sign in to comment.