diff --git a/.changes/unreleased/Fixes-20240923-190758.yaml b/.changes/unreleased/Fixes-20240923-190758.yaml new file mode 100644 index 00000000000..4d005ec5999 --- /dev/null +++ b/.changes/unreleased/Fixes-20240923-190758.yaml @@ -0,0 +1,6 @@ +kind: Fixes +body: Allow singular tests to be documented in properties.yml +time: 2024-09-23T19:07:58.151069+01:00 +custom: + Author: aranke + Issue: "9005" diff --git a/core/dbt/config/project.py b/core/dbt/config/project.py index 25b0f343ef2..cbad5a38434 100644 --- a/core/dbt/config/project.py +++ b/core/dbt/config/project.py @@ -158,14 +158,8 @@ def _parse_versions(versions: Union[List[str], str]) -> List[VersionSpecifier]: return [VersionSpecifier.from_version_string(v) for v in versions] -def _all_source_paths( - model_paths: List[str], - seed_paths: List[str], - snapshot_paths: List[str], - analysis_paths: List[str], - macro_paths: List[str], -) -> List[str]: - paths = chain(model_paths, seed_paths, snapshot_paths, analysis_paths, macro_paths) +def _all_source_paths(*args: List[str]) -> List[str]: + paths = chain(*args) # Strip trailing slashes since the path is the same even though the name is not stripped_paths = map(lambda s: s.rstrip("/"), paths) return list(set(stripped_paths)) @@ -409,7 +403,7 @@ def create_project(self, rendered: RenderComponents) -> "Project": snapshot_paths: List[str] = value_or(cfg.snapshot_paths, ["snapshots"]) all_source_paths: List[str] = _all_source_paths( - model_paths, seed_paths, snapshot_paths, analysis_paths, macro_paths + model_paths, seed_paths, snapshot_paths, analysis_paths, macro_paths, test_paths ) docs_paths: List[str] = value_or(cfg.docs_paths, all_source_paths) @@ -652,6 +646,7 @@ def all_source_paths(self) -> List[str]: self.snapshot_paths, self.analysis_paths, self.macro_paths, + self.test_paths, ) @property diff --git a/core/dbt/contracts/graph/manifest.py b/core/dbt/contracts/graph/manifest.py index f4cdafea737..b556b479fb4 100644 --- a/core/dbt/contracts/graph/manifest.py +++ b/core/dbt/contracts/graph/manifest.py @@ -58,6 +58,7 @@ SavedQuery, SeedNode, SemanticModel, + SingularTestNode, SourceDefinition, UnitTestDefinition, UnitTestFileFixture, @@ -89,7 +90,7 @@ RefName = str -def find_unique_id_for_package(storage, key, package: Optional[PackageName]): +def find_unique_id_for_package(storage, key, package: Optional[PackageName]) -> Optional[UniqueID]: if key not in storage: return None @@ -470,6 +471,43 @@ class AnalysisLookup(RefableLookup): _versioned_types: ClassVar[set] = set() +class SingularTestLookup(dbtClassMixin): + def __init__(self, manifest: "Manifest") -> None: + self.storage: Dict[str, Dict[PackageName, UniqueID]] = {} + self.populate(manifest) + + def get_unique_id(self, search_name, package: Optional[PackageName]) -> Optional[UniqueID]: + return find_unique_id_for_package(self.storage, search_name, package) + + def find( + self, search_name, package: Optional[PackageName], manifest: "Manifest" + ) -> Optional[SingularTestNode]: + unique_id = self.get_unique_id(search_name, package) + if unique_id is not None: + return self.perform_lookup(unique_id, manifest) + return None + + def add_singular_test(self, source: SingularTestNode) -> None: + if source.search_name not in self.storage: + self.storage[source.search_name] = {} + + self.storage[source.search_name][source.package_name] = source.unique_id + + def populate(self, manifest: "Manifest") -> None: + for node in manifest.nodes.values(): + if isinstance(node, SingularTestNode): + self.add_singular_test(node) + + def perform_lookup(self, unique_id: UniqueID, manifest: "Manifest") -> SingularTestNode: + if unique_id not in manifest.nodes: + raise dbt_common.exceptions.DbtInternalError( + f"Singular test {unique_id} found in cache but not found in manifest" + ) + node = manifest.nodes[unique_id] + assert isinstance(node, SingularTestNode) + return node + + def _packages_to_search( current_project: str, node_package: str, @@ -869,6 +907,9 @@ class Manifest(MacroMethods, dbtClassMixin): _analysis_lookup: Optional[AnalysisLookup] = field( default=None, metadata={"serialize": lambda x: None, "deserialize": lambda x: None} ) + _singular_test_lookup: Optional[SingularTestLookup] = field( + default=None, metadata={"serialize": lambda x: None, "deserialize": lambda x: None} + ) _parsing_info: ParsingInfo = field( default_factory=ParsingInfo, metadata={"serialize": lambda x: None, "deserialize": lambda x: None}, @@ -1264,6 +1305,12 @@ def analysis_lookup(self) -> AnalysisLookup: self._analysis_lookup = AnalysisLookup(self) return self._analysis_lookup + @property + def singular_test_lookup(self) -> SingularTestLookup: + if self._singular_test_lookup is None: + self._singular_test_lookup = SingularTestLookup(self) + return self._singular_test_lookup + @property def external_node_unique_ids(self): return [node.unique_id for node in self.nodes.values() if node.is_external_node] @@ -1708,6 +1755,7 @@ def __reduce_ex__(self, protocol): self._semantic_model_by_measure_lookup, self._disabled_lookup, self._analysis_lookup, + self._singular_test_lookup, ) return self.__class__, args diff --git a/core/dbt/contracts/graph/nodes.py b/core/dbt/contracts/graph/nodes.py index d04d03d0f81..3a386eaaf34 100644 --- a/core/dbt/contracts/graph/nodes.py +++ b/core/dbt/contracts/graph/nodes.py @@ -1642,6 +1642,11 @@ class ParsedMacroPatch(ParsedPatch): arguments: List[MacroArgument] = field(default_factory=list) +@dataclass +class ParsedSingularTestPatch(ParsedPatch): + pass + + # ==================================== # Node unions/categories # ==================================== diff --git a/core/dbt/contracts/graph/unparsed.py b/core/dbt/contracts/graph/unparsed.py index 847be3d3a2a..ebe704fc1c5 100644 --- a/core/dbt/contracts/graph/unparsed.py +++ b/core/dbt/contracts/graph/unparsed.py @@ -202,6 +202,11 @@ class UnparsedAnalysisUpdate(HasConfig, HasColumnDocs, HasColumnProps, HasYamlMe access: Optional[str] = None +@dataclass +class UnparsedSingularTestUpdate(HasConfig, HasColumnProps, HasYamlMetadata): + pass + + @dataclass class UnparsedNodeUpdate(HasConfig, HasColumnTests, HasColumnAndTestProps, HasYamlMetadata): quote_columns: Optional[bool] = None diff --git a/core/dbt/parser/common.py b/core/dbt/parser/common.py index 66d84d2db9b..5cc4385ea1c 100644 --- a/core/dbt/parser/common.py +++ b/core/dbt/parser/common.py @@ -13,6 +13,7 @@ UnparsedMacroUpdate, UnparsedModelUpdate, UnparsedNodeUpdate, + UnparsedSingularTestUpdate, ) from dbt.exceptions import ParsingError from dbt.node_types import NodeType @@ -58,6 +59,7 @@ def trimmed(inp: str) -> str: UnpatchedSourceDefinition, UnparsedExposure, UnparsedModelUpdate, + UnparsedSingularTestUpdate, ) diff --git a/core/dbt/parser/schemas.py b/core/dbt/parser/schemas.py index 29f36f96b34..cf2c3b7f109 100644 --- a/core/dbt/parser/schemas.py +++ b/core/dbt/parser/schemas.py @@ -17,6 +17,7 @@ ModelNode, ParsedMacroPatch, ParsedNodePatch, + ParsedSingularTestPatch, UnpatchedSourceDefinition, ) from dbt.contracts.graph.unparsed import ( @@ -27,6 +28,7 @@ UnparsedMacroUpdate, UnparsedModelUpdate, UnparsedNodeUpdate, + UnparsedSingularTestUpdate, UnparsedSourceDefinition, ) from dbt.events.types import ( @@ -65,7 +67,9 @@ from dbt.utils import coerce_dict_str from dbt_common.contracts.constraints import ConstraintType, ModelLevelConstraint from dbt_common.dataclass_schema import ValidationError, dbtClassMixin -from dbt_common.events.functions import warn_or_error +from dbt_common.events import EventLevel +from dbt_common.events.functions import fire_event, warn_or_error +from dbt_common.events.types import Note from dbt_common.exceptions import DbtValidationError from dbt_common.utils import deep_merge @@ -207,6 +211,18 @@ def parse_file(self, block: FileBlock, dct: Optional[Dict] = None) -> None: parser = MacroPatchParser(self, yaml_block, "macros") parser.parse() + if "data_tests" in dct: + parser = SingularTestPatchParser(self, yaml_block, "data_tests") + try: + parser.parse() + except ParsingError as e: + fire_event( + Note( + msg=f"Unable to parse 'data_tests' section of file '{block.path.original_file_path}'\n{e}", + ), + EventLevel.WARN, + ) + # PatchParser.parse() (but never test_blocks) if "analyses" in dct: parser = AnalysisPatchParser(self, yaml_block, "analyses") @@ -301,7 +317,9 @@ def _add_yaml_snapshot_nodes_to_manifest( self.manifest.rebuild_ref_lookup() -Parsed = TypeVar("Parsed", UnpatchedSourceDefinition, ParsedNodePatch, ParsedMacroPatch) +Parsed = TypeVar( + "Parsed", UnpatchedSourceDefinition, ParsedNodePatch, ParsedMacroPatch, ParsedSingularTestPatch +) NodeTarget = TypeVar("NodeTarget", UnparsedNodeUpdate, UnparsedAnalysisUpdate, UnparsedModelUpdate) NonSourceTarget = TypeVar( "NonSourceTarget", @@ -309,6 +327,7 @@ def _add_yaml_snapshot_nodes_to_manifest( UnparsedAnalysisUpdate, UnparsedMacroUpdate, UnparsedModelUpdate, + UnparsedSingularTestUpdate, ) @@ -1116,6 +1135,55 @@ def _target_type(self) -> Type[UnparsedAnalysisUpdate]: return UnparsedAnalysisUpdate +class SingularTestPatchParser(PatchParser[UnparsedSingularTestUpdate, ParsedSingularTestPatch]): + def get_block(self, node: UnparsedSingularTestUpdate) -> TargetBlock: + return TargetBlock.from_yaml_block(self.yaml, node) + + def _target_type(self) -> Type[UnparsedSingularTestUpdate]: + return UnparsedSingularTestUpdate + + def parse_patch(self, block: TargetBlock[UnparsedSingularTestUpdate], refs: ParserRef) -> None: + patch = ParsedSingularTestPatch( + name=block.target.name, + description=block.target.description, + meta=block.target.meta, + docs=block.target.docs, + config=block.target.config, + original_file_path=block.target.original_file_path, + yaml_key=block.target.yaml_key, + package_name=block.target.package_name, + ) + + assert isinstance(self.yaml.file, SchemaSourceFile) + source_file: SchemaSourceFile = self.yaml.file + + unique_id = self.manifest.singular_test_lookup.get_unique_id( + block.name, block.target.package_name + ) + if not unique_id: + warn_or_error( + NoNodeForYamlKey( + patch_name=patch.name, + yaml_key=patch.yaml_key, + file_path=source_file.path.original_file_path, + ) + ) + return + + node = self.manifest.nodes.get(unique_id) + assert node is not None + + source_file.append_patch(patch.yaml_key, unique_id) + if patch.config: + self.patch_node_config(node, patch) + + node.patch_path = patch.file_id + node.description = patch.description + node.created_at = time.time() + node.meta = patch.meta + node.docs = patch.docs + + class MacroPatchParser(PatchParser[UnparsedMacroUpdate, ParsedMacroPatch]): def get_block(self, node: UnparsedMacroUpdate) -> TargetBlock: return TargetBlock.from_yaml_block(self.yaml, node) diff --git a/tests/functional/data_test_patch/fixtures.py b/tests/functional/data_test_patch/fixtures.py new file mode 100644 index 00000000000..be056f32680 --- /dev/null +++ b/tests/functional/data_test_patch/fixtures.py @@ -0,0 +1,38 @@ +tests__my_singular_test_sql = """ +with my_cte as ( + select 1 as id, 'foo' as name + union all + select 2 as id, 'bar' as name +) +select * from my_cte +""" + +tests__schema_yml = """ +data_tests: + - name: my_singular_test + description: "{{ doc('my_singular_test_documentation') }}" + config: + error_if: ">10" + meta: + some_key: some_val +""" + +tests__doc_block_md = """ +{% docs my_singular_test_documentation %} + +Some docs from a doc block + +{% enddocs %} +""" + +tests__invalid_name_schema_yml = """ +data_tests: + - name: my_double_test + description: documentation, but make it double +""" + +tests__malformed_schema_yml = """ +data_tests: ¬_null + - not_null: + where: some_condition +""" diff --git a/tests/functional/data_test_patch/test_singular_test_patch.py b/tests/functional/data_test_patch/test_singular_test_patch.py new file mode 100644 index 00000000000..df359c5e645 --- /dev/null +++ b/tests/functional/data_test_patch/test_singular_test_patch.py @@ -0,0 +1,65 @@ +from pathlib import Path + +import pytest + +from dbt.tests.util import get_artifact, run_dbt, run_dbt_and_capture +from tests.functional.data_test_patch.fixtures import ( + tests__doc_block_md, + tests__invalid_name_schema_yml, + tests__malformed_schema_yml, + tests__my_singular_test_sql, + tests__schema_yml, +) + + +class TestPatchSingularTest: + @pytest.fixture(scope="class") + def tests(self): + return { + "my_singular_test.sql": tests__my_singular_test_sql, + "schema.yml": tests__schema_yml, + "doc_block.md": tests__doc_block_md, + } + + def test_compile(self, project): + run_dbt(["compile"]) + manifest = get_artifact(project.project_root, "target", "manifest.json") + assert len(manifest["nodes"]) == 1 + + my_singular_test_node = manifest["nodes"]["test.test.my_singular_test"] + assert my_singular_test_node["description"] == "Some docs from a doc block" + assert my_singular_test_node["config"]["error_if"] == ">10" + assert my_singular_test_node["config"]["meta"] == {"some_key": "some_val"} + + +class TestPatchSingularTestInvalidName: + @pytest.fixture(scope="class") + def tests(self): + return { + "my_singular_test.sql": tests__my_singular_test_sql, + "schema_with_invalid_name.yml": tests__invalid_name_schema_yml, + } + + def test_compile(self, project): + _, log_output = run_dbt_and_capture(["compile"]) + + file_path = Path("tests/schema_with_invalid_name.yml") + assert ( + f"Did not find matching node for patch with name 'my_double_test' in the 'data_tests' section of file '{file_path}'" + in log_output + ) + + +class TestPatchSingularTestMalformedYaml: + @pytest.fixture(scope="class") + def tests(self): + return { + "my_singular_test.sql": tests__my_singular_test_sql, + "schema.yml": tests__malformed_schema_yml, + } + + def test_compile(self, project): + _, log_output = run_dbt_and_capture(["compile"]) + file_path = Path("tests/schema.yml") + assert f"Unable to parse 'data_tests' section of file '{file_path}'" in log_output + assert "Entry did not contain a name" in log_output diff --git a/tests/unit/config/test_project.py b/tests/unit/config/test_project.py index ab842c164d7..ddd519cc6ee 100644 --- a/tests/unit/config/test_project.py +++ b/tests/unit/config/test_project.py @@ -31,7 +31,7 @@ class TestProjectMethods: def test_all_source_paths(self, project: Project): assert ( project.all_source_paths.sort() - == ["models", "seeds", "snapshots", "analyses", "macros"].sort() + == ["models", "seeds", "snapshots", "analyses", "macros", "tests"].sort() ) def test_generic_test_paths(self, project: Project): @@ -99,7 +99,8 @@ def test_defaults(self): self.assertEqual(project.test_paths, ["tests"]) self.assertEqual(project.analysis_paths, ["analyses"]) self.assertEqual( - set(project.docs_paths), set(["models", "seeds", "snapshots", "analyses", "macros"]) + set(project.docs_paths), + {"models", "seeds", "snapshots", "analyses", "macros", "tests"}, ) self.assertEqual(project.asset_paths, []) self.assertEqual(project.target_path, "target") @@ -128,7 +129,7 @@ def test_implicit_overrides(self): ) self.assertEqual( set(project.docs_paths), - set(["other-models", "seeds", "snapshots", "analyses", "macros"]), + {"other-models", "seeds", "snapshots", "analyses", "macros", "tests"}, ) def test_all_overrides(self): diff --git a/tests/unit/config/test_runtime.py b/tests/unit/config/test_runtime.py index 816ec8f98c3..d03d33dab94 100644 --- a/tests/unit/config/test_runtime.py +++ b/tests/unit/config/test_runtime.py @@ -129,7 +129,7 @@ def test_from_args(self): self.assertEqual(config.test_paths, ["tests"]) self.assertEqual(config.analysis_paths, ["analyses"]) self.assertEqual( - set(config.docs_paths), set(["models", "seeds", "snapshots", "analyses", "macros"]) + set(config.docs_paths), {"models", "seeds", "snapshots", "analyses", "macros", "tests"} ) self.assertEqual(config.asset_paths, []) self.assertEqual(config.target_path, "target")