From 347ac51da6a553a7904739f0f3ad6b4bb4db01c6 Mon Sep 17 00:00:00 2001 From: Jo <46752250+GeorgeSittas@users.noreply.github.com> Date: Wed, 4 Oct 2023 19:10:41 +0300 Subject: [PATCH] Feat(redshift): add support for Redshift's super array index iteration (#2373) * Feat(redshift): add support for Redshift's super array index iteration * Support Redshift's special syntax for UNPIVOT * Refactor type hint * Make id var parsing more lax for AT clause --- sqlglot/dialects/redshift.py | 19 +++++++++++++++++++ sqlglot/expressions.py | 3 ++- sqlglot/generator.py | 9 ++++++++- sqlglot/parser.py | 7 ++++++- tests/dialects/test_redshift.py | 16 ++++++++++++++++ 5 files changed, 51 insertions(+), 3 deletions(-) diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index 88e4448c12..b70a8a1209 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -31,6 +31,7 @@ class Redshift(Postgres): RESOLVES_IDENTIFIERS_AS_UPPERCASE = None SUPPORTS_USER_DEFINED_TYPES = False + INDEX_OFFSET = 0 TIME_FORMAT = "'YYYY-MM-DD HH:MI:SS'" TIME_MAPPING = { @@ -57,6 +58,24 @@ class Parser(Postgres.Parser): "STRTOL": exp.FromBase.from_arg_list, } + def _parse_table( + self, + schema: bool = False, + joins: bool = False, + alias_tokens: t.Optional[t.Collection[TokenType]] = None, + parse_bracket: bool = False, + ) -> t.Optional[exp.Expression]: + # Redshift supports UNPIVOTing SUPER objects, e.g. `UNPIVOT foo.obj[0] AS val AT attr` + unpivot = self._match(TokenType.UNPIVOT) + table = super()._parse_table( + schema=schema, + joins=joins, + alias_tokens=alias_tokens, + parse_bracket=parse_bracket, + ) + + return self.expression(exp.Pivot, this=table, unpivot=True) if unpivot else table + def _parse_types( self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True ) -> t.Optional[exp.Expression]: diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 1e4aad61d1..1dd73f5e26 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -2465,6 +2465,7 @@ class Table(Expression): "version": False, "format": False, "pattern": False, + "index": False, } @property @@ -3431,7 +3432,7 @@ class Pivot(Expression): arg_types = { "this": False, "alias": False, - "expressions": True, + "expressions": False, "field": False, "unpivot": False, "using": False, diff --git a/sqlglot/generator.py b/sqlglot/generator.py index edc69393f7..3cf4a7bd04 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -1350,13 +1350,17 @@ def table_sql(self, expression: exp.Table, sep: str = " AS ") -> str: pivots = f" {pivots}" if pivots else "" joins = self.expressions(expression, key="joins", sep="", skip_first=True) laterals = self.expressions(expression, key="laterals", sep="") + file_format = self.sql(expression, "format") if file_format: pattern = self.sql(expression, "pattern") pattern = f", PATTERN => {pattern}" if pattern else "" file_format = f" (FILE_FORMAT => {file_format}{pattern})" - return f"{table}{version}{file_format}{alias}{hints}{pivots}{joins}{laterals}" + index = self.sql(expression, "index") + index = f" AT {index}" if index else "" + + return f"{table}{version}{file_format}{alias}{index}{hints}{pivots}{joins}{laterals}" def tablesample_sql( self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS " @@ -1401,6 +1405,9 @@ def pivot_sql(self, expression: exp.Pivot) -> str: if expression.this: this = self.sql(expression, "this") + if not expressions: + return f"UNPIVOT {this}" + on = f"{self.seg('ON')} {expressions}" using = self.expressions(expression, key="using", flat=True) using = f"{self.seg('USING')} {using}" if using else "" diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 5e56961278..ddaa9d6ebe 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -2622,7 +2622,9 @@ def _parse_table( bracket = parse_bracket and self._parse_bracket(None) bracket = self.expression(exp.Table, this=bracket) if bracket else None - this: exp.Expression = bracket or self._parse_table_parts(schema=schema) + this = t.cast( + exp.Expression, bracket or self._parse_bracket(self._parse_table_parts(schema=schema)) + ) if schema: return self._parse_schema(this=this) @@ -2639,6 +2641,9 @@ def _parse_table( if alias: this.set("alias", alias) + if self._match_text_seq("AT"): + this.set("index", self._parse_id_var()) + this.set("hints", self._parse_table_hints()) if not this.args.get("pivots"): diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py index ae1b987e0c..9f2761f4fe 100644 --- a/tests/dialects/test_redshift.py +++ b/tests/dialects/test_redshift.py @@ -273,6 +273,22 @@ def test_identity(self): "SELECT DATE_ADD('day', 1, DATE('2023-01-01'))", "SELECT DATEADD(day, 1, CAST(DATE('2023-01-01') AS DATE))", ) + self.validate_identity( + """SELECT + c_name, + orders.o_orderkey AS orderkey, + index AS orderkey_index +FROM customer_orders_lineitem AS c, c.c_orders AS orders AT index +ORDER BY + orderkey_index""", + pretty=True, + ) + self.validate_identity( + "SELECT attr AS attr, JSON_TYPEOF(val) AS value_type FROM customer_orders_lineitem AS c, UNPIVOT c.c_orders[0] WHERE c_custkey = 9451" + ) + self.validate_identity( + "SELECT attr AS attr, JSON_TYPEOF(val) AS value_type FROM customer_orders_lineitem AS c, UNPIVOT c.c_orders AS val AT attr WHERE c_custkey = 9451" + ) def test_values(self): # Test crazy-sized VALUES clause to UNION ALL conversion to ensure we don't get RecursionError