Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add casts and string to datetime conversion #27

Merged
merged 25 commits into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
1f9dbaf
add round in polars
finn-rudolph Sep 27, 2024
4992c69
add some support for floating point log
finn-rudolph Sep 27, 2024
d3f63b7
add some tests for decimals
finn-rudolph Sep 27, 2024
70e94f8
add common numeric ops for decimal
finn-rudolph Sep 27, 2024
34e2ab7
add cast
finn-rudolph Sep 28, 2024
7ac1a25
make string to float casts work
finn-rudolph Sep 28, 2024
fe0ee84
test string to int cast
finn-rudolph Sep 28, 2024
b54438f
add tests for float to int casts
finn-rudolph Sep 28, 2024
b4d1b44
add floor / ceil
finn-rudolph Sep 28, 2024
5602830
make float to int casts consistent
finn-rudolph Sep 28, 2024
bb73672
also catch -nan in sqlite / mssql str -> float
finn-rudolph Sep 29, 2024
2c9a589
add stronger test cases
finn-rudolph Sep 29, 2024
2d00ebe
allow Callables in the pipe
finn-rudolph Sep 30, 2024
c35a980
add nan / inf cols manually
finn-rudolph Sep 30, 2024
e7383d4
sqlite: fix inf, don't distinguish nan and null
finn-rudolph Sep 30, 2024
ec306c9
fix duckdb / postgres cast float -> int
finn-rudolph Sep 30, 2024
611c3b0
add datetime to date cast
finn-rudolph Sep 30, 2024
5aa690b
implement int to string cast
finn-rudolph Sep 30, 2024
9c79efb
partially implement cast of float to string
finn-rudolph Sep 30, 2024
193905a
allow Date to String and Datetime to String cast
finn-rudolph Sep 30, 2024
b5fafe9
make sqa_type and pdt_type in sql classmethods
finn-rudolph Sep 30, 2024
ed5fc1b
implement str.to_datetime
finn-rudolph Sep 30, 2024
289d57f
make dtype comparisons more readable
finn-rudolph Sep 30, 2024
081712c
rename Int to Int64 for consistency with Float64
finn-rudolph Sep 30, 2024
a3933e5
fix operator registry tests
finn-rudolph Sep 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion src/pydiverse/transform/backend/duckdb.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from __future__ import annotations

import polars as pl
import sqlalchemy as sqa

from pydiverse.transform.backend import sql
from pydiverse.transform.backend.sql import SqlImpl
from pydiverse.transform.backend.targets import Polars, Target
from pydiverse.transform.tree import dtypes
from pydiverse.transform.tree.ast import AstNode
from pydiverse.transform.tree.col_expr import Col
from pydiverse.transform.tree.col_expr import Cast, Col


class DuckDbImpl(SqlImpl):
Expand All @@ -21,3 +23,11 @@ def export(cls, nd: AstNode, target: Target, final_select: list[Col]):
DuckDbImpl.build_query(nd, final_select), connection=conn
)
return SqlImpl.export(nd, target, final_select)

@classmethod
def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> sqa.Cast:
    """Compile a ``Cast`` expression node into a SQLAlchemy cast for DuckDB.

    A Float64 -> Int64 cast is emitted as ``trunc(value)`` cast to
    ``BigInteger``; every other combination is delegated to the parent
    implementation.

    :param cast: the cast node; ``cast.val`` is the expression being converted
        and ``cast.target_type`` the requested dtype.
    :param sqa_col: mapping forwarded unchanged to ``compile_col_expr`` and to
        the parent ``compile_cast``.
    :return: the SQLAlchemy expression implementing the cast.
    """
    is_float_to_int = (
        cast.val.dtype() == dtypes.Float64 and cast.target_type == dtypes.Int64
    )
    if not is_float_to_int:
        return super().compile_cast(cast, sqa_col)

    # NOTE(review): trunc() is applied before the cast presumably so that the
    # fractional part is dropped instead of being rounded by the database's
    # own float -> integer conversion -- confirm against the backend tests.
    compiled_val = cls.compile_col_expr(cast.val, sqa_col)
    return sqa.func.trunc(compiled_val).cast(sqa.BigInteger())
77 changes: 76 additions & 1 deletion src/pydiverse/transform/backend/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any

import sqlalchemy as sqa
from sqlalchemy.dialects.mssql import DATETIME2

