Skip to content

Commit

Permalink
Feat(redshift): add support for Redshift's super array index iteration (
Browse files Browse the repository at this point in the history
#2373)

* Feat(redshift): add support for Redshift's super array index iteration

* Support Redshift's special syntax for UNPIVOT

* Refactor type hint

* Make id var parsing more lax for AT clause
  • Loading branch information
georgesittas authored Oct 4, 2023
1 parent 40bb71f commit 347ac51
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 3 deletions.
19 changes: 19 additions & 0 deletions sqlglot/dialects/redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class Redshift(Postgres):
RESOLVES_IDENTIFIERS_AS_UPPERCASE = None

SUPPORTS_USER_DEFINED_TYPES = False
INDEX_OFFSET = 0

TIME_FORMAT = "'YYYY-MM-DD HH:MI:SS'"
TIME_MAPPING = {
Expand All @@ -57,6 +58,24 @@ class Parser(Postgres.Parser):
"STRTOL": exp.FromBase.from_arg_list,
}

def _parse_table(
self,
schema: bool = False,
joins: bool = False,
alias_tokens: t.Optional[t.Collection[TokenType]] = None,
parse_bracket: bool = False,
) -> t.Optional[exp.Expression]:
# Redshift supports UNPIVOTing SUPER objects, e.g. `UNPIVOT foo.obj[0] AS val AT attr`
unpivot = self._match(TokenType.UNPIVOT)
table = super()._parse_table(
schema=schema,
joins=joins,
alias_tokens=alias_tokens,
parse_bracket=parse_bracket,
)

return self.expression(exp.Pivot, this=table, unpivot=True) if unpivot else table

def _parse_types(
self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True
) -> t.Optional[exp.Expression]:
Expand Down
3 changes: 2 additions & 1 deletion sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2465,6 +2465,7 @@ class Table(Expression):
"version": False,
"format": False,
"pattern": False,
"index": False,
}

@property
Expand Down Expand Up @@ -3431,7 +3432,7 @@ class Pivot(Expression):
arg_types = {
"this": False,
"alias": False,
"expressions": True,
"expressions": False,
"field": False,
"unpivot": False,
"using": False,
Expand Down
9 changes: 8 additions & 1 deletion sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1350,13 +1350,17 @@ def table_sql(self, expression: exp.Table, sep: str = " AS ") -> str:
pivots = f" {pivots}" if pivots else ""
joins = self.expressions(expression, key="joins", sep="", skip_first=True)
laterals = self.expressions(expression, key="laterals", sep="")

file_format = self.sql(expression, "format")
if file_format:
pattern = self.sql(expression, "pattern")
pattern = f", PATTERN => {pattern}" if pattern else ""
file_format = f" (FILE_FORMAT => {file_format}{pattern})"

return f"{table}{version}{file_format}{alias}{hints}{pivots}{joins}{laterals}"
index = self.sql(expression, "index")
index = f" AT {index}" if index else ""

return f"{table}{version}{file_format}{alias}{index}{hints}{pivots}{joins}{laterals}"

def tablesample_sql(
self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
Expand Down Expand Up @@ -1401,6 +1405,9 @@ def pivot_sql(self, expression: exp.Pivot) -> str:

if expression.this:
this = self.sql(expression, "this")
if not expressions:
return f"UNPIVOT {this}"

on = f"{self.seg('ON')} {expressions}"
using = self.expressions(expression, key="using", flat=True)
using = f"{self.seg('USING')} {using}" if using else ""
Expand Down
7 changes: 6 additions & 1 deletion sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2622,7 +2622,9 @@ def _parse_table(

bracket = parse_bracket and self._parse_bracket(None)
bracket = self.expression(exp.Table, this=bracket) if bracket else None
this: exp.Expression = bracket or self._parse_table_parts(schema=schema)
this = t.cast(
exp.Expression, bracket or self._parse_bracket(self._parse_table_parts(schema=schema))
)

if schema:
return self._parse_schema(this=this)
Expand All @@ -2639,6 +2641,9 @@ def _parse_table(
if alias:
this.set("alias", alias)

if self._match_text_seq("AT"):
this.set("index", self._parse_id_var())

this.set("hints", self._parse_table_hints())

if not this.args.get("pivots"):
Expand Down
16 changes: 16 additions & 0 deletions tests/dialects/test_redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,22 @@ def test_identity(self):
"SELECT DATE_ADD('day', 1, DATE('2023-01-01'))",
"SELECT DATEADD(day, 1, CAST(DATE('2023-01-01') AS DATE))",
)
self.validate_identity(
"""SELECT
c_name,
orders.o_orderkey AS orderkey,
index AS orderkey_index
FROM customer_orders_lineitem AS c, c.c_orders AS orders AT index
ORDER BY
orderkey_index""",
pretty=True,
)
self.validate_identity(
"SELECT attr AS attr, JSON_TYPEOF(val) AS value_type FROM customer_orders_lineitem AS c, UNPIVOT c.c_orders[0] WHERE c_custkey = 9451"
)
self.validate_identity(
"SELECT attr AS attr, JSON_TYPEOF(val) AS value_type FROM customer_orders_lineitem AS c, UNPIVOT c.c_orders AS val AT attr WHERE c_custkey = 9451"
)

def test_values(self):
# Test crazy-sized VALUES clause to UNION ALL conversion to ensure we don't get RecursionError
Expand Down

0 comments on commit 347ac51

Please sign in to comment.