Skip to content

Commit

Permalink
Save selector dictionary and write out in manifest [#2693][#2800]
Browse files Browse the repository at this point in the history
  • Loading branch information
gshank committed Nov 8, 2020
1 parent af3d668 commit 4947954
Show file tree
Hide file tree
Showing 15 changed files with 93 additions and 15 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
### Features
- Added macro get_partitions_metadata(table) to return partition metadata for partitioned table [#2596](https://github.com/fishtown-analytics/dbt/pull/2596)
- Added native python 're' module for regex in jinja templates [#2851](https://github.com/fishtown-analytics/dbt/pull/2851)
- Save selectors dictionary to manifest, allow descriptions ([#2693](https://github.com/fishtown-analytics/dbt/issues/2693), [#2866](https://github.com/fishtown-analytics/dbt/pull/2866))

### Fixes
- Respect --project-dir in dbt clean command ([#2840](https://github.com/fishtown-analytics/dbt/issues/2840), [#2841](https://github.com/fishtown-analytics/dbt/pull/2841))
Expand Down
2 changes: 2 additions & 0 deletions core/dbt/config/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ def read_user_config(directory: str) -> UserConfig:
return UserConfig()


# The Profile class is included in RuntimeConfig, so any attribute
# additions must also be set where the RuntimeConfig class is created
@dataclass
class Profile(HasCredentials):
profile_name: str
Expand Down
13 changes: 13 additions & 0 deletions core/dbt/config/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,14 @@ def create_project(self, rendered: RenderComponents) -> 'Project':
query_comment = _query_comment_from_cfg(cfg.query_comment)

packages = package_config_from_data(rendered.packages_dict)
manifest_selectors: Dict[str, Any] = {}
if rendered.selectors_dict:
# this is a dict with a single key 'selectors' pointing to a list
# of dicts.
if rendered.selectors_dict['selectors']:
# for each selector dict, transform into 'name': { }
for sel in rendered.selectors_dict['selectors']:
manifest_selectors[sel['name']] = sel
selectors = selector_config_from_data(rendered.selectors_dict)

project = Project(
Expand Down Expand Up @@ -396,6 +404,7 @@ def create_project(self, rendered: RenderComponents) -> 'Project':
snapshots=snapshots,
dbt_version=dbt_version,
packages=packages,
manifest_selectors=manifest_selectors,
selectors=selectors,
query_comment=query_comment,
sources=sources,
Expand Down Expand Up @@ -458,6 +467,7 @@ def from_project_root(

class VarProvider:
"""Var providers are tied to a particular Project."""

def __init__(
self,
vars: Dict[str, Dict[str, Any]]
Expand All @@ -476,6 +486,8 @@ def to_dict(self):
return self.vars


# The Project class is included in RuntimeConfig, so any attribute
# additions must also be set where the RuntimeConfig class is created
@dataclass
class Project:
project_name: str
Expand Down Expand Up @@ -504,6 +516,7 @@ class Project:
vars: VarProvider
dbt_version: List[VersionSpecifier]
packages: Dict[str, Any]
manifest_selectors: Dict[str, Any]
selectors: SelectorConfig
query_comment: QueryComment
config_version: int
Expand Down
2 changes: 2 additions & 0 deletions core/dbt/config/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def from_parts(
snapshots=project.snapshots,
dbt_version=project.dbt_version,
packages=project.packages,
manifest_selectors=project.manifest_selectors,
selectors=project.selectors,
query_comment=project.query_comment,
sources=project.sources,
Expand Down Expand Up @@ -483,6 +484,7 @@ def from_parts(
snapshots=project.snapshots,
dbt_version=project.dbt_version,
packages=project.packages,
manifest_selectors=project.manifest_selectors,
selectors=project.selectors,
query_comment=project.query_comment,
sources=project.sources,
Expand Down
20 changes: 18 additions & 2 deletions core/dbt/contracts/graph/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,7 @@ class Manifest:
macros: MutableMapping[str, ParsedMacro]
docs: MutableMapping[str, ParsedDocumentation]
exposures: MutableMapping[str, ParsedExposure]
selectors: MutableMapping[str, Any]
disabled: List[CompileResultNode]
files: MutableMapping[str, SourceFile]
metadata: ManifestMetadata = field(default_factory=ManifestMetadata)
Expand All @@ -461,6 +462,7 @@ def from_macros(
macros=macros,
docs={},
exposures={},
selectors={},
disabled=[],
files=files,
)
Expand Down Expand Up @@ -730,8 +732,9 @@ def deepcopy(self):
macros={k: _deepcopy(v) for k, v in self.macros.items()},
docs={k: _deepcopy(v) for k, v in self.docs.items()},
exposures={k: _deepcopy(v) for k, v in self.exposures.items()},
disabled=[_deepcopy(n) for n in self.disabled],
selectors=self.root_project.manifest_selectors,
metadata=self.metadata,
disabled=[_deepcopy(n) for n in self.disabled],
files={k: _deepcopy(v) for k, v in self.files.items()},
)

Expand All @@ -749,6 +752,7 @@ def writable_manifest(self):
macros=self.macros,
docs=self.docs,
exposures=self.exposures,
selectors=self.selectors,
metadata=self.metadata,
disabled=self.disabled,
child_map=forward_edges,
Expand Down Expand Up @@ -905,14 +909,21 @@ def merge_from_artifact(
f'Merged {len(merged)} items from state (sample: {sample})'
)

# provide support for copy.deepcopy() - we jsut need to avoid the lock!
# Provide support for copy.deepcopy() - we just need to avoid the lock!
# pickle and deepcopy use this. It returns a callable object used to
# create the initial version of the object and a tuple of arguments
# for the object, i.e. the Manifest.
# The order of the arguments must match the order of the attributes
# in the Manifest class declaration, because they are used as
# positional arguments to construct a Manifest.
def __reduce_ex__(self, protocol):
args = (
self.nodes,
self.sources,
self.macros,
self.docs,
self.exposures,
self.selectors,
self.disabled,
self.files,
self.metadata,
Expand Down Expand Up @@ -952,6 +963,11 @@ class WritableManifest(ArtifactMixin):
'The exposures defined in the dbt project and its dependencies'
))
)
selectors: Mapping[UniqueID, Any] = field(
metadata=dict(description=(
'The selectors defined in selectors.yml'
))
)
disabled: Optional[List[CompileResultNode]] = field(metadata=dict(
description='A list of the disabled nodes in the target'
))
Expand Down
1 change: 1 addition & 0 deletions core/dbt/contracts/selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
class SelectorDefinition(JsonSchemaMixin):
name: str
definition: Union[str, Dict[str, Any]]
description: str = ''


@dataclass
Expand Down
1 change: 1 addition & 0 deletions core/dbt/parser/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,7 @@ def create_manifest(self) -> Manifest:
metadata=self.root_project.get_metadata(),
disabled=disabled,
files=self.results.files,
selectors=self.root_project.manifest_selectors,
)
manifest.patch_nodes(self.results.patches)
manifest.patch_macros(self.results.macro_patches)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from test.integration.base import DBTIntegrationTest, use_profile
import yaml
import json
import os

class TestGraphSelection(DBTIntegrationTest):

Expand All @@ -10,6 +13,18 @@ def schema(self):
def models(self):
return "models"

@property
def selectors_config(self):
return yaml.safe_load('''
selectors:
- name: bi_selector
description: This is a BI selector
definition:
method: tag
value: bi
''')


def assert_correct_schemas(self):
with self.get_connection():
exists = self.adapter.check_schema_exists(
Expand Down Expand Up @@ -43,7 +58,7 @@ def test__postgres__specific_model(self):
def test__postgres__tags(self):
self.run_sql_file("seed.sql")

results = self.run_dbt(['run', '--models', 'tag:bi'])
results = self.run_dbt(['run', '--selector', 'bi_selector'])
self.assertEqual(len(results), 2)

created_models = self.get_models_in_schema()
Expand All @@ -52,6 +67,12 @@ def test__postgres__tags(self):
self.assertTrue('users' in created_models)
self.assertTrue('users_rollup' in created_models)
self.assert_correct_schemas()
self.assertTrue(os.path.exists('./target/manifest.json'))
with open('./target/manifest.json') as fp:
manifest = json.load(fp)
self.assertTrue('selectors' in manifest)



@use_profile('postgres')
def test__postgres__tags_and_children(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1503,6 +1503,7 @@ def expected_seeded_manifest(self, model_database=None):
'maturity': None,
}
},
'selectors': {},
'parent_map': {
'model.test.model': ['seed.test.seed'],
'model.test.second_model': ['seed.test.seed'],
Expand Down Expand Up @@ -1825,6 +1826,7 @@ def expected_postgres_references_manifest(self, model_database=None):
},
},
'exposures': {},
'selectors': {},
'docs': {
'dbt.__overview__': ANY,
'test.column_info': {
Expand Down Expand Up @@ -2322,6 +2324,7 @@ def expected_bigquery_complex_manifest(self):
},
'sources': {},
'exposures': {},
'selectors': {},
'child_map': {
'model.test.clustered': [],
'model.test.multi_clustered': [],
Expand Down Expand Up @@ -2533,6 +2536,7 @@ def expected_redshift_incremental_view_manifest(self):
},
'sources': {},
'exposures': {},
'selectors': {},
'parent_map': {
'model.test.model': ['seed.test.seed'],
'seed.test.seed': []
Expand Down Expand Up @@ -2572,7 +2576,7 @@ def verify_manifest(self, expected_manifest):

manifest_keys = frozenset({
'nodes', 'sources', 'macros', 'parent_map', 'child_map',
'docs', 'metadata', 'docs', 'disabled', 'exposures'
'docs', 'metadata', 'docs', 'disabled', 'exposures', 'selectors',
})

self.assertEqual(frozenset(manifest), manifest_keys)
Expand Down
5 changes: 4 additions & 1 deletion test/rpc/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,11 +402,14 @@ def get_manifest(self, request_id=1):
)

def is_result(self, data: Dict[str, Any], id=None) -> Dict[str, Any]:

if id is not None:
assert data['id'] == id
assert data['jsonrpc'] == '2.0'
assert 'result' in data
if 'error' in data:
print(data['error']['message'])
assert 'error' not in data
assert 'result' in data
return data['result']

def is_async_result(self, data: Dict[str, Any], id=None) -> str:
Expand Down
5 changes: 5 additions & 0 deletions test/unit/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ def test__prepend_ctes__already_has_cte(self):
disabled=[],
files={},
exposures={},
selectors={},
)

compiler = dbt.compilation.Compiler(self.config)
Expand Down Expand Up @@ -236,6 +237,7 @@ def test__prepend_ctes__no_ctes(self):
disabled=[],
files={},
exposures={},
selectors={},
)

compiler = dbt.compilation.Compiler(self.config)
Expand Down Expand Up @@ -327,6 +329,7 @@ def test__prepend_ctes(self):
disabled=[],
files={},
exposures={},
selectors={},
)

compiler = dbt.compilation.Compiler(self.config)
Expand Down Expand Up @@ -430,6 +433,7 @@ def test__prepend_ctes__cte_not_compiled(self):
disabled=[],
files={},
exposures={},
selectors={},
)

compiler = dbt.compilation.Compiler(self.config)
Expand Down Expand Up @@ -534,6 +538,7 @@ def test__prepend_ctes__multiple_levels(self):
disabled=[],
files={},
exposures={},
selectors={},
)

compiler = dbt.compilation.Compiler(self.config)
Expand Down
1 change: 1 addition & 0 deletions test/unit/test_graph_selector_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,7 @@ def manifest(seed, source, ephemeral_model, view_model, table_model, ext_source,
files={},
exposures={},
disabled=[],
selectors={},
)
return manifest

Expand Down
2 changes: 2 additions & 0 deletions test/unit/test_graph_selector_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,13 @@ def test_parse_simple():
sf = parse_file('''\
selectors:
- name: tagged_foo
description: Selector for foo-tagged models
definition:
tag: foo
''')

assert len(sf.selectors) == 1
assert sf.selectors[0].description == 'Selector for foo-tagged models'
parsed = cli.parse_from_selectors_definition(sf)
assert len(parsed) == 1
assert 'tagged_foo' in parsed
Expand Down
Loading

0 comments on commit 4947954

Please sign in to comment.