from pydiverse.transform import ops
from pydiverse.transform.backend import sql
Expand All @@ -13,6 +14,7 @@
from pydiverse.transform.tree.ast import AstNode
from pydiverse.transform.tree.col_expr import (
CaseExpr,
Cast,
Col,
ColExpr,
ColFn,
Expand All @@ -25,6 +27,37 @@
class MsSqlImpl(SqlImpl):
dialect_name = "mssql"

# Special floating point values built as SQL expressions from float arithmetic.
# INF is the string literal "1.0" cast to float, divided by a float-typed "0.0".
# NOTE(review): this relies on the server evaluating the division to infinity
# rather than raising a divide-by-zero error -- confirm for the target setup.
INF = sqa.cast(sqa.literal("1.0"), type_=sqa.Float()) / sqa.literal(
"0.0", type_=sqa.Float()
)
# Negative infinity is the arithmetic negation of INF.
NEG_INF = -INF
# NaN is expressed as INF + (-INF).
NAN = INF + NEG_INF

@classmethod
def compile_cast(
    cls, cast: Cast, sqa_col: dict[str, sqa.Label]
) -> sqa.ColumnElement:
    """Compile a ``Cast`` expression node into a SQLAlchemy expression for MSSQL.

    Two conversions are special-cased, both as ``CASE`` expressions:

    * String -> Float64: the strings ``"inf"``, ``"-inf"``, ``"nan"`` and
      ``"-nan"`` are mapped to ``cls.INF`` / ``cls.NEG_INF`` / ``cls.NAN``;
      any other string goes through a regular cast.
    * Float64 -> String: the strings ``"1.#QNAN"``, ``"1.#INF"`` and
      ``"-1.#INF"`` produced by the cast are rewritten to ``"nan"``,
      ``"inf"`` and ``"-inf"``.

    Every other combination is a plain cast to ``cls.sqa_type(target_type)``.

    :param cast: the cast node; ``cast.val`` is the expression being converted
        and ``cast.target_type`` the requested dtype.
    :param sqa_col: mapping forwarded unchanged to ``compile_col_expr``.
    :return: a SQLAlchemy ``Case`` for the two special cases above, otherwise
        a SQLAlchemy ``Cast``.
    """
    # Compile the operand exactly once; every branch below reuses it.
    compiled_val = cls.compile_col_expr(cast.val, sqa_col)
    source_type = cast.val.dtype()

    if source_type == dtypes.String and cast.target_type == dtypes.Float64:
        return sqa.case(
            (compiled_val == "inf", cls.INF),
            (compiled_val == "-inf", cls.NEG_INF),
            (compiled_val.in_(("nan", "-nan")), cls.NAN),
            else_=sqa.cast(compiled_val, cls.sqa_type(cast.target_type)),
        )

    if source_type == dtypes.Float64 and cast.target_type == dtypes.String:
        as_string = sqa.cast(compiled_val, sqa.String)
        return sqa.case(
            (as_string == "1.#QNAN", "nan"),
            (as_string == "1.#INF", "inf"),
            (as_string == "-1.#INF", "-inf"),
            else_=as_string,
        )

    return sqa.cast(compiled_val, cls.sqa_type(cast.target_type))

@classmethod
def build_select(cls, nd: AstNode, final_select: list[Col]) -> Any:
# boolean / bit conversion
Expand Down Expand Up @@ -54,6 +87,13 @@ def build_select(cls, nd: AstNode, final_select: list[Col]) -> Any:
table, query, _ = cls.compile_ast(nd, {col._uuid: 1 for col in final_select})
return cls.compile_query(table, query)

@classmethod
def sqa_type(cls, t: dtypes.Dtype):
    """Return the SQLAlchemy type used for dtype ``t`` on MSSQL.

    ``dtypes.DateTime`` instances map to the dialect's ``DATETIME2``; all
    other dtypes are resolved by the parent implementation.
    """
    if not isinstance(t, dtypes.DateTime):
        return super().sqa_type(t)
    return DATETIME2()


def convert_order_list(order_list: list[Order]) -> list[Order]:
new_list: list[Order] = []
Expand Down Expand Up @@ -93,7 +133,7 @@ def convert_bool_bit(expr: ColExpr | Order, wants_bool_as_bit: bool) -> ColExpr
)

elif isinstance(expr, Col):
if not wants_bool_as_bit and isinstance(expr.dtype(), dtypes.Bool):
if not wants_bool_as_bit and expr.dtype() == dtypes.Bool:
return ColFn("__eq__", expr, LiteralCol(True))
return expr

Expand Down Expand Up @@ -146,6 +186,14 @@ def convert_bool_bit(expr: ColExpr | Order, wants_bool_as_bit: bool) -> ColExpr
elif isinstance(expr, LiteralCol):
return expr

elif isinstance(expr, Cast):
# TODO: does this really work for casting onto / from booleans? we probably have
# to use wants_bool_as_bit in some way when casting to bool
return Cast(
convert_bool_bit(expr.val, wants_bool_as_bit=wants_bool_as_bit),
expr.target_type,
)

raise AssertionError


