Skip to content

Commit

Permalink
Initial implementation of unit testing (from pr #2911)
Browse files Browse the repository at this point in the history
Co-authored-by: Michelle Ark <michelle.ark@dbtlabs.com>
  • Loading branch information
gshank and MichelleArk committed Aug 14, 2023
1 parent b045180 commit 181f520
Show file tree
Hide file tree
Showing 24 changed files with 912 additions and 14 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20230802-145011.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: Initial implementation of unit testing
time: 2023-08-02T14:50:11.391992-04:00
custom:
Author: gshank
Issue: "8287"
2 changes: 1 addition & 1 deletion core/dbt/adapters/relation_configs/config_change.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class RelationConfigChangeAction(StrEnum):
drop = "drop"


@dataclass(frozen=True, eq=True, unsafe_hash=True)
@dataclass(frozen=True, eq=True, unsafe_hash=True) # type: ignore
class RelationConfigChange(RelationConfigBase, ABC):
action: RelationConfigChangeAction
context: Hashable # this is usually a RelationConfig, e.g. IndexConfig, but shouldn't be limited
Expand Down
1 change: 1 addition & 0 deletions core/dbt/cli/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,7 @@ def command_args(command: CliCommand) -> ArgsList:
CliCommand.SOURCE_FRESHNESS: cli.freshness,
CliCommand.TEST: cli.test,
CliCommand.RETRY: cli.retry,
CliCommand.UNIT_TEST: cli.unit_test,
}
click_cmd: Optional[ClickCommand] = CMD_DICT.get(command, None)
if click_cmd is None:
Expand Down
47 changes: 47 additions & 0 deletions core/dbt/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from dbt.task.show import ShowTask
from dbt.task.snapshot import SnapshotTask
from dbt.task.test import TestTask
from dbt.task.unit_test import UnitTestTask


@dataclass
Expand Down Expand Up @@ -847,6 +848,52 @@ def test(ctx, **kwargs):
return results, success


# dbt unit-test
@cli.command("unit-test")
@click.pass_context
@p.defer
@p.deprecated_defer
@p.exclude
@p.fail_fast
@p.favor_state
@p.deprecated_favor_state
@p.indirect_selection
@p.show_output_format
@p.profile
@p.profiles_dir
@p.project_dir
@p.select
@p.selector
@p.state
@p.defer_state
@p.deprecated_state
@p.store_failures
@p.target
@p.target_path
@p.threads
@p.vars
@p.version_check
@requires.postflight
@requires.preflight
@requires.profile
@requires.project
@requires.runtime_config
@requires.manifest
@requires.unit_test_collection
def unit_test(ctx, **kwargs):
"""Runs tests on data in deployed models. Run this after `dbt run`"""
task = UnitTestTask(
ctx.obj["flags"],
ctx.obj["runtime_config"],
ctx.obj["manifest"],
ctx.obj["unit_test_collection"],
)

results = task.run()
success = task.interpret_results(results)
return results, success


# Support running as a module
if __name__ == "__main__":
cli()
23 changes: 23 additions & 0 deletions core/dbt/cli/requires.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from dbt.profiler import profiler
from dbt.tracking import active_user, initialize_from_flags, track_run
from dbt.utils import cast_dict_to_dict_of_strings
from dbt.parser.unit_tests import UnitTestManifestLoader
from dbt.plugins import set_up_plugin_manager, get_plugin_manager

from click import Context
Expand Down Expand Up @@ -265,3 +266,25 @@ def wrapper(*args, **kwargs):
if len(args0) == 0:
return outer_wrapper
return outer_wrapper(args0[0])


def unit_test_collection(func):
"""A decorator used by click command functions for generating a unit test collection provided a manifest"""

def wrapper(*args, **kwargs):
ctx = args[0]
assert isinstance(ctx, Context)

req_strs = ["manifest", "runtime_config"]
reqs = [ctx.obj.get(req_str) for req_str in req_strs]

if None in reqs:
raise DbtProjectError("manifest and runtime_config required for unit_test_collection")

collection = UnitTestManifestLoader.load(ctx.obj["manifest"], ctx.obj["runtime_config"])

ctx.obj["unit_test_collection"] = collection

return func(*args, **kwargs)

