diff --git a/.changes/unreleased/Features-20230928-163205.yaml b/.changes/unreleased/Features-20230928-163205.yaml new file mode 100644 index 00000000000..7f9b7c047ac --- /dev/null +++ b/.changes/unreleased/Features-20230928-163205.yaml @@ -0,0 +1,6 @@ +kind: Features +body: Enable inline csv fixtures in unit tests +time: 2023-09-28T16:32:05.573776-04:00 +custom: + Author: gshank + Issue: "8626" diff --git a/core/dbt/contracts/graph/model_config.py b/core/dbt/contracts/graph/model_config.py index 205fe426e07..b576b9affed 100644 --- a/core/dbt/contracts/graph/model_config.py +++ b/core/dbt/contracts/graph/model_config.py @@ -526,6 +526,11 @@ class ModelConfig(NodeConfig): ) +@dataclass +class UnitTestNodeConfig(NodeConfig): + expected_rows: List[Dict[str, Any]] = field(default_factory=list) + + @dataclass class SeedConfig(NodeConfig): materialized: str = "seed" diff --git a/core/dbt/contracts/graph/nodes.py b/core/dbt/contracts/graph/nodes.py index 37ef85e3008..5076a6ce9ad 100644 --- a/core/dbt/contracts/graph/nodes.py +++ b/core/dbt/contracts/graph/nodes.py @@ -35,7 +35,8 @@ UnparsedSourceTableDefinition, UnparsedColumn, UnitTestOverrides, - InputFixture, + UnitTestInputFixture, + UnitTestOutputFixture, ) from dbt.contracts.graph.node_args import ModelNodeArgs from dbt.contracts.util import Replaceable, AdditionalPropertiesMixin @@ -78,6 +79,7 @@ SnapshotConfig, SemanticModelConfig, UnitTestConfig, + UnitTestNodeConfig, ) @@ -1063,13 +1065,14 @@ class UnitTestNode(CompiledNode): resource_type: NodeType = field(metadata={"restrict": [NodeType.Unit]}) attached_node: Optional[str] = None overrides: Optional[UnitTestOverrides] = None + config: UnitTestNodeConfig = field(default_factory=UnitTestNodeConfig) @dataclass class UnitTestDefinition(GraphNode): model: str - given: Sequence[InputFixture] - expect: List[Dict[str, Any]] + given: Sequence[UnitTestInputFixture] + expect: UnitTestOutputFixture description: str = "" overrides: Optional[UnitTestOverrides] = None depends_on: DependsOn = field(default_factory=DependsOn) diff --git a/core/dbt/contracts/graph/unparsed.py b/core/dbt/contracts/graph/unparsed.py index f9cbd316d6e..349e92953fe 100644 --- a/core/dbt/contracts/graph/unparsed.py +++ b/core/dbt/contracts/graph/unparsed.py @@ -1,5 +1,7 @@ import datetime import re +import csv +from io import StringIO from dbt import deprecations from dbt.node_types import NodeType @@ -736,10 +738,53 @@ def normalize_date(d: Optional[datetime.date]) -> Optional[datetime.datetime]: return dt +class UnitTestFormat(StrEnum): + CSV = "csv" + Dict = "dict" + + +class UnitTestFixture: + @property + def format(self) -> UnitTestFormat: + return UnitTestFormat.Dict + + @property + def rows(self) -> Union[str, List[Dict[str, Any]]]: + return [] + + def get_rows(self) -> List[Dict[str, Any]]: + if self.format == UnitTestFormat.Dict: + assert isinstance(self.rows, List) + return self.rows + elif self.format == UnitTestFormat.CSV: + assert isinstance(self.rows, str) + dummy_file = StringIO(self.rows) + reader = csv.DictReader(dummy_file) + rows = [] + for row in reader: + rows.append(row) + return rows + + def validate_fixture(self, fixture_type, test_name) -> None: + if (self.format == UnitTestFormat.Dict and not isinstance(self.rows, list)) or ( + self.format == UnitTestFormat.CSV and not isinstance(self.rows, str) + ): + raise ParsingError( + f"Unit test {test_name} has {fixture_type} rows which do not match format {self.format}" + ) + + @dataclass -class InputFixture(dbtClassMixin): +class UnitTestInputFixture(dbtClassMixin, UnitTestFixture): input: str - rows: List[Dict[str, Any]] = field(default_factory=list) + rows: Union[str, List[Dict[str, Any]]] = "" + format: UnitTestFormat = UnitTestFormat.Dict + + +@dataclass +class UnitTestOutputFixture(dbtClassMixin, UnitTestFixture): + rows: Union[str, List[Dict[str, Any]]] = "" + format: UnitTestFormat = UnitTestFormat.Dict @dataclass @@ -752,8 +797,8 @@ class UnitTestOverrides(dbtClassMixin): @dataclass class UnparsedUnitTestDefinition(dbtClassMixin): name: str - given: Sequence[InputFixture] - expect: List[Dict[str, Any]] + given: Sequence[UnitTestInputFixture] + expect: UnitTestOutputFixture description: str = "" overrides: Optional[UnitTestOverrides] = None config: Dict[str, Any] = field(default_factory=dict) diff --git a/core/dbt/parser/unit_tests.py b/core/dbt/parser/unit_tests.py index b011c3f32ba..fa8aa6c48c3 100644 --- a/core/dbt/parser/unit_tests.py +++ b/core/dbt/parser/unit_tests.py @@ -5,7 +5,7 @@ from dbt.context.providers import generate_parse_exposure, get_rendered from dbt.contracts.files import FileHash from dbt.contracts.graph.manifest import Manifest -from dbt.contracts.graph.model_config import NodeConfig +from dbt.contracts.graph.model_config import UnitTestNodeConfig, ModelConfig from dbt.contracts.graph.nodes import ( ModelNode, UnitTestNode, @@ -66,7 +66,9 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition): path=get_pseudo_test_path(name, test_case.original_file_path), original_file_path=test_case.original_file_path, unique_id=test_case.unique_id, - config=NodeConfig(materialized="unit", _extra={"expected_rows": test_case.expect}), + config=UnitTestNodeConfig( + materialized="unit", expected_rows=test_case.expect.get_rows() + ), raw_code=actual_node.raw_code, database=actual_node.database, schema=actual_node.schema, @@ -118,16 +120,15 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition): # TODO: package_name? input_name = f"{test_case.model}__{test_case.name}__{original_input_node.name}" input_unique_id = f"model.{package_name}.{input_name}" - input_node = ModelNode( - raw_code=self._build_raw_code(given.rows, original_input_node_columns), + raw_code=self._build_raw_code(given.get_rows(), original_input_node_columns), resource_type=NodeType.Model, package_name=package_name, path=original_input_node.path, original_file_path=original_input_node.original_file_path, unique_id=input_unique_id, name=input_name, - config=NodeConfig(materialized="ephemeral"), + config=ModelConfig(materialized="ephemeral"), database=original_input_node.database, schema=original_input_node.schema, alias=original_input_node.alias, @@ -189,6 +190,11 @@ def parse(self) -> ParseResult: unit_test_fqn = [self.project.project_name] + model_name_split + [test.name] unit_test_config = self._build_unit_test_config(unit_test_fqn, test.config) + # Check that format and type of rows matches for each given input + for input in test.given: + input.validate_fixture("input", test.name) + test.expect.validate_fixture("expected", test.name) + unit_test_definition = UnitTestDefinition( name=test.name, model=unit_test_suite.model, diff --git a/schemas/dbt/manifest/v11.json b/schemas/dbt/manifest/v11.json index 94598a20471..230052fcd39 100644 --- a/schemas/dbt/manifest/v11.json +++ b/schemas/dbt/manifest/v11.json @@ -6005,6 +6005,13 @@ "type": "string" } } + }, + "format": { + "enum": [ + "csv", + "dict" + ], + "default": "dict" } }, "additionalProperties": false, diff --git a/tests/functional/unit_testing/test_unit_testing.py b/tests/functional/unit_testing/test_unit_testing.py index 0773a53e613..4cb426a5343 100644 --- a/tests/functional/unit_testing/test_unit_testing.py +++ b/tests/functional/unit_testing/test_unit_testing.py @@ -1,6 +1,6 @@ import pytest from dbt.tests.util import run_dbt, write_file, get_manifest, get_artifact -from dbt.exceptions import DuplicateResourceNameError +from dbt.exceptions import DuplicateResourceNameError, YamlParseDictError, ParsingError my_model_sql = """ SELECT @@ -47,7 +47,8 @@ - {id: 1, b: 2} - {id: 2, b: 2} expect: - - {c: 2} + rows: + - {c: 2} - name: test_my_model_empty given: @@ -57,7 +58,8 @@ rows: - {id: 1, b: 2} - {id: 2, b: 2} - expect: [] + expect: + rows: [] - name: test_my_model_overrides given: - input: ref('my_model_a') @@ -76,7 +78,8 @@ env_vars: MY_TEST: env_var_override expect: - - {macro_call: override, var_call: var_override, env_var_call: env_var_override, invocation_id: 123} + rows: + - {macro_call: override, var_call: var_override, env_var_call: env_var_override, invocation_id: 123} - name: test_my_model_string_concat given: - input: ref('my_model_a') @@ -86,7 +89,8 @@ rows: - {id: 1, string_b: b} expect: - - {string_c: ab} + rows: + - {string_c: ab} config: tags: test_this """ @@ -101,7 +105,8 @@ rows: - {id: 1} expect: - - {date_a: "2020-01-01"} + rows: + - {date_a: "2020-01-01"} """ @@ -172,3 +177,157 @@ def test_basic(self, project): ) with pytest.raises(DuplicateResourceNameError): run_dbt(["unit-test", "--select", "my_model"]) + + +test_my_model_csv_yml = """ +unit: + - model: my_model + tests: + - name: test_my_model + given: + - input: ref('my_model_a') + format: csv + rows: | + id,a + 1,1 + - input: ref('my_model_b') + format: csv + rows: | + id,b + 1,2 + 2,2 + expect: + format: csv + rows: | + c + 2 + + - name: test_my_model_empty + given: + - input: ref('my_model_a') + rows: [] + - input: ref('my_model_b') + format: csv + rows: | + id,b + 1,2 + 2,2 + expect: + rows: [] + - name: test_my_model_overrides + given: + - input: ref('my_model_a') + format: csv + rows: | + id,a + 1,1 + - input: ref('my_model_b') + format: csv + rows: | + id,b + 1,2 + 2,2 + overrides: + macros: + type_numeric: override + invocation_id: 123 + vars: + my_test: var_override + env_vars: + MY_TEST: env_var_override + expect: + rows: + - {macro_call: override, var_call: var_override, env_var_call: env_var_override, invocation_id: 123} + - name: test_my_model_string_concat + given: + - input: ref('my_model_a') + format: csv + rows: | + id,string_a + 1,a + - input: ref('my_model_b') + format: csv + rows: | + id,string_b + 1,b + expect: + format: csv + rows: | + string_c + ab + config: + tags: test_this +""" + +datetime_test_invalid_format = """ + - name: test_my_model_datetime + given: + - input: ref('my_model_a') + format: xxxx + rows: + - {id: 1, date_a: "2020-01-01"} + - input: ref('my_model_b') + rows: + - {id: 1} + expect: + rows: + - {date_a: "2020-01-01"} +""" + +datetime_test_invalid_format2 = """ + - name: test_my_model_datetime + given: + - input: ref('my_model_a') + format: csv + rows: + - {id: 1, date_a: "2020-01-01"} + - input: ref('my_model_b') + rows: + - {id: 1} + expect: + rows: + - {date_a: "2020-01-01"} +""" + + +class TestUnitTestsWithInlineCSV: + @pytest.fixture(scope="class") + def models(self): + return { + "my_model.sql": my_model_sql, + "my_model_a.sql": my_model_a_sql, + "my_model_b.sql": my_model_b_sql, + "test_my_model.yml": test_my_model_csv_yml + datetime_test, + } + + @pytest.fixture(scope="class") + def project_config_update(self): + return {"vars": {"my_test": "my_test_var"}} + + def test_basic(self, project): + results = run_dbt(["run"]) + assert len(results) == 3 + + # Select by model name + results = run_dbt(["unit-test", "--select", "my_model"], expect_pass=False) + assert len(results) == 5 + + # Check error with invalid format + write_file( + test_my_model_csv_yml + datetime_test_invalid_format, + project.project_root, + "models", + "test_my_model.yml", + ) + with pytest.raises(YamlParseDictError): + results = run_dbt(["unit-test", "--select", "my_model"], expect_pass=False) + + # Check error with format not matching rows + write_file( + test_my_model_csv_yml + datetime_test_invalid_format2, + project.project_root, + "models", + "test_my_model.yml", + ) + with pytest.raises(ParsingError): + results = run_dbt(["unit-test", "--select", "my_model"], expect_pass=False) diff --git a/tests/unit/test_unit_test_parser.py b/tests/unit/test_unit_test_parser.py index 31d98c18b8e..e58bf869814 100644 --- a/tests/unit/test_unit_test_parser.py +++ b/tests/unit/test_unit_test_parser.py @@ -7,6 +7,7 @@ from .test_parser import SchemaParserTest, assertEqualNodes from unittest import mock +from dbt.contracts.graph.unparsed import UnitTestOutputFixture UNIT_TEST_MODEL_NOT_FOUND_SOURCE = """ @@ -16,7 +17,9 @@ - name: test_my_model_doesnt_exist description: "unit test description" given: [] - expect: [] + expect: + rows: + - {a: 1} """ @@ -27,7 +30,9 @@ - name: test_my_model description: "unit test description" given: [] - expect: [] + expect: + rows: + - {a: 1} """ @@ -38,7 +43,9 @@ - name: test_my_model_versioned description: "unit test description" given: [] - expect: [] + expect: + rows: + - {a: 1} """ @@ -54,7 +61,9 @@ meta_jinja_key: '{{ 1 + 1 }}' description: "unit test description" given: [] - expect: [] + expect: + rows: + - {a: 1} """ @@ -65,11 +74,15 @@ - name: test_my_model description: "unit test description" given: [] - expect: [] + expect: + rows: + - {a: 1} - name: test_my_model2 description: "unit test description" given: [] - expect: [] + expect: + rows: + - {a: 1} """ @@ -116,7 +129,7 @@ def test_basic(self): original_file_path=block.path.original_file_path, unique_id="unit.snowplow.my_model.test_my_model", given=[], - expect=[], + expect=UnitTestOutputFixture(rows=[{"a": 1}]), description="unit test description", overrides=None, depends_on=DependsOn(nodes=["model.snowplow.my_model"]),