Expand Down Expand Up @@ -289,3 +337,30 @@ def _day_of_week(x):
@op.auto
def _mean(x):
return sqa.func.AVG(sqa.cast(x, sqa.Double()), type_=sqa.Double())


with MsSqlImpl.op(ops.Log()) as op:

    @op.auto
    def _log(x):
        """Compile ``ops.Log`` as a CASE expression over the sign of ``x``."""
        # TODO: we still need to handle inf / -inf / nan
        positive_branch = (x > 0, sqa.func.log(x))
        negative_branch = (x < 0, MsSqlImpl.NAN)
        null_branch = (x.is_(sqa.null()), None)
        # Rows matching none of the branches above fall through to -INF.
        return sqa.case(
            positive_branch,
            negative_branch,
            null_branch,
            else_=-MsSqlImpl.INF,
        )


with MsSqlImpl.op(ops.Ceil()) as op:

    @op.auto
    def _ceil(x):
        """Compile ``ops.Ceil`` to a call of the SQL function ``ceiling``."""
        ceiling_call = sqa.func.ceiling(x)
        return ceiling_call


with MsSqlImpl.op(ops.StrToDateTime()) as op:

    @op.auto
    def _str_to_datetime(x):
        """Compile ``ops.StrToDateTime`` as a cast of ``x`` to ``DATETIME2``."""
        as_datetime = sqa.cast(x, DATETIME2)
        return as_datetime
95 changes: 66 additions & 29 deletions src/pydiverse/transform/backend/polars.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from __future__ import annotations

import datetime
from types import NoneType
from typing import Any
from uuid import UUID

Expand All @@ -15,6 +13,7 @@
from pydiverse.transform.tree.ast import AstNode
from pydiverse.transform.tree.col_expr import (
CaseExpr,
Cast,
Col,
ColExpr,
ColFn,
Expand Down Expand Up @@ -159,7 +158,7 @@ def compile_col_expr(expr: ColExpr, name_in_df: dict[UUID, str]) -> pl.Expr:

# the function was executed on the ordered arguments. here we
# restore the original order of the table.
inv_permutation = pl.int_range(0, pl.len(), dtype=pl.Int64).sort_by(
inv_permutation = pl.int_range(0, pl.len(), dtype=pl.Int64()).sort_by(
by=order_by,
descending=descending,
nulls_last=nulls_last,
Expand All @@ -182,10 +181,20 @@ def compile_col_expr(expr: ColExpr, name_in_df: dict[UUID, str]) -> pl.Expr:
return compiled

elif isinstance(expr, LiteralCol):
if isinstance(expr.dtype(), dtypes.String):
if expr.dtype() == dtypes.String:
return pl.lit(expr.val) # polars interprets strings as column names
return expr.val

elif isinstance(expr, Cast):
compiled = compile_col_expr(expr.val, name_in_df).cast(
pdt_type_to_polars(expr.target_type)
)

if expr.val.dtype() == dtypes.Float64 and expr.target_type == dtypes.String:
compiled = compiled.replace("NaN", "nan")

return compiled

else:
raise AssertionError

Expand Down Expand Up @@ -340,9 +349,9 @@ def has_path_to_leaf_without_agg(expr: ColExpr):

def polars_type_to_pdt(t: pl.DataType) -> dtypes.Dtype:
if t.is_float():
return dtypes.Float()
return dtypes.Float64()
elif t.is_integer():
return dtypes.Int()
return dtypes.Int64()
elif isinstance(t, pl.Boolean):
return dtypes.Bool()
elif isinstance(t, pl.String):
Expand All @@ -360,9 +369,9 @@ def polars_type_to_pdt(t: pl.DataType) -> dtypes.Dtype:


def pdt_type_to_polars(t: dtypes.Dtype) -> pl.DataType:
if isinstance(t, dtypes.Float):
if isinstance(t, (dtypes.Float64, dtypes.Decimal)):
return pl.Float64()
elif isinstance(t, dtypes.Int):
elif isinstance(t, dtypes.Int64):
return pl.Int64()
elif isinstance(t, dtypes.Bool):
return pl.Boolean()
Expand All @@ -380,27 +389,6 @@ def pdt_type_to_polars(t: dtypes.Dtype) -> pl.DataType:
raise AssertionError


def python_type_to_polars(t: type) -> pl.DataType:
    """Map a Python builtin type object to an instance of a polars data type.

    The lookup is by identity (``is``), so subclasses of the listed types are
    not matched.

    :param t: the Python type object to translate.
    :return: a freshly constructed polars data type.
    :raises TypeError: if ``t`` is none of the supported types.
    """
    # (python type, polars dtype constructor) pairs, checked in order.
    supported = (
        (int, pl.Int64),
        (float, pl.Float64),
        (bool, pl.Boolean),
        (str, pl.String),
        (datetime.datetime, pl.Datetime),
        (datetime.date, pl.Date),
        (datetime.timedelta, pl.Duration),
        (NoneType, pl.Null),
    )
    for python_type, make_polars_type in supported:
        if t is python_type:
            return make_polars_type()

    raise TypeError(f"python builtin type {t} is not supported by pydiverse.transform")


with PolarsImpl.op(ops.Mean()) as op:

@op.auto
Expand Down Expand Up @@ -709,3 +697,52 @@ def _greatest(*x):
@op.auto
def _least(*x):
return pl.min_horizontal(*x)


with PolarsImpl.op(ops.Round()) as op:

    @op.auto
    def _round(x, digits=0):
        """Implement ``ops.Round`` via ``x.round(digits)``; ``digits`` defaults to 0."""
        rounded = x.round(digits)
        return rounded


with PolarsImpl.op(ops.Exp()) as op:

    @op.auto
    def _exp(x):
        """Implement ``ops.Exp`` by delegating to ``x.exp()``."""
        exponentiated = x.exp()
        return exponentiated


with PolarsImpl.op(ops.Log()) as op:

    @op.auto
    def _log(x):
        """Implement ``ops.Log`` by delegating to ``x.log()`` with its default base."""
        logarithm = x.log()
        return logarithm


with PolarsImpl.op(ops.Floor()) as op:

    @op.auto
    def _floor(x):
        """Implement ``ops.Floor`` by delegating to ``x.floor()``."""
        floored = x.floor()
        return floored


with PolarsImpl.op(ops.Ceil()) as op:

    @op.auto
    def _ceil(x):
        """Implement ``ops.Ceil`` by delegating to ``x.ceil()``."""
        ceiled = x.ceil()
        return ceiled


with PolarsImpl.op(ops.StrToDateTime()) as op:

    @op.auto
    def _str_to_datetime(x):
        """Implement ``ops.StrToDateTime`` via ``x.str.to_datetime()`` with default arguments."""
        string_namespace = x.str
        return string_namespace.to_datetime()


with PolarsImpl.op(ops.StrToDate()) as op:

    @op.auto
    def _str_to_date(x):
        """Implement ``ops.StrToDate`` via ``x.str.to_date()`` with default arguments."""
        string_namespace = x.str
        return string_namespace.to_date()
21 changes: 21 additions & 0 deletions src/pydiverse/transform/backend/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,32 @@

from pydiverse.transform import ops
from pydiverse.transform.backend.sql import SqlImpl
from pydiverse.transform.tree import dtypes
from pydiverse.transform.tree.col_expr import Cast


class PostgresImpl(SqlImpl):
dialect_name = "postgresql"

@classmethod
def compile_cast(
    cls, cast: Cast, sqa_col: dict[str, sqa.Label]
) -> sqa.ColumnElement:
    """Compile a ``Cast`` expression node into a SQLAlchemy expression for Postgres.

    Casts from Float64 are special-cased:

    * Float64 -> Int64 is emitted as ``trunc(value)`` cast to ``BigInteger``.
    * Float64 -> String casts to string and then rewrites the strings
      ``"NaN"``, ``"Infinity"`` and ``"-Infinity"`` to ``"nan"``, ``"inf"``
      and ``"-inf"`` through a ``CASE`` expression.

    Every other combination is a plain cast to ``cls.sqa_type(target_type)``.

    :param cast: the cast node; ``cast.val`` is the expression being converted
        and ``cast.target_type`` the requested dtype.
    :param sqa_col: mapping forwarded unchanged to ``compile_col_expr``.
    :return: a SQLAlchemy ``Case`` for Float64 -> String, otherwise a
        SQLAlchemy ``Cast``.
    """
    compiled_val = cls.compile_col_expr(cast.val, sqa_col)

    # Dtype checks use ``==`` against the dtype class, matching the other
    # backends' compile_cast implementations.
    if cast.val.dtype() == dtypes.Float64:
        if cast.target_type == dtypes.Int64:
            # NOTE(review): trunc() presumably avoids rounding in the
            # database's float -> integer conversion -- confirm with tests.
            return sqa.func.trunc(compiled_val).cast(sqa.BigInteger())

        if cast.target_type == dtypes.String:
            as_string = sqa.cast(compiled_val, sqa.String)
            return sqa.case(
                (as_string == "NaN", "nan"),
                (as_string == "Infinity", "inf"),
                (as_string == "-Infinity", "-inf"),
                else_=as_string,
            )

    return sqa.cast(compiled_val, cls.sqa_type(cast.target_type))


with PostgresImpl.op(ops.Less()) as op:

Expand Down
Loading
Loading