Skip to content

Commit

Permalink
Feat(snowflake): add support for staged file file_format clause (#2359)
Browse files Browse the repository at this point in the history
* Feat(snowflake): add support for staged file file_format clause

* Increase test coverage
  • Loading branch information
georgesittas authored Oct 2, 2023
1 parent e8273e2 commit d2047ec
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 23 deletions.
11 changes: 10 additions & 1 deletion sqlglot/dialects/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,16 @@ def _parse_table_parts(self, schema: bool = False) -> exp.Table:
table = self._parse_string()

if table:
return self.expression(exp.Table, this=table)
file_format = None
pattern = None

if self._match_text_seq("(", "FILE_FORMAT", "=>"):
file_format = self._parse_string() or super()._parse_table_parts()
if self._match_text_seq(",", "PATTERN", "=>"):
pattern = self._parse_string()
self._match_r_paren()

return self.expression(exp.Table, this=table, format=file_format, pattern=pattern)

return super()._parse_table_parts(schema=schema)

Expand Down
2 changes: 2 additions & 0 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2462,6 +2462,8 @@ class Table(Expression):
"hints": False,
"system_time": False,
"version": False,
"format": False,
"pattern": False,
}

@property
Expand Down
7 changes: 6 additions & 1 deletion sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1349,8 +1349,13 @@ def table_sql(self, expression: exp.Table, sep: str = " AS ") -> str:
pivots = f" {pivots}" if pivots else ""
joins = self.expressions(expression, key="joins", sep="", skip_first=True)
laterals = self.expressions(expression, key="laterals", sep="")
file_format = self.sql(expression, "format")
if file_format:
pattern = self.sql(expression, "pattern")
pattern = f", PATTERN => {pattern}" if pattern else ""
file_format = f" (FILE_FORMAT => {file_format}{pattern})"

return f"{table}{version}{alias}{hints}{pivots}{joins}{laterals}"
return f"{table}{version}{file_format}{alias}{hints}{pivots}{joins}{laterals}"

def tablesample_sql(
self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
Expand Down
55 changes: 34 additions & 21 deletions tests/dialects/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,6 @@ class TestSnowflake(Validator):
dialect = "snowflake"

def test_snowflake(self):
# Ensure we don't treat staged file paths as identifiers (i.e. they're not normalized)
staged_file = parse_one("SELECT * FROM @foo", read="snowflake")
self.assertEqual(
normalize_identifiers(staged_file, dialect="snowflake").sql(dialect="snowflake"),
staged_file.sql(dialect="snowflake"),
)

self.validate_identity("SELECT * FROM @~")
self.validate_identity("SELECT * FROM @~/some/path/to/file.csv")
self.validate_identity("SELECT * FROM @mystage")
self.validate_identity("SELECT * FROM '@mystage'")
self.validate_identity("SELECT * FROM @namespace.mystage/path/to/file.json.gz")
self.validate_identity("SELECT * FROM @namespace.%table_name/path/to/file.json.gz")
self.validate_identity("LISTAGG(data['some_field'], ',')")
self.validate_identity("WEEKOFYEAR(tstamp)")
self.validate_identity("SELECT SUM(amount) FROM mytable GROUP BY ALL")
Expand All @@ -40,7 +27,6 @@ def test_snowflake(self):
self.validate_identity("$x") # parameter
self.validate_identity("a$b") # valid snowflake identifier
self.validate_identity("SELECT REGEXP_LIKE(a, b, c)")
self.validate_identity("PUT file:///dir/tmp.csv @%table")
self.validate_identity("CREATE TABLE foo (bar FLOAT AUTOINCREMENT START 0 INCREMENT 1)")
self.validate_identity("ALTER TABLE IF EXISTS foo SET TAG a = 'a', b = 'b', c = 'c'")
self.validate_identity("ALTER TABLE foo UNSET TAG a, b, c")
Expand All @@ -53,9 +39,6 @@ def test_snowflake(self):
self.validate_identity(
'DESCRIBE TABLE "SNOWFLAKE_SAMPLE_DATA"."TPCDS_SF100TCL"."WEB_SITE" type=stage'
)
self.validate_identity(
'COPY INTO NEW_TABLE ("foo", "bar") FROM (SELECT $1, $2, $3, $4 FROM @%old_table)'
)
self.validate_identity(
"SELECT state, city, SUM(retail_price * quantity) AS gross_revenue FROM sales GROUP BY ALL"
)
Expand All @@ -75,10 +58,6 @@ def test_snowflake(self):
"SELECT {'test': 'best'}::VARIANT",
"SELECT CAST(OBJECT_CONSTRUCT('test', 'best') AS VARIANT)",
)
self.validate_identity(
"SELECT parse_json($1):a.b FROM @mystage2/data1.json.gz",
"SELECT PARSE_JSON($1)['a'].b FROM @mystage2/data1.json.gz",
)

