
Compile on-run-(start|end) hooks to file #412

Merged · 7 commits · May 9, 2017
Changes from 3 commits
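At a glance, this PR compiles the on-run-start and on-run-end hooks declared in dbt_project.yml into first-class "operation" nodes, so they flow through the same parse → compile → link pipeline as models and tests instead of being rendered ad hoc inside the runner. A minimal sketch of the config these changes consume — the project dict shape matches what get_hooks_from_project (added in dbt/parser.py below) expects, but the project name and SQL are made up:

# hypothetical dbt_project.yml contents, loaded into a Python dict
project_cfg = {
    'name': 'my_project',
    # a single hook may be given as a bare string...
    'on-run-start': "create table if not exists audit (id int)",
    # ...or several hooks as a list
    'on-run-end': [
        "insert into audit values (1)",
        "grant select on analytics.audit to role reporter",
    ],
}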
39 changes: 32 additions & 7 deletions dbt/compilation.py
@@ -6,6 +6,7 @@
import dbt.utils
import dbt.include
import dbt.wrapper
import dbt.tracking

from dbt.model import Model
from dbt.utils import This, Var, is_enabled, get_materialization, NodeType, \
@@ -38,6 +39,7 @@ def print_compile_stats(stats):
NodeType.Archive: 'archives',
NodeType.Analysis: 'analyses',
NodeType.Macro: 'macros',
NodeType.Operation: 'operations',
}

results = {
@@ -46,6 +48,7 @@ def print_compile_stats(stats):
NodeType.Archive: 0,
NodeType.Analysis: 0,
NodeType.Macro: 0,
NodeType.Operation: 0,
}

results.update(stats)
@@ -235,8 +238,8 @@ def get_compiler_context(self, model, flat_graph):

context.update(wrapper.get_context_functions())

context['run_started_at'] = '{{ run_started_at }}'
context['invocation_id'] = '{{ invocation_id }}'
context['run_started_at'] = dbt.tracking.active_user.run_started_at
context['invocation_id'] = dbt.tracking.active_user.invocation_id
context['sql_now'] = adapter.date_function()

for unique_id, macro in flat_graph.get('macros').items():
@@ -280,7 +283,8 @@ def compile_node(self, node, flat_graph):
injected_node, _ = prepend_ctes(compiled_node, flat_graph)

if compiled_node.get('resource_type') in [NodeType.Test,
NodeType.Analysis]:
NodeType.Analysis,
NodeType.Operation]:
# data tests get wrapped in count(*)
# TODO : move this somewhere more reasonable
if 'data' in injected_node['tags'] and \
@@ -351,11 +355,13 @@ def link_graph(self, linker, flat_graph):
linked_graph = {
'nodes': {},
'macros': flat_graph.get('macros'),
'operations': flat_graph.get('operations'),
}

for name, node in flat_graph.get('nodes').items():
self.link_node(linker, node, flat_graph)
linked_graph['nodes'][name] = node
for node_type in ['nodes', 'operations']:
Member:

curious why you put operations in a separate subgraph

Contributor Author:

yeah, i think you're right -- these can go in the nodes graph... let me investigate

for name, node in flat_graph.get(node_type).items():
self.link_node(linker, node, flat_graph)
linked_graph[node_type][name] = node

cycle = linker.find_cycles()

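To make the structure concrete (a sketch, not code from this diff — ids and contents are illustrative), the linked graph now carries operations alongside nodes and macros, which is the separate subgraph the reviewer asks about above:

# sketch of the linked graph shape as of this commit; keys are unique ids,
# values are parsed node dicts (elided here)
linked_graph = {
    'nodes': {},       # models, tests, archives, analyses
    'macros': {},
    'operations': {},  # one on-run-start/on-run-end node per project with hooks
}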
@@ -393,6 +399,17 @@ def get_parsed_macros(self, root_project, all_projects):

return parsed_macros

def get_parsed_operations(self, root_project, all_projects):
parsed_operations = {}

for name, project in all_projects.items():
parsed_operations.update(
dbt.parser.load_and_parse_run_hooks(root_project, all_projects, 'on-run-start'))
parsed_operations.update(
dbt.parser.load_and_parse_run_hooks(root_project, all_projects, 'on-run-end'))

return parsed_operations

def get_parsed_models(self, root_project, all_projects):
parsed_models = {}

@@ -456,6 +473,9 @@ def get_parsed_schema_tests(self, root_project, all_projects):
def load_all_macros(self, root_project, all_projects):
return self.get_parsed_macros(root_project, all_projects)

def load_all_operations(self, root_project, all_projects):
return self.get_parsed_operations(root_project, all_projects)

def load_all_nodes(self, root_project, all_projects):
all_nodes = {}

@@ -479,10 +499,12 @@ def compile(self):

all_macros = self.load_all_macros(root_project, all_projects)
all_nodes = self.load_all_nodes(root_project, all_projects)
all_operations = self.load_all_operations(root_project, all_projects)

flat_graph = {
'nodes': all_nodes,
'macros': all_macros
'macros': all_macros,
'operations': all_operations
}

flat_graph = dbt.parser.process_refs(flat_graph,
@@ -498,6 +520,9 @@ def compile(self):
for node_name, node in linked_graph.get('macros').items():
stats[node.get('resource_type')] += 1

for node_name, node in linked_graph.get('operations').items():
stats[node.get('resource_type')] += 1

print_compile_stats(stats)

return linked_graph, linker
3 changes: 2 additions & 1 deletion dbt/contracts/graph/unparsed.py
@@ -19,7 +19,8 @@
unparsed_node_contract = unparsed_base_contract.extend({
Required('resource_type'): Any(NodeType.Model,
NodeType.Test,
NodeType.Analysis)
NodeType.Analysis,
NodeType.Operation)
})

unparsed_nodes_contract = Schema([unparsed_node_contract])
46 changes: 46 additions & 0 deletions dbt/parser.py
100644 → 100755
@@ -331,6 +331,52 @@ def load_and_parse_sql(package_name, root_project, all_projects, root_dir,
return parse_sql_nodes(result, root_project, all_projects, tags)
Member:

looks like this file was chmod +xed -- can you undo that



def get_hooks_from_project(project_cfg, hook_type):
hooks = project_cfg.get(hook_type, [])

if type(hooks) not in (list, tuple):
hooks = [hooks]

return hooks


def get_hooks(all_projects, hook_type):
project_hooks = {}

for project_name, project in all_projects.items():
hooks = get_hooks_from_project(project, hook_type)

if len(hooks) > 0:
project_hooks[project_name] = ";\n".join(hooks)

return project_hooks
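A small usage sketch for the two helpers above — it assumes they are importable from dbt.parser once this PR lands, and the project dicts are made up:

from dbt.parser import get_hooks, get_hooks_from_project

all_projects = {
    'my_project': {'on-run-start': ['create table a (id int)',
                                    'create table b (id int)']},
    'some_dependency': {},  # no hooks configured
}

# a bare string is normalized to a one-element list
assert get_hooks_from_project({'on-run-start': 'select 1'},
                              'on-run-start') == ['select 1']

# each project's hooks are joined into one SQL string; hook-less projects drop out
assert get_hooks(all_projects, 'on-run-start') == {
    'my_project': 'create table a (id int);\ncreate table b (id int)',
}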


def load_and_parse_run_hooks(root_project, all_projects, hook_type):

if dbt.flags.STRICT_MODE:
dbt.contracts.project.validate_list(all_projects)

project_hooks = get_hooks(all_projects, hook_type)

result = []
for project_name, hooks in project_hooks.items():
project = all_projects[project_name]

hook_path = dbt.utils.get_pseudo_hook_path(hook_type)

result.append({
'name': hook_type,
'root_path': "{}/dbt_project.yml".format(project_name),
'resource_type': NodeType.Operation,
'path': hook_path,
'package_name': project_name,
'raw_sql': hooks
})

return parse_sql_nodes(result, root_project, all_projects, tags={hook_type})
Member:

looks great


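Concretely, each project with hooks contributes one raw node to `result` before parse_sql_nodes runs — a sketch assembled from the code above, with a hypothetical project name and SQL:

from dbt.utils import NodeType

# one entry of `result` inside load_and_parse_run_hooks, for hook_type 'on-run-start'
raw_operation = {
    'name': 'on-run-start',
    'root_path': 'my_project/dbt_project.yml',
    'resource_type': NodeType.Operation,
    'path': 'hooks/on-run-start.sql',  # dbt.utils.get_pseudo_hook_path('on-run-start')
    'package_name': 'my_project',
    'raw_sql': 'create table a (id int);\ncreate table b (id int)',
}
# parse_sql_nodes then tags the node with {'on-run-start'}, which is how
# the runner later selects it via get_nodes_by_tags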

def load_and_parse_macros(package_name, root_project, all_projects, root_dir,
relative_dirs, resource_type, tags=None):
extension = "[!.#~]*.sql"
82 changes: 40 additions & 42 deletions dbt/runner.py
@@ -5,7 +5,6 @@
import os
import time
import itertools
from datetime import datetime

from dbt.adapters.factory import get_adapter
from dbt.logger import GLOBAL_LOGGER as logger
@@ -370,24 +369,13 @@ def execute_archive(profile, node, context):
return result


def run_hooks(profile, hooks, context, source):
if type(hooks) not in (list, tuple):
hooks = [hooks]

ctx = {
"target": profile,
"state": "start",
"invocation_id": context['invocation_id'],
"run_started_at": context['run_started_at']
}

compiled_hooks = [
dbt.clients.jinja.get_rendered(hook, ctx) for hook in hooks
]

def run_hooks(profile, hooks):
adapter = get_adapter(profile)

return adapter.execute_all(profile=profile, sqls=compiled_hooks)
master_connection = adapter.begin(profile)
compiled_hooks = [hook['wrapped_sql'] for hook in hooks]
adapter.execute_all(profile=profile, sqls=compiled_hooks)
master_connection = adapter.commit(master_connection)


def track_model_run(index, num_nodes, run_model_result):
@@ -461,10 +449,8 @@ def call_table_exists(schema, table):
return adapter.table_exists(
profile, schema, table, node.get('name'))

self.run_started_at = datetime.now()

return {
"run_started_at": datetime.now(),
"run_started_at": dbt.tracking.active_user.run_started_at,
"invocation_id": dbt.tracking.active_user.invocation_id,
"get_columns_in_table": call_get_columns_in_table,
"get_missing_columns": call_get_missing_columns,
@@ -513,7 +499,6 @@ def execute_node(self, node, flat_graph, existing, profile, adapter):
return node, result

def compile_node(self, node, flat_graph):

compiler = dbt.compilation.Compiler(self.project)
node = compiler.compile_node(node, flat_graph)
return node
@@ -687,12 +672,9 @@ def execute_nodes(self, flat_graph, node_dependency_list, on_failure,
start_time = time.time()

if should_run_hooks:
master_connection = adapter.begin(profile)
run_hooks(self.project.get_target(),
self.project.cfg.get('on-run-start', []),
self.node_context({}),
'on-run-start hooks')
master_connection = adapter.commit(master_connection)
start_hooks = dbt.utils.get_nodes_by_tags(flat_graph, {'on-run-start'}, "operations")
hooks = [self.compile_node(hook, flat_graph) for hook in start_hooks]
run_hooks(profile, hooks)

def get_idx(node):
return node_id_to_index_map.get(node.get('unique_id'))
@@ -739,12 +721,9 @@ def get_idx(node):
pool.join()

if should_run_hooks:
adapter.begin(profile)
run_hooks(self.project.get_target(),
self.project.cfg.get('on-run-end', []),
self.node_context({}),
'on-run-end hooks')
adapter.commit(master_connection)
end_hooks = dbt.utils.get_nodes_by_tags(flat_graph, {'on-run-end'}, "operations")
hooks = [self.compile_node(hook, flat_graph) for hook in end_hooks]
run_hooks(profile, hooks)

execution_time = time.time() - start_time

@@ -755,18 +734,35 @@

def get_ancestor_ephemeral_nodes(self, flat_graph, linked_graph,
selected_nodes):
node_names = {
node: flat_graph['nodes'].get(node).get('name')
for node in selected_nodes
if node in flat_graph['nodes']
}

include_spec = [
'+{}'.format(node_names[node])
for node in selected_nodes if node in node_names
]

all_ancestors = dbt.graph.selector.select_nodes(
self.project,
linked_graph,
['+{}'.format(flat_graph.get('nodes').get(node).get('name'))
for node in selected_nodes],
include_spec,
[])

return set([ancestor for ancestor in all_ancestors
if(flat_graph['nodes'][ancestor].get(
'resource_type') == NodeType.Model and
get_materialization(
flat_graph['nodes'][ancestor]) == 'ephemeral')])
res = []

for ancestor in all_ancestors:
if ancestor not in flat_graph['nodes']:
continue
ancestor_node = flat_graph['nodes'][ancestor]
is_model = ancestor_node.get('resource_type') == NodeType.Model
is_ephemeral = get_materialization(ancestor_node) == 'ephemeral'
if is_model and is_ephemeral:
res.append(ancestor)

return set(res)
Member:

👍


def get_nodes_to_run(self, graph, include_spec, exclude_spec,
resource_types, tags):
@@ -874,15 +870,17 @@ def compile_models(self, include_spec, exclude_spec):
NodeType.Model,
NodeType.Test,
NodeType.Archive,
NodeType.Analysis
NodeType.Analysis,
NodeType.Operation
]

return self.run_types_from_graph(include_spec,
exclude_spec,
resource_types=resource_types,
tags=set(),
should_run_hooks=False,
should_execute=False)
should_execute=False,
flatten_graph=True)
Member:

why flatten the graph here?

Contributor Author:

debugging! good catch


def run_models(self, include_spec, exclude_spec):
return self.run_types_from_graph(include_spec,
6 changes: 3 additions & 3 deletions dbt/tracking.py
@@ -2,6 +2,7 @@
from dbt import version as dbt_version
from snowplow_tracker import Subject, Tracker, Emitter, logger as sp_logger
from snowplow_tracker import SelfDescribingJson, disable_contracts
from datetime import datetime

import platform
import uuid
@@ -42,16 +43,15 @@ def __init__(self):
self.do_not_track = True

self.id = None
self.invocation_id = None
self.invocation_id = str(uuid.uuid4())
self.run_started_at = datetime.now()

def state(self):
return "do not track" if self.do_not_track else "tracking"

def initialize(self):
self.do_not_track = False

self.invocation_id = str(uuid.uuid4())

cookie = self.get_cookie()
self.id = cookie.get('id')

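The effect of moving these two assignments into __init__ is that they exist even when tracking is disabled, so compile-time code can read them unconditionally — the dbt/compilation.py change above does exactly this (a sketch, assuming active_user has already been constructed):

import dbt.tracking

# both attributes are now set at construction time, before initialize()
# decides whether to track
context = {
    'run_started_at': dbt.tracking.active_user.run_started_at,
    'invocation_id': dbt.tracking.active_user.invocation_id,
}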
15 changes: 15 additions & 0 deletions dbt/utils.py
@@ -28,6 +28,7 @@ class NodeType(object):
Test = 'test'
Archive = 'archive'
Macro = 'macro'
Operation = 'operation'


class This(object):
@@ -263,6 +264,11 @@ def get_pseudo_test_path(node_name, source_path, test_type):
return os.path.join(*pseudo_path_parts)


def get_pseudo_hook_path(hook_name):
path_parts = ['hooks', "{}.sql".format(hook_name)]
return os.path.join(*path_parts)


def get_run_status_line(results):
total = len(results)
errored = len([r for r in results if r.errored or r.failed])
@@ -277,3 +283,12 @@ def get_run_status_line(results):
errored=errored,
skipped=skipped
))


def get_nodes_by_tags(flat_graph, match_tags, resource_type):
nodes = []
for node_name, node in flat_graph[resource_type].items():
node_tags = node.get('tags', set())
if len(node_tags & match_tags):
nodes.append(node)
return nodes
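A toy example of the tag-based selection this enables — it assumes dbt.utils is importable at this commit; the graph contents are made up:

from dbt.utils import get_nodes_by_tags

flat_graph = {
    'operations': {
        'op.start': {'tags': {'on-run-start'}, 'raw_sql': 'select 1'},
        'op.end':   {'tags': {'on-run-end'},   'raw_sql': 'select 2'},
    },
}

# only nodes whose tag set intersects the match set are returned
start_hooks = get_nodes_by_tags(flat_graph, {'on-run-start'}, 'operations')
assert [h['raw_sql'] for h in start_hooks] == ['select 1']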