dbt 0.18.0 support #103

Merged 7 commits on Sep 14, 2020

Changes from all commits
2 changes: 1 addition & 1 deletion .bumpversion-dbt.cfg
@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 0.17.2
+current_version = 0.18.0rc1
 parse = (?P<major>\d+)
 	\.(?P<minor>\d+)
 	\.(?P<patch>\d+)
3 changes: 2 additions & 1 deletion .bumpversion.cfg
@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 0.17.2
+current_version = 0.18.0rc1
 parse = (?P<major>\d+)
 	\.(?P<minor>\d+)
 	\.(?P<patch>\d+)
@@ -27,3 +27,4 @@ first_value = 1
 first_value = 1
 
 [bumpversion:file:dbt/adapters/spark/__version__.py]
+
2 changes: 1 addition & 1 deletion dbt/adapters/spark/__version__.py
@@ -1 +1 @@
-version = "0.17.2"
+version = "0.18.0rc1"
85 changes: 73 additions & 12 deletions dbt/adapters/spark/impl.py
@@ -222,7 +222,7 @@ def get_properties(self, relation: Relation) -> Dict[str, str]:
 
     def get_catalog(self, manifest):
         schema_map = self._get_catalog_schemas(manifest)
-        if len(schema_map) != 1:
+        if len(schema_map) > 1:
             dbt.exceptions.raise_compiler_error(
                 f'Expected only one database in get_catalog, found '
                 f'{list(schema_map)}'
@@ -232,7 +232,8 @@ def get_catalog(self, manifest):
             futures: List[Future[agate.Table]] = []
             for info, schemas in schema_map.items():
                 for schema in schemas:
-                    futures.append(tpe.submit(
+                    futures.append(tpe.submit_connected(
+                        self, schema,
                         self._get_one_catalog, info, [schema], manifest
                     ))
             catalogs, exceptions = catch_as_completed(futures)
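
Note: submit_connected is the dbt 0.18 executor API that opens a named connection around each submitted task, which is why the next hunks can drop the connection_named block from _get_one_catalog. A rough Python sketch of the assumed wrapper behavior, inferred from this diff rather than copied from dbt-core:

    # Sketch only: what submit_connected is assumed to do for the caller.
    def submit_connected(self, adapter, conn_name, func, *args, **kwargs):
        def connected(*f_args, **f_kwargs):
            # open/close the named connection around the task, which is
            # exactly what _get_one_catalog used to do inline
            with adapter.connection_named(conn_name):
                return func(*f_args, **f_kwargs)
        return self.submit(connected, *args, **kwargs)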
@@ -241,8 +242,6 @@ def get_catalog(self, manifest):
     def _get_one_catalog(
         self, information_schema, schemas, manifest,
     ) -> agate.Table:
-        name = f'{information_schema.database}.information_schema'
-
         if len(schemas) != 1:
             dbt.exceptions.raise_compiler_error(
                 f'Expected only one schema in spark _get_one_catalog, found '
@@ -252,14 +251,13 @@ def _get_one_catalog(
         database = information_schema.database
         schema = list(schemas)[0]
 
-        with self.connection_named(name):
-            columns: List[Dict[str, Any]] = []
-            for relation in self.list_relations(database, schema):
-                logger.debug("Getting table schema for relation {}", relation)
-                columns.extend(self._get_columns_for_catalog(relation))
-            return agate.Table.from_object(
-                columns, column_types=DEFAULT_TYPE_TESTER
-            )
+        columns: List[Dict[str, Any]] = []
+        for relation in self.list_relations(database, schema):
+            logger.debug("Getting table schema for relation {}", relation)
+            columns.extend(self._get_columns_for_catalog(relation))
+        return agate.Table.from_object(
+            columns, column_types=DEFAULT_TYPE_TESTER
+        )
 
     def check_schema_exists(self, database, schema):
         results = self.execute_macro(
@@ -269,3 +267,66 @@ def check_schema_exists(self, database, schema):
 
         exists = True if schema in [row[0] for row in results] else False
         return exists
+
+    def get_rows_different_sql(
+        self,
+        relation_a: BaseRelation,
+        relation_b: BaseRelation,
+        column_names: Optional[List[str]] = None,
+        except_operator: str = 'EXCEPT',
+    ) -> str:
+        """Generate SQL for a query that returns a single row with two
+        columns: the number of rows that are different between the two
+        relations and the number of mismatched rows.
+        """
+        # This method only really exists for test reasons.
+        names: List[str]
+        if column_names is None:
+            columns = self.get_columns_in_relation(relation_a)
+            names = sorted((self.quote(c.name) for c in columns))
+        else:
+            names = sorted((self.quote(n) for n in column_names))
+        columns_csv = ', '.join(names)
+
+        sql = COLUMNS_EQUAL_SQL.format(
+            columns=columns_csv,
+            relation_a=str(relation_a),
+            relation_b=str(relation_b),
+        )
+
+        return sql
+
+
+# spark does something interesting with joins when both tables have the same
+# static values for the join condition and complains that the join condition is
+# "trivial". Which is true, though it seems like an unreasonable cause for
+# failure! It also doesn't like the `from foo, bar` syntax as opposed to
+# `from foo cross join bar`.
+COLUMNS_EQUAL_SQL = '''
+with diff_count as (
+    SELECT
+        1 as id,
+        COUNT(*) as num_missing FROM (
+            (SELECT {columns} FROM {relation_a} EXCEPT
+             SELECT {columns} FROM {relation_b})
+             UNION ALL
+            (SELECT {columns} FROM {relation_b} EXCEPT
+             SELECT {columns} FROM {relation_a})
+        ) as a
+), table_a as (
+    SELECT COUNT(*) as num_rows FROM {relation_a}
+), table_b as (
+    SELECT COUNT(*) as num_rows FROM {relation_b}
+), row_count_diff as (
+    select
+        1 as id,
+        table_a.num_rows - table_b.num_rows as difference
+    from table_a
+    cross join table_b
+)
+select
+    row_count_diff.difference as row_count_difference,
+    diff_count.num_missing as num_mismatched
+from row_count_diff
+cross join diff_count
+'''.strip()
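
Note: get_rows_different_sql just fills the three placeholders in COLUMNS_EQUAL_SQL (defined above) and returns plain SQL. A minimal sketch of the rendering; the relation and column names here are invented for illustration:

    # Renders COLUMNS_EQUAL_SQL for two hypothetical relations; running the
    # result yields one row with row_count_difference and num_mismatched.
    columns_csv = ', '.join(sorted(['`id`', '`name`']))
    sql = COLUMNS_EQUAL_SQL.format(
        columns=columns_csv,
        relation_a='my_db.table_a',  # hypothetical
        relation_b='my_db.table_b',  # hypothetical
    )
    print(sql)  # EXCEPT/UNION ALL counts mismatched rows in both directions;
                # the cross-joined CTEs add the row-count difference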
43 changes: 37 additions & 6 deletions dbt/include/spark/macros/materializations/seed.sql
@@ -1,32 +1,34 @@
 {% macro spark__load_csv_rows(model, agate_table) %}
     {% set batch_size = 1000 %}
     {% set cols_sql = ", ".join(agate_table.column_names) %}
-    {% set bindings = [] %}
 
     {% set statements = [] %}
 
     {% for chunk in agate_table.rows | batch(batch_size) %}
+        {% set bindings = [] %}
+
         {% for row in chunk %}
-            {% set _ = bindings.extend(row) %}
+            {% do bindings.extend(row) %}
         {% endfor %}
 
         {% set sql %}
             insert into {{ this.render() }} values
             {% for row in chunk -%}
-                ({%- for column in agate_table.column_names -%}
+                ({%- for column in agate_table.columns -%}
                     {%- if 'ISODate' in (column.data_type | string) -%}
                         cast(%s as timestamp)
                     {%- else -%}
                         %s
                     {%- endif -%}
                     {%- if not loop.last%},{%- endif %}
                 {%- endfor -%})
                 {%- if not loop.last%},{%- endif %}
             {%- endfor %}
         {% endset %}
 
-        {% set _ = adapter.add_query(sql, bindings=bindings, abridge_sql_log=True) %}
+        {% do adapter.add_query(sql, bindings=bindings, abridge_sql_log=True) %}
 
         {% if loop.index0 == 0 %}
-            {% set _ = statements.append(sql) %}
+            {% do statements.append(sql) %}
         {% endif %}
     {% endfor %}
@@ -42,6 +44,35 @@
     {{ return(sql) }}
 {% endmacro %}
 
+
+{% macro spark__create_csv_table(model, agate_table) %}
+    {%- set column_override = model['config'].get('column_types', {}) -%}
+    {%- set quote_seed_column = model['config'].get('quote_columns', None) -%}
+
+    {% set sql %}
+        create table {{ this.render() }} (
+        {%- for col_name in agate_table.column_names -%}
+            {%- set inferred_type = adapter.convert_type(agate_table, loop.index0) -%}
+            {%- set type = column_override.get(col_name, inferred_type) -%}
+            {%- set column_name = (col_name | string) -%}
+            {{ adapter.quote_seed_column(column_name, quote_seed_column) }} {{ type }} {%- if not loop.last -%}, {%- endif -%}
+        {%- endfor -%}
+        )
+        {{ file_format_clause() }}
+        {{ partition_cols(label="partitioned by") }}
+        {{ clustered_cols(label="clustered by") }}
+        {{ location_clause() }}
+        {{ comment_clause() }}
+    {% endset %}
+
+    {% call statement('_') -%}
+        {{ sql }}
+    {%- endcall %}
+
+    {{ return(sql) }}
+{% endmacro %}
+
+
 {% materialization seed, adapter='spark' %}
 
     {%- set identifier = model['alias'] -%}
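
Note: the substantive fix in spark__load_csv_rows above is that bindings is now re-initialized inside the chunk loop, so each INSERT carries only its own batch's values instead of accumulating every previous batch's. In Python terms, the macro now behaves roughly like this sketch (rows and add_query stand in for the agate table and adapter.add_query):

    # Sketch of the per-chunk binding behavior; names are illustrative.
    def load_rows_in_batches(rows, add_query, batch_size=1000):
        for start in range(0, len(rows), batch_size):
            chunk = rows[start:start + batch_size]
            bindings = []  # reset per chunk -- the bug fix
            for row in chunk:
                bindings.extend(row)
            values = ', '.join(
                '(' + ', '.join('%s' for _ in row) + ')' for row in chunk
            )
            add_query('insert into my_seed values ' + values, bindings)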
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,3 +1,3 @@
-dbt-core==0.17.2
+dbt-core==0.18.0rc1
 PyHive[hive]>=0.6.0,<0.7.0
 thrift>=0.11.0,<0.12.0
4 changes: 2 additions & 2 deletions setup.py
@@ -28,9 +28,9 @@ def _dbt_spark_version():
 package_version = _dbt_spark_version()
 description = """The SparkSQL plugin for dbt (data build tool)"""
 
-dbt_version = '0.17.2'
+dbt_version = '0.18.0rc1'
 # the package version should be the dbt version, with maybe some things on the
-# ends of it. (0.17.2 vs 0.17.2a1, 0.17.2.1, ...)
+# ends of it. (0.18.0rc1 vs 0.18.0rc1a1, 0.18.0rc1.1, ...)
 if not package_version.startswith(dbt_version):
     raise ValueError(
         f'Invalid setup.py: package_version={package_version} must start with '
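
Note: the guard above only requires the plugin's version to begin with the pinned dbt version, so suffixed builds still pass. For example (values hypothetical):

    dbt_version = '0.18.0rc1'
    assert '0.18.0rc1'.startswith(dbt_version)    # exact match: accepted
    assert '0.18.0rc1.1'.startswith(dbt_version)  # extra suffix: accepted
    assert not '0.18.1'.startswith(dbt_version)   # mismatch: setup.py raises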
21 changes: 14 additions & 7 deletions test/unit/utils.py
@@ -38,7 +38,7 @@ def profile_from_dict(profile, profile_name, cli_vars='{}'):
     from dbt.config import Profile
     from dbt.config.renderer import ProfileRenderer
     from dbt.context.base import generate_base_context
-    from dbt.utils import parse_cli_vars
+    from dbt.config.utils import parse_cli_vars
     if not isinstance(cli_vars, dict):
         cli_vars = parse_cli_vars(cli_vars)
 
@@ -50,11 +50,11 @@ def profile_from_dict(profile, profile_name, cli_vars='{}'):
     )
 
 
-def project_from_dict(project, profile, packages=None, cli_vars='{}'):
+def project_from_dict(project, profile, packages=None, selectors=None, cli_vars='{}'):
     from dbt.context.target import generate_target_context
     from dbt.config import Project
     from dbt.config.renderer import DbtProjectYamlRenderer
-    from dbt.utils import parse_cli_vars
+    from dbt.config.utils import parse_cli_vars
     if not isinstance(cli_vars, dict):
         cli_vars = parse_cli_vars(cli_vars)
 
@@ -63,11 +63,11 @@ def project_from_dict(project, profile, packages=None, cli_vars='{}'):
     project_root = project.pop('project-root', os.getcwd())
 
     return Project.render_from_dict(
-        project_root, project, packages, renderer
+        project_root, project, packages, selectors, renderer
     )
 
 
-def config_from_parts_or_dicts(project, profile, packages=None, cli_vars='{}'):
+def config_from_parts_or_dicts(project, profile, packages=None, selectors=None, cli_vars='{}'):
     from dbt.config import Project, Profile, RuntimeConfig
     from copy import deepcopy
 
@@ -88,6 +88,7 @@ def config_from_parts_or_dicts(project, profile, packages=None, cli_vars='{}'):
         deepcopy(project),
         profile,
         packages,
+        selectors,
         cli_vars,
     )

@@ -101,14 +102,20 @@ def config_from_parts_or_dicts(project, profile, packages=None, cli_vars='{}'):
     )
 
 
-def inject_adapter(value):
+def inject_plugin(plugin):
+    from dbt.adapters.factory import FACTORY
+    key = plugin.adapter.type()
+    FACTORY.plugins[key] = plugin
+
+
+def inject_adapter(value, plugin):
     """Inject the given adapter into the adapter factory, so your hand-crafted
     artisanal adapter will be available from get_adapter() as if dbt loaded it.
     """
+    inject_plugin(plugin)
     from dbt.adapters.factory import FACTORY
     key = value.type()
-    FACTORY.adapters[key] = value
+    FACTORY.adapter_types[key] = type(value)
 
 
 class ContractTestCase(TestCase):

Contributor, inline comment on the "hand-crafted artisanal adapter" docstring above: this is funny
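
Note: with the new signature, a test registers the plugin alongside the adapter instance. A sketch of the assumed usage, under the assumptions that dbt-spark exposes Plugin and SparkAdapter from dbt.adapters.spark and that config was built via config_from_parts_or_dicts above:

    # Hypothetical test setup using the updated helpers.
    from dbt.adapters.spark import Plugin as SparkPlugin, SparkAdapter

    adapter = SparkAdapter(config)        # config built elsewhere in the test
    inject_adapter(adapter, SparkPlugin)  # registers the plugin, then the type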