Skip to content

Commit

Permalink
feat(trino): Support for LISTAGG function (#4253)
Browse files Browse the repository at this point in the history
* feat(trino): Support LISTAGG function

* PR Feedback 1
  • Loading branch information
VaggelisD authored Oct 17, 2024
1 parent cfd692f commit 1c43348
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 5 deletions.
10 changes: 7 additions & 3 deletions sqlglot/dialects/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,9 +386,6 @@ class Generator(generator.Generator):
exp.GenerateSeries: sequence_sql,
exp.GenerateDateArray: sequence_sql,
exp.Group: transforms.preprocess([transforms.unalias_group]),
exp.GroupConcat: lambda self, e: self.func(
"ARRAY_JOIN", self.func("ARRAY_AGG", e.this), e.args.get("separator")
),
exp.If: if_sql(),
exp.ILike: no_ilike_sql,
exp.Initcap: _initcap_sql,
Expand Down Expand Up @@ -680,3 +677,10 @@ def jsonextract_sql(self, expression: exp.JSONExtract) -> str:
expr = "".join(segments)

return f"{this}{expr}"

def groupconcat_sql(self, expression: exp.GroupConcat) -> str:
return self.func(
"ARRAY_JOIN",
self.func("ARRAY_AGG", expression.this),
expression.args.get("separator"),
)
17 changes: 15 additions & 2 deletions sqlglot/dialects/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,15 @@ class Trino(Presto):
SUPPORTS_USER_DEFINED_TYPES = False
LOG_BASE_FIRST = True

class Tokenizer(Presto.Tokenizer):
HEX_STRINGS = [("X'", "'")]

class Parser(Presto.Parser):
FUNCTION_PARSERS = {
**Presto.Parser.FUNCTION_PARSERS,
"TRIM": lambda self: self._parse_trim(),
"JSON_QUERY": lambda self: self._parse_json_query(),
"LISTAGG": lambda self: self._parse_string_agg(),
}

JSON_QUERY_OPTIONS: parser.OPTIONS_TYPE = {
Expand Down Expand Up @@ -65,5 +69,14 @@ def jsonextract_sql(self, expression: exp.JSONExtract) -> str:

return self.func("JSON_QUERY", expression.this, json_path + option)

class Tokenizer(Presto.Tokenizer):
HEX_STRINGS = [("X'", "'")]
def groupconcat_sql(self, expression: exp.GroupConcat) -> str:
this = expression.this
separator = expression.args.get("separator") or exp.Literal.string(",")

if isinstance(this, exp.Order):
if this.this:
this = this.this.pop()

return f"LISTAGG({self.format_args(this, separator)}) WITHIN GROUP ({self.sql(expression.this).lstrip()})"

return super().groupconcat_sql(expression)
3 changes: 3 additions & 0 deletions tests/dialects/test_trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ def test_trino(self):
self.validate_identity("JSON_QUERY(content, 'lax $.HY.*')")
self.validate_identity("JSON_QUERY(content, 'strict $.HY.*' WITH UNCONDITIONAL WRAPPER)")
self.validate_identity("JSON_QUERY(content, 'strict $.HY.*' WITHOUT CONDITIONAL WRAPPER)")
self.validate_identity(
"SELECT LISTAGG(DISTINCT col, ',') WITHIN GROUP (ORDER BY col ASC) FROM tbl"
)

def test_trim(self):
self.validate_identity("SELECT TRIM('!' FROM '!foo!')")
Expand Down

0 comments on commit 1c43348

Please sign in to comment.