Skip to content

Commit

Permalink
Fix: unescape escape sequences on read, re-escape them on generation (#2367)
Browse files Browse the repository at this point in the history

* Fix: unescape escape sequences on read, re-escape them on generation

* Get rid of setuptools_scm pin

* Bring back setuptools_scm dep but without the pin

* update test to read from mysql
  • Loading branch information
georgesittas authored Oct 3, 2023
1 parent f6750ef commit f777155
Show file tree
Hide file tree
Showing 8 changed files with 53 additions and 5 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"fallback_version": "0.0.0",
"local_scheme": "no-local-version",
},
setup_requires=["setuptools_scm<8.0.1"],
setup_requires=["setuptools_scm"],
python_requires=">=3.7",
extras_require={
"dev": [
Expand Down
11 changes: 10 additions & 1 deletion sqlglot/dialects/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,16 @@ class BigQuery(Dialect):
"%D": "%m/%d/%y",
}

ESCAPE_SEQUENCES = {
"\\a": "\a",
"\\b": "\b",
"\\f": "\f",
"\\n": "\n",
"\\r": "\r",
"\\t": "\t",
"\\v": "\v",
}

FORMAT_MAPPING = {
"DD": "%d",
"MM": "%m",
Expand Down Expand Up @@ -416,7 +426,6 @@ class Generator(generator.Generator):
TABLE_HINTS = False
LIMIT_FETCH = "LIMIT"
RENAME_TABLE_WITH_DB = False
ESCAPE_LINE_BREAK = True
NVL2_SUPPORTED = False
UNNEST_WITH_ORDINALITY = False
COLLATE_IS_FUNC = True
Expand Down
4 changes: 4 additions & 0 deletions sqlglot/dialects/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ class ClickHouse(Dialect):
STRICT_STRING_CONCAT = True
SUPPORTS_USER_DEFINED_TYPES = False

ESCAPE_SEQUENCES = {
"\\0": "\0",
}

class Tokenizer(tokens.Tokenizer):
COMMENTS = ["--", "#", "#!", ("/*", "*/")]
IDENTIFIERS = ['"', "`"]
Expand Down
7 changes: 7 additions & 0 deletions sqlglot/dialects/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ def __new__(cls, clsname, bases, attrs):
klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)

klass.INVERSE_ESCAPE_SEQUENCES = {v: k for k, v in klass.ESCAPE_SEQUENCES.items()}

klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
klass.parser_class = getattr(klass, "Parser", Parser)
klass.generator_class = getattr(klass, "Generator", Generator)
Expand Down Expand Up @@ -188,6 +190,9 @@ class Dialect(metaclass=_Dialect):
# special syntax cast(x as date format 'yyyy') defaults to time_mapping
FORMAT_MAPPING: t.Dict[str, str] = {}

# Mapping of an escape sequence as written in SQL text (e.g. a backslash followed by "n")
# to the single character it represents (e.g. a newline)
ESCAPE_SEQUENCES: t.Dict[str, str] = {}

# Columns that are auto-generated by the engine corresponding to this dialect
# Such columns may be excluded from SELECT * queries, for example
PSEUDOCOLUMNS: t.Set[str] = set()
Expand All @@ -204,6 +209,8 @@ class Dialect(metaclass=_Dialect):
INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
INVERSE_TIME_TRIE: t.Dict = {}

INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}

def __eq__(self, other: t.Any) -> bool:
return type(self) == other

Expand Down
6 changes: 3 additions & 3 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,14 +346,14 @@ class Generator:
# Autofilled
INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
INVERSE_TIME_TRIE: t.Dict = {}
INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}
INDEX_OFFSET = 0
UNNEST_COLUMN_ONLY = False
ALIAS_POST_TABLESAMPLE = False
IDENTIFIERS_CAN_START_WITH_DIGIT = False
STRICT_STRING_CONCAT = False
NORMALIZE_FUNCTIONS: bool | str = "upper"
NULL_ORDERING = "nulls_are_small"
ESCAPE_LINE_BREAK = False

can_identify: t.Callable[[str, str | bool], bool]

Expand Down Expand Up @@ -1670,8 +1670,8 @@ def literal_sql(self, expression: exp.Literal) -> str:

def escape_str(self, text: str) -> str:
text = text.replace(self.QUOTE_END, self._escaped_quote_end)
if self.ESCAPE_LINE_BREAK:
text = text.replace("\n", "\\n")
if self.INVERSE_ESCAPE_SEQUENCES:
text = "".join(self.INVERSE_ESCAPE_SEQUENCES.get(ch, ch) for ch in text)
elif self.pretty:
text = text.replace("\n", self.SENTINEL_LINE_BREAK)
return text
Expand Down
8 changes: 8 additions & 0 deletions sqlglot/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,7 @@ class Tokenizer(metaclass=_Tokenizer):
QUOTES: t.List[t.Tuple[str, str] | str] = ["'"]
STRING_ESCAPES = ["'"]
VAR_SINGLE_TOKENS: t.Set[str] = set()
ESCAPE_SEQUENCES: t.Dict[str, str] = {}

# Autofilled
IDENTIFIERS_CAN_START_WITH_DIGIT: bool = False
Expand Down Expand Up @@ -1203,6 +1204,13 @@ def _extract_string(self, delimiter: str, escapes=None) -> str:
if self._end:
raise TokenError(f"Missing {delimiter} from {self._line}:{self._start}")

if self.ESCAPE_SEQUENCES and self._peek and self._char in self.STRING_ESCAPES:
escaped_sequence = self.ESCAPE_SEQUENCES.get(self._char + self._peek)
if escaped_sequence:
self._advance(2)
text += escaped_sequence
continue

current = self._current - 1
self._advance(alnum=True)
text += self.sql[current : self._current - 1]
Expand Down
10 changes: 10 additions & 0 deletions tests/dialects/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,16 @@ def test_bigquery(self):
self.validate_all('x <> """"""', write={"bigquery": "x <> ''"})
self.validate_all("x <> ''''''", write={"bigquery": "x <> ''"})
self.validate_all("CAST(x AS DATETIME)", read={"": "x::timestamp"})
self.validate_all(
"SELECT '\\n'",
read={
"bigquery": "SELECT '''\n'''",
},
write={
"bigquery": "SELECT '\\n'",
"postgres": "SELECT '\n'",
},
)
self.validate_all(
"TRIM(item, '*')",
read={
Expand Down
10 changes: 10 additions & 0 deletions tests/dialects/test_clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,16 @@ def test_clickhouse(self):
"CREATE MATERIALIZED VIEW test_view (id UInt8) TO db.table1 AS SELECT * FROM test_data"
)

self.validate_all(
"SELECT '\\0'",
read={
"mysql": "SELECT '\0'",
},
write={
"clickhouse": "SELECT '\\0'",
"mysql": "SELECT '\0'",
},
)
self.validate_all(
"DATE_ADD('day', 1, x)",
read={
Expand Down

0 comments on commit f777155

Please sign in to comment.