Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable inline csv format in unit testing #8743

Merged
merged 11 commits into from
Oct 5, 2023
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20230928-163205.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: Enable inline csv fixtures in unit tests
time: 2023-09-28T16:32:05.573776-04:00
custom:
Author: gshank
Issue: "8626"
5 changes: 5 additions & 0 deletions core/dbt/contracts/graph/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,11 @@ class ModelConfig(NodeConfig):
)


@dataclass
class UnitTestNodeConfig(NodeConfig):
expected_rows: List[Dict[str, Any]] = field(default_factory=list)


@dataclass
class SeedConfig(NodeConfig):
materialized: str = "seed"
Expand Down
9 changes: 6 additions & 3 deletions core/dbt/contracts/graph/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@
UnparsedSourceTableDefinition,
UnparsedColumn,
UnitTestOverrides,
InputFixture,
UnitTestInputFixture,
UnitTestOutputFixture,
)
from dbt.contracts.graph.node_args import ModelNodeArgs
from dbt.contracts.util import Replaceable, AdditionalPropertiesMixin
Expand Down Expand Up @@ -78,6 +79,7 @@
SnapshotConfig,
SemanticModelConfig,
UnitTestConfig,
UnitTestNodeConfig,
)


Expand Down Expand Up @@ -1063,13 +1065,14 @@ class UnitTestNode(CompiledNode):
resource_type: NodeType = field(metadata={"restrict": [NodeType.Unit]})
attached_node: Optional[str] = None
overrides: Optional[UnitTestOverrides] = None
config: UnitTestNodeConfig = field(default_factory=UnitTestNodeConfig)


@dataclass
class UnitTestDefinition(GraphNode):
model: str
given: Sequence[InputFixture]
expect: List[Dict[str, Any]]
given: Sequence[UnitTestInputFixture]
expect: UnitTestOutputFixture
description: str = ""
overrides: Optional[UnitTestOverrides] = None
depends_on: DependsOn = field(default_factory=DependsOn)
Expand Down
45 changes: 41 additions & 4 deletions core/dbt/contracts/graph/unparsed.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import datetime
import re
import csv
from io import StringIO

from dbt import deprecations
from dbt.node_types import NodeType
Expand Down Expand Up @@ -736,10 +738,45 @@
return dt


class UnitTestFormat(StrEnum):
CSV = "csv"
Dict = "dict"


class UnitTestFixture:
@property
def format(self) -> UnitTestFormat:
return UnitTestFormat.Dict

Check warning on line 749 in core/dbt/contracts/graph/unparsed.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/contracts/graph/unparsed.py#L749

Added line #L749 was not covered by tests

@property
def rows(self) -> Union[str, List[Dict[str, Any]]]:
return []

Check warning on line 753 in core/dbt/contracts/graph/unparsed.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/contracts/graph/unparsed.py#L753

Added line #L753 was not covered by tests
Comment on lines +747 to +753
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it necessary for these to be properties? Could they instead be attributes that are inherited by UnitTestInputFixture and UnitTestOutputFixture?

e.g.

rows: Union[str, List[Dict[str, Any]]] = ""
format: UnitTestFormat = UnitTestFormat.Dict

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we do that then we run into the frustrating issue of fields without defaults can't come after fields with defaults issue and have to split them out into a special class and do a different order. I'd kind of rather not.


def get_rows(self) -> List[Dict[str, Any]]:
if self.format == UnitTestFormat.Dict:
assert isinstance(self.rows, List)
return self.rows
elif self.format == UnitTestFormat.CSV:
assert isinstance(self.rows, str)
dummy_file = StringIO(self.rows)
reader = csv.DictReader(dummy_file)
rows = []
for row in reader:
rows.append(row)
return rows


@dataclass
class InputFixture(dbtClassMixin):
class UnitTestInputFixture(dbtClassMixin, UnitTestFixture):
input: str
rows: List[Dict[str, Any]] = field(default_factory=list)
rows: Union[str, List[Dict[str, Any]]] = ""
format: UnitTestFormat = UnitTestFormat.Dict


@dataclass
class UnitTestOutputFixture(dbtClassMixin, UnitTestFixture):
rows: Union[str, List[Dict[str, Any]]] = ""
format: UnitTestFormat = UnitTestFormat.Dict