return update_wrapper(wrapper, func)
1 change: 1 addition & 0 deletions core/dbt/cli/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class Command(Enum):
SOURCE_FRESHNESS = "freshness"
TEST = "test"
RETRY = "retry"
UNIT_TEST = "unit-test"

@classmethod
def from_str(cls, s: str) -> "Command":
Expand Down
20 changes: 20 additions & 0 deletions core/dbt/clients/jinja.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,26 @@ def __call__(self, *args, **kwargs):
return self.call_macro(*args, **kwargs)


class UnitTestMacroGenerator(MacroGenerator):
# this makes UnitTestMacroGenerator objects callable like functions
def __init__(
self,
macro_generator: MacroGenerator,
call_return_value: Any,
) -> None:
super().__init__(
macro_generator.macro,
macro_generator.context,
macro_generator.node,
macro_generator.stack,
)
self.call_return_value = call_return_value

def __call__(self, *args, **kwargs):
with self.track_call():
return self.call_return_value


class QueryStringGenerator(BaseMacroGenerator):
def __init__(self, template_str: str, context: Dict[str, Any]) -> None:
super().__init__(context)
Expand Down
13 changes: 10 additions & 3 deletions core/dbt/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
from dbt.adapters.factory import get_adapter
from dbt.clients import jinja
from dbt.clients.system import make_directory
from dbt.context.providers import generate_runtime_model_context
from dbt.context.providers import (
generate_runtime_model_context,
generate_runtime_unit_test_context,
)
from dbt.contracts.graph.manifest import Manifest, UniqueID
from dbt.contracts.graph.nodes import (
ManifestNode,
Expand All @@ -21,6 +24,7 @@
GraphMemberNode,
InjectedCTE,
SeedNode,
UnitTestNode,
)
from dbt.exceptions import (
GraphDependencyNotFoundError,
Expand All @@ -44,6 +48,7 @@ def print_compile_stats(stats):
names = {
NodeType.Model: "model",
NodeType.Test: "test",
NodeType.Unit: "unit test",
NodeType.Snapshot: "snapshot",
NodeType.Analysis: "analysis",
NodeType.Macro: "macro",
Expand Down Expand Up @@ -289,8 +294,10 @@ def _create_node_context(
manifest: Manifest,
extra_context: Dict[str, Any],
) -> Dict[str, Any]:

context = generate_runtime_model_context(node, self.config, manifest)
if isinstance(node, UnitTestNode):
context = generate_runtime_unit_test_context(node, self.config, manifest)
else:
context = generate_runtime_model_context(node, self.config, manifest)
context.update(extra_context)

if isinstance(node, GenericTestNode):
Expand Down
82 changes: 79 additions & 3 deletions core/dbt/context/providers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import abc
from copy import deepcopy
import os
from typing import (
Callable,
Expand All @@ -17,7 +18,7 @@
from dbt.adapters.base.column import Column
from dbt.adapters.factory import get_adapter, get_adapter_package_names, get_adapter_type_names
from dbt.clients import agate_helper
from dbt.clients.jinja import get_rendered, MacroGenerator, MacroStack
from dbt.clients.jinja import get_rendered, MacroGenerator, MacroStack, UnitTestMacroGenerator
from dbt.config import RuntimeConfig, Project
from dbt.constants import SECRET_ENV_PREFIX, DEFAULT_ENV_PLACEHOLDER
from dbt.context.base import contextmember, contextproperty, Var
Expand All @@ -39,6 +40,7 @@
RefArgs,
AccessType,
SemanticModel,
UnitTestNode,
)
from dbt.contracts.graph.metrics import MetricReference, ResolvedMetricReference
from dbt.contracts.graph.unparsed import NodeVersion
Expand Down Expand Up @@ -566,6 +568,17 @@ def create_relation(self, target_model: ManifestNode) -> RelationProxy:
return super().create_relation(target_model)


class RuntimeUnitTestRefResolver(RuntimeRefResolver):
def resolve(
self,
target_name: str,
target_package: Optional[str] = None,
target_version: Optional[NodeVersion] = None,
) -> RelationProxy:
target_name = f"{self.model.name}__{target_name}"
return super().resolve(target_name, target_package, target_version)


# `source` implementations
class ParseSourceResolver(BaseSourceResolver):
def resolve(self, source_name: str, table_name: str):
Expand Down Expand Up @@ -670,6 +683,22 @@ class RuntimeVar(ModelConfiguredVar):
pass


class UnitTestVar(RuntimeVar):
def __init__(
self,
context: Dict[str, Any],
config: RuntimeConfig,
node: Resource,
) -> None:
config_copy = None
assert isinstance(node, UnitTestNode)
if node.overrides and node.overrides.vars:
config_copy = deepcopy(config)
config_copy.cli_vars.update(node.overrides.vars)

super().__init__(context, config_copy or config, node=node)


# Providers
class Provider(Protocol):
execute: bool
Expand Down Expand Up @@ -711,6 +740,16 @@ class RuntimeProvider(Provider):
metric = RuntimeMetricResolver


class RuntimeUnitTestProvider(Provider):
execute = True
Config = RuntimeConfigObject
DatabaseWrapper = RuntimeDatabaseWrapper
Var = UnitTestVar
ref = RuntimeUnitTestRefResolver
source = RuntimeSourceResolver # TODO: RuntimeUnitTestSourceResolver
metric = RuntimeMetricResolver


class OperationProvider(RuntimeProvider):
ref = OperationRefResolver

Expand Down Expand Up @@ -1360,7 +1399,7 @@ class ModelContext(ProviderContext):

@contextproperty
def pre_hooks(self) -> List[Dict[str, Any]]:
if self.model.resource_type in [NodeType.Source, NodeType.Test]:
if self.model.resource_type in [NodeType.Source, NodeType.Test, NodeType.Unit]:
return []
# TODO CT-211
return [
Expand All @@ -1369,7 +1408,7 @@ def pre_hooks(self) -> List[Dict[str, Any]]:

@contextproperty
def post_hooks(self) -> List[Dict[str, Any]]:
if self.model.resource_type in [NodeType.Source, NodeType.Test]:
if self.model.resource_type in [NodeType.Source, NodeType.Test, NodeType.Unit]:
return []
# TODO CT-211
return [
Expand Down Expand Up @@ -1462,6 +1501,25 @@ def defer_relation(self) -> Optional[RelationProxy]:
return None


class UnitTestContext(ModelContext):
model: UnitTestNode

@contextmember
def env_var(self, var: str, default: Optional[str] = None) -> str:
"""The env_var() function. Return the overriden unit test environment variable named 'var'.
If there is no unit test override, return the environment variable named 'var'.
If there is no such environment variable set, return the default.
If the default is None, raise an exception for an undefined variable.
"""
if self.model.overrides and var in self.model.overrides.env_vars:
return self.model.overrides.env_vars[var]
else:
return super().env_var(var, default)


# This is called by '_context_for', used in 'render_with_context'
def generate_parser_model_context(
model: ManifestNode,
Expand Down Expand Up @@ -1506,6 +1564,24 @@ def generate_runtime_macro_context(
return ctx.to_dict()


def generate_runtime_unit_test_context(
unit_test: UnitTestNode,
config: RuntimeConfig,
manifest: Manifest,
) -> Dict[str, Any]:
ctx = UnitTestContext(unit_test, config, manifest, RuntimeUnitTestProvider(), None)
ctx_dict = ctx.to_dict()

if unit_test.overrides and unit_test.overrides.macros:
for macro_name, macro_value in unit_test.overrides.macros.items():
context_value = ctx_dict.get(macro_name)
if isinstance(context_value, MacroGenerator):
ctx_dict[macro_name] = UnitTestMacroGenerator(context_value, macro_value)
else:
ctx_dict[macro_name] = macro_value
return ctx_dict


class ExposureRefResolver(BaseResolver):
def __call__(self, *args, **kwargs) -> str:
package = None
Expand Down
1 change: 1 addition & 0 deletions core/dbt/contracts/graph/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,7 @@ def finalize_and_validate(self):
NodeType.Source: SourceConfig,
NodeType.Seed: SeedConfig,
NodeType.Test: TestConfig,
NodeType.Unit: TestConfig,
NodeType.Model: NodeConfig,
NodeType.Snapshot: SnapshotConfig,
}
Expand Down
Loading

0 comments on commit 181f520

Please sign in to comment.