Skip to content

Commit

Permalink
Merge pull request #26 from pydiverse/lazy-tree-eval
Browse files Browse the repository at this point in the history
translate lazily to the backends
  • Loading branch information
finn-rudolph authored Sep 26, 2024
2 parents a79906e + 1f1afc6 commit 6a547d6
Show file tree
Hide file tree
Showing 76 changed files with 4,683 additions and 6,269 deletions.
4 changes: 2 additions & 2 deletions docs/package/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ from pydiverse.transform.lazy import SQLTableImpl
from pydiverse.transform.eager import PandasTableImpl
from pydiverse.transform.core.verbs import *
import pandas as pd
import sqlalchemy as sa
import sqlalchemy as sqa


def main():
Expand Down Expand Up @@ -52,7 +52,7 @@ def main():
print("\nPandas based result:")
print(out1)

engine = sa.create_engine("sqlite:///:memory:")
engine = sqa.create_engine("sqlite:///:memory:")
dfA.to_sql("dfA", engine, index=False, if_exists="replace")
dfB.to_sql("dfB", engine, index=False, if_exists="replace")
input1 = Table(SQLTableImpl(engine, "dfA"))
Expand Down
23 changes: 16 additions & 7 deletions src/pydiverse/transform/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,25 @@
from __future__ import annotations

from pydiverse.transform.core import functions
from pydiverse.transform.core.alignment import aligned, eval_aligned
from pydiverse.transform.core.dispatchers import verb
from pydiverse.transform.core.expressions.lambda_getter import C
from pydiverse.transform.core.table import Table
from pydiverse.transform.backend.targets import DuckDb, Polars, SqlAlchemy
from pydiverse.transform.pipe.c import C
from pydiverse.transform.pipe.functions import (
count,
dense_rank,
max,
min,
rank,
row_number,
when,
)
from pydiverse.transform.pipe.pipeable import verb
from pydiverse.transform.pipe.table import Table

__all__ = [
"Polars",
"SqlAlchemy",
"DuckDb",
"Table",
"aligned",
"eval_aligned",
"functions",
"verb",
"C",
]
4 changes: 2 additions & 2 deletions src/pydiverse/transform/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from typing import TYPE_CHECKING, Callable, TypeVar

if TYPE_CHECKING:
from pydiverse.transform.core.table_impl import AbstractTableImpl
from pydiverse.transform.backend.table_impl import TableImpl


T = TypeVar("T")
ImplT = TypeVar("ImplT", bound="AbstractTableImpl")
ImplT = TypeVar("ImplT", bound="TableImpl")
CallableT = TypeVar("CallableT", bound=Callable)
10 changes: 10 additions & 0 deletions src/pydiverse/transform/backend/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from __future__ import annotations

from .duckdb import DuckDbImpl
from .mssql import MsSqlImpl
from .polars import PolarsImpl
from .postgres import PostgresImpl
from .sql import SqlImpl
from .sqlite import SqliteImpl
from .table_impl import TableImpl
from .targets import DuckDb, Polars, SqlAlchemy
23 changes: 23 additions & 0 deletions src/pydiverse/transform/backend/duckdb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from __future__ import annotations

import polars as pl

from pydiverse.transform.backend import sql
from pydiverse.transform.backend.sql import SqlImpl
from pydiverse.transform.backend.targets import Polars, Target
from pydiverse.transform.tree.ast import AstNode
from pydiverse.transform.tree.col_expr import Col


class DuckDbImpl(SqlImpl):
dialect_name = "duckdb"

@classmethod
def export(cls, nd: AstNode, target: Target, final_select: list[Col]):
if isinstance(target, Polars):
engine = sql.get_engine(nd)
with engine.connect() as conn:
return pl.read_database(
DuckDbImpl.build_query(nd, final_select), connection=conn
)
return SqlImpl.export(nd, target, final_select)
291 changes: 291 additions & 0 deletions src/pydiverse/transform/backend/mssql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,291 @@
from __future__ import annotations

import copy
import functools
from typing import Any

import sqlalchemy as sqa

