Skip to content

Commit

Permalink
feat(duckdb): Transpile Spark's LATERAL VIEW EXPLODE (#4252)
Browse files Browse the repository at this point in the history
  • Loading branch information
VaggelisD authored Oct 16, 2024
1 parent 94013a2 commit ed97954
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 16 deletions.
15 changes: 15 additions & 0 deletions sqlglot/dialects/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -1697,3 +1697,18 @@ def build_regexp_extract(args: t.List, dialect: Dialect) -> exp.RegexpExtract:
expression=seq_get(args, 1),
group=seq_get(args, 2) or exp.Literal.number(dialect.REGEXP_EXTRACT_DEFAULT_GROUP),
)


def explode_to_unnest_sql(self: Generator, expression: exp.Lateral) -> str:
if isinstance(expression.this, exp.Explode):
return self.sql(
exp.Join(
this=exp.Unnest(
expressions=[expression.this.this],
alias=expression.args.get("alias"),
offset=isinstance(expression.this, exp.Posexplode),
),
kind="cross",
)
)
return self.lateral_sql(expression)
2 changes: 2 additions & 0 deletions sqlglot/dialects/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
unit_to_str,
sha256_sql,
build_regexp_extract,
explode_to_unnest_sql,
)
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
Expand Down Expand Up @@ -538,6 +539,7 @@ class Generator(generator.Generator):
exp.JSONExtract: _arrow_json_extract_sql,
exp.JSONExtractScalar: _arrow_json_extract_sql,
exp.JSONFormat: _json_format_sql,
exp.Lateral: explode_to_unnest_sql,
exp.LogicalOr: rename_func("BOOL_OR"),
exp.LogicalAnd: rename_func("BOOL_AND"),
exp.MD5Digest: lambda self, e: self.func("UNHEX", self.func("MD5", e.this)),
Expand Down
18 changes: 2 additions & 16 deletions sqlglot/dialects/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
unit_to_str,
sequence_sql,
build_regexp_extract,
explode_to_unnest_sql,
)
from sqlglot.dialects.hive import Hive
from sqlglot.dialects.mysql import MySQL
Expand All @@ -40,21 +41,6 @@
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TimestampAdd, exp.DateSub]


def _explode_to_unnest_sql(self: Presto.Generator, expression: exp.Lateral) -> str:
if isinstance(expression.this, exp.Explode):
return self.sql(
exp.Join(
this=exp.Unnest(
expressions=[expression.this.this],
alias=expression.args.get("alias"),
offset=isinstance(expression.this, exp.Posexplode),
),
kind="cross",
)
)
return self.lateral_sql(expression)


def _initcap_sql(self: Presto.Generator, expression: exp.Initcap) -> str:
regex = r"(\w)(\w*)"
return f"REGEXP_REPLACE({self.sql(expression, 'this')}, '{regex}', x -> UPPER(x[1]) || LOWER(x[2]))"
Expand Down Expand Up @@ -410,7 +396,7 @@ class Generator(generator.Generator):
exp.Last: _first_last_sql,
exp.LastValue: _first_last_sql,
exp.LastDay: lambda self, e: self.func("LAST_DAY_OF_MONTH", e.this),
exp.Lateral: _explode_to_unnest_sql,
exp.Lateral: explode_to_unnest_sql,
exp.Left: left_to_substring_sql,
exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
exp.LogicalAnd: rename_func("BOOL_AND"),
Expand Down
3 changes: 3 additions & 0 deletions tests/dialects/test_hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ def test_lateral_view(self):
"SELECT a, b FROM x LATERAL VIEW EXPLODE(y) t AS a LATERAL VIEW EXPLODE(z) u AS b",
write={
"presto": "SELECT a, b FROM x CROSS JOIN UNNEST(y) AS t(a) CROSS JOIN UNNEST(z) AS u(b)",
"duckdb": "SELECT a, b FROM x CROSS JOIN UNNEST(y) AS t(a) CROSS JOIN UNNEST(z) AS u(b)",
"hive": "SELECT a, b FROM x LATERAL VIEW EXPLODE(y) t AS a LATERAL VIEW EXPLODE(z) u AS b",
"spark": "SELECT a, b FROM x LATERAL VIEW EXPLODE(y) t AS a LATERAL VIEW EXPLODE(z) u AS b",
},
Expand All @@ -195,6 +196,7 @@ def test_lateral_view(self):
"SELECT a FROM x LATERAL VIEW EXPLODE(y) t AS a",
write={
"presto": "SELECT a FROM x CROSS JOIN UNNEST(y) AS t(a)",
"duckdb": "SELECT a FROM x CROSS JOIN UNNEST(y) AS t(a)",
"hive": "SELECT a FROM x LATERAL VIEW EXPLODE(y) t AS a",
"spark": "SELECT a FROM x LATERAL VIEW EXPLODE(y) t AS a",
},
Expand All @@ -211,6 +213,7 @@ def test_lateral_view(self):
"SELECT a FROM x LATERAL VIEW EXPLODE(ARRAY(y)) t AS a",
write={
"presto": "SELECT a FROM x CROSS JOIN UNNEST(ARRAY[y]) AS t(a)",
"duckdb": "SELECT a FROM x CROSS JOIN UNNEST([y]) AS t(a)",
"hive": "SELECT a FROM x LATERAL VIEW EXPLODE(ARRAY(y)) t AS a",
"spark": "SELECT a FROM x LATERAL VIEW EXPLODE(ARRAY(y)) t AS a",
},
Expand Down

0 comments on commit ed97954

Please sign in to comment.