self.validate_all("CAST(x AS BYTEINT)", write={"snowflake": "CAST(x AS INT)"})
self.validate_all("CAST(x AS CHAR VARYING)", write={"snowflake": "CAST(x AS VARCHAR)"})
Expand Down Expand Up @@ -540,6 +519,40 @@ def test_null_treatment(self):
},
)

def test_staged_files(self):
    """Check that Snowflake staged-file references round-trip through the parser and generator."""
    # Ensure we don't treat staged file paths as identifiers (i.e. they're not normalized).
    # The normalized SQL is generated first, then the SQL of the parsed tree, mirroring
    # the evaluation order of the two values being compared.
    parsed = parse_one("SELECT * FROM @foo", read="snowflake")
    normalized_sql = normalize_identifiers(parsed, dialect="snowflake").sql(dialect="snowflake")
    regenerated_sql = parsed.sql(dialect="snowflake")
    self.assertEqual(normalized_sql, regenerated_sql)

    # Statements expected to be regenerated exactly as written.
    unchanged_queries = (
        "SELECT * FROM @~",
        "SELECT * FROM @~/some/path/to/file.csv",
        "SELECT * FROM @mystage",
        "SELECT * FROM '@mystage'",
        "SELECT * FROM @namespace.mystage/path/to/file.json.gz",
        "SELECT * FROM @namespace.%table_name/path/to/file.json.gz",
        "SELECT * FROM '@external/location' (FILE_FORMAT => 'path.to.csv')",
        "PUT file:///dir/tmp.csv @%table",
        'COPY INTO NEW_TABLE ("foo", "bar") FROM (SELECT $1, $2, $3, $4 FROM @%old_table)',
        "SELECT * FROM @foo/bar (FILE_FORMAT => ds_sandbox.test.my_csv_format, PATTERN => 'test') AS bla",
        "SELECT t.$1, t.$2 FROM @mystage1 (FILE_FORMAT => 'myformat', PATTERN => '.*data.*[.]csv.gz') AS t",
    )
    for query in unchanged_queries:
        self.validate_identity(query)

    # (input SQL, expected generated SQL) pairs where the output differs from the input.
    rewritten_queries = (
        (
            "SELECT parse_json($1):a.b FROM @mystage2/data1.json.gz",
            "SELECT PARSE_JSON($1)['a'].b FROM @mystage2/data1.json.gz",
        ),
        (
            "SELECT * FROM @mystage t (c1)",
            "SELECT * FROM @mystage AS t(c1)",
        ),
    )
    for query, expected_sql in rewritten_queries:
        self.validate_identity(query, expected_sql)

def test_sample(self):
self.validate_identity("SELECT * FROM testtable TABLESAMPLE BERNOULLI (20.3)")
self.validate_identity("SELECT * FROM testtable TABLESAMPLE (100)")
Expand Down

0 comments on commit d2047ec

Please sign in to comment.