Skip to content

Commit

Permalink
feat(optimizer): replace date funcs (#2299)
Browse files Browse the repository at this point in the history
* feat(optimizer): replace date funcs

* fixup
  • Loading branch information
barakalon authored Sep 22, 2023
1 parent 6429042 commit 13877fe
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 5 deletions.
5 changes: 4 additions & 1 deletion sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5947,7 +5947,10 @@ def cast(expression: ExpOrStr, to: str | DataType | DataType.Type, **opts) -> Ca
The new Cast instance.
"""
expression = maybe_parse(expression, **opts)
return Cast(this=expression, to=DataType.build(to, **opts))
data_type = DataType.build(to, **opts)
expression = Cast(this=expression, to=data_type)
expression.type = data_type
return expression


def table_(
Expand Down
1 change: 1 addition & 0 deletions sqlglot/optimizer/annotate_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.TimeAdd,
exp.TimeStrToTime,
exp.TimeSub,
exp.Timestamp,
exp.TimestampAdd,
exp.TimestampSub,
exp.UnixToTime,
Expand Down
14 changes: 10 additions & 4 deletions sqlglot/optimizer/canonicalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def canonicalize(expression: exp.Expression) -> exp.Expression:
exp.replace_children(expression, canonicalize)

expression = add_text_to_concat(expression)
expression = replace_date_funcs(expression)
expression = coerce_type(expression)
expression = remove_redundant_casts(expression)
expression = ensure_bool_predicates(expression)
Expand All @@ -31,6 +32,14 @@ def add_text_to_concat(node: exp.Expression) -> exp.Expression:
return node


def replace_date_funcs(node: exp.Expression) -> exp.Expression:
if isinstance(node, exp.Date) and not node.expressions and not node.args.get("zone"):
return exp.cast(node.this, to=exp.DataType.Type.DATE)
if isinstance(node, exp.Timestamp) and not node.expression:
return exp.cast(node.this, to=exp.DataType.Type.TIMESTAMP)
return node


def coerce_type(node: exp.Expression) -> exp.Expression:
if isinstance(node, exp.Binary):
_coerce_date(node.left, node.right)
Expand Down Expand Up @@ -84,10 +93,7 @@ def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:


def _replace_cast(node: exp.Expression, to: str) -> None:
data_type = exp.DataType.build(to)
cast = exp.Cast(this=node.copy(), to=data_type)
cast.type = data_type
node.replace(cast)
node.replace(exp.cast(node.copy(), to=to))


def _replace_int_predicate(expression: exp.Expression) -> None:
Expand Down
12 changes: 12 additions & 0 deletions tests/fixtures/optimizer/canonicalize.sql
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,15 @@ SELECT "x"."a" AS "a" FROM "x" AS "x" GROUP BY "x"."a" HAVING SUM("x"."b") <> 0

SELECT a FROM x WHERE 1;
SELECT "x"."a" AS "a" FROM "x" AS "x" WHERE 1 <> 0;

--------------------------------------
-- Replace date functions
--------------------------------------
DATE('2023-01-01');
CAST('2023-01-01' AS DATE);

TIMESTAMP('2023-01-01');
CAST('2023-01-01' AS TIMESTAMP);

TIMESTAMP('2023-01-01', '12:00:00');
TIMESTAMP('2023-01-01', '12:00:00');

0 comments on commit 13877fe

Please sign in to comment.