diff --git a/.changes/unreleased/Features-20240531-150816.yaml b/.changes/unreleased/Features-20240531-150816.yaml new file mode 100644 index 00000000000..ebe69c0c5e3 --- /dev/null +++ b/.changes/unreleased/Features-20240531-150816.yaml @@ -0,0 +1,6 @@ +kind: Features +body: Update data_test to accept arbitrary config options +time: 2024-05-31T15:08:16.431966-05:00 +custom: + Author: McKnight-42 + Issue: "10197" diff --git a/core/dbt/parser/generic_test_builders.py b/core/dbt/parser/generic_test_builders.py index 8a4864be82e..6bca8300dae 100644 --- a/core/dbt/parser/generic_test_builders.py +++ b/core/dbt/parser/generic_test_builders.py @@ -114,7 +114,8 @@ def __init__( self.package_name: str = package_name self.target: Testable = target self.version: Optional[NodeVersion] = version - + self.render_ctx: Dict[str, Any] = render_ctx + self.column_name: Optional[str] = column_name self.args["model"] = self.build_model_str() match = self.TEST_NAME_PATTERN.match(test_name) @@ -125,39 +126,12 @@ def __init__( self.name: str = groups["test_name"] self.namespace: str = groups["test_namespace"] self.config: Dict[str, Any] = {} + # Process legacy args + self.config.update(self._process_legacy_args()) - # This code removes keys identified as config args from the test entry - # dictionary. The keys remaining in the 'args' dictionary will be - # "kwargs", or keyword args that are passed to the test macro. - # The "kwargs" are not rendered into strings until compilation time. - # The "configs" are rendered here (since they were not rendered back - # in the 'get_key_dicts' methods in the schema parsers). - for key in self.CONFIG_ARGS: - value = self.args.pop(key, None) - # 'modifier' config could be either top level arg or in config - if value and "config" in self.args and key in self.args["config"]: - raise SameKeyNestedError() - if not value and "config" in self.args: - value = self.args["config"].pop(key, None) - if isinstance(value, str): - - try: - value = get_rendered(value, render_ctx, native=True) - except UndefinedMacroError as e: - - raise CustomMacroPopulatingConfigValueError( - target_name=self.target.name, - column_name=column_name, - name=self.name, - key=key, - err_msg=e.msg, - ) - - if value is not None: - self.config[key] = value - + # Process config args if present if "config" in self.args: - del self.args["config"] + self.config.update(self._render_values(self.args.pop("config", {}))) if self.namespace is not None: self.package_name = self.namespace @@ -182,6 +156,36 @@ def __init__( if short_name != full_name and "alias" not in self.config: self.config["alias"] = short_name + def _process_legacy_args(self): + config = {} + for key in self.CONFIG_ARGS: + value = self.args.pop(key, None) + if value and "config" in self.args and key in self.args["config"]: + raise SameKeyNestedError() + if not value and "config" in self.args: + value = self.args["config"].pop(key, None) + config[key] = value + + return self._render_values(config) + + def _render_values(self, config: Dict[str, Any]) -> Dict[str, Any]: + rendered_config = {} + for key, value in config.items(): + if isinstance(value, str): + try: + value = get_rendered(value, self.render_ctx, native=True) + except UndefinedMacroError as e: + raise CustomMacroPopulatingConfigValueError( + target_name=self.target.name, + column_name=self.column_name, + name=self.name, + key=key, + err_msg=e.msg, + ) + if value is not None: + rendered_config[key] = value + return rendered_config + def _bad_type(self) -> TypeError: return TypeError('invalid target type "{}"'.format(type(self.target))) diff --git a/tests/functional/schema_tests/data_test_config.py b/tests/functional/schema_tests/data_test_config.py new file mode 100644 index 00000000000..377f14aac04 --- /dev/null +++ b/tests/functional/schema_tests/data_test_config.py @@ -0,0 +1,115 @@ +import re + +import pytest + +from dbt.exceptions import CompilationError +from dbt.tests.util import get_manifest, run_dbt +from tests.functional.schema_tests.fixtures import ( + custom_config_yml, + mixed_config_yml, + same_key_error_yml, + seed_csv, + table_sql, +) + + +class BaseDataTestsConfig: + @pytest.fixture(scope="class") + def seeds(self): + return {"seed.csv": seed_csv} + + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "config-version": 2, + } + + @pytest.fixture(scope="class", autouse=True) + def setUp(self, project): + run_dbt(["seed"]) + + +class TestCustomDataTestConfig(BaseDataTestsConfig): + @pytest.fixture(scope="class") + def models(self): + return {"table.sql": table_sql, "custom_config.yml": custom_config_yml} + + def test_custom_config(self, project): + run_dbt(["parse"]) + manifest = get_manifest(project.project_root) + + # Pattern to match the test_id without the specific suffix + pattern = re.compile(r"test\.test\.accepted_values_table_color__blue__red\.\d+") + + # Find the test_id dynamically + test_id = None + for node_id in manifest.nodes: + if pattern.match(node_id): + test_id = node_id + break + + # Ensure the test_id was found + assert ( + test_id is not None + ), "Test ID matching the pattern was not found in the manifest nodes" + + # Proceed with the assertions + test_node = manifest.nodes[test_id] + assert "custom_config_key" in test_node.config + assert test_node.config["custom_config_key"] == "some_value" + + +class TestMixedDataTestConfig(BaseDataTestsConfig): + @pytest.fixture(scope="class") + def models(self): + return {"table.sql": table_sql, "mixed_config.yml": mixed_config_yml} + + def test_mixed_config(self, project): + run_dbt(["parse"]) + manifest = get_manifest(project.project_root) + + # Pattern to match the test_id without the specific suffix + pattern = re.compile(r"test\.test\.accepted_values_table_color__blue__red\.\d+") + + # Find the test_id dynamically + test_id = None + for node_id in manifest.nodes: + if pattern.match(node_id): + test_id = node_id + break + + # Ensure the test_id was found + assert ( + test_id is not None + ), "Test ID matching the pattern was not found in the manifest nodes" + + # Proceed with the assertions + test_node = manifest.nodes[test_id] + assert "custom_config_key" in test_node.config + assert test_node.config["custom_config_key"] == "some_value" + assert "severity" in test_node.config + assert test_node.config["severity"] == "warn" + + +class TestSameKeyErrorDataTestConfig: + @pytest.fixture(scope="class") + def models(self): + return {"table.sql": table_sql, "same_key_error.yml": same_key_error_yml} + + def test_same_key_error(self, project): + """ + Test that verifies dbt raises a CompilationError when the test configuration + contains the same key at the top level and inside the config dictionary. + """ + # Run dbt and expect a CompilationError due to the invalid configuration + with pytest.raises(CompilationError) as exc_info: + run_dbt(["parse"]) + + # Extract the exception message + exception_message = str(exc_info.value) + + # Assert that the error message contains the expected text + assert "Test cannot have the same key at the top-level and in config" in exception_message + + # Assert that the error message contains the context of the error + assert "models/same_key_error.yml" in exception_message diff --git a/tests/functional/schema_tests/fixtures.py b/tests/functional/schema_tests/fixtures.py index 51ae067bd84..bf16148e0c7 100644 --- a/tests/functional/schema_tests/fixtures.py +++ b/tests/functional/schema_tests/fixtures.py @@ -1273,3 +1273,63 @@ data_tests: - my_custom_test """ + +custom_config_yml = """ +version: 2 +models: + - name: table + columns: + - name: color + data_tests: + - accepted_values: + values: ['blue', 'red'] + config: + custom_config_key: some_value +""" + +mixed_config_yml = """ +version: 2 +models: + - name: table + columns: + - name: color + data_tests: + - accepted_values: + values: ['blue', 'red'] + severity: warn + config: + custom_config_key: some_value +""" + +same_key_error_yml = """ +version: 2 +models: + - name: table + columns: + - name: color + data_tests: + - accepted_values: + values: ['blue', 'red'] + severity: warn + config: + severity: error +""" + +seed_csv = """ +id,color,value +1,blue,10 +2,red,20 +3,green,30 +4,yellow,40 +5,blue,50 +6,red,60 +7,blue,70 +8,green,80 +9,yellow,90 +10,blue,100 +""" + +table_sql = """ +-- content of the table.sql +select * from {{ ref('seed') }} +"""