Skip to content

Commit

Permalink
Feat!: add support for heredoc strings (Postgres, ClickHouse) (#2328)
Browse files Browse the repository at this point in the history
* Feat!: add support for heredoc strings (Postgres, ClickHouse)

* Add redshift to the heredoc tests

* Add tests with empty strings
  • Loading branch information
georgesittas authored Sep 26, 2023
1 parent dc7a6c2 commit ebdfc59
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 5 deletions.
6 changes: 6 additions & 0 deletions sqlglot/dialects/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class Tokenizer(tokens.Tokenizer):
STRING_ESCAPES = ["'", "\\"]
BIT_STRINGS = [("0b", "")]
HEX_STRINGS = [("0x", ""), ("0X", "")]
HEREDOC_STRINGS = ["$"]

KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
Expand Down Expand Up @@ -75,6 +76,11 @@ class Tokenizer(tokens.Tokenizer):
"UINT8": TokenType.UTINYINT,
}

SINGLE_TOKENS = {
**tokens.Tokenizer.SINGLE_TOKENS,
"$": TokenType.HEREDOC_STRING,
}

class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
Expand Down
5 changes: 2 additions & 3 deletions sqlglot/dialects/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,11 +248,10 @@ class Postgres(Dialect):
}

class Tokenizer(tokens.Tokenizer):
QUOTES = ["'", "$$"]

BIT_STRINGS = [("b'", "'"), ("B'", "'")]
HEX_STRINGS = [("x'", "'"), ("X'", "'")]
BYTE_STRINGS = [("e'", "'"), ("E'", "'")]
HEREDOC_STRINGS = ["$"]

KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
Expand Down Expand Up @@ -296,7 +295,7 @@ class Tokenizer(tokens.Tokenizer):

SINGLE_TOKENS = {
**tokens.Tokenizer.SINGLE_TOKENS,
"$": TokenType.PARAMETER,
"$": TokenType.HEREDOC_STRING,
}

VAR_SINGLE_TOKENS = {"$"}
Expand Down
3 changes: 3 additions & 0 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,9 @@ class Parser(metaclass=_Parser):
exp.National, this=token.text
),
TokenType.RAW_STRING: lambda self, token: self.expression(exp.RawString, this=token.text),
TokenType.HEREDOC_STRING: lambda self, token: self.expression(
exp.RawString, this=token.text
),
TokenType.SESSION_PARAMETER: lambda self, _: self._parse_session_parameter(),
}

Expand Down
9 changes: 9 additions & 0 deletions sqlglot/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ class TokenType(AutoName):
BYTE_STRING = auto()
NATIONAL_STRING = auto()
RAW_STRING = auto()
HEREDOC_STRING = auto()

# types
BIT = auto()
Expand Down Expand Up @@ -418,6 +419,7 @@ def _quotes_to_format(
**_quotes_to_format(TokenType.BYTE_STRING, klass.BYTE_STRINGS),
**_quotes_to_format(TokenType.HEX_STRING, klass.HEX_STRINGS),
**_quotes_to_format(TokenType.RAW_STRING, klass.RAW_STRINGS),
**_quotes_to_format(TokenType.HEREDOC_STRING, klass.HEREDOC_STRINGS),
}

klass._STRING_ESCAPES = set(klass.STRING_ESCAPES)
Expand Down Expand Up @@ -484,6 +486,7 @@ class Tokenizer(metaclass=_Tokenizer):
BYTE_STRINGS: t.List[str | t.Tuple[str, str]] = []
HEX_STRINGS: t.List[str | t.Tuple[str, str]] = []
RAW_STRINGS: t.List[str | t.Tuple[str, str]] = []
HEREDOC_STRINGS: t.List[str | t.Tuple[str, str]] = []
IDENTIFIERS: t.List[str | t.Tuple[str, str]] = ['"']
IDENTIFIER_ESCAPES = ['"']
QUOTES: t.List[t.Tuple[str, str] | str] = ["'"]
Expand Down Expand Up @@ -997,9 +1000,11 @@ def _scan_keywords(self) -> None:
word = word.upper()
self._add(self.KEYWORDS[word], text=word)
return

if self._char in self.SINGLE_TOKENS:
self._add(self.SINGLE_TOKENS[self._char], text=self._char)
return

self._scan_var()

def _scan_comment(self, comment_start: str) -> bool:
Expand Down Expand Up @@ -1126,6 +1131,10 @@ def _scan_string(self, start: str) -> bool:
base = 16
elif token_type == TokenType.BIT_STRING:
base = 2
elif token_type == TokenType.HEREDOC_STRING:
self._advance()
tag = "" if self._char == end else self._extract_string(end)
end = f"{start}{tag}{end}"
else:
return False

Expand Down
39 changes: 39 additions & 0 deletions tests/dialects/test_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
Dialects,
ErrorLevel,
ParseError,
TokenError,
UnsupportedError,
parse_one,
)
Expand Down Expand Up @@ -308,6 +309,44 @@ def test_cast(self):
read={"postgres": "INET '127.0.0.1/32'"},
)

def test_heredoc_strings(self):
for dialect in ("clickhouse", "postgres", "redshift"):
# Invalid matching tag
with self.assertRaises(TokenError):
parse_one("SELECT $tag1$invalid heredoc string$tag2$", dialect=dialect)

# Unmatched tag
with self.assertRaises(TokenError):
parse_one("SELECT $tag1$invalid heredoc string", dialect=dialect)

# Without tag
self.validate_all(
"SELECT 'this is a heredoc string'",
read={
dialect: "SELECT $$this is a heredoc string$$",
},
)
self.validate_all(
"SELECT ''",
read={
dialect: "SELECT $$$$",
},
)

# With tag
self.validate_all(
"SELECT 'this is also a heredoc string'",
read={
dialect: "SELECT $foo$this is also a heredoc string$foo$",
},
)
self.validate_all(
"SELECT ''",
read={
dialect: "SELECT $foo$$foo$",
},
)

def test_decode(self):
self.validate_identity("DECODE(bin, charset)")

Expand Down
1 change: 0 additions & 1 deletion tests/dialects/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,6 @@ def test_postgres(self):
self.validate_identity("SELECT ARRAY[1, 2, 3] @> ARRAY[1, 2]")
self.validate_identity("SELECT ARRAY[1, 2, 3] <@ ARRAY[1, 2]")
self.validate_identity("SELECT ARRAY[1, 2, 3] && ARRAY[1, 2]")
self.validate_identity("$x")
self.validate_identity("x$")
self.validate_identity("SELECT ARRAY[1, 2, 3]")
self.validate_identity("SELECT ARRAY(SELECT 1)")
Expand Down
1 change: 0 additions & 1 deletion tests/dialects/test_redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,6 @@ def test_identity(self):
self.validate_identity("SELECT * FROM #x")
self.validate_identity("SELECT INTERVAL '5 day'")
self.validate_identity("foo$")
self.validate_identity("$foo")
self.validate_identity("CAST('bla' AS SUPER)")
self.validate_identity("CREATE TABLE real1 (realcol REAL)")
self.validate_identity("CAST('foo' AS HLLSKETCH)")
Expand Down

0 comments on commit ebdfc59

Please sign in to comment.