@dataclass
Expand All @@ -752,8 +789,8 @@
@dataclass
class UnparsedUnitTestDefinition(dbtClassMixin):
name: str
given: Sequence[InputFixture]
expect: List[Dict[str, Any]]
given: Sequence[UnitTestInputFixture]
expect: UnitTestOutputFixture
description: str = ""
overrides: Optional[UnitTestOverrides] = None
config: Dict[str, Any] = field(default_factory=dict)
Expand Down
33 changes: 27 additions & 6 deletions core/dbt/parser/unit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dbt.context.providers import generate_parse_exposure, get_rendered
from dbt.contracts.files import FileHash
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.graph.model_config import NodeConfig
from dbt.contracts.graph.model_config import UnitTestNodeConfig, ModelConfig
from dbt.contracts.graph.nodes import (
ModelNode,
UnitTestNode,
Expand All @@ -14,7 +14,7 @@
DependsOn,
UnitTestConfig,
)
from dbt.contracts.graph.unparsed import UnparsedUnitTestSuite
from dbt.contracts.graph.unparsed import UnparsedUnitTestSuite, UnitTestFormat
from dbt.exceptions import ParsingError
from dbt.graph import UniqueId
from dbt.node_types import NodeType
Expand Down Expand Up @@ -66,7 +66,9 @@
path=get_pseudo_test_path(name, test_case.original_file_path),
original_file_path=test_case.original_file_path,
unique_id=test_case.unique_id,
config=NodeConfig(materialized="unit", _extra={"expected_rows": test_case.expect}),
config=UnitTestNodeConfig(
materialized="unit", expected_rows=test_case.expect.get_rows()
),
raw_code=actual_node.raw_code,
database=actual_node.database,
schema=actual_node.schema,
Expand Down Expand Up @@ -118,16 +120,15 @@
# TODO: package_name?
input_name = f"{test_case.model}__{test_case.name}__{original_input_node.name}"
input_unique_id = f"model.{package_name}.{input_name}"

input_node = ModelNode(
raw_code=self._build_raw_code(given.rows, original_input_node_columns),
raw_code=self._build_raw_code(given.get_rows(), original_input_node_columns),
resource_type=NodeType.Model,
package_name=package_name,
path=original_input_node.path,
original_file_path=original_input_node.original_file_path,
unique_id=input_unique_id,
name=input_name,
config=NodeConfig(materialized="ephemeral"),
config=ModelConfig(materialized="ephemeral"),
database=original_input_node.database,
schema=original_input_node.schema,
alias=original_input_node.alias,
Expand Down Expand Up @@ -189,6 +190,26 @@
unit_test_fqn = [self.project.project_name] + model_name_split + [test.name]
unit_test_config = self._build_unit_test_config(unit_test_fqn, test.config)

# Check that format and type of rows matches for each given input
for input in test.given:
if (input.format == "dict" and not isinstance(input.rows, list)) or (
input.format == UnitTestFormat.CSV and not isinstance(input.rows, str)
):
raise ParsingError(

Check warning on line 198 in core/dbt/parser/unit_tests.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/parser/unit_tests.py#L198

Added line #L198 was not covered by tests
f"Input rows in invalid format for unit test {test.name}"
)
# Check that format type of rows matches for expect
if (
test.expect.format == UnitTestFormat.Dict
and not isinstance(test.expect.rows, list)
) or (
test.expect.format == UnitTestFormat.CSV
and not isinstance(test.expect.rows, str)
):
raise ParsingError(

Check warning on line 209 in core/dbt/parser/unit_tests.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/parser/unit_tests.py#L209

Added line #L209 was not covered by tests
f"Expected rows in invalid format for unit test {test.name}"
)

gshank marked this conversation as resolved.
Show resolved Hide resolved
unit_test_definition = UnitTestDefinition(
name=test.name,
model=unit_test_suite.model,
Expand Down
7 changes: 7 additions & 0 deletions schemas/dbt/manifest/v11.json
Original file line number Diff line number Diff line change
Expand Up @@ -6005,6 +6005,13 @@
"type": "string"
}
}
},
"format": {
"enum": [
"csv",
"dict"
],
"default": "dict"
}
},
"additionalProperties": false,
Expand Down
Loading