from pydiverse.transform import ops
from pydiverse.transform.backend import sql
from pydiverse.transform.backend.sql import SqlImpl
from pydiverse.transform.tree import dtypes, verbs
from pydiverse.transform.tree.ast import AstNode
from pydiverse.transform.tree.col_expr import (
CaseExpr,
Col,
ColExpr,
ColFn,
LiteralCol,
Order,
)
from pydiverse.transform.util.warnings import warn_non_standard


class MsSqlImpl(SqlImpl):
dialect_name = "mssql"

@classmethod
def build_select(cls, nd: AstNode, final_select: list[Col]) -> Any:
# boolean / bit conversion
for desc in nd.iter_subtree():
if isinstance(desc, verbs.Verb):
desc.map_col_roots(
functools.partial(
convert_bool_bit,
wants_bool_as_bit=not isinstance(
desc, (verbs.Filter, verbs.Join)
),
)
)

# workaround for correct nulls_first / nulls_last behaviour on MSSQL
for desc in nd.iter_subtree():
if isinstance(desc, verbs.Arrange):
desc.order_by = convert_order_list(desc.order_by)
if isinstance(desc, verbs.Verb):
for node in desc.iter_col_nodes():
if isinstance(node, ColFn) and (
arrange := node.context_kwargs.get("arrange")
):
node.context_kwargs["arrange"] = convert_order_list(arrange)

sql.create_aliases(nd, {})
table, query, _ = cls.compile_ast(nd, {col._uuid: 1 for col in final_select})
return cls.compile_query(table, query)


def convert_order_list(order_list: list[Order]) -> list[Order]:
new_list: list[Order] = []
for ord in order_list:
# is True / is False are important here since we don't want to do this costly
# workaround if nulls_last is None (i.e. the user doesn't care)
if ord.nulls_last is True and not ord.descending:
new_list.append(
Order(
CaseExpr([(ord.order_by.is_null(), LiteralCol(1))], LiteralCol(0)),
)
)

elif ord.nulls_last is False and ord.descending:
new_list.append(
Order(
CaseExpr([(ord.order_by.is_null(), LiteralCol(0))], LiteralCol(1)),
)
)

new_list.append(Order(ord.order_by, ord.descending, None))

return new_list


# MSSQL doesn't have a boolean type. This means that expressions that return a boolean
# (e.g. ==, !=, >) can't be used in other expressions without casting to the BIT type.
# Conversely, after casting to BIT, we sometimes may need to convert back to booleans.


def convert_bool_bit(expr: ColExpr | Order, wants_bool_as_bit: bool) -> ColExpr | Order:
if isinstance(expr, Order):
return Order(
convert_bool_bit(expr.order_by, wants_bool_as_bit),
expr.descending,
expr.nulls_last,
)

elif isinstance(expr, Col):
if not wants_bool_as_bit and isinstance(expr.dtype(), dtypes.Bool):
return ColFn("__eq__", expr, LiteralCol(True))
return expr

elif isinstance(expr, ColFn):
op = MsSqlImpl.registry.get_op(expr.name)
wants_bool_as_bit_input = not isinstance(
op, (ops.logical.BooleanBinary, ops.logical.Invert)
)

converted = copy.copy(expr)
converted.args = [
convert_bool_bit(arg, wants_bool_as_bit_input) for arg in expr.args
]
converted.context_kwargs = {
key: [convert_bool_bit(val, wants_bool_as_bit) for val in arr]
for key, arr in expr.context_kwargs.items()
}

impl = MsSqlImpl.registry.get_impl(
expr.name, tuple(arg.dtype() for arg in expr.args)
)

if isinstance(impl.return_type, dtypes.Bool):
returns_bool_as_bit = not isinstance(op, ops.logical.Logical)

if wants_bool_as_bit and not returns_bool_as_bit:
return CaseExpr(
[(converted, LiteralCol(True)), (~converted, LiteralCol(False))],
None,
)
elif not wants_bool_as_bit and returns_bool_as_bit:
return ColFn("__eq__", converted, LiteralCol(True))

return converted

elif isinstance(expr, CaseExpr):
converted = copy.copy(expr)
converted.cases = [
(convert_bool_bit(cond, False), convert_bool_bit(val, True))
for cond, val in expr.cases
]
converted.default_val = (
None
if expr.default_val is None
else convert_bool_bit(expr.default_val, wants_bool_as_bit)
)

return converted

elif isinstance(expr, LiteralCol):
return expr

raise AssertionError


with MsSqlImpl.op(ops.Equal()) as op:

@op("str, str -> bool")
def _eq(x, y):
warn_non_standard(
"MSSQL ignores trailing whitespace when comparing strings",
)
return x == y


with MsSqlImpl.op(ops.NotEqual()) as op:

@op("str, str -> bool")
def _ne(x, y):
warn_non_standard(
"MSSQL ignores trailing whitespace when comparing strings",
)
return x != y


with MsSqlImpl.op(ops.Less()) as op:

@op("str, str -> bool")
def _lt(x, y):
warn_non_standard(
"MSSQL ignores trailing whitespace when comparing strings",
)
return x < y


with MsSqlImpl.op(ops.LessEqual()) as op:

@op("str, str -> bool")
def _le(x, y):
warn_non_standard(
"MSSQL ignores trailing whitespace when comparing strings",
)
return x <= y


with MsSqlImpl.op(ops.Greater()) as op:

@op("str, str -> bool")
def _gt(x, y):
warn_non_standard(
"MSSQL ignores trailing whitespace when comparing strings",
)
return x > y


with MsSqlImpl.op(ops.GreaterEqual()) as op:

@op("str, str -> bool")
def _ge(x, y):
warn_non_standard(
"MSSQL ignores trailing whitespace when comparing strings",
)
return x >= y


with MsSqlImpl.op(ops.Pow()) as op:

@op.auto
def _pow(lhs, rhs):
# In MSSQL, the output type of pow is the same as the input type.
# This means, that if lhs is a decimal, then we may very easily loose
# a lot of precision if the exponent is <= 1
# https://learn.microsoft.com/en-us/sql/t-sql/functions/power-transact-sql?view=sql-server-ver16
return sqa.func.POWER(sqa.cast(lhs, sqa.Double()), rhs, type_=sqa.Double())


with MsSqlImpl.op(ops.RPow()) as op:

@op.auto
def _rpow(rhs, lhs):
return _pow(lhs, rhs)


with MsSqlImpl.op(ops.StrLen()) as op:

@op.auto
def _str_length(x):
return sqa.func.LENGTH(x + "a", type_=sqa.Integer()) - 1


with MsSqlImpl.op(ops.StrReplaceAll()) as op:

@op.auto
def _replace_all(x, y, z):
x = x.collate("Latin1_General_CS_AS")
return sqa.func.REPLACE(x, y, z, type_=x.type)


with MsSqlImpl.op(ops.StrStartsWith()) as op:

@op.auto
def _startswith(x, y):
x = x.collate("Latin1_General_CS_AS")
return x.startswith(y, autoescape=True)


with MsSqlImpl.op(ops.StrEndsWith()) as op:

@op.auto
def _endswith(x, y):
x = x.collate("Latin1_General_CS_AS")
return x.endswith(y, autoescape=True)


with MsSqlImpl.op(ops.StrContains()) as op:

@op.auto
def _contains(x, y):
x = x.collate("Latin1_General_CS_AS")
return x.contains(y, autoescape=True)


with MsSqlImpl.op(ops.StrSlice()) as op:

@op.auto
def _str_slice(x, offset, length):
return sqa.func.SUBSTRING(x, offset + 1, length)


with MsSqlImpl.op(ops.DtDayOfWeek()) as op:

@op.auto
def _day_of_week(x):
# Offset DOW such that Mon=1, Sun=7
_1 = sqa.literal_column("1")
_2 = sqa.literal_column("2")
_7 = sqa.literal_column("7")
return (sqa.extract("dow", x) + sqa.text("@@DATEFIRST") - _2) % _7 + _1


with MsSqlImpl.op(ops.Mean()) as op:

@op.auto
def _mean(x):
return sqa.func.AVG(sqa.cast(x, sqa.Double()), type_=sqa.Double())
Loading

0 comments on commit 6a547d6

Please sign in to comment.