diff --git a/docs/package/README.md b/docs/package/README.md index 054622b2..dc5d20b7 100644 --- a/docs/package/README.md +++ b/docs/package/README.md @@ -23,7 +23,7 @@ from pydiverse.transform.lazy import SQLTableImpl from pydiverse.transform.eager import PandasTableImpl from pydiverse.transform.core.verbs import * import pandas as pd -import sqlalchemy as sa +import sqlalchemy as sqa def main(): @@ -52,7 +52,7 @@ def main(): print("\nPandas based result:") print(out1) - engine = sa.create_engine("sqlite:///:memory:") + engine = sqa.create_engine("sqlite:///:memory:") dfA.to_sql("dfA", engine, index=False, if_exists="replace") dfB.to_sql("dfB", engine, index=False, if_exists="replace") input1 = Table(SQLTableImpl(engine, "dfA")) diff --git a/src/pydiverse/transform/__init__.py b/src/pydiverse/transform/__init__.py index 714f34c7..fe548e54 100644 --- a/src/pydiverse/transform/__init__.py +++ b/src/pydiverse/transform/__init__.py @@ -1,16 +1,25 @@ from __future__ import annotations -from pydiverse.transform.core import functions -from pydiverse.transform.core.alignment import aligned, eval_aligned -from pydiverse.transform.core.dispatchers import verb -from pydiverse.transform.core.expressions.lambda_getter import C -from pydiverse.transform.core.table import Table +from pydiverse.transform.backend.targets import DuckDb, Polars, SqlAlchemy +from pydiverse.transform.pipe.c import C +from pydiverse.transform.pipe.functions import ( + count, + dense_rank, + max, + min, + rank, + row_number, + when, +) +from pydiverse.transform.pipe.pipeable import verb +from pydiverse.transform.pipe.table import Table __all__ = [ + "Polars", + "SqlAlchemy", + "DuckDb", "Table", "aligned", - "eval_aligned", - "functions", "verb", "C", ] diff --git a/src/pydiverse/transform/_typing.py b/src/pydiverse/transform/_typing.py index e6577509..9406418f 100644 --- a/src/pydiverse/transform/_typing.py +++ b/src/pydiverse/transform/_typing.py @@ -3,9 +3,9 @@ from typing import TYPE_CHECKING, Callable, TypeVar if TYPE_CHECKING: - from pydiverse.transform.core.table_impl import AbstractTableImpl + from pydiverse.transform.backend.table_impl import TableImpl T = TypeVar("T") -ImplT = TypeVar("ImplT", bound="AbstractTableImpl") +ImplT = TypeVar("ImplT", bound="TableImpl") CallableT = TypeVar("CallableT", bound=Callable) diff --git a/src/pydiverse/transform/backend/__init__.py b/src/pydiverse/transform/backend/__init__.py new file mode 100644 index 00000000..cc21fb29 --- /dev/null +++ b/src/pydiverse/transform/backend/__init__.py @@ -0,0 +1,10 @@ +from __future__ import annotations + +from .duckdb import DuckDbImpl +from .mssql import MsSqlImpl +from .polars import PolarsImpl +from .postgres import PostgresImpl +from .sql import SqlImpl +from .sqlite import SqliteImpl +from .table_impl import TableImpl +from .targets import DuckDb, Polars, SqlAlchemy diff --git a/src/pydiverse/transform/backend/duckdb.py b/src/pydiverse/transform/backend/duckdb.py new file mode 100644 index 00000000..5937e3e5 --- /dev/null +++ b/src/pydiverse/transform/backend/duckdb.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +import polars as pl + +from pydiverse.transform.backend import sql +from pydiverse.transform.backend.sql import SqlImpl +from pydiverse.transform.backend.targets import Polars, Target +from pydiverse.transform.tree.ast import AstNode +from pydiverse.transform.tree.col_expr import Col + + +class DuckDbImpl(SqlImpl): + dialect_name = "duckdb" + + @classmethod + def export(cls, nd: AstNode, target: Target, final_select: 
list[Col]): + if isinstance(target, Polars): + engine = sql.get_engine(nd) + with engine.connect() as conn: + return pl.read_database( + DuckDbImpl.build_query(nd, final_select), connection=conn + ) + return SqlImpl.export(nd, target, final_select) diff --git a/src/pydiverse/transform/backend/mssql.py b/src/pydiverse/transform/backend/mssql.py new file mode 100644 index 00000000..e00a6309 --- /dev/null +++ b/src/pydiverse/transform/backend/mssql.py @@ -0,0 +1,291 @@ +from __future__ import annotations + +import copy +import functools +from typing import Any + +import sqlalchemy as sqa + +from pydiverse.transform import ops +from pydiverse.transform.backend import sql +from pydiverse.transform.backend.sql import SqlImpl +from pydiverse.transform.tree import dtypes, verbs +from pydiverse.transform.tree.ast import AstNode +from pydiverse.transform.tree.col_expr import ( + CaseExpr, + Col, + ColExpr, + ColFn, + LiteralCol, + Order, +) +from pydiverse.transform.util.warnings import warn_non_standard + + +class MsSqlImpl(SqlImpl): + dialect_name = "mssql" + + @classmethod + def build_select(cls, nd: AstNode, final_select: list[Col]) -> Any: + # boolean / bit conversion + for desc in nd.iter_subtree(): + if isinstance(desc, verbs.Verb): + desc.map_col_roots( + functools.partial( + convert_bool_bit, + wants_bool_as_bit=not isinstance( + desc, (verbs.Filter, verbs.Join) + ), + ) + ) + + # workaround for correct nulls_first / nulls_last behaviour on MSSQL + for desc in nd.iter_subtree(): + if isinstance(desc, verbs.Arrange): + desc.order_by = convert_order_list(desc.order_by) + if isinstance(desc, verbs.Verb): + for node in desc.iter_col_nodes(): + if isinstance(node, ColFn) and ( + arrange := node.context_kwargs.get("arrange") + ): + node.context_kwargs["arrange"] = convert_order_list(arrange) + + sql.create_aliases(nd, {}) + table, query, _ = cls.compile_ast(nd, {col._uuid: 1 for col in final_select}) + return cls.compile_query(table, query) + + +def convert_order_list(order_list: list[Order]) -> list[Order]: + new_list: list[Order] = [] + for ord in order_list: + # is True / is False are important here since we don't want to do this costly + # workaround if nulls_last is None (i.e. the user doesn't care) + if ord.nulls_last is True and not ord.descending: + new_list.append( + Order( + CaseExpr([(ord.order_by.is_null(), LiteralCol(1))], LiteralCol(0)), + ) + ) + + elif ord.nulls_last is False and ord.descending: + new_list.append( + Order( + CaseExpr([(ord.order_by.is_null(), LiteralCol(0))], LiteralCol(1)), + ) + ) + + new_list.append(Order(ord.order_by, ord.descending, None)) + + return new_list + + +# MSSQL doesn't have a boolean type. This means that expressions that return a boolean +# (e.g. ==, !=, >) can't be used in other expressions without casting to the BIT type. +# Conversely, after casting to BIT, we sometimes may need to convert back to booleans. 
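A minimal sketch of the two directions of this rewrite, using hypothetical columns (illustration only, not part of the patch):

```python
import sqlalchemy as sqa

a, b, flag = sqa.column("a"), sqa.column("b"), sqa.column("flag")

# A boolean expression used as a *value* must be wrapped into a BIT:
as_bit = sqa.case((a > b, 1), else_=0)  # CASE WHEN a > b THEN 1 ELSE 0 END

# A BIT column used as a *predicate* must be turned back into a boolean:
as_predicate = flag == 1  # usable in WHERE / ON / CASE conditions
```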
+ + +def convert_bool_bit(expr: ColExpr | Order, wants_bool_as_bit: bool) -> ColExpr | Order: + if isinstance(expr, Order): + return Order( + convert_bool_bit(expr.order_by, wants_bool_as_bit), + expr.descending, + expr.nulls_last, + ) + + elif isinstance(expr, Col): + if not wants_bool_as_bit and isinstance(expr.dtype(), dtypes.Bool): + return ColFn("__eq__", expr, LiteralCol(True)) + return expr + + elif isinstance(expr, ColFn): + op = MsSqlImpl.registry.get_op(expr.name) + wants_bool_as_bit_input = not isinstance( + op, (ops.logical.BooleanBinary, ops.logical.Invert) + ) + + converted = copy.copy(expr) + converted.args = [ + convert_bool_bit(arg, wants_bool_as_bit_input) for arg in expr.args + ] + converted.context_kwargs = { + key: [convert_bool_bit(val, wants_bool_as_bit) for val in arr] + for key, arr in expr.context_kwargs.items() + } + + impl = MsSqlImpl.registry.get_impl( + expr.name, tuple(arg.dtype() for arg in expr.args) + ) + + if isinstance(impl.return_type, dtypes.Bool): + returns_bool_as_bit = not isinstance(op, ops.logical.Logical) + + if wants_bool_as_bit and not returns_bool_as_bit: + return CaseExpr( + [(converted, LiteralCol(True)), (~converted, LiteralCol(False))], + None, + ) + elif not wants_bool_as_bit and returns_bool_as_bit: + return ColFn("__eq__", converted, LiteralCol(True)) + + return converted + + elif isinstance(expr, CaseExpr): + converted = copy.copy(expr) + converted.cases = [ + (convert_bool_bit(cond, False), convert_bool_bit(val, True)) + for cond, val in expr.cases + ] + converted.default_val = ( + None + if expr.default_val is None + else convert_bool_bit(expr.default_val, wants_bool_as_bit) + ) + + return converted + + elif isinstance(expr, LiteralCol): + return expr + + raise AssertionError + + +with MsSqlImpl.op(ops.Equal()) as op: + + @op("str, str -> bool") + def _eq(x, y): + warn_non_standard( + "MSSQL ignores trailing whitespace when comparing strings", + ) + return x == y + + +with MsSqlImpl.op(ops.NotEqual()) as op: + + @op("str, str -> bool") + def _ne(x, y): + warn_non_standard( + "MSSQL ignores trailing whitespace when comparing strings", + ) + return x != y + + +with MsSqlImpl.op(ops.Less()) as op: + + @op("str, str -> bool") + def _lt(x, y): + warn_non_standard( + "MSSQL ignores trailing whitespace when comparing strings", + ) + return x < y + + +with MsSqlImpl.op(ops.LessEqual()) as op: + + @op("str, str -> bool") + def _le(x, y): + warn_non_standard( + "MSSQL ignores trailing whitespace when comparing strings", + ) + return x <= y + + +with MsSqlImpl.op(ops.Greater()) as op: + + @op("str, str -> bool") + def _gt(x, y): + warn_non_standard( + "MSSQL ignores trailing whitespace when comparing strings", + ) + return x > y + + +with MsSqlImpl.op(ops.GreaterEqual()) as op: + + @op("str, str -> bool") + def _ge(x, y): + warn_non_standard( + "MSSQL ignores trailing whitespace when comparing strings", + ) + return x >= y + + +with MsSqlImpl.op(ops.Pow()) as op: + + @op.auto + def _pow(lhs, rhs): + # In MSSQL, the output type of pow is the same as the input type. 
+        # This means that if lhs is a decimal, we may very easily lose
+        # a lot of precision if the exponent is <= 1.
+        # https://learn.microsoft.com/en-us/sql/t-sql/functions/power-transact-sql?view=sql-server-ver16
+        return sqa.func.POWER(sqa.cast(lhs, sqa.Double()), rhs, type_=sqa.Double())
+
+
+with MsSqlImpl.op(ops.RPow()) as op:
+
+    @op.auto
+    def _rpow(rhs, lhs):
+        return _pow(lhs, rhs)
+
+
+with MsSqlImpl.op(ops.StrLen()) as op:
+
+    @op.auto
+    def _str_length(x):
+        # MSSQL's LEN ignores trailing whitespace; appending a character and
+        # subtracting one afterwards counts it correctly.
+        return sqa.func.LENGTH(x + "a", type_=sqa.Integer()) - 1
+
+
+with MsSqlImpl.op(ops.StrReplaceAll()) as op:
+
+    @op.auto
+    def _replace_all(x, y, z):
+        x = x.collate("Latin1_General_CS_AS")
+        return sqa.func.REPLACE(x, y, z, type_=x.type)
+
+
+with MsSqlImpl.op(ops.StrStartsWith()) as op:
+
+    @op.auto
+    def _startswith(x, y):
+        x = x.collate("Latin1_General_CS_AS")
+        return x.startswith(y, autoescape=True)
+
+
+with MsSqlImpl.op(ops.StrEndsWith()) as op:
+
+    @op.auto
+    def _endswith(x, y):
+        x = x.collate("Latin1_General_CS_AS")
+        return x.endswith(y, autoescape=True)
+
+
+with MsSqlImpl.op(ops.StrContains()) as op:
+
+    @op.auto
+    def _contains(x, y):
+        x = x.collate("Latin1_General_CS_AS")
+        return x.contains(y, autoescape=True)
+
+
+with MsSqlImpl.op(ops.StrSlice()) as op:
+
+    @op.auto
+    def _str_slice(x, offset, length):
+        return sqa.func.SUBSTRING(x, offset + 1, length)
+
+
+with MsSqlImpl.op(ops.DtDayOfWeek()) as op:
+
+    @op.auto
+    def _day_of_week(x):
+        # Offset DOW such that Mon=1, Sun=7
+        _1 = sqa.literal_column("1")
+        _2 = sqa.literal_column("2")
+        _7 = sqa.literal_column("7")
+        return (sqa.extract("dow", x) + sqa.text("@@DATEFIRST") - _2) % _7 + _1
+
+
+with MsSqlImpl.op(ops.Mean()) as op:
+
+    @op.auto
+    def _mean(x):
+        return sqa.func.AVG(sqa.cast(x, sqa.Double()), type_=sqa.Double())
diff --git a/src/pydiverse/transform/backend/polars.py b/src/pydiverse/transform/backend/polars.py
new file mode 100644
index 00000000..c2ed6e8a
--- /dev/null
+++ b/src/pydiverse/transform/backend/polars.py
@@ -0,0 +1,711 @@
+from __future__ import annotations
+
+import datetime
+from types import NoneType
+from typing import Any
+from uuid import UUID
+
+import polars as pl
+
+from pydiverse.transform import ops
+from pydiverse.transform.backend.table_impl import TableImpl
+from pydiverse.transform.backend.targets import Polars, Target
+from pydiverse.transform.ops.core import Ftype
+from pydiverse.transform.tree import dtypes, verbs
+from pydiverse.transform.tree.ast import AstNode
+from pydiverse.transform.tree.col_expr import (
+    CaseExpr,
+    Col,
+    ColExpr,
+    ColFn,
+    LiteralCol,
+    Order,
+)
+
+
+class PolarsImpl(TableImpl):
+    def __init__(self, name: str, df: pl.DataFrame | pl.LazyFrame):
+        self.df = df
+        super().__init__(
+            name,
+            {
+                name: polars_type_to_pdt(dtype)
+                for name, dtype in df.collect_schema().items()
+            },
+        )
+
+    @staticmethod
+    def build_query(nd: AstNode, final_select: list[Col]) -> None:
+        return None
+
+    @staticmethod
+    def export(nd: AstNode, target: Target, final_select: list[Col]) -> Any:
+        lf, _, select, _ = compile_ast(nd)
+        lf = lf.select(select)
+        if isinstance(target, Polars):
+            if not target.lazy and isinstance(lf, pl.LazyFrame):
+                lf = lf.collect()
+            lf.name = nd.name
+            return lf
+
+    def _clone(self) -> tuple[PolarsImpl, dict[AstNode, AstNode], dict[UUID, UUID]]:
+        cloned = PolarsImpl(self.name, self.df.clone())
+        return (
+            cloned,
+            {self: cloned},
+            {
+                self.cols[name]._uuid: cloned.cols[name]._uuid
+                for name in self.cols.keys()
+            },
+        )
+
+
+# merges descending and nulls_last markers into the ordering expression
+def merge_desc_nulls_last(
+    order_by: list[pl.Expr], descending: list[bool], nulls_last: list[bool]
+) -> list[pl.Expr]:
+    with_signs: list[pl.Expr] = []
+    for ord, desc in zip(order_by, descending):
+        numeric = ord.rank("dense").cast(pl.Int64)
+        with_signs.append(-numeric if desc else numeric)
+    return [
+        expr.fill_null(
+            pl.len().cast(pl.Int64) + 1 if nl else -(pl.len().cast(pl.Int64) + 1)
+        )
+        for expr, nl in zip(with_signs, nulls_last)
+    ]
+
+
+def compile_order(
+    order: Order, name_in_df: dict[UUID, str]
+) -> tuple[pl.Expr, bool, bool]:
+    return (
+        compile_col_expr(order.order_by, name_in_df),
+        order.descending,
+        order.nulls_last if order.nulls_last is not None else False,
+    )
+
+
+def compile_col_expr(expr: ColExpr, name_in_df: dict[UUID, str]) -> pl.Expr:
+    if isinstance(expr, Col):
+        return pl.col(name_in_df[expr._uuid])
+
+    elif isinstance(expr, ColFn):
+        op = PolarsImpl.registry.get_op(expr.name)
+        args: list[pl.Expr] = [compile_col_expr(arg, name_in_df) for arg in expr.args]
+        impl = PolarsImpl.registry.get_impl(
+            expr.name,
+            tuple(arg.dtype() for arg in expr.args),
+        )
+
+        if (partition_by := expr.context_kwargs.get("partition_by")) is not None:
+            partition_by = [compile_col_expr(pb, name_in_df) for pb in partition_by]
+
+        arrange = expr.context_kwargs.get("arrange")
+        if arrange:
+            order_by, descending, nulls_last = zip(
+                *[compile_order(order, name_in_df) for order in arrange]
+            )
+
+        # The following `if` block is not necessary for correctness, just an
+        # optimization. Otherwise, `over` would be used for sorting, but we cannot
+        # pass descending / nulls_last there, and the required workaround is
+        # probably slower than polars's native `sort_by`.
+        if arrange and not partition_by:
+            # Order the args. If the table is grouped via group_by or
+            # partition_by=, the groups will be sorted via over(order_by=)
+            # anyway, so it need not be done here.
+            args = [
+                arg.sort_by(by=order_by, descending=descending, nulls_last=nulls_last)
+                if isinstance(arg, pl.Expr)
+                else arg
+                for arg in args
+            ]
+
+        if op.name in ("rank", "dense_rank"):
+            assert len(expr.args) == 0
+            args = [pl.struct(merge_desc_nulls_last(order_by, descending, nulls_last))]
+            arrange = None
+
+        value: pl.Expr = impl(*args)
+
+        # TODO: currently, count is the only aggregation function where we don't want
+        # to return null for cols containing only null values. If this happens for more
+        # aggregation functions, make this configurable in e.g. the operator spec.
+        if op.ftype == Ftype.AGGREGATE and op.name != "count":
+            # In `sum` / `any` and other aggregation functions, polars puts a
+            # default value (e.g. 0, False) for empty columns, but we want to put
+            # Null in this case to let the user decide about the default value via
+            # `fill_null` if they want to set one.
+            assert all(arg.dtype().const for arg in expr.args[1:])
+            value = pl.when(args[0].count() == 0).then(None).otherwise(value)
+
+        if partition_by:
+            # When doing sort_by -> over in polars, for whatever reason the
+            # `nulls_last` argument is ignored. Thus, when both a grouping and an
+            # arrangement are specified, we manually add the descending and
+            # nulls_last markers to the ordering.
+            if arrange:
+                order_by = merge_desc_nulls_last(order_by, descending, nulls_last)
+            else:
+                order_by = None
+            value = value.over(partition_by, order_by=order_by)
+
+        elif arrange:
+            if op.ftype == Ftype.AGGREGATE:
+                # TODO: don't fail, but give a warning that `arrange` is useless
+                # here
+                ...
+ + # the function was executed on the ordered arguments. here we + # restore the original order of the table. + inv_permutation = pl.int_range(0, pl.len(), dtype=pl.Int64).sort_by( + by=order_by, + descending=descending, + nulls_last=nulls_last, + ) + value = value.sort_by(inv_permutation) + + return value + + elif isinstance(expr, CaseExpr): + assert len(expr.cases) >= 1 + compiled = pl # to initialize the when/then-chain + for cond, val in expr.cases: + compiled = compiled.when(compile_col_expr(cond, name_in_df)).then( + compile_col_expr(val, name_in_df) + ) + if expr.default_val is not None: + compiled = compiled.otherwise( + compile_col_expr(expr.default_val, name_in_df) + ) + return compiled + + elif isinstance(expr, LiteralCol): + if isinstance(expr.dtype(), dtypes.String): + return pl.lit(expr.val) # polars interprets strings as column names + return expr.val + + else: + raise AssertionError + + +def compile_join_cond( + expr: ColExpr, name_in_df: dict[UUID, str] +) -> list[tuple[pl.Expr, pl.Expr]]: + if isinstance(expr, ColFn): + if expr.name == "__and__": + return compile_join_cond(expr.args[0], name_in_df) + compile_join_cond( + expr.args[1], name_in_df + ) + if expr.name == "__eq__": + return [ + ( + compile_col_expr(expr.args[0], name_in_df), + compile_col_expr(expr.args[1], name_in_df), + ) + ] + + raise AssertionError() + + +def compile_ast( + nd: AstNode, +) -> tuple[pl.LazyFrame, dict[UUID, str], list[str], list[UUID]]: + if isinstance(nd, verbs.Verb): + df, name_in_df, select, partition_by = compile_ast(nd.child) + + if isinstance(nd, (verbs.Mutate, verbs.Summarise)): + overwritten = set(name for name in nd.names if name in set(select)) + if overwritten: + # We rename overwritten cols to some unique dummy name + name_map = {name: f"{name}_{str(hex(id(nd)))[2:]}" for name in overwritten} + name_in_df = { + uid: (name_map[name] if name in name_map else name) + for uid, name in name_in_df.items() + } + df = df.rename(name_map) + + select = [col_name for col_name in select if col_name not in overwritten] + + if isinstance(nd, verbs.Select): + select = [name_in_df[col._uuid] for col in nd.select] + + elif isinstance(nd, verbs.Rename): + df = df.rename(nd.name_map) + name_in_df = { + uid: (nd.name_map[name] if name in nd.name_map else name) + for uid, name in name_in_df.items() + } + select = [ + nd.name_map[col_name] if col_name in nd.name_map else col_name + for col_name in select + ] + + elif isinstance(nd, verbs.Mutate): + df = df.with_columns( + **{ + name: compile_col_expr(value, name_in_df) + for name, value in zip(nd.names, nd.values) + } + ) + + name_in_df.update({uid: name for uid, name in zip(nd.uuids, nd.names)}) + select += nd.names + + elif isinstance(nd, verbs.Filter): + df = df.filter([compile_col_expr(fil, name_in_df) for fil in nd.filters]) + + elif isinstance(nd, verbs.Arrange): + order_by, descending, nulls_last = zip( + *[compile_order(order, name_in_df) for order in nd.order_by] + ) + df = df.sort( + order_by, + descending=descending, + nulls_last=nulls_last, + maintain_order=True, + ) + + elif isinstance(nd, verbs.Summarise): + # We support usage of aggregated columns in expressions in summarise, but polars + # creates arrays when doing that. Thus we unwrap the arrays when necessary. 
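The polars behavior this comment refers to can be reproduced in isolation; the sample data here is made up:

```python
import polars as pl

df = pl.DataFrame({"g": [1, 1, 2], "x": [1.0, 2.0, 4.0]})

# A group-wise expression that is not scalar per group comes back as a
# list column ...
lists = df.group_by("g").agg(y=pl.col("x").max() - pl.col("x"))

# ... while appending .first() unwraps it to one value per group, which is
# the unwrapping decided by `has_path_to_leaf_without_agg` below.
scalars = df.group_by("g").agg(y=(pl.col("x").max() - pl.col("x")).first())
```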
+ def has_path_to_leaf_without_agg(expr: ColExpr): + if isinstance(expr, Col): + return True + if ( + isinstance(expr, ColFn) + and PolarsImpl.registry.get_op(expr.name).ftype == Ftype.AGGREGATE + ): + return False + return any( + has_path_to_leaf_without_agg(child) for child in expr.iter_children() + ) + + aggregations = {} + for name, val in zip(nd.names, nd.values): + compiled = compile_col_expr(val, name_in_df) + if has_path_to_leaf_without_agg(val): + compiled = compiled.first() + aggregations[name] = compiled + + if partition_by: + df = df.group_by(*(name_in_df[uid] for uid in partition_by)).agg( + **aggregations + ) + else: + df = df.select(**aggregations) + + name_in_df.update({uid: name for name, uid in zip(nd.names, nd.uuids)}) + select = [*(name_in_df[uid] for uid in partition_by), *nd.names] + partition_by = [] + + elif isinstance(nd, verbs.SliceHead): + df = df.slice(nd.offset, nd.n) + + elif isinstance(nd, verbs.GroupBy): + new_group_by = [col._uuid for col in nd.group_by] + partition_by = partition_by + new_group_by if nd.add else new_group_by + + elif isinstance(nd, verbs.Ungroup): + partition_by = [] + + elif isinstance(nd, verbs.Join): + right_df, right_name_in_df, right_select, _ = compile_ast(nd.right) + name_in_df.update( + {uid: name + nd.suffix for uid, name in right_name_in_df.items()} + ) + left_on, right_on = zip(*compile_join_cond(nd.on, name_in_df)) + + assert len(partition_by) == 0 + select += [col_name + nd.suffix for col_name in right_select] + + df = df.join( + right_df.rename({name: name + nd.suffix for name in right_df.columns}), + left_on=left_on, + right_on=right_on, + how=nd.how, + validate=nd.validate, + coalesce=False, + ) + + elif isinstance(nd, PolarsImpl): + df = nd.df + name_in_df = {col._uuid: col.name for col in nd.cols.values()} + select = list(nd.cols.keys()) + partition_by = [] + + return df, name_in_df, select, partition_by + + +def polars_type_to_pdt(t: pl.DataType) -> dtypes.Dtype: + if t.is_float(): + return dtypes.Float() + elif t.is_integer(): + return dtypes.Int() + elif isinstance(t, pl.Boolean): + return dtypes.Bool() + elif isinstance(t, pl.String): + return dtypes.String() + elif isinstance(t, pl.Datetime): + return dtypes.DateTime() + elif isinstance(t, pl.Date): + return dtypes.Date() + elif isinstance(t, pl.Duration): + return dtypes.Duration() + elif isinstance(t, pl.Null): + return dtypes.NoneDtype() + + raise TypeError(f"polars type {t} is not supported by pydiverse.transform") + + +def pdt_type_to_polars(t: dtypes.Dtype) -> pl.DataType: + if isinstance(t, dtypes.Float): + return pl.Float64() + elif isinstance(t, dtypes.Int): + return pl.Int64() + elif isinstance(t, dtypes.Bool): + return pl.Boolean() + elif isinstance(t, dtypes.String): + return pl.String() + elif isinstance(t, dtypes.DateTime): + return pl.Datetime() + elif isinstance(t, dtypes.Date): + return pl.Date() + elif isinstance(t, dtypes.Duration): + return pl.Duration() + elif isinstance(t, dtypes.NoneDtype): + return pl.Null() + + raise AssertionError + + +def python_type_to_polars(t: type) -> pl.DataType: + if t is int: + return pl.Int64() + elif t is float: + return pl.Float64() + elif t is bool: + return pl.Boolean() + elif t is str: + return pl.String() + elif t is datetime.datetime: + return pl.Datetime() + elif t is datetime.date: + return pl.Date() + elif t is datetime.timedelta: + return pl.Duration() + elif t is NoneType: + return pl.Null() + + raise TypeError(f"python builtin type {t} is not supported by pydiverse.transform") + + +with 
PolarsImpl.op(ops.Mean()) as op: + + @op.auto + def _mean(x): + return x.mean() + + +with PolarsImpl.op(ops.Min()) as op: + + @op.auto + def _min(x): + return x.min() + + +with PolarsImpl.op(ops.Max()) as op: + + @op.auto + def _max(x): + return x.max() + + +with PolarsImpl.op(ops.Sum()) as op: + + @op.auto + def _sum(x): + return x.sum() + + +with PolarsImpl.op(ops.All()) as op: + + @op.auto + def _all(x): + return x.all() + + +with PolarsImpl.op(ops.Any()) as op: + + @op.auto + def _any(x): + return x.any() + + +with PolarsImpl.op(ops.IsNull()) as op: + + @op.auto + def _is_null(x): + return x.is_null() + + +with PolarsImpl.op(ops.IsNotNull()) as op: + + @op.auto + def _is_not_null(x): + return x.is_not_null() + + +with PolarsImpl.op(ops.FillNull()) as op: + + @op.auto + def _fill_null(x, y): + return x.fill_null(y) + + +with PolarsImpl.op(ops.DtYear()) as op: + + @op.auto + def _dt_year(x): + return x.dt.year() + + +with PolarsImpl.op(ops.DtMonth()) as op: + + @op.auto + def _dt_month(x): + return x.dt.month() + + +with PolarsImpl.op(ops.DtDay()) as op: + + @op.auto + def _dt_day(x): + return x.dt.day() + + +with PolarsImpl.op(ops.DtHour()) as op: + + @op.auto + def _dt_hour(x): + return x.dt.hour() + + +with PolarsImpl.op(ops.DtMinute()) as op: + + @op.auto + def _dt_minute(x): + return x.dt.minute() + + +with PolarsImpl.op(ops.DtSecond()) as op: + + @op.auto + def _dt_second(x): + return x.dt.second() + + +with PolarsImpl.op(ops.DtMillisecond()) as op: + + @op.auto + def _dt_millisecond(x): + return x.dt.millisecond() + + +with PolarsImpl.op(ops.DtDayOfWeek()) as op: + + @op.auto + def _dt_day_of_week(x): + return x.dt.weekday() + + +with PolarsImpl.op(ops.DtDayOfYear()) as op: + + @op.auto + def _dt_day_of_year(x): + return x.dt.ordinal_day() + + +with PolarsImpl.op(ops.DtDays()) as op: + + @op.auto + def _days(x): + return x.dt.total_days() + + +with PolarsImpl.op(ops.DtHours()) as op: + + @op.auto + def _hours(x): + return x.dt.total_hours() + + +with PolarsImpl.op(ops.DtMinutes()) as op: + + @op.auto + def _minutes(x): + return x.dt.total_minutes() + + +with PolarsImpl.op(ops.DtSeconds()) as op: + + @op.auto + def _seconds(x): + return x.dt.total_seconds() + + +with PolarsImpl.op(ops.DtMilliseconds()) as op: + + @op.auto + def _milliseconds(x): + return x.dt.total_milliseconds() + + +with PolarsImpl.op(ops.Sub()) as op: + + @op.extension(ops.DtSub) + def _dt_sub(lhs, rhs): + return lhs - rhs + + +with PolarsImpl.op(ops.RSub()) as op: + + @op.extension(ops.DtRSub) + def _dt_rsub(rhs, lhs): + return lhs - rhs + + +with PolarsImpl.op(ops.Add()) as op: + + @op.extension(ops.DtDurAdd) + def _dt_dur_add(lhs, rhs): + return lhs + rhs + + +with PolarsImpl.op(ops.RAdd()) as op: + + @op.extension(ops.DtDurRAdd) + def _dt_dur_radd(rhs, lhs): + return lhs + rhs + + +with PolarsImpl.op(ops.RowNumber()) as op: + + @op.auto + def _row_number(): + return pl.int_range(start=1, end=pl.len() + 1, dtype=pl.Int64) + + +with PolarsImpl.op(ops.Rank()) as op: + + @op.auto + def _rank(x): + return x.rank("min").cast(pl.Int64) + + +with PolarsImpl.op(ops.DenseRank()) as op: + + @op.auto + def _dense_rank(x): + return x.rank("dense").cast(pl.Int64) + + +with PolarsImpl.op(ops.Shift()) as op: + + @op.auto + def _shift(x, n, fill_value=None): + return x.shift(n, fill_value=fill_value) + + +with PolarsImpl.op(ops.IsIn()) as op: + + @op.auto + def _isin(x, *values): + return pl.any_horizontal( + (x == v if v is not None else x.is_null()) for v in values + ) + + +with PolarsImpl.op(ops.StrContains()) as op: + + 
@op.auto + def _contains(x, y): + return x.str.contains(y) + + +with PolarsImpl.op(ops.StrStartsWith()) as op: + + @op.auto + def _starts_with(x, y): + return x.str.starts_with(y) + + +with PolarsImpl.op(ops.StrEndsWith()) as op: + + @op.auto + def _ends_with(x, y): + return x.str.ends_with(y) + + +with PolarsImpl.op(ops.StrToLower()) as op: + + @op.auto + def _lower(x): + return x.str.to_lowercase() + + +with PolarsImpl.op(ops.StrToUpper()) as op: + + @op.auto + def _upper(x): + return x.str.to_uppercase() + + +with PolarsImpl.op(ops.StrReplaceAll()) as op: + + @op.auto + def _replace_all(x, to_replace, replacement): + return x.str.replace_all(to_replace, replacement) + + +with PolarsImpl.op(ops.StrLen()) as op: + + @op.auto + def _string_length(x): + return x.str.len_chars().cast(pl.Int64) + + +with PolarsImpl.op(ops.StrStrip()) as op: + + @op.auto + def _str_strip(x): + return x.str.strip_chars() + + +with PolarsImpl.op(ops.StrSlice()) as op: + + @op.auto + def _str_slice(x, offset, length): + return x.str.slice(offset, length) + + +with PolarsImpl.op(ops.Count()) as op: + + @op.auto + def _count(x=None): + return pl.len() if x is None else x.count() + + +with PolarsImpl.op(ops.Greatest()) as op: + + @op.auto + def _greatest(*x): + return pl.max_horizontal(*x) + + +with PolarsImpl.op(ops.Least()) as op: + + @op.auto + def _least(*x): + return pl.min_horizontal(*x) diff --git a/src/pydiverse/transform/backend/postgres.py b/src/pydiverse/transform/backend/postgres.py new file mode 100644 index 00000000..b83a32db --- /dev/null +++ b/src/pydiverse/transform/backend/postgres.py @@ -0,0 +1,122 @@ +from __future__ import annotations + +import sqlalchemy as sqa + +from pydiverse.transform import ops +from pydiverse.transform.backend.sql import SqlImpl + + +class PostgresImpl(SqlImpl): + dialect_name = "postgresql" + + +with PostgresImpl.op(ops.Less()) as op: + + @op("str, str -> bool") + def _lt(x, y): + return x < sqa.collate(y, "POSIX") + + +with PostgresImpl.op(ops.LessEqual()) as op: + + @op("str, str -> bool") + def _le(x, y): + return x <= sqa.collate(y, "POSIX") + + +with PostgresImpl.op(ops.Greater()) as op: + + @op("str, str -> bool") + def _gt(x, y): + return x > sqa.collate(y, "POSIX") + + +with PostgresImpl.op(ops.GreaterEqual()) as op: + + @op("str, str -> bool") + def _ge(x, y): + return x >= sqa.collate(y, "POSIX") + + +with PostgresImpl.op(ops.Round()) as op: + + @op.auto + def _round(x, decimals=0): + if decimals == 0: + if isinstance(x.type, sqa.Integer): + return x + return sqa.func.ROUND(x, type_=x.type) + + if isinstance(x.type, sqa.Float): + # Postgres doesn't support rounding of doubles to specific precision + # -> Must first cast to numeric + return sqa.func.ROUND(sqa.cast(x, sqa.Numeric), decimals, type_=sqa.Numeric) + + return sqa.func.ROUND(x, decimals, type_=x.type) + + +with PostgresImpl.op(ops.DtSecond()) as op: + + @op.auto + def _second(x): + return sqa.func.FLOOR(sqa.extract("second", x), type_=sqa.Integer()) + + +with PostgresImpl.op(ops.DtMillisecond()) as op: + + @op.auto + def _millisecond(x): + _1000 = sqa.literal_column("1000") + return sqa.func.FLOOR( + sqa.extract("milliseconds", x) % _1000, type_=sqa.Integer() + ) + + +with PostgresImpl.op(ops.Greatest()) as op: + + @op("str... -> str") + def _greatest(*x): + # TODO: Determine return type + return sqa.func.GREATEST(*(sqa.collate(e, "POSIX") for e in x)) + + +with PostgresImpl.op(ops.Least()) as op: + + @op("str... 
-> str") + def _least(*x): + # TODO: Determine return type + return sqa.func.LEAST(*(sqa.collate(e, "POSIX") for e in x)) + + +with PostgresImpl.op(ops.Any()) as op: + + @op.auto + def _any(x, *, _window_partition_by=None, _window_order_by=None): + return sqa.func.coalesce(sqa.func.BOOL_OR(x, type_=sqa.Boolean()), sqa.null()) + + @op.auto(variant="window") + def _any(x, *, partition_by=None, order_by=None): + return sqa.func.coalesce( + sqa.func.BOOL_OR(x, type_=sqa.Boolean()).over( + partition_by=partition_by, + order_by=order_by, + ), + sqa.null(), + ) + + +with PostgresImpl.op(ops.All()) as op: + + @op.auto + def _all(x): + return sqa.func.coalesce(sqa.func.BOOL_AND(x, type_=sqa.Boolean()), sqa.null()) + + @op.auto(variant="window") + def _all(x, *, partition_by=None, order_by=None): + return sqa.func.coalesce( + sqa.func.BOOL_AND(x, type_=sqa.Boolean()).over( + partition_by=partition_by, + order_by=order_by, + ), + sqa.null(), + ) diff --git a/src/pydiverse/transform/backend/sql.py b/src/pydiverse/transform/backend/sql.py new file mode 100644 index 00000000..05be684c --- /dev/null +++ b/src/pydiverse/transform/backend/sql.py @@ -0,0 +1,970 @@ +from __future__ import annotations + +import dataclasses +import functools +import inspect +import itertools +import operator +from collections.abc import Iterable +from typing import Any +from uuid import UUID + +import polars as pl +import sqlalchemy as sqa + +from pydiverse.transform import ops +from pydiverse.transform.backend.polars import pdt_type_to_polars +from pydiverse.transform.backend.table_impl import TableImpl +from pydiverse.transform.backend.targets import Polars, SqlAlchemy, Target +from pydiverse.transform.ops.core import Ftype +from pydiverse.transform.tree import dtypes, verbs +from pydiverse.transform.tree.ast import AstNode +from pydiverse.transform.tree.col_expr import ( + CaseExpr, + Col, + ColExpr, + ColFn, + LiteralCol, + Order, +) +from pydiverse.transform.tree.dtypes import Dtype + + +class SqlImpl(TableImpl): + Dialects: dict[str, type[TableImpl]] = {} + + def __new__(cls, *args, **kwargs) -> SqlImpl: + engine: str | sqa.Engine = ( + inspect.signature(cls.__init__) + .bind(None, *args, **kwargs) + .arguments["conf"] + .engine + ) + + dialect = ( + engine.dialect.name + if isinstance(engine, sqa.Engine) + else sqa.make_url(engine).get_dialect().name + ) + + return super().__new__(SqlImpl.Dialects[dialect]) + + def __init__(self, table: str | sqa.Table, conf: SqlAlchemy, name: str | None): + assert type(self) is not SqlImpl + + self.engine = ( + conf.engine + if isinstance(conf.engine, sqa.Engine) + else sqa.create_engine(conf.engine) + ) + if isinstance(table, str): + self.table = sqa.Table( + table, sqa.MetaData(), schema=conf.schema, autoload_with=self.engine + ) + else: + self.table = table + + if name is None: + name = self.table.name + + super().__init__( + name, + {col.name: sqa_type_to_pdt(col.type) for col in self.table.columns}, + ) + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + SqlImpl.Dialects[cls.dialect_name] = cls + + def col_names(self) -> list[str]: + return [col.name for col in self.table.columns] + + def schema(self) -> dict[str, Dtype]: + return {col.name: sqa_type_to_pdt(col.type) for col in self.table.columns} + + def _clone(self) -> tuple[SqlImpl, dict[AstNode, AstNode], dict[UUID, UUID]]: + cloned = self.__class__(self.table, SqlAlchemy(self.engine), self.name) + return ( + cloned, + {self: cloned}, + { + self.cols[name]._uuid: cloned.cols[name]._uuid + for name 
in self.cols.keys() + }, + ) + + @classmethod + def build_select(cls, nd: AstNode, final_select: list[Col]) -> sqa.Select: + create_aliases(nd, {}) + nd, query, _ = cls.compile_ast(nd, {col._uuid: 1 for col in final_select}) + return cls.compile_query(nd, query) + + @classmethod + def export(cls, nd: AstNode, target: Target, final_select: list[Col]) -> Any: + sel = cls.build_select(nd, final_select) + engine = get_engine(nd) + if isinstance(target, Polars): + with engine.connect() as conn: + df = pl.read_database( + sel, + connection=conn, + schema_overrides={ + sql_col.name: pdt_type_to_polars(col.dtype()) + for sql_col, col in zip(sel.columns.values(), final_select) + }, + ) + df.name = nd.name + return df + + raise NotImplementedError + + @classmethod + def build_query(cls, nd: AstNode, final_select: list[Col]) -> str | None: + sel = cls.build_select(nd, final_select) + engine = get_engine(nd) + return str( + sel.compile(dialect=engine.dialect, compile_kwargs={"literal_binds": True}) + ) + + @classmethod + def compile_order( + cls, order: Order, sqa_col: dict[str, sqa.Label] + ) -> sqa.UnaryExpression: + order_expr = cls.compile_col_expr(order.order_by, sqa_col) + order_expr = order_expr.desc() if order.descending else order_expr.asc() + if order.nulls_last is not None: + order_expr = ( + order_expr.nulls_last() + if order.nulls_last + else order_expr.nulls_first() + ) + return order_expr + + @classmethod + def compile_col_expr( + cls, expr: ColExpr, sqa_col: dict[str, sqa.Label] + ) -> sqa.ColumnElement: + if isinstance(expr, Col): + return sqa_col[expr._uuid] + + elif isinstance(expr, ColFn): + args: list[sqa.ColumnElement] = [ + cls.compile_col_expr(arg, sqa_col) for arg in expr.args + ] + impl = cls.registry.get_impl( + expr.name, tuple(arg.dtype() for arg in expr.args) + ) + + partition_by = expr.context_kwargs.get("partition_by") + if partition_by is not None: + partition_by = sqa.sql.expression.ClauseList( + *(cls.compile_col_expr(col, sqa_col) for col in partition_by) + ) + + arrange = expr.context_kwargs.get("arrange") + + if arrange: + order_by = sqa.sql.expression.ClauseList( + *( + dedup_order_by( + cls.compile_order(order, sqa_col) for order in arrange + ) + ) + ) + else: + order_by = None + + # we need this since some backends cannot do `any` / `all` as a window + # function, so we need to emulate it via `max` / `min`. 
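A reduced illustration of the fallback described above, with invented column names:

```python
import sqlalchemy as sqa

flag, grp = sqa.column("flag"), sqa.column("grp")

# The natural form, where the backend supports it:
windowed_any = sqa.func.bool_or(flag).over(partition_by=grp)

# The emulation registered as the "window" variant on such backends:
# for BIT / boolean-like values, MAX over the same window is equivalent.
emulated_any = sqa.func.max(flag).over(partition_by=grp)
```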
+ if (partition_by is not None or order_by is not None) and ( + window_impl := impl.get_variant("window") + ): + value = window_impl(*args, partition_by=partition_by, order_by=order_by) + + else: + value: sqa.ColumnElement = impl(*args) + if partition_by is not None or order_by is not None: + value = value.over(partition_by=partition_by, order_by=order_by) + + return value + + elif isinstance(expr, CaseExpr): + return sqa.case( + *( + ( + cls.compile_col_expr(cond, sqa_col), + cls.compile_col_expr(val, sqa_col), + ) + for cond, val in expr.cases + ), + else_=( + cls.compile_col_expr(expr.default_val, sqa_col) + if expr.default_val is not None + else None + ), + ) + + elif isinstance(expr, LiteralCol): + return expr.val + + raise AssertionError + + @classmethod + def compile_query(cls, table: sqa.Table, query: Query) -> sqa.sql.Select: + sel = table.select().select_from(table) + + for j in query.join: + sel = sel.join( + j.right, + onclause=j.on, + isouter=j.how != "inner", + full=j.how == "outer", + ) + + if query.where: + sel = sel.where(*query.where) + + if query.group_by: + sel = sel.group_by(*query.group_by) + + if query.having: + sel = sel.having(*query.having) + + if query.limit is not None: + sel = sel.limit(query.limit).offset(query.offset) + + if query.order_by: + sel = sel.order_by(*query.order_by) + + sel = sel.with_only_columns(*query.select) + + return sel + + @classmethod + def compile_ast( + cls, nd: AstNode, needed_cols: dict[UUID, int] + ) -> tuple[sqa.Table, Query, dict[UUID, sqa.Label]]: + if isinstance(nd, verbs.Verb): + # store a counter how often each UUID is referenced by ancestors. This + # allows to only select necessary columns in a subquery. + for node in nd.iter_col_nodes(): + if isinstance(node, Col): + cnt = needed_cols.get(node._uuid) + if cnt is None: + needed_cols[node._uuid] = 1 + else: + needed_cols[node._uuid] = cnt + 1 + + table, query, sqa_col = cls.compile_ast(nd.child, needed_cols) + + # check if a subquery is required + if ( + ( + isinstance( + nd, + ( + verbs.Filter, + verbs.Summarise, + verbs.Arrange, + verbs.GroupBy, + verbs.Join, + ), + ) + and query.limit is not None + ) + or ( + isinstance(nd, (verbs.Mutate, verbs.Filter)) + and any( + node.ftype(agg_is_window=True) == Ftype.WINDOW + for node in nd.iter_col_nodes() + if isinstance(node, Col) + ) + ) + or ( + isinstance(nd, verbs.Summarise) + and ( + ( + bool(query.group_by) + and set(query.group_by) != set(query.partition_by) + ) + or any( + ( + node.ftype(agg_is_window=False) + in (Ftype.WINDOW, Ftype.AGGREGATE) + ) + for node in nd.iter_col_nodes() + if isinstance(node, Col) + ) + ) + ) + ): + if needed_cols.keys().isdisjoint(sqa_col.keys()): + # We cannot select zero columns from a subquery. This happens when the + # user only 0-ary functions after the subquery, e.g. `count`. + needed_cols[next(iter(sqa_col.keys()))] = 1 + + # TODO: do we want `alias` to automatically create a subquery? or add a + # flag to the node that a subquery would be allowed? or special verb to + # mark subquery? + + # We only want to select those columns that (1) the user uses in some + # expression later or (2) are present in the final selection. 
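The `needed_cols` bookkeeping above, together with the matching decrement at the end of `compile_ast`, amounts to a per-UUID reference count; a toy version of the scheme (string keys standing in for UUIDs):

```python
needed_cols: dict[str, int] = {}

def acquire(uid: str) -> None:
    # On the way down the tree: count one more referencing ancestor.
    needed_cols[uid] = needed_cols.get(uid, 0) + 1

def release(uid: str) -> None:
    # On the way back up: a key is present exactly while some ancestor
    # verb still references the column.
    if needed_cols[uid] == 1:
        del needed_cols[uid]
    else:
        needed_cols[uid] -= 1
```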
+ orig_select = query.select + query.select = [ + sqa_col[uid] for uid in needed_cols.keys() if uid in sqa_col + ] + table = cls.compile_query(table, query).subquery() + sqa_col.update( + { + uid: sqa.label( + sqa_col[uid].name, table.columns.get(sqa_col[uid].name) + ) + for uid in needed_cols.keys() + if uid in sqa_col + } + ) + + # rewire col refs to the subquery + query = Query( + [ + sqa.Label(lb.name, col) + for lb in orig_select + if (col := table.columns.get(lb.name)) is not None + ], + partition_by=[ + sqa.Label(lb.name, col) + for lb in query.partition_by + if (col := table.columns.get(lb.name)) is not None + ], + ) + + if isinstance(nd, (verbs.Mutate, verbs.Summarise)): + query.select = [lb for lb in query.select if lb.name not in set(nd.names)] + + if isinstance(nd, verbs.Select): + query.select = [sqa_col[col._uuid] for col in nd.select] + + elif isinstance(nd, verbs.Rename): + sqa_col = { + uid: ( + sqa.label(nd.name_map[lb.name], lb) + if lb.name in nd.name_map + else lb + ) + for uid, lb in sqa_col.items() + } + + query.select, query.partition_by, query.group_by = ( + [ + sqa.label(nd.name_map[lb.name], lb) + if lb.name in nd.name_map + else lb + for lb in label_arr + ] + for label_arr in (query.select, query.partition_by, query.group_by) + ) + + elif isinstance(nd, verbs.Mutate): + for name, val, uid in zip(nd.names, nd.values, nd.uuids): + sqa_col[uid] = sqa.label(name, cls.compile_col_expr(val, sqa_col)) + query.select.append(sqa_col[uid]) + + elif isinstance(nd, verbs.Filter): + if query.group_by: + query.having.extend( + cls.compile_col_expr(fil, sqa_col) for fil in nd.filters + ) + else: + query.where.extend( + cls.compile_col_expr(fil, sqa_col) for fil in nd.filters + ) + + elif isinstance(nd, verbs.Arrange): + query.order_by = dedup_order_by( + itertools.chain( + (cls.compile_order(ord, sqa_col) for ord in nd.order_by), + query.order_by, + ) + ) + + elif isinstance(nd, verbs.Summarise): + query.group_by.extend(query.partition_by) + + for name, val, uid in zip(nd.names, nd.values, nd.uuids): + sqa_col[uid] = sqa.Label(name, cls.compile_col_expr(val, sqa_col)) + + query.select = query.partition_by + [sqa_col[uid] for uid in nd.uuids] + query.partition_by = [] + query.order_by.clear() + + elif isinstance(nd, verbs.SliceHead): + if query.limit is None: + query.limit = nd.n + query.offset = nd.offset + else: + query.limit = min(abs(query.limit - nd.offset), nd.n) + query.offset += nd.offset + + elif isinstance(nd, verbs.GroupBy): + compiled_group_by = (sqa_col[col._uuid] for col in nd.group_by) + if nd.add: + query.partition_by.extend(compiled_group_by) + else: + query.partition_by = list(compiled_group_by) + + elif isinstance(nd, verbs.Ungroup): + assert not (query.partition_by and query.group_by) + query.partition_by.clear() + + elif isinstance(nd, verbs.Join): + right_table, right_query, right_sqa_col = cls.compile_ast( + nd.right, needed_cols + ) + + sqa_col.update( + { + uid: sqa.label(lb.name + nd.suffix, lb) + for uid, lb in right_sqa_col.items() + } + ) + + j = SqlJoin( + right_table, + cls.compile_col_expr(nd.on, sqa_col), + nd.how, + ) + + if nd.how == "inner": + query.where.extend(right_query.where) + elif nd.how == "left": + j.on = functools.reduce(operator.and_, (j.on, *right_query.where)) + elif nd.how == "outer": + if query.where or right_query.where: + raise ValueError("invalid filter before outer join") + + query.join.append(j) + query.select += [ + sqa.Label(lb.name + nd.suffix, lb) for lb in right_query.select + ] + + elif isinstance(nd, SqlImpl): + 
table = nd.table + query = Query([sqa.Label(col.name, col) for col in nd.table.columns]) + sqa_col = { + col._uuid: sqa.label(col.name, nd.table.columns[col.name]) + for col in nd.cols.values() + } + + if isinstance(nd, verbs.Verb): + # decrease counters (`needed_cols` is not copied) + for node in nd.iter_col_nodes(): + if isinstance(node, Col): + cnt = needed_cols.get(node._uuid) + if cnt == 1: + del needed_cols[node._uuid] + else: + needed_cols[node._uuid] = cnt - 1 + + return table, query, sqa_col + + +@dataclasses.dataclass(slots=True) +class Query: + select: list[sqa.Label] + join: list[SqlJoin] = dataclasses.field(default_factory=list) + partition_by: list[sqa.Label] = dataclasses.field(default_factory=list) + group_by: list[sqa.Label] = dataclasses.field(default_factory=list) + where: list[sqa.ColumnElement] = dataclasses.field(default_factory=list) + having: list[sqa.ColumnElement] = dataclasses.field(default_factory=list) + order_by: list[sqa.UnaryExpression] = dataclasses.field(default_factory=list) + limit: int | None = None + offset: int | None = None + + +@dataclasses.dataclass(slots=True) +class SqlJoin: + right: sqa.Subquery + on: sqa.ColumnElement + how: verbs.JoinHow + + +# MSSQL complains about duplicates in ORDER BY. +def dedup_order_by( + order_by: Iterable[sqa.UnaryExpression], +) -> list[sqa.UnaryExpression]: + new_order_by: list[sqa.UnaryExpression] = [] + occurred: set[sqa.ColumnElement] = set() + + for ord in order_by: + peeled = ord + while isinstance(peeled, sqa.UnaryExpression) and peeled.modifier is not None: + peeled = peeled.element + if peeled not in occurred: + new_order_by.append(ord) + occurred.add(peeled) + + return new_order_by + + +# Gives any leaf a unique alias to allow self-joins. We do this here to not force +# the user to come up with dummy names that are not required later anymore. It has +# to be done before a join so that all column references in the join subtrees remain +# valid. 
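In reduced form, the problem this pass solves (toy table, not from the patch): a self-join needs distinct names for the two sides, and `create_aliases` below derives them from the table name plus an occurrence counter.

```python
import sqlalchemy as sqa

t = sqa.Table("t", sqa.MetaData(), sqa.Column("id", sqa.Integer))
t_1 = t.alias("t_1")  # same "<name>_<count>" scheme as create_aliases

# With the alias, both sides of the self-join are unambiguous:
sel = sqa.select(t.c.id, t_1.c.id).join_from(t, t_1, t.c.id == t_1.c.id)
```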
+def create_aliases(nd: AstNode, num_occurences: dict[str, int]) -> dict[str, int]: + if isinstance(nd, verbs.Verb): + num_occurences = create_aliases(nd.child, num_occurences) + + if isinstance(nd, verbs.Join): + num_occurences = create_aliases(nd.right, num_occurences) + + elif isinstance(nd, SqlImpl): + if cnt := num_occurences.get(nd.table.name): + nd.table = nd.table.alias(f"{nd.table.name}_{cnt}") + else: + cnt = 0 + num_occurences[nd.table.name] = cnt + 1 + + else: + raise AssertionError + + return num_occurences + + +def get_engine(nd: AstNode) -> sqa.Engine: + if isinstance(nd, verbs.Verb): + engine = get_engine(nd.child) + + if isinstance(nd, verbs.Join): + right_engine = get_engine(nd.right) + if engine != right_engine: + raise NotImplementedError # TODO: find some good error for this + + else: + assert isinstance(nd, SqlImpl) + engine = nd.engine + + return engine + + +def sqa_type_to_pdt(t: sqa.types.TypeEngine) -> Dtype: + if isinstance(t, sqa.Integer): + return dtypes.Int() + elif isinstance(t, sqa.Numeric): + return dtypes.Float() + elif isinstance(t, sqa.String): + return dtypes.String() + elif isinstance(t, sqa.Boolean): + return dtypes.Bool() + elif isinstance(t, sqa.DateTime): + return dtypes.DateTime() + elif isinstance(t, sqa.Date): + return dtypes.Date() + elif isinstance(t, sqa.Interval): + return dtypes.Duration() + elif isinstance(t, sqa.Null): + return dtypes.NoneDtype() + + raise TypeError(f"SQLAlchemy type {t} not supported by pydiverse.transform") + + +def pdt_type_to_sqa(t: Dtype) -> sqa.types.TypeEngine: + if isinstance(t, dtypes.Int): + return sqa.Integer() + elif isinstance(t, dtypes.Float): + return sqa.Numeric() + elif isinstance(t, dtypes.String): + return sqa.String() + elif isinstance(t, dtypes.Bool): + return sqa.Boolean() + elif isinstance(t, dtypes.DateTime): + return sqa.DateTime() + elif isinstance(t, dtypes.Date): + return sqa.Date() + elif isinstance(t, dtypes.Duration): + return sqa.Interval() + elif isinstance(t, dtypes.NoneDtype): + return sqa.types.NullType() + + raise AssertionError + + +with SqlImpl.op(ops.FloorDiv(), check_super=False) as op: + if sqa.__version__ < "2": + + @op.auto + def _floordiv(lhs, rhs): + return sqa.cast(lhs / rhs, sqa.Integer()) + + else: + + @op.auto + def _floordiv(lhs, rhs): + return lhs // rhs + + +with SqlImpl.op(ops.RFloorDiv(), check_super=False) as op: + + @op.auto + def _rfloordiv(rhs, lhs): + return _floordiv(lhs, rhs) + + +with SqlImpl.op(ops.Pow()) as op: + + @op.auto + def _pow(lhs, rhs): + if isinstance(lhs.type, sqa.Float) or isinstance(rhs.type, sqa.Float): + type_ = sqa.Double() + elif isinstance(lhs.type, sqa.Numeric) or isinstance(rhs, sqa.Numeric): + type_ = sqa.Numeric() + else: + type_ = sqa.Double() + + return sqa.func.POW(lhs, rhs, type_=type_) + + +with SqlImpl.op(ops.RPow()) as op: + + @op.auto + def _rpow(rhs, lhs): + return _pow(lhs, rhs) + + +with SqlImpl.op(ops.Xor()) as op: + + @op.auto + def _xor(lhs, rhs): + return lhs != rhs + + +with SqlImpl.op(ops.RXor()) as op: + + @op.auto + def _rxor(rhs, lhs): + return lhs != rhs + + +with SqlImpl.op(ops.Pos()) as op: + + @op.auto + def _pos(x): + return x + + +with SqlImpl.op(ops.Abs()) as op: + + @op.auto + def _abs(x): + return sqa.func.ABS(x, type_=x.type) + + +with SqlImpl.op(ops.Round()) as op: + + @op.auto + def _round(x, decimals=0): + return sqa.func.ROUND(x, decimals, type_=x.type) + + +with SqlImpl.op(ops.IsIn()) as op: + + @op.auto + def _isin(x, *values): + return functools.reduce(operator.or_, map(lambda v: x == v, values)) + + 
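For reference, the chained-OR shape this generic `isin` fallback produces for, say, three values (hypothetical column `c`):

```python
import functools
import operator

import sqlalchemy as sqa

c = sqa.column("c")
isin = functools.reduce(operator.or_, (c == v for v in (1, 2, 3)))
# Renders roughly as: c = :c_1 OR c = :c_2 OR c = :c_3
```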
+with SqlImpl.op(ops.IsNull()) as op: + + @op.auto + def _is_null(x): + return x.is_(sqa.null()) + + +with SqlImpl.op(ops.IsNotNull()) as op: + + @op.auto + def _is_not_null(x): + return x.is_not(sqa.null()) + + +#### String Functions #### + + +with SqlImpl.op(ops.StrStrip()) as op: + + @op.auto + def _str_strip(x): + return sqa.func.TRIM(x, type_=x.type) + + +with SqlImpl.op(ops.StrLen()) as op: + + @op.auto + def _str_length(x): + return sqa.func.LENGTH(x, type_=sqa.Integer()) + + +with SqlImpl.op(ops.StrToUpper()) as op: + + @op.auto + def _upper(x): + return sqa.func.UPPER(x, type_=x.type) + + +with SqlImpl.op(ops.StrToLower()) as op: + + @op.auto + def _upper(x): + return sqa.func.LOWER(x, type_=x.type) + + +with SqlImpl.op(ops.StrReplaceAll()) as op: + + @op.auto + def _replace_all(x, y, z): + return sqa.func.REPLACE(x, y, z, type_=x.type) + + +with SqlImpl.op(ops.StrStartsWith()) as op: + + @op.auto + def _startswith(x, y): + return x.startswith(y, autoescape=True) + + +with SqlImpl.op(ops.StrEndsWith()) as op: + + @op.auto + def _endswith(x, y): + return x.endswith(y, autoescape=True) + + +with SqlImpl.op(ops.StrContains()) as op: + + @op.auto + def _contains(x, y): + return x.contains(y, autoescape=True) + + +with SqlImpl.op(ops.StrSlice()) as op: + + @op.auto + def _str_slice(x, offset, length): + # SQL has 1-indexed strings but we do it 0-indexed + return sqa.func.SUBSTR(x, offset + 1, length) + + +#### Datetime Functions #### + + +with SqlImpl.op(ops.DtYear()) as op: + + @op.auto + def _year(x): + return sqa.extract("year", x) + + +with SqlImpl.op(ops.DtMonth()) as op: + + @op.auto + def _month(x): + return sqa.extract("month", x) + + +with SqlImpl.op(ops.DtDay()) as op: + + @op.auto + def _day(x): + return sqa.extract("day", x) + + +with SqlImpl.op(ops.DtHour()) as op: + + @op.auto + def _hour(x): + return sqa.extract("hour", x) + + +with SqlImpl.op(ops.DtMinute()) as op: + + @op.auto + def _minute(x): + return sqa.extract("minute", x) + + +with SqlImpl.op(ops.DtSecond()) as op: + + @op.auto + def _second(x): + return sqa.extract("second", x) + + +with SqlImpl.op(ops.DtMillisecond()) as op: + + @op.auto + def _millisecond(x): + return sqa.extract("milliseconds", x) % 1000 + + +with SqlImpl.op(ops.DtDayOfWeek()) as op: + + @op.auto + def _day_of_week(x): + return sqa.extract("dow", x) + + +with SqlImpl.op(ops.DtDayOfYear()) as op: + + @op.auto + def _day_of_year(x): + return sqa.extract("doy", x) + + +#### Generic Functions #### + + +with SqlImpl.op(ops.Greatest()) as op: + + @op.auto + def _greatest(*x): + # TODO: Determine return type + return sqa.func.GREATEST(*x) + + +with SqlImpl.op(ops.Least()) as op: + + @op.auto + def _least(*x): + # TODO: Determine return type + return sqa.func.LEAST(*x) + + +#### Summarising Functions #### + + +with SqlImpl.op(ops.Mean()) as op: + + @op.auto + def _mean(x): + type_ = sqa.Numeric() + if isinstance(x.type, sqa.Float): + type_ = sqa.Double() + + return sqa.func.AVG(x, type_=type_) + + +with SqlImpl.op(ops.Min()) as op: + + @op.auto + def _min(x): + return sqa.func.min(x) + + +with SqlImpl.op(ops.Max()) as op: + + @op.auto + def _max(x): + return sqa.func.max(x) + + +with SqlImpl.op(ops.Sum()) as op: + + @op.auto + def _sum(x): + return sqa.func.sum(x) + + +with SqlImpl.op(ops.Any()) as op: + + @op.auto + def _any(x, *, _window_partition_by=None, _window_order_by=None): + return sqa.func.coalesce(sqa.func.max(x), sqa.null()) + + @op.auto(variant="window") + def _any(x, *, partition_by=None, order_by=None): + return sqa.func.coalesce( + 
sqa.func.max(x).over( + partition_by=partition_by, + order_by=order_by, + ), + sqa.null(), + ) + + +with SqlImpl.op(ops.All()) as op: + + @op.auto + def _all(x): + return sqa.func.coalesce(sqa.func.min(x), sqa.null()) + + @op.auto(variant="window") + def _all(x, *, partition_by=None, order_by=None): + return sqa.func.coalesce( + sqa.func.min(x).over( + partition_by=partition_by, + order_by=order_by, + ), + sqa.null(), + ) + + +with SqlImpl.op(ops.Count()) as op: + + @op.auto + def _count(x=None): + if x is None: + # Get the number of rows + return sqa.func.count() + else: + # Count non null values + return sqa.func.count(x) + + +#### Window Functions #### + + +with SqlImpl.op(ops.Shift()) as op: + + @op.auto + def _shift(): + raise RuntimeError("This is a stub") + + @op.auto(variant="window") + def _shift( + x, + by, + empty_value=None, + *, + partition_by=None, + order_by=None, + ): + if by == 0: + return x + if by > 0: + return sqa.func.LAG(x, by, empty_value, type_=x.type).over( + partition_by=partition_by, order_by=order_by + ) + if by < 0: + return sqa.func.LEAD(x, -by, empty_value, type_=x.type).over( + partition_by=partition_by, order_by=order_by + ) + + +with SqlImpl.op(ops.RowNumber()) as op: + + @op.auto + def _row_number(): + return sqa.func.ROW_NUMBER(type_=sqa.Integer()) + + +with SqlImpl.op(ops.Rank()) as op: + + @op.auto + def _rank(): + return sqa.func.rank() + + +with SqlImpl.op(ops.DenseRank()) as op: + + @op.auto + def _dense_rank(): + return sqa.func.dense_rank() diff --git a/src/pydiverse/transform/sql/sqlite.py b/src/pydiverse/transform/backend/sqlite.py similarity index 70% rename from src/pydiverse/transform/sql/sqlite.py rename to src/pydiverse/transform/backend/sqlite.py index 5f30744a..1dc14f07 100644 --- a/src/pydiverse/transform/sql/sqlite.py +++ b/src/pydiverse/transform/backend/sqlite.py @@ -1,27 +1,27 @@ from __future__ import annotations -import sqlalchemy as sa +import sqlalchemy as sqa from pydiverse.transform import ops -from pydiverse.transform.sql.sql_table import SQLTableImpl +from pydiverse.transform.backend.sql import SqlImpl from pydiverse.transform.util.warnings import warn_non_standard -class SQLiteTableImpl(SQLTableImpl): - _dialect_name = "sqlite" +class SqliteImpl(SqlImpl): + dialect_name = "sqlite" -with SQLiteTableImpl.op(ops.Round()) as op: +with SqliteImpl.op(ops.Round()) as op: @op.auto def _round(x, decimals=0): if decimals >= 0: - return sa.func.ROUND(x, decimals, type_=x.type) + return sqa.func.ROUND(x, decimals, type_=x.type) # For some reason SQLite doesn't like negative decimals values - return sa.func.ROUND(x / (10**-decimals), type_=x.type) * (10**-decimals) + return sqa.func.ROUND(x / (10**-decimals), type_=x.type) * (10**-decimals) -with SQLiteTableImpl.op(ops.StrStartsWith()) as op: +with SqliteImpl.op(ops.StrStartsWith()) as op: @op.auto def _startswith(x, y): @@ -33,7 +33,7 @@ def _startswith(x, y): return x.startswith(y, autoescape=True) -with SQLiteTableImpl.op(ops.StrEndsWith()) as op: +with SqliteImpl.op(ops.StrEndsWith()) as op: @op.auto def _endswith(x, y): @@ -45,7 +45,7 @@ def _endswith(x, y): return x.endswith(y, autoescape=True) -with SQLiteTableImpl.op(ops.StrContains()) as op: +with SqliteImpl.op(ops.StrContains()) as op: @op.auto def _contains(x, y): @@ -57,19 +57,19 @@ def _contains(x, y): return x.contains(y, autoescape=True) -with SQLiteTableImpl.op(ops.DtMillisecond()) as op: +with SqliteImpl.op(ops.DtMillisecond()) as op: @op.auto def _millisecond(x): warn_non_standard( "SQLite returns rounded milliseconds", 
) - _1000 = sa.literal_column("1000") - frac_seconds = sa.cast(sa.func.STRFTIME("%f", x), sa.Numeric()) - return sa.cast((frac_seconds * _1000) % _1000, sa.Integer()) + _1000 = sqa.literal_column("1000") + frac_seconds = sqa.cast(sqa.func.STRFTIME("%f", x), sqa.Numeric()) + return sqa.cast((frac_seconds * _1000) % _1000, sqa.Integer()) -with SQLiteTableImpl.op(ops.Greatest()) as op: +with SqliteImpl.op(ops.Greatest()) as op: @op.auto def _greatest(*x): @@ -83,10 +83,10 @@ def _greatest(*x): right = _greatest(*x[mid:]) # TODO: Determine return type - return sa.func.coalesce(sa.func.MAX(left, right), left, right) + return sqa.func.coalesce(sqa.func.MAX(left, right), left, right) -with SQLiteTableImpl.op(ops.Least()) as op: +with SqliteImpl.op(ops.Least()) as op: @op.auto def _least(*x): @@ -100,4 +100,4 @@ def _least(*x): right = _least(*x[mid:]) # TODO: Determine return type - return sa.func.coalesce(sa.func.MIN(left, right), left, right) + return sqa.func.coalesce(sqa.func.MIN(left, right), left, right) diff --git a/src/pydiverse/transform/backend/table_impl.py b/src/pydiverse/transform/backend/table_impl.py new file mode 100644 index 00000000..763a9a98 --- /dev/null +++ b/src/pydiverse/transform/backend/table_impl.py @@ -0,0 +1,316 @@ +from __future__ import annotations + +import uuid +from collections.abc import Iterable +from typing import TYPE_CHECKING, Any + +from pydiverse.transform import ops +from pydiverse.transform.backend.targets import Target +from pydiverse.transform.ops.core import Ftype +from pydiverse.transform.tree.ast import AstNode +from pydiverse.transform.tree.col_expr import ( + Col, +) +from pydiverse.transform.tree.dtypes import Dtype +from pydiverse.transform.tree.registry import ( + OperatorRegistrationContextManager, + OperatorRegistry, +) + +if TYPE_CHECKING: + from pydiverse.transform.ops import Operator + + +class TableImpl(AstNode): + """ + Base class from which all table backend implementations are derived from. + """ + + registry = OperatorRegistry("TableImpl") + + def __init__(self, name: str, schema: dict[str, Dtype]): + self.name = name + self.cols = { + name: Col(name, self, uuid.uuid1(), dtype, Ftype.EWISE) + for name, dtype in schema.items() + } + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + + # Add new `registry` class variable to subclass. + # We define the super registry by walking up the MRO. This allows us + # to check for potential operation definitions in the parent classes. + super_reg = None + for super_cls in cls.__mro__: + if hasattr(super_cls, "registry"): + super_reg = super_cls.registry + break + cls.registry = OperatorRegistry(cls.__name__, super_reg) + + def iter_subtree(self) -> Iterable[AstNode]: + yield self + + @staticmethod + def build_query(nd: AstNode, final_select: list[Col]) -> str | None: ... + + @staticmethod + def export(nd: AstNode, target: Target, final_select: list[Col]) -> Any: ... + + @classmethod + def _html_repr_expr(cls, expr): + """ + Return an appropriate string to display an expression from this backend. + This is mainly used to IPython. 
+ """ + return repr(expr) + + @classmethod + def op(cls, operator: Operator, **kwargs) -> OperatorRegistrationContextManager: + return OperatorRegistrationContextManager(cls.registry, operator, **kwargs) + + +with TableImpl.op(ops.NullsFirst()) as op: + + @op.auto + def _nulls_first(_): + raise AssertionError + + +with TableImpl.op(ops.NullsLast()) as op: + + @op.auto + def _nulls_last(_): + raise AssertionError + + +with TableImpl.op(ops.Ascending()) as op: + + @op.auto + def _ascending(_): + raise AssertionError + + +with TableImpl.op(ops.Descending()) as op: + + @op.auto + def _descending(_): + raise AssertionError + + +with TableImpl.op(ops.Add()) as op: + + @op.auto + def _add(lhs, rhs): + return lhs + rhs + + @op.extension(ops.StrAdd) + def _str_add(lhs, rhs): + return lhs + rhs + + +with TableImpl.op(ops.RAdd()) as op: + + @op.auto + def _radd(rhs, lhs): + return lhs + rhs + + @op.extension(ops.StrRAdd) + def _str_radd(lhs, rhs): + return lhs + rhs + + +with TableImpl.op(ops.Sub()) as op: + + @op.auto + def _sub(lhs, rhs): + return lhs - rhs + + +with TableImpl.op(ops.RSub()) as op: + + @op.auto + def _rsub(rhs, lhs): + return lhs - rhs + + +with TableImpl.op(ops.Mul()) as op: + + @op.auto + def _mul(lhs, rhs): + return lhs * rhs + + +with TableImpl.op(ops.RMul()) as op: + + @op.auto + def _rmul(rhs, lhs): + return lhs * rhs + + +with TableImpl.op(ops.TrueDiv()) as op: + + @op.auto + def _truediv(lhs, rhs): + return lhs / rhs + + +with TableImpl.op(ops.RTrueDiv()) as op: + + @op.auto + def _rtruediv(rhs, lhs): + return lhs / rhs + + +with TableImpl.op(ops.FloorDiv()) as op: + + @op.auto + def _floordiv(lhs, rhs): + return lhs // rhs + + +with TableImpl.op(ops.RFloorDiv()) as op: + + @op.auto + def _rfloordiv(rhs, lhs): + return lhs // rhs + + +with TableImpl.op(ops.Pow()) as op: + + @op.auto + def _pow(lhs, rhs): + return lhs**rhs + + +with TableImpl.op(ops.RPow()) as op: + + @op.auto + def _rpow(rhs, lhs): + return lhs**rhs + + +with TableImpl.op(ops.Mod()) as op: + + @op.auto + def _mod(lhs, rhs): + return lhs % rhs + + +with TableImpl.op(ops.RMod()) as op: + + @op.auto + def _rmod(rhs, lhs): + return lhs % rhs + + +with TableImpl.op(ops.Neg()) as op: + + @op.auto + def _neg(x): + return -x + + +with TableImpl.op(ops.Pos()) as op: + + @op.auto + def _pos(x): + return +x + + +with TableImpl.op(ops.Abs()) as op: + + @op.auto + def _abs(x): + return abs(x) + + +with TableImpl.op(ops.And()) as op: + + @op.auto + def _and(lhs, rhs): + return lhs & rhs + + +with TableImpl.op(ops.RAnd()) as op: + + @op.auto + def _rand(rhs, lhs): + return lhs & rhs + + +with TableImpl.op(ops.Or()) as op: + + @op.auto + def _or(lhs, rhs): + return lhs | rhs + + +with TableImpl.op(ops.ROr()) as op: + + @op.auto + def _ror(rhs, lhs): + return lhs | rhs + + +with TableImpl.op(ops.Xor()) as op: + + @op.auto + def _xor(lhs, rhs): + return lhs ^ rhs + + +with TableImpl.op(ops.RXor()) as op: + + @op.auto + def _rxor(rhs, lhs): + return lhs ^ rhs + + +with TableImpl.op(ops.Invert()) as op: + + @op.auto + def _invert(x): + return ~x + + +with TableImpl.op(ops.Equal()) as op: + + @op.auto + def _eq(lhs, rhs): + return lhs == rhs + + +with TableImpl.op(ops.NotEqual()) as op: + + @op.auto + def _ne(lhs, rhs): + return lhs != rhs + + +with TableImpl.op(ops.Less()) as op: + + @op.auto + def _lt(lhs, rhs): + return lhs < rhs + + +with TableImpl.op(ops.LessEqual()) as op: + + @op.auto + def _le(lhs, rhs): + return lhs <= rhs + + +with TableImpl.op(ops.Greater()) as op: + + @op.auto + def _gt(lhs, rhs): + return lhs > rhs 
+ + +with TableImpl.op(ops.GreaterEqual()) as op: + + @op.auto + def _ge(lhs, rhs): + return lhs >= rhs diff --git a/src/pydiverse/transform/backend/targets.py b/src/pydiverse/transform/backend/targets.py new file mode 100644 index 00000000..d16db36f --- /dev/null +++ b/src/pydiverse/transform/backend/targets.py @@ -0,0 +1,25 @@ +# This module defines the config classes provided to the user to configure +# the backend on import / export. + + +from __future__ import annotations + +import sqlalchemy as sqa + + +# TODO: better name for this? (the user sees this) +class Target: ... + + +class Polars(Target): + def __init__(self, *, lazy: bool = False) -> None: + self.lazy = lazy + + +class DuckDb(Target): ... + + +class SqlAlchemy(Target): + def __init__(self, engine: sqa.Engine, *, schema: str | None = None): + self.engine = engine + self.schema = schema diff --git a/src/pydiverse/transform/core/__init__.py b/src/pydiverse/transform/core/__init__.py deleted file mode 100644 index bc4e289f..00000000 --- a/src/pydiverse/transform/core/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from __future__ import annotations - -from .table import Table -from .table_impl import AbstractTableImpl - -__all__ = [ - Table, - AbstractTableImpl, -] diff --git a/src/pydiverse/transform/core/alignment.py b/src/pydiverse/transform/core/alignment.py deleted file mode 100644 index 2c25fccd..00000000 --- a/src/pydiverse/transform/core/alignment.py +++ /dev/null @@ -1,101 +0,0 @@ -from __future__ import annotations - -import inspect -from typing import TYPE_CHECKING - -from pydiverse.transform.core.expressions import ( - Column, - LiteralColumn, - SymbolicExpression, - util, -) -from pydiverse.transform.errors import AlignmentError - -if TYPE_CHECKING: - from pydiverse.transform.core import AbstractTableImpl, Table - - -def aligned(*, with_: str): - """Decorator for aligned functions.""" - from pydiverse.transform.core import AbstractTableImpl, Table - - if callable(with_): - raise ValueError("Decorator @aligned requires with_ argument.") - - def decorator(func): - signature = inspect.signature(func) - if not isinstance(with_, str): - raise TypeError( - f"Argument 'with_' must be of type str, not '{type(with_).__name__}'." - ) - if with_ not in signature.parameters: - raise ValueError(f"Function has no argument named '{with_}'") - - def wrapper(*args, **kwargs): - # Execute func - result = func(*args, **kwargs) - if not isinstance(result, SymbolicExpression): - raise TypeError( - "Aligned function must return a symbolic expression not" - f" '{result}'." 
- ) - - # Extract the correct `with_` argument for eval_aligned - bound_sig = signature.bind(*args, **kwargs) - bound_sig.apply_defaults() - - alignment_param = bound_sig.arguments[with_] - if isinstance(alignment_param, SymbolicExpression): - alignment_param = alignment_param._ - - if isinstance(alignment_param, Column): - aligned_with = alignment_param.table - elif isinstance(alignment_param, (Table, AbstractTableImpl)): - aligned_with = alignment_param - else: - raise NotImplementedError - - # Evaluate aligned - return eval_aligned(result, with_=aligned_with) - - return wrapper - - return decorator - - -def eval_aligned( - sexpr: SymbolicExpression, with_: AbstractTableImpl | Table = None, **kwargs -) -> SymbolicExpression[LiteralColumn]: - """Evaluates an expression using the AlignedExpressionEvaluator.""" - from pydiverse.transform.core import AbstractTableImpl, Table - - expr = sexpr._ if isinstance(sexpr, SymbolicExpression) else sexpr - - # Determine Backend - backend = util.determine_expr_backend(expr) - if backend is None: - # TODO: Handle this case. Should return some value... - raise NotImplementedError - - # Evaluate the function calls on the shared backend - alignedEvaluator = backend.AlignedExpressionEvaluator(backend.operator_registry) - result = alignedEvaluator.translate(expr, **kwargs) - - literal_column = LiteralColumn(typed_value=result, expr=expr, backend=backend) - - # Check if alignment condition holds - if with_ is not None: - if isinstance(with_, Table): - with_ = with_._impl - if not isinstance(with_, AbstractTableImpl): - raise TypeError( - "'with_' must either be an instance of a Table or TableImpl. Not" - f" '{with_}'." - ) - - if not with_.is_aligned_with(literal_column): - raise AlignmentError(f"Result of eval_aligned isn't aligned with {with_}.") - - # Convert to sexpr so that the user can easily continue transforming - # it symbolically. - return SymbolicExpression(literal_column) diff --git a/src/pydiverse/transform/core/dispatchers.py b/src/pydiverse/transform/core/dispatchers.py deleted file mode 100644 index f7a2692e..00000000 --- a/src/pydiverse/transform/core/dispatchers.py +++ /dev/null @@ -1,179 +0,0 @@ -from __future__ import annotations - -import copy -from functools import partial, reduce, wraps -from typing import Any - -from pydiverse.transform.core.expressions import ( - Column, - LambdaColumn, - unwrap_symbolic_expressions, -) -from pydiverse.transform.core.util import bidict, traverse - - -class Pipeable: - def __init__(self, f=None, calls=None): - if f is not None: - if calls is not None: - raise ValueError - self.calls = [f] - else: - self.calls = calls - - def __rshift__(self, other) -> Pipeable: - """ - Pipeable >> other - -> Lazy. Extend pipe. - """ - if isinstance(other, Pipeable): - return Pipeable(calls=self.calls + other.calls) - elif callable(other): - return Pipeable(calls=self.calls + [other]) - - raise RuntimeError - - def __rrshift__(self, other): - """ - other >> Pipeable - -> Eager. - """ - if callable(other): - return Pipeable(calls=[other] + self.calls) - return self(other) - - def __call__(self, arg): - return reduce(lambda x, f: f(x), self.calls, arg) - - -class inverse_partial(partial): - """ - Just like partial, but the arguments get applied to the back instead of the front. - This means that a function `def x(a, b, c)` decorated with `@inverse_partial(1, 2)` - that gets called with `x(0)` is equivalent to calling `x(0, 1, 2)` on the non - decorated function. 
- """ - - def __call__(self, /, *args, **keywords): - keywords = {**self.keywords, **keywords} - # ↙ *args moved to front. - return self.func(*args, *self.args, **keywords) - - -def verb(func): - from pydiverse.transform.core.table import Table - - def copy_tables(arg: Any = None): - return traverse(arg, lambda x: copy.copy(x) if isinstance(x, Table) else x) - - @wraps(func) - def wrapper(*args, **kwargs): - # Copy Table objects to prevent mutating them - # This can be the case if the user uses __setitem__ inside the verb - def f(*args, **kwargs): - args = copy_tables(args) - kwargs = copy_tables(kwargs) - return func(*args, **kwargs) - - f = inverse_partial(f, *args, **kwargs) # Bind arguments - return Pipeable(f) - - return wrapper - - -def builtin_verb(backends=None): - def wrap_and_unwrap(func): - @wraps(func) - def wrapper(*args, **kwargs): - args = list(args) - args = unwrap_symbolic_expressions(args) - if len(args): - args[0] = col_to_table(args[0]) - args = unwrap_tables(args) - - kwargs = unwrap_symbolic_expressions(kwargs) - kwargs = unwrap_tables(kwargs) - - return wrap_tables(func(*args, **kwargs)) - - return wrapper - - def check_backend(func): - if backends is None: - return func - - @wraps(func) - def wrapper(*args, **kwargs): - assert len(args) > 0 - impl = args[0]._impl - if isinstance(impl, backends): - return func(*args, **kwargs) - raise TypeError(f"Backend {impl} not supported for verb '{func.__name__}'.") - - return wrapper - - def decorator(func): - @wraps(func) - def wrapper(*args, **kwargs): - f = func - f = wrap_and_unwrap(f) # Convert from Table to Impl and back - f = check_backend(f) # Check type of backend - f = inverse_partial(f, *args, **kwargs) # Bind arguments - return Pipeable(f) # Make pipeable - - return wrapper - - return decorator - - -# Helper - - -def col_to_table(arg: Any = None): - """ - Takes a single argument and if it is a column, replaces it with a table - implementation that only contains that one column. - - This allows for more eager style code where you perform operations on - columns like with the following example:: - - def get_c(b, tB): - tC = b >> left_join(tB, b == tB.b) - return tC[tB.c] - feature_col = get_c(tblA.b, tblB) - - """ - from pydiverse.transform.core.verbs import select - - if isinstance(arg, Column): - tbl = (arg.table >> select(arg))._impl - col = tbl.get_col(arg) - - tbl.available_cols = {col.uuid} - tbl.named_cols = bidict({col.name: col.uuid}) - return tbl - elif isinstance(arg, LambdaColumn): - raise ValueError("Can't start a pipe with a lambda column.") - - return arg - - -def unwrap_tables(arg: Any = None): - """ - Takes an instance or collection of `Table` objects and replaces them with - their implementation. - """ - from pydiverse.transform.core.table import Table - - return traverse(arg, lambda x: x._impl if isinstance(x, Table) else x) - - -def wrap_tables(arg: Any = None): - """ - Takes an instance or collection of `AbstractTableImpl` objects and wraps - them in a `Table` object. This is an inverse to the `unwrap_tables` function. 
- """ - from pydiverse.transform.core.table import Table - from pydiverse.transform.core.table_impl import AbstractTableImpl - - return traverse(arg, lambda x: Table(x) if isinstance(x, AbstractTableImpl) else x) diff --git a/src/pydiverse/transform/core/expressions/__init__.py b/src/pydiverse/transform/core/expressions/__init__.py deleted file mode 100644 index 85666d3a..00000000 --- a/src/pydiverse/transform/core/expressions/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -from __future__ import annotations - -from .expressions import ( - CaseExpression, - Column, - FunctionCall, - LambdaColumn, - LiteralColumn, - expr_repr, -) -from .symbolic_expressions import SymbolicExpression, unwrap_symbolic_expressions -from .translator import Translator, TypedValue -from .util import iterate_over_expr diff --git a/src/pydiverse/transform/core/expressions/expressions.py b/src/pydiverse/transform/core/expressions/expressions.py deleted file mode 100644 index b9cb0821..00000000 --- a/src/pydiverse/transform/core/expressions/expressions.py +++ /dev/null @@ -1,262 +0,0 @@ -from __future__ import annotations - -import uuid -from collections.abc import Iterable -from typing import TYPE_CHECKING, Any, Generic - -from pydiverse.transform._typing import ImplT, T -from pydiverse.transform.core.dtypes import DType - -if TYPE_CHECKING: - from pydiverse.transform.core.expressions.translator import TypedValue - from pydiverse.transform.core.table_impl import AbstractTableImpl - - -def expr_repr(it: Any): - from pydiverse.transform.core.expressions import SymbolicExpression - - if isinstance(it, SymbolicExpression): - return expr_repr(it._) - if isinstance(it, BaseExpression): - return it._expr_repr() - if isinstance(it, (list, tuple)): - return f"[{ ', '.join([expr_repr(e) for e in it]) }]" - return repr(it) - - -_dunder_expr_repr = { - "__add__": lambda lhs, rhs: f"({lhs} + {rhs})", - "__radd__": lambda rhs, lhs: f"({lhs} + {rhs})", - "__sub__": lambda lhs, rhs: f"({lhs} - {rhs})", - "__rsub__": lambda rhs, lhs: f"({lhs} - {rhs})", - "__mul__": lambda lhs, rhs: f"({lhs} * {rhs})", - "__rmul__": lambda rhs, lhs: f"({lhs} * {rhs})", - "__truediv__": lambda lhs, rhs: f"({lhs} / {rhs})", - "__rtruediv__": lambda rhs, lhs: f"({lhs} / {rhs})", - "__floordiv__": lambda lhs, rhs: f"({lhs} // {rhs})", - "__rfloordiv__": lambda rhs, lhs: f"({lhs} // {rhs})", - "__pow__": lambda lhs, rhs: f"({lhs} ** {rhs})", - "__rpow__": lambda rhs, lhs: f"({lhs} ** {rhs})", - "__mod__": lambda lhs, rhs: f"({lhs} % {rhs})", - "__rmod__": lambda rhs, lhs: f"({lhs} % {rhs})", - "__round__": lambda x, y=None: f"round({x}, {y})" if y else f"round({x})", - "__pos__": lambda x: f"(+{x})", - "__neg__": lambda x: f"(-{x})", - "__abs__": lambda x: f"abs({x})", - "__and__": lambda lhs, rhs: f"({lhs} & {rhs})", - "__rand__": lambda rhs, lhs: f"({lhs} & {rhs})", - "__or__": lambda lhs, rhs: f"({lhs} | {rhs})", - "__ror__": lambda rhs, lhs: f"({lhs} | {rhs})", - "__xor__": lambda lhs, rhs: f"({lhs} ^ {rhs})", - "__rxor__": lambda rhs, lhs: f"({lhs} ^ {rhs})", - "__invert__": lambda x: f"(~{x})", - "__lt__": lambda lhs, rhs: f"({lhs} < {rhs})", - "__le__": lambda lhs, rhs: f"({lhs} <= {rhs})", - "__eq__": lambda lhs, rhs: f"({lhs} == {rhs})", - "__ne__": lambda lhs, rhs: f"({lhs} != {rhs})", - "__gt__": lambda lhs, rhs: f"({lhs} > {rhs})", - "__ge__": lambda lhs, rhs: f"({lhs} >= {rhs})", -} - - -class BaseExpression: - def _expr_repr(self) -> str: - """String repr that, when executed, returns the same expression""" - raise NotImplementedError - - -class 
Column(BaseExpression, Generic[ImplT]): - __slots__ = ("name", "table", "dtype", "uuid") - - def __init__(self, name: str, table: ImplT, dtype: DType, uuid: uuid.UUID = None): - self.name = name - self.table = table - self.dtype = dtype - self.uuid = uuid or Column.generate_col_uuid() - - def __repr__(self): - return f"<{self.table.name}.{self.name}({self.dtype})>" - - def _expr_repr(self) -> str: - return f"{self.table.name}.{self.name}" - - def __eq__(self, other): - if not isinstance(other, self.__class__): - return False - return self.name == other.name and self.uuid == other.uuid - - def __ne__(self, other): - return not self.__eq__(other) - - def __hash__(self): - return hash(self.uuid) - - @classmethod - def generate_col_uuid(cls) -> uuid.UUID: - return uuid.uuid1() - - -class LambdaColumn(BaseExpression): - """Anonymous Column - - A lambda column is a column without an associated table or UUID. This means - that it can be used to reference columns in the same pipe as it was created. - - Example: - The following fails because `table.a` gets referenced before it gets created. - table >> mutate(a = table.x) >> mutate(b = table.a) - Instead you can use a lambda column to achieve this: - table >> mutate(a = table.x) >> mutate(b = C.a) - """ - - __slots__ = "name" - - def __init__(self, name: str): - self.name = name - - def __repr__(self): - return f"" - - def _expr_repr(self) -> str: - return f"C.{self.name}" - - def __eq__(self, other): - if isinstance(other, self.__class__): - return self.name == other.name - return False - - def __ne__(self, other): - return not self.__eq__(other) - - def __hash__(self): - return hash(("C", self.name)) - - -class LiteralColumn(BaseExpression, Generic[T]): - __slots__ = ("typed_value", "expr", "backend") - - def __init__( - self, - typed_value: TypedValue[T], - expr: Any, - backend: type[AbstractTableImpl], - ): - self.typed_value = typed_value - self.expr = expr - self.backend = backend - - def __repr__(self): - return f"" - - def _expr_repr(self) -> str: - return repr(self) - - def __eq__(self, other): - if not isinstance(other, self.__class__): - return False - return ( - self.typed_value == other.typed_value - and self.expr == other.expr - and self.backend == other.backend - ) - - def __ne__(self, other): - return not self.__eq__(other) - - -class FunctionCall(BaseExpression): - """ - AST node to represent a function / operator call. 
- """ - - def __init__(self, name: str, *args, **kwargs): - from pydiverse.transform.core.expressions.symbolic_expressions import ( - unwrap_symbolic_expressions, - ) - - # Unwrap all symbolic expressions in the input - args = unwrap_symbolic_expressions(args) - kwargs = unwrap_symbolic_expressions(kwargs) - - self.name = name - self.args = args - self.kwargs = kwargs - - def __repr__(self): - args = [repr(e) for e in self.args] + [ - f"{k}={repr(v)}" for k, v in self.kwargs.items() - ] - return f'{self.name}({", ".join(args)})' - - def _expr_repr(self) -> str: - args = [expr_repr(e) for e in self.args] + [ - f"{k}={expr_repr(v)}" for k, v in self.kwargs.items() - ] - - if self.name in _dunder_expr_repr: - return _dunder_expr_repr[self.name](*args) - - if len(self.args) == 0: - args_str = ", ".join(args) - return f"f.{self.name}({args_str})" - else: - args_str = ", ".join(args[1:]) - return f"{args[0]}.{self.name}({args_str})" - - def __eq__(self, other): - if isinstance(other, self.__class__): - return self.__dict__ == other.__dict__ - else: - return False - - def __ne__(self, other): - return not self.__eq__(other) - - def __hash__(self): - return hash((self.name, self.args, tuple(self.kwargs.items()))) - - def iter_children(self): - yield from self.args - - -class CaseExpression(BaseExpression): - def __init__( - self, switching_on: Any | None, cases: Iterable[tuple[Any, Any]], default: Any - ): - from pydiverse.transform.core.expressions.symbolic_expressions import ( - unwrap_symbolic_expressions, - ) - - # Unwrap all symbolic expressions in the input - switching_on = unwrap_symbolic_expressions(switching_on) - cases = unwrap_symbolic_expressions(list(cases)) - default = unwrap_symbolic_expressions(default) - - self.switching_on = switching_on - self.cases = cases - self.default = default - - def __repr__(self): - if self.switching_on: - return f"case({self.switching_on}, {self.cases}, default={self.default})" - else: - return f"case({self.cases}, default={self.default})" - - def _expr_repr(self) -> str: - prefix = "f" - if self.switching_on: - prefix = expr_repr(self.switching_on) - - args = [expr_repr(case) for case in self.cases] - args.append(f"default={expr_repr(self.default)}") - return f"{prefix}.case({', '.join(args)})" - - def iter_children(self): - if self.switching_on: - yield self.switching_on - - for k, v in self.cases: - yield k - yield v - - yield self.default diff --git a/src/pydiverse/transform/core/expressions/lambda_getter.py b/src/pydiverse/transform/core/expressions/lambda_getter.py deleted file mode 100644 index eb935d32..00000000 --- a/src/pydiverse/transform/core/expressions/lambda_getter.py +++ /dev/null @@ -1,18 +0,0 @@ -from __future__ import annotations - -from pydiverse.transform.core.expressions import LambdaColumn -from pydiverse.transform.core.expressions.symbolic_expressions import SymbolicExpression - -__all__ = ["C"] - - -class MC(type): - def __getattr__(cls, name: str) -> SymbolicExpression: - return SymbolicExpression(LambdaColumn(name)) - - def __getitem__(cls, name: str) -> SymbolicExpression: - return SymbolicExpression(LambdaColumn(name)) - - -class C(metaclass=MC): - pass diff --git a/src/pydiverse/transform/core/expressions/symbolic_expressions.py b/src/pydiverse/transform/core/expressions/symbolic_expressions.py deleted file mode 100644 index 45144f98..00000000 --- a/src/pydiverse/transform/core/expressions/symbolic_expressions.py +++ /dev/null @@ -1,153 +0,0 @@ -from __future__ import annotations - -from html import escape -from typing 
import Any, Generic - -from pydiverse.transform._typing import T -from pydiverse.transform.core.expressions import CaseExpression, FunctionCall, util -from pydiverse.transform.core.registry import OperatorRegistry -from pydiverse.transform.core.util import traverse - - -class SymbolicExpression(Generic[T]): - """ - Base class to represent a symbolic expression. It can be manipulated using - standard python operators (for example you can add them) or by calling - attributes of it. - - To get the non-symbolic version of this expression you use the - underscore `_` attribute. - """ - - __slots__ = ("_",) - - def __init__(self, underlying: T): - self._ = underlying - - def __getattr__(self, item) -> SymbolAttribute: - if item.startswith("_") and item.endswith("_") and len(item) >= 3: - # Attribute names can't begin and end with an underscore because - # IPython calls hasattr() to select the correct pretty printing - # function. Instead of hard coding a specific list, just throw - # an exception for all attributes that match the general pattern. - raise AttributeError( - f"Invalid attribute {item}. Attributes can't begin and end with an" - " underscore." - ) - - return SymbolAttribute(item, self) - - def __getitem__(self, item): - return SymbolicExpression(FunctionCall("__getitem__", self, item)) - - def case(self, *cases: tuple[Any, Any], default: Any = None) -> SymbolicExpression: - case_expression = CaseExpression( - switching_on=self, - cases=cases, - default=default, - ) - - return SymbolicExpression(case_expression) - - def __dir__(self): - # TODO: Instead of displaying all available operators, translate the - # expression and according to the dtype and backend only display - # the operators that actually are available. - return sorted(OperatorRegistry.ALL_REGISTERED_OPS) - - # __contains__, __iter__ and __bool__ are all invalid on s-expressions - __contains__ = None - __iter__ = None - - def __bool__(self): - raise TypeError( - "Symbolic expressions can't be converted to True/False, " - "or used with these keywords: not, and, or." - ) - - def __str__(self): - from pydiverse.transform.core.alignment import eval_aligned - - try: - result = eval_aligned(self._, check_alignment=False)._ - - dtype = result.typed_value.dtype - value = result.typed_value.value - return ( - f"Symbolic Expression: {repr(self._)}\ndtype: {dtype}\n\n{str(value)}" - ) - except Exception as e: - return ( - f"Symbolic Expression: {repr(self._)}\n" - "Failed to get evaluate due to an exception:\n" - f"{type(e).__name__}: {str(e)}" - ) - - def __repr__(self): - return f"" - - def _repr_html_(self): - from pydiverse.transform.core.alignment import eval_aligned - - html = f"
<pre>Symbolic Expression:\n{escape(repr(self._))}</pre>" - - try: - result = eval_aligned(self._, check_alignment=False)._ - backend = util.determine_expr_backend(self._) - - value_repr = backend._html_repr_expr(result.typed_value.value) - html += ( - f"dtype: <code>{escape(str(result.typed_value.dtype))}</code><br><br>" - ) - html += f"<pre>{escape(value_repr)}</pre>" - except Exception as e: - html += ( - "<pre>Failed to get evaluate due to an exception:\n" - f"{escape(e.__class__.__name__)}: {escape(str(e))}</pre>
" - ) - - return html - - def _repr_pretty_(self, p, cycle): - p.text(str(self) if not cycle else "...") - - -class SymbolAttribute: - def __init__(self, name: str, on: SymbolicExpression): - self.__name = name - self.__on = on - - def __getattr__(self, item) -> SymbolAttribute: - return SymbolAttribute(self.__name + "." + item, self.__on) - - def __call__(self, *args, **kwargs) -> SymbolicExpression: - return SymbolicExpression(FunctionCall(self.__name, self.__on, *args, **kwargs)) - - def __hash__(self): - raise Exception( - "Nope... You probably didn't want to do this. Did you misspell the" - f" attribute name '{self.__name}' of '{self.__on}'? Maybe you forgot a" - " leading underscore." - ) - - -def unwrap_symbolic_expressions(arg: Any = None): - """ - Replaces all symbolic expressions in the input with their underlying value. - """ - return traverse(arg, lambda x: x._ if isinstance(x, SymbolicExpression) else x) - - -# Add all supported dunder methods to `SymbolicExpression`. -# This has to be done, because Python doesn't call __getattr__ for -# dunder methods. -def create_operator(op): - def impl(*args, **kwargs): - return SymbolicExpression(FunctionCall(op, *args, **kwargs)) - - return impl - - -for dunder in OperatorRegistry.SUPPORTED_DUNDER: - setattr(SymbolicExpression, dunder, create_operator(dunder)) -del create_operator diff --git a/src/pydiverse/transform/core/expressions/translator.py b/src/pydiverse/transform/core/expressions/translator.py deleted file mode 100644 index 47a8381d..00000000 --- a/src/pydiverse/transform/core/expressions/translator.py +++ /dev/null @@ -1,182 +0,0 @@ -from __future__ import annotations - -import dataclasses -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Generic - -from pydiverse.transform._typing import T -from pydiverse.transform.core import registry -from pydiverse.transform.core.expressions import ( - CaseExpression, - Column, - FunctionCall, - LiteralColumn, -) -from pydiverse.transform.ops.core import Operator, OPType -from pydiverse.transform.util import reraise - -if TYPE_CHECKING: - from pydiverse.transform.core.dtypes import DType - - -# Basic container to store value and associated type metadata -@dataclass -class TypedValue(Generic[T]): - value: T - dtype: DType - ftype: OPType = dataclasses.field(default=OPType.EWISE) - - def __iter__(self): - return iter((self.value, self.dtype)) - - -class Translator(Generic[T]): - def translate(self, expr, **kwargs) -> T: - """Translate an expression recursively.""" - try: - return bottom_up_replace(expr, lambda e: self._translate(e, **kwargs)) - except Exception as e: - msg = f"This exception occurred while translating the expression: {expr}" - reraise(e, suffix=msg) - - def _translate(self, expr, **kwargs) -> T: - """Translate an expression non recursively.""" - raise NotImplementedError - - -class DelegatingTranslator(Translator[T], Generic[T]): - """ - Translator that dispatches to different translate functions based on - the type of the expression. 
- """ - - def __init__(self, operator_registry: registry.OperatorRegistry): - self.operator_registry = operator_registry - - def translate(self, expr, **kwargs): - """Translate an expression recursively.""" - try: - return self._translate(expr, **kwargs) - except Exception as e: - msg = f"This exception occurred while translating the expression: {expr}" - reraise(e, suffix=msg) - - def _translate(self, expr, **kwargs): - if isinstance(expr, Column): - return self._translate_col(expr, **kwargs) - - if isinstance(expr, LiteralColumn): - return self._translate_literal_col(expr, **kwargs) - - if isinstance(expr, FunctionCall): - operator = self.operator_registry.get_operator(expr.name) - expr = FunctionCall(expr.name, *expr.args, **expr.kwargs) - - op_args, op_kwargs, context_kwargs = self._translate_function_arguments( - expr, operator, **kwargs - ) - - if op_kwargs: - raise NotImplementedError - - signature = tuple(arg.dtype for arg in op_args) - implementation = self.operator_registry.get_implementation( - expr.name, signature - ) - - return self._translate_function( - implementation, op_args, context_kwargs, **kwargs - ) - - if isinstance(expr, CaseExpression): - switching_on = ( - self._translate(expr.switching_on, **{**kwargs, "context": "case_val"}) - if expr.switching_on is not None - else None - ) - - cases = [] - for cond, value in expr.cases: - cases.append( - ( - self._translate(cond, **{**kwargs, "context": "case_cond"}), - self._translate(value, **{**kwargs, "context": "case_val"}), - ) - ) - - default = self._translate(expr.default, **{**kwargs, "context": "case_val"}) - return self._translate_case(expr, switching_on, cases, default, **kwargs) - - if literal_result := self._translate_literal(expr, **kwargs): - return literal_result - - raise NotImplementedError( - f"Couldn't find a way to translate object of type {type(expr)} with value" - f" {expr}." 
- ) - - def _translate_col(self, col: Column, **kwargs) -> T: - raise NotImplementedError - - def _translate_literal_col(self, col: LiteralColumn, **kwargs) -> T: - raise NotImplementedError - - def _translate_function( - self, - implementation: registry.TypedOperatorImpl, - op_args: list[T], - context_kwargs: dict[str, Any], - **kwargs, - ) -> T: - raise NotImplementedError - - def _translate_case( - self, - expr: CaseExpression, - switching_on: T | None, - cases: list[tuple[T, T]], - default: T, - **kwargs, - ) -> T: - raise NotImplementedError - - def _translate_literal(self, expr, **kwargs) -> T: - raise NotImplementedError - - def _translate_function_arguments( - self, expr: FunctionCall, operator: Operator, **kwargs - ) -> tuple[list[T], dict[str, T], dict[str, Any]]: - op_args = [self._translate(arg, **kwargs) for arg in expr.args] - op_kwargs = {} - context_kwargs = {} - - for k, v in expr.kwargs.items(): - if k in operator.context_kwargs: - context_kwargs[k] = v - else: - op_kwargs[k] = self._translate(v, **kwargs) - - return op_args, op_kwargs, context_kwargs - - -def bottom_up_replace(expr, replace): - def transform(expr): - if isinstance(expr, FunctionCall): - f = FunctionCall( - expr.name, - *(transform(arg) for arg in expr.args), - **{k: transform(v) for k, v in expr.kwargs.items()}, - ) - return replace(f) - - if isinstance(expr, CaseExpression): - c = CaseExpression( - switching_on=transform(expr.switching_on), - cases=[(transform(k), transform(v)) for k, v in expr.cases], - default=transform(expr.default), - ) - return replace(c) - - return replace(expr) - - return transform(expr) diff --git a/src/pydiverse/transform/core/expressions/util.py b/src/pydiverse/transform/core/expressions/util.py deleted file mode 100644 index d3c23d0b..00000000 --- a/src/pydiverse/transform/core/expressions/util.py +++ /dev/null @@ -1,62 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from pydiverse.transform.core.expressions import ( - CaseExpression, - Column, - FunctionCall, - LiteralColumn, -) - -if TYPE_CHECKING: - # noinspection PyUnresolvedReferences - from pydiverse.transform.core.table_impl import AbstractTableImpl - - -def iterate_over_expr(expr, expand_literal_col=False): - """ - Iterate in depth-first preorder over the expression and yield all components. - """ - - yield expr - - if isinstance(expr, FunctionCall): - for child in expr.iter_children(): - yield from iterate_over_expr(child, expand_literal_col=expand_literal_col) - return - - if isinstance(expr, CaseExpression): - for child in expr.iter_children(): - yield from iterate_over_expr(child, expand_literal_col=expand_literal_col) - return - - if expand_literal_col and isinstance(expr, LiteralColumn): - yield from iterate_over_expr(expr.expr, expand_literal_col=expand_literal_col) - return - - -def determine_expr_backend(expr) -> type[AbstractTableImpl] | None: - """Returns the backend used in an expression. - - Iterates over an expression and extracts the underlying backend type used. - If no backend can be determined (because the expression doesn't contain a - column), None is returned instead. If different backends are being used, - throws an exception. 
- """ - - backends = set() - for atom in iterate_over_expr(expr): - if isinstance(atom, Column): - backends.add(type(atom.table)) - if isinstance(atom, LiteralColumn): - backends.add(atom.backend) - - if len(backends) == 1: - return backends.pop() - if len(backends) >= 2: - raise ValueError( - "Expression contains different backends " - f"(found: {[backend.__name__ for backend in backends]})." - ) - return None diff --git a/src/pydiverse/transform/core/functions.py b/src/pydiverse/transform/core/functions.py deleted file mode 100644 index eaa12a94..00000000 --- a/src/pydiverse/transform/core/functions.py +++ /dev/null @@ -1,55 +0,0 @@ -from __future__ import annotations - -from typing import Any - -from pydiverse.transform.core.expressions import ( - CaseExpression, - FunctionCall, - SymbolicExpression, -) - -__all__ = [ - "count", - "row_number", -] - - -def _sym_f_call(name, *args, **kwargs) -> SymbolicExpression[FunctionCall]: - return SymbolicExpression(FunctionCall(name, *args, **kwargs)) - - -def count(expr: SymbolicExpression | None = None): - if expr is None: - return _sym_f_call("count") - else: - return _sym_f_call("count", expr) - - -def row_number(*, arrange: list, partition_by: list | None = None): - return _sym_f_call("row_number", arrange=arrange, partition_by=partition_by) - - -def rank(*, arrange: list, partition_by: list | None = None): - return _sym_f_call("rank", arrange=arrange, partition_by=partition_by) - - -def dense_rank(*, arrange: list, partition_by: list | None = None): - return _sym_f_call("dense_rank", arrange=arrange, partition_by=partition_by) - - -def case(*cases: tuple[Any, Any], default: Any = None): - case_expression = CaseExpression( - switching_on=None, - cases=cases, - default=default, - ) - - return SymbolicExpression(case_expression) - - -def min(first: Any, *expr: Any): - return _sym_f_call("__least", first, *expr) - - -def max(first: Any, *expr: Any): - return _sym_f_call("__greatest", first, *expr) diff --git a/src/pydiverse/transform/core/table.py b/src/pydiverse/transform/core/table.py deleted file mode 100644 index fd2052c6..00000000 --- a/src/pydiverse/transform/core/table.py +++ /dev/null @@ -1,122 +0,0 @@ -from __future__ import annotations - -from collections.abc import Iterable -from html import escape -from typing import Generic - -from pydiverse.transform._typing import ImplT -from pydiverse.transform.core.expressions import ( - Column, - LambdaColumn, - SymbolicExpression, -) -from pydiverse.transform.core.verbs import export - - -class Table(Generic[ImplT]): - """ - All attributes of a table are columns except for the `_impl` attribute - which is a reference to the underlying table implementation. - """ - - def __init__(self, implementation: ImplT): - self._impl = implementation - - def __getitem__(self, key) -> SymbolicExpression[Column]: - if isinstance(key, SymbolicExpression): - key = key._ - return SymbolicExpression(self._impl.get_col(key)) - - def __setitem__(self, col, expr): - """Mutate a column - :param col: Either a str or SymbolicColumn - """ - from pydiverse.transform.core.verbs import mutate - - col_name = None - - if isinstance(col, SymbolicExpression): - underlying = col._ - if isinstance(underlying, (Column, LambdaColumn)): - col_name = underlying.name - elif isinstance(col, str): - col_name = col - - if not col_name: - raise KeyError( - f"Invalid key {col}. Must be either a string, Column or LambdaColumn." 
- ) - self._impl = (self >> mutate(**{col_name: expr}))._impl - - def __getattr__(self, name) -> SymbolicExpression[Column]: - return SymbolicExpression(self._impl.get_col(name)) - - def __iter__(self) -> Iterable[SymbolicExpression[Column]]: - # Capture current state (this allows modifying the table inside a loop) - cols = [ - SymbolicExpression(self._impl.get_col(name)) - for name, _ in self._impl.selected_cols() - ] - return iter(cols) - - def __eq__(self, other): - return isinstance(other, type(self)) and self._impl == other._impl - - def __ne__(self, other): - return not self.__eq__(other) - - def __hash__(self): - return hash(self._impl) - - def __dir__(self): - return sorted(self._impl.named_cols.fwd.keys()) - - def __contains__(self, item): - if isinstance(item, SymbolicExpression): - item = item._ - if isinstance(item, LambdaColumn): - return item.name in self._impl.named_cols.fwd - if isinstance(item, Column): - return item.uuid in self._impl.available_cols - return False - - def __copy__(self): - impl_copy = self._impl.copy() - return self.__class__(impl_copy) - - def __str__(self): - try: - return ( - f"Table: {self._impl.name}, backend: {type(self._impl).__name__}\n" - f"{self >> export()}" - ) - except Exception as e: - return ( - f"Table: {self._impl.name}, backend: {type(self._impl).__name__}\n" - "Failed to collect table due to an exception:\n" - f"{type(e).__name__}: {str(e)}" - ) - - def _repr_html_(self) -> str | None: - html = ( - f"Table {self._impl.name} using" - f" {type(self._impl).__name__} backend:
" - ) - try: - # TODO: For lazy backend only show preview (eg. take first 20 rows) - html += (self >> export())._repr_html_() - except Exception as e: - html += ( - "
Failed to collect table due to an exception:\n"
-                f"{escape(e.__class__.__name__)}: {escape(str(e))}
" - ) - return html - - def _repr_pretty_(self, p, cycle): - p.text(str(self) if not cycle else "...") - - def cols(self) -> list[Column]: - return [ - self._impl.cols[uuid].as_column(name, self._impl) - for (name, uuid) in self._impl.selected_cols() - ] diff --git a/src/pydiverse/transform/core/table_impl.py b/src/pydiverse/transform/core/table_impl.py deleted file mode 100644 index 3831e0bd..00000000 --- a/src/pydiverse/transform/core/table_impl.py +++ /dev/null @@ -1,697 +0,0 @@ -from __future__ import annotations - -import copy -import dataclasses -import datetime -import uuid -import warnings -from collections.abc import Iterable -from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar - -from pydiverse.transform import ops -from pydiverse.transform._typing import ImplT -from pydiverse.transform.core import dtypes -from pydiverse.transform.core.expressions import ( - CaseExpression, - Column, - LambdaColumn, - LiteralColumn, -) -from pydiverse.transform.core.expressions.translator import ( - DelegatingTranslator, - Translator, - TypedValue, -) -from pydiverse.transform.core.registry import ( - OperatorRegistrationContextManager, - OperatorRegistry, -) -from pydiverse.transform.core.util import bidict, ordered_set -from pydiverse.transform.errors import ExpressionTypeError, FunctionTypeError -from pydiverse.transform.ops import OPType - -if TYPE_CHECKING: - from pydiverse.transform.core.util import OrderingDescriptor - from pydiverse.transform.ops import Operator - - -ExprCompT = TypeVar("ExprCompT", bound="TypedValue") -AlignedT = TypeVar("AlignedT", bound="TypedValue") - - -class AbstractTableImpl: - """ - Base class from which all table backend implementations are derived from. - It tracks various metadata that is relevant for all backends. - - Attributes: - name: The name of the table. - - selects: Ordered set of selected names. - named_cols: Map from name to column uuid containing all columns that - have been named. - available_cols: Set of UUIDs that can be referenced in symbolic - expressions. This set gets used to validate verb inputs. It usually - contains the same uuids as the col_exprs. Only a summarising - operation resets this. - col_expr: Map from uuid to the `SymbolicExpression` that corresponds - to this column. - col_dtype: Map from uuid to the datatype of the corresponding column. - It is the responsibility of the backend to keep track of - this information. - - grouped_by: Ordered set of columns by which the table is grouped by. - intrinsic_grouped_by: Ordered set of columns representing the underlying - grouping level of the table. This gets set when performing a - summarising operation. 
- """ - - operator_registry = OperatorRegistry("AbstractTableImpl") - - def __init__( - self, - name: str, - columns: dict[str, Column], - ): - self.name = name - self.compiler = self.ExpressionCompiler(self) - self.lambda_translator = self.LambdaTranslator(self) - - self.selects: ordered_set[str] = ordered_set() # subset of named_cols - self.named_cols: bidict[str, uuid.UUID] = bidict() - self.available_cols: set[uuid.UUID] = set() - self.cols: dict[uuid.UUID, ColumnMetaData] = dict() - - self.grouped_by: ordered_set[Column] = ordered_set() - self.intrinsic_grouped_by: ordered_set[Column] = ordered_set() - - # Init Values - for name, col in columns.items(): - self.selects.add(name) - self.named_cols.fwd[name] = col.uuid - self.available_cols.add(col.uuid) - self.cols[col.uuid] = ColumnMetaData.from_expr(col.uuid, col, self) - - def __init_subclass__(cls, **kwargs): - super().__init_subclass__(**kwargs) - - # Add new `operator_registry` class variable to subclass. - # We define the super registry by walking up the MRO. This allows us - # to check for potential operation definitions in the parent classes. - super_reg = None - for super_cls in cls.__mro__: - if hasattr(super_cls, "operator_registry"): - super_reg = super_cls.operator_registry - break - cls.operator_registry = OperatorRegistry(cls.__name__, super_reg) - - def copy(self): - c = copy.copy(self) - # Copy containers - for k, v in self.__dict__.items(): - if isinstance(v, (list, dict, set, bidict, ordered_set)): - c.__dict__[k] = copy.copy(v) - - # Must create a new translator, so that it can access the current df. - c.compiler = self.ExpressionCompiler(c) - c.lambda_translator = self.LambdaTranslator(c) - return c - - def get_col(self, key: str | Column | LambdaColumn): - """Getter used by `Table.__getattr__`""" - - if isinstance(key, LambdaColumn): - key = key.name - - if isinstance(key, str): - if uuid := self.named_cols.fwd.get(key, None): - return self.cols[uuid].as_column(key, self) - # Must return AttributeError, else `hasattr` doesn't work on Table instances - raise AttributeError(f"Table '{self.name}' has not column named '{key}'.") - - if isinstance(key, Column): - uuid = key.uuid - if uuid in self.available_cols: - name = self.named_cols.bwd[uuid] - return self.cols[uuid].as_column(name, self) - raise KeyError(f"Table '{self.name}' has no column that matches '{key}'.") - - def selected_cols(self) -> Iterable[tuple[str, uuid.UUID]]: - for name in self.selects: - yield (name, self.named_cols.fwd[name]) - - def resolve_lambda_cols(self, expr: Any): - return self.lambda_translator.translate(expr) - - def is_aligned_with(self, col: Column | LiteralColumn) -> bool: - """Determine if a column is aligned with the table. - - :param col: The column or literal colum against which alignment - should be checked. - :return: A boolean indicating if `col` is aligned with self. - """ - raise NotImplementedError - - @classmethod - def _html_repr_expr(cls, expr): - """ - Return an appropriate string to display an expression from this backend. - This is mainly used to IPython. - """ - return repr(expr) - - #### Verb Callbacks #### - - def preverb_hook(self, verb: str, *args, **kwargs) -> None: - """Hook that gets called right after `copy` inside a verb - - This gives the backend a chance to react and modify it's state. This - can, for example, be used to create a subquery based on specific - conditions. 
- - :param verb: The name of the verb - :param args: The arguments passed to the verb - :param kwargs: The keyword arguments passed to the verb - """ - ... - - def alias(self, name=None) -> AbstractTableImpl: ... - - def collect(self): ... - - def build_query(self): ... - - def select(self, *args): ... - - def mutate(self, **kwargs): ... - - def join(self, right, on, how, *, validate="m:m"): ... - - def filter(self, *args): ... - - def arrange(self, ordering: list[OrderingDescriptor]): ... - - def group_by(self, *args): ... - - def ungroup(self, *args): ... - - def summarise(self, **kwargs): ... - - def slice_head(self, n: int, offset: int): ... - - def export(self): ... - - #### Symbolic Operators #### - - @classmethod - def op(cls, operator: Operator, **kwargs) -> OperatorRegistrationContextManager: - return OperatorRegistrationContextManager( - cls.operator_registry, operator, **kwargs - ) - - #### Expressions #### - - class ExpressionCompiler( - DelegatingTranslator[ExprCompT], Generic[ImplT, ExprCompT] - ): - """ - Class convert an expression into a function that, when provided with - the appropriate arguments, evaluates the expression. - - The reason we can't just eagerly evaluate the expression is because for - grouped data we often have to use the split-apply-combine strategy. - """ - - def __init__(self, backend: ImplT): - self.backend = backend - super().__init__(backend.operator_registry) - - def _translate_literal(self, expr, **kwargs): - literal = self._translate_literal_value(expr) - - if isinstance(expr, bool): - return TypedValue(literal, dtypes.Bool(const=True)) - if isinstance(expr, int): - return TypedValue(literal, dtypes.Int(const=True)) - if isinstance(expr, float): - return TypedValue(literal, dtypes.Float(const=True)) - if isinstance(expr, str): - return TypedValue(literal, dtypes.String(const=True)) - if isinstance(expr, datetime.datetime): - return TypedValue(literal, dtypes.DateTime(const=True)) - if isinstance(expr, datetime.date): - return TypedValue(literal, dtypes.Date(const=True)) - if isinstance(expr, datetime.timedelta): - return TypedValue(literal, dtypes.Duration(const=True)) - - if expr is None: - return TypedValue(literal, dtypes.NoneDType(const=True)) - - def _translate_literal_value(self, expr): - def literal_func(*args, **kwargs): - return expr - - return literal_func - - def _translate_case_common( - self, - expr: CaseExpression, - switching_on: ExprCompT | None, - cases: list[tuple[ExprCompT, ExprCompT]], - default: ExprCompT, - **kwargs, - ) -> tuple[dtypes.DType, OPType]: - # Determine dtype of result - val_dtypes = [default.dtype.without_modifiers()] - for _, val in cases: - val_dtypes.append(val.dtype.without_modifiers()) - - result_dtype = dtypes.promote_dtypes(val_dtypes) - - # Determine ftype of result - val_ftypes = set() - if not default.dtype.const: - val_ftypes.add(default.ftype) - - for _, val in cases: - if not val.dtype.const: - val_ftypes.add(val.ftype) - - if len(val_ftypes) == 0: - result_ftype = OPType.EWISE - elif len(val_ftypes) == 1: - (result_ftype,) = val_ftypes - elif OPType.WINDOW in val_ftypes: - result_ftype = OPType.WINDOW - else: - # AGGREGATE and EWISE are incompatible - raise FunctionTypeError( - "Incompatible function types found in case statement: " ", ".join( - val_ftypes - ) - ) - - if result_ftype is OPType.EWISE and switching_on is not None: - result_ftype = switching_on.ftype - - # Type check conditions - if switching_on is None: - # All conditions must be boolean - for cond, _ in cases: - if not 
dtypes.Bool().same_kind(cond.dtype): - raise ExpressionTypeError( - "All conditions in a case statement return booleans. " - f"{cond} is of type {cond.dtype}." - ) - else: - # All conditions must be of the same type as switching_on - for cond, _ in cases: - if not cond.dtype.can_promote_to( - switching_on.dtype.without_modifiers() - ): - # Can't compare - raise ExpressionTypeError( - f"Condition value {cond} (dtype: {cond.dtype}) " - f"is incompatible with switch dtype {switching_on.dtype}." - ) - - return result_dtype, result_ftype - - class AlignedExpressionEvaluator(DelegatingTranslator[AlignedT], Generic[AlignedT]): - """ - Used for evaluating an expression in a typical eager style where, as - long as two columns have the same alignment / length, we can perform - operations on them without first having to join them. - """ - - def _translate_literal(self, expr, **kwargs): - if isinstance(expr, bool): - return TypedValue(expr, dtypes.Bool(const=True)) - if isinstance(expr, int): - return TypedValue(expr, dtypes.Int(const=True)) - if isinstance(expr, float): - return TypedValue(expr, dtypes.Float(const=True)) - if isinstance(expr, str): - return TypedValue(expr, dtypes.String(const=True)) - if isinstance(expr, datetime.datetime): - return TypedValue(expr, dtypes.DateTime(const=True)) - if isinstance(expr, datetime.date): - return TypedValue(expr, dtypes.Date(const=True)) - if isinstance(expr, datetime.timedelta): - return TypedValue(expr, dtypes.Duration(const=True)) - - if expr is None: - return TypedValue(expr, dtypes.NoneDType(const=True)) - - class LambdaTranslator(Translator): - """ - Translator that takes an expression and replaces all LambdaColumns - inside it with the corresponding Column instance. - """ - - def __init__(self, backend: ImplT): - self.backend = backend - super().__init__() - - def _translate(self, expr, **kwargs): - # Resolve lambda and return Column object - if isinstance(expr, LambdaColumn): - if expr.name not in self.backend.named_cols.fwd: - raise ValueError( - f"Invalid lambda column '{expr.name}'. No column with this name" - f" found for table '{self.backend.name}'." - ) - uuid = self.backend.named_cols.fwd[expr.name] - return self.backend.cols[uuid].as_column(expr.name, self.backend) - return expr - - #### Helpers #### - - @classmethod - def _get_op_ftype( - cls, args, operator: Operator, override_ftype: OPType = None, strict=False - ) -> OPType: - """ - Get the ftype based on a function implementation and the arguments. - - e(e) -> e a(e) -> a w(e) -> w - e(a) -> a a(a) -> Err w(a) -> w - e(w) -> w a(w) -> Err w(w) -> Err - - If the implementation ftype is incompatible with the arguments, this - function raises an Exception. - """ - - ftypes = [arg.ftype for arg in args] - op_ftype = override_ftype or operator.ftype - - if op_ftype == OPType.EWISE: - if OPType.WINDOW in ftypes: - return OPType.WINDOW - if OPType.AGGREGATE in ftypes: - return OPType.AGGREGATE - return op_ftype - - if op_ftype == OPType.AGGREGATE: - if OPType.WINDOW in ftypes: - if strict: - raise FunctionTypeError( - "Can't nest a window function inside an aggregate function" - f" ({operator.name})." - ) - else: - # TODO: Replace with logger - warnings.warn( - "Nesting a window function inside an aggregate function is not" - " supported by SQL backend." - ) - if OPType.AGGREGATE in ftypes: - raise FunctionTypeError( - "Can't nest an aggregate function inside an aggregate function" - f" ({operator.name})." 
- ) - return op_ftype - - if op_ftype == OPType.WINDOW: - if OPType.WINDOW in ftypes: - if strict: - raise FunctionTypeError( - "Can't nest a window function inside a window function" - f" ({operator.name})." - ) - else: - warnings.warn( - "Nesting a window function inside a window function is not" - " supported by SQL backend." - ) - return op_ftype - - -@dataclasses.dataclass -class ColumnMetaData: - uuid: uuid.UUID - expr: Any - compiled: Callable[[Any], TypedValue] - dtype: dtypes.DType - ftype: OPType - - @classmethod - def from_expr(cls, uuid, expr, table: AbstractTableImpl, **kwargs): - v: TypedValue = table.compiler.translate(expr, **kwargs) - return cls( - uuid=uuid, - expr=expr, - compiled=v.value, - dtype=v.dtype.without_modifiers(), - ftype=v.ftype, - ) - - def __hash__(self): - return hash(self.uuid) - - def as_column(self, name, table: AbstractTableImpl): - return Column(name, table, self.dtype, self.uuid) - - -#### MARKER OPERATIONS ######################################################### - - -with AbstractTableImpl.op(ops.NullsFirst()) as op: - - @op.auto - def _nulls_first(_): - raise RuntimeError("This is just a marker that never should get called") - - -with AbstractTableImpl.op(ops.NullsLast()) as op: - - @op.auto - def _nulls_last(_): - raise RuntimeError("This is just a marker that never should get called") - - -#### ARITHMETIC OPERATORS ###################################################### - - -with AbstractTableImpl.op(ops.Add()) as op: - - @op.auto - def _add(lhs, rhs): - return lhs + rhs - - @op.extension(ops.StrAdd) - def _str_add(lhs, rhs): - return lhs + rhs - - -with AbstractTableImpl.op(ops.RAdd()) as op: - - @op.auto - def _radd(rhs, lhs): - return lhs + rhs - - @op.extension(ops.StrRAdd) - def _str_radd(lhs, rhs): - return lhs + rhs - - -with AbstractTableImpl.op(ops.Sub()) as op: - - @op.auto - def _sub(lhs, rhs): - return lhs - rhs - - -with AbstractTableImpl.op(ops.RSub()) as op: - - @op.auto - def _rsub(rhs, lhs): - return lhs - rhs - - -with AbstractTableImpl.op(ops.Mul()) as op: - - @op.auto - def _mul(lhs, rhs): - return lhs * rhs - - -with AbstractTableImpl.op(ops.RMul()) as op: - - @op.auto - def _rmul(rhs, lhs): - return lhs * rhs - - -with AbstractTableImpl.op(ops.TrueDiv()) as op: - - @op.auto - def _truediv(lhs, rhs): - return lhs / rhs - - -with AbstractTableImpl.op(ops.RTrueDiv()) as op: - - @op.auto - def _rtruediv(rhs, lhs): - return lhs / rhs - - -with AbstractTableImpl.op(ops.FloorDiv()) as op: - - @op.auto - def _floordiv(lhs, rhs): - return lhs // rhs - - -with AbstractTableImpl.op(ops.RFloorDiv()) as op: - - @op.auto - def _rfloordiv(rhs, lhs): - return lhs // rhs - - -with AbstractTableImpl.op(ops.Pow()) as op: - - @op.auto - def _pow(lhs, rhs): - return lhs**rhs - - -with AbstractTableImpl.op(ops.RPow()) as op: - - @op.auto - def _rpow(rhs, lhs): - return lhs**rhs - - -with AbstractTableImpl.op(ops.Mod()) as op: - - @op.auto - def _mod(lhs, rhs): - return lhs % rhs - - -with AbstractTableImpl.op(ops.RMod()) as op: - - @op.auto - def _rmod(rhs, lhs): - return lhs % rhs - - -with AbstractTableImpl.op(ops.Neg()) as op: - - @op.auto - def _neg(x): - return -x - - -with AbstractTableImpl.op(ops.Pos()) as op: - - @op.auto - def _pos(x): - return +x - - -with AbstractTableImpl.op(ops.Abs()) as op: - - @op.auto - def _abs(x): - return abs(x) - - -#### BINARY OPERATORS ########################################################## - - -with AbstractTableImpl.op(ops.And()) as op: - - @op.auto - def _and(lhs, rhs): - return lhs & rhs - - -with 
AbstractTableImpl.op(ops.RAnd()) as op: - - @op.auto - def _rand(rhs, lhs): - return lhs & rhs - - -with AbstractTableImpl.op(ops.Or()) as op: - - @op.auto - def _or(lhs, rhs): - return lhs | rhs - - -with AbstractTableImpl.op(ops.ROr()) as op: - - @op.auto - def _ror(rhs, lhs): - return lhs | rhs - - -with AbstractTableImpl.op(ops.Xor()) as op: - - @op.auto - def _xor(lhs, rhs): - return lhs ^ rhs - - -with AbstractTableImpl.op(ops.RXor()) as op: - - @op.auto - def _rxor(rhs, lhs): - return lhs ^ rhs - - -with AbstractTableImpl.op(ops.Invert()) as op: - - @op.auto - def _invert(x): - return ~x - - -#### COMPARISON OPERATORS ###################################################### - - -with AbstractTableImpl.op(ops.Equal()) as op: - - @op.auto - def _eq(lhs, rhs): - return lhs == rhs - - -with AbstractTableImpl.op(ops.NotEqual()) as op: - - @op.auto - def _ne(lhs, rhs): - return lhs != rhs - - -with AbstractTableImpl.op(ops.Less()) as op: - - @op.auto - def _lt(lhs, rhs): - return lhs < rhs - - -with AbstractTableImpl.op(ops.LessEqual()) as op: - - @op.auto - def _le(lhs, rhs): - return lhs <= rhs - - -with AbstractTableImpl.op(ops.Greater()) as op: - - @op.auto - def _gt(lhs, rhs): - return lhs > rhs - - -with AbstractTableImpl.op(ops.GreaterEqual()) as op: - - @op.auto - def _ge(lhs, rhs): - return lhs >= rhs diff --git a/src/pydiverse/transform/core/util/__init__.py b/src/pydiverse/transform/core/util/__init__.py deleted file mode 100644 index 04973739..00000000 --- a/src/pydiverse/transform/core/util/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from __future__ import annotations - -from .bidict import bidict -from .ordered_set import ordered_set -from .util import * diff --git a/src/pydiverse/transform/core/util/bidict.py b/src/pydiverse/transform/core/util/bidict.py deleted file mode 100644 index 59308976..00000000 --- a/src/pydiverse/transform/core/util/bidict.py +++ /dev/null @@ -1,95 +0,0 @@ -from __future__ import annotations - -from collections.abc import ( - ItemsView, - Iterable, - KeysView, - Mapping, - MutableMapping, - ValuesView, -) -from typing import ( - Generic, - TypeVar, -) - -KT = TypeVar("KT") -VT = TypeVar("VT") - - -class bidict(Generic[KT, VT]): - """ - Bidirectional Dictionary - All keys and values must be unique (bijective one to one mapping). - - To go from key to value use `bidict.fwd`. - To go from value to key use `bidict.bwd`. - """ - - def __init__(self, seq: Mapping[KT, VT] = None, /, *, fwd=None, bwd=None): - if fwd is not None and bwd is not None: - self.__fwd = fwd - self.__bwd = bwd - else: - self.__fwd = dict(seq) if seq is not None else dict() - self.__bwd = {v: k for k, v in self.__fwd.items()} - - if len(self.__fwd) != len(self.__bwd) != len(seq): - raise ValueError( - "Input sequence contains duplicate key value pairs. Mapping must be" - " unique." 
- ) - - self.fwd = _BidictInterface(self.__fwd, self.__bwd) # type: _BidictInterface[KT, VT] - self.bwd = _BidictInterface(self.__bwd, self.__fwd) # type: _BidictInterface[VT, KT] - - def __copy__(self): - return bidict(fwd=self.__fwd.copy(), bwd=self.__bwd.copy()) - - def __len__(self): - return len(self.__fwd) - - def clear(self): - self.__fwd.clear() - self.__bwd.clear() - - -class _BidictInterface(MutableMapping[KT, VT]): - def __init__(self, fwd: dict[KT, VT], bwd: dict[VT, KT]): - self.__fwd = fwd - self.__bwd = bwd - - def __setitem__(self, key: KT, value: VT): - if key in self.__fwd: - fwd_value = self.__fwd[key] - del self.__bwd[fwd_value] - if value in self.__bwd: - raise ValueError(f"Duplicate value '{value}'. Mapping must be unique.") - self.__fwd[key] = value - self.__bwd[value] = key - - def __getitem__(self, key: KT) -> VT: - return self.__fwd[key] - - def __delitem__(self, key: KT): - value = self.__fwd[key] - del self.__fwd[key] - del self.__bwd[value] - - def __iter__(self) -> Iterable[KT]: - yield from self.__fwd.__iter__() - - def __len__(self) -> int: - return len(self.__fwd) - - def __contains__(self, item) -> bool: - return item in self.__fwd - - def items(self) -> ItemsView[KT, VT]: - return self.__fwd.items() - - def keys(self) -> KeysView[KT]: - return self.__fwd.keys() - - def values(self) -> ValuesView[VT]: - return self.__fwd.values() diff --git a/src/pydiverse/transform/core/util/ordered_set.py b/src/pydiverse/transform/core/util/ordered_set.py deleted file mode 100644 index 085bb2a1..00000000 --- a/src/pydiverse/transform/core/util/ordered_set.py +++ /dev/null @@ -1,45 +0,0 @@ -from __future__ import annotations - -from collections.abc import Iterable, MutableSet - -from pydiverse.transform._typing import T - - -class ordered_set(MutableSet[T]): - def __init__(self, values: Iterable[T] = tuple()): - self.__data = {v: None for v in values} - - def __contains__(self, item: T) -> bool: - return item in self.__data - - def __iter__(self) -> Iterable[T]: - yield from self.__data.keys() - - def __len__(self) -> int: - return len(self.__data) - - def __repr__(self): - return f'{", ".join(repr(e) for e in self)}' - - def __copy__(self): - return self.__class__(self) - - def add(self, value: T) -> None: - self.__data[value] = None - - def discard(self, value: T) -> None: - del self.__data[value] - - def clear(self) -> None: - self.__data.clear() - - def copy(self): - return self.__copy__() - - def pop_back(self) -> None: - """Return the popped value.Raise KeyError if empty.""" - if len(self) == 0: - raise KeyError("Ordered set is empty.") - back = next(reversed(self.__data.keys())) - self.discard(back) - return back diff --git a/src/pydiverse/transform/core/util/util.py b/src/pydiverse/transform/core/util/util.py deleted file mode 100644 index 238d3f9f..00000000 --- a/src/pydiverse/transform/core/util/util.py +++ /dev/null @@ -1,90 +0,0 @@ -from __future__ import annotations - -import typing -from dataclasses import dataclass - -from pydiverse.transform._typing import T -from pydiverse.transform.core.expressions import FunctionCall - -__all__ = ( - "traverse", - "sign_peeler", - "OrderingDescriptor", - "translate_ordering", -) - - -def traverse(obj: T, callback: typing.Callable) -> T: - if isinstance(obj, list): - return [traverse(elem, callback) for elem in obj] - if isinstance(obj, dict): - return {k: traverse(v, callback) for k, v in obj.items()} - if isinstance(obj, tuple): - if type(obj) is not tuple: - # Named tuples cause problems - raise Exception - return 
tuple(traverse(elem, callback) for elem in obj) - - return callback(obj) - - -def peel_markers(expr, markers): - found_markers = [] - while isinstance(expr, FunctionCall): - if expr.name in markers: - found_markers.append(expr.name) - assert len(expr.args) == 1 - expr = expr.args[0] - else: - break - return expr, found_markers - - -def sign_peeler(expr): - """ - Remove unary - and + prefix and return the sign - :return: `True` for `+` and `False` for `-` - """ - - expr, markers = peel_markers(expr, {"__neg__", "__pos__"}) - num_neg = markers.count("__neg__") - return expr, num_neg % 2 == 0 - - -def ordering_peeler(expr): - expr, markers = peel_markers( - expr, {"__neg__", "__pos__", "nulls_first", "nulls_last"} - ) - - ascending = markers.count("__neg__") % 2 == 0 - nulls_first = False - for marker in markers: - if marker == "nulls_first": - nulls_first = True - break - if marker == "nulls_last": - break - - return expr, ascending, nulls_first - - -#### - - -@dataclass -class OrderingDescriptor: - __slots__ = ("order", "asc", "nulls_first") - - order: typing.Any - asc: bool - nulls_first: bool - - -def translate_ordering(tbl, order_list) -> list[OrderingDescriptor]: - ordering = [] - for arg in order_list: - col, ascending, nulls_first = ordering_peeler(arg) - col = tbl.resolve_lambda_cols(col) - ordering.append(OrderingDescriptor(col, ascending, nulls_first)) - - return ordering diff --git a/src/pydiverse/transform/core/verbs.py b/src/pydiverse/transform/core/verbs.py deleted file mode 100644 index 6ab184a2..00000000 --- a/src/pydiverse/transform/core/verbs.py +++ /dev/null @@ -1,506 +0,0 @@ -from __future__ import annotations - -import functools -from collections import ChainMap -from collections.abc import Iterable -from typing import Literal - -from pydiverse.transform.core import dtypes -from pydiverse.transform.core.dispatchers import builtin_verb -from pydiverse.transform.core.expressions import ( - Column, - LambdaColumn, - SymbolicExpression, -) -from pydiverse.transform.core.expressions.util import iterate_over_expr -from pydiverse.transform.core.table_impl import AbstractTableImpl, ColumnMetaData -from pydiverse.transform.core.util import ( - bidict, - ordered_set, - sign_peeler, - translate_ordering, -) -from pydiverse.transform.errors import ExpressionTypeError, FunctionTypeError -from pydiverse.transform.ops import OPType - -__all__ = [ - "alias", - "collect", - "build_query", - "show_query", - "select", - "rename", - "mutate", - "join", - "left_join", - "inner_join", - "outer_join", - "filter", - "arrange", - "group_by", - "ungroup", - "summarise", - "slice_head", - "export", -] - - -def check_cols_available( - tables: AbstractTableImpl | Iterable[AbstractTableImpl], - columns: set[Column], - function_name: str, -): - if isinstance(tables, AbstractTableImpl): - tables = (tables,) - available_columns = ChainMap(*(table.available_cols for table in tables)) - missing_columns = [] - for col in columns: - if col.uuid not in available_columns: - missing_columns.append(col) - if missing_columns: - missing_columns_str = ", ".join(map(lambda x: str(x), missing_columns)) - raise ValueError( - f"Can't access column(s) {missing_columns_str} in {function_name}() because" - " they aren't available in the input." 
- ) - - -def check_lambdas_valid(tbl: AbstractTableImpl, *expressions): - lambdas = [] - for expression in expressions: - lambdas.extend( - lc for lc in iterate_over_expr(expression) if isinstance(lc, LambdaColumn) - ) - missing_lambdas = {lc for lc in lambdas if lc.name not in tbl.named_cols.fwd} - if missing_lambdas: - missing_lambdas_str = ", ".join(map(lambda x: str(x), missing_lambdas)) - raise ValueError(f"Invalid lambda column(s) {missing_lambdas_str}.") - - -def cols_in_expression(expression) -> set[Column]: - return {c for c in iterate_over_expr(expression) if isinstance(c, Column)} - - -def cols_in_expressions(expressions) -> set[Column]: - if len(expressions) == 0: - return set() - return set.union(*(cols_in_expression(e) for e in expressions)) - - -def validate_table_args(*tables): - if len(tables) == 0: - return - - for table in tables: - if not isinstance(table, AbstractTableImpl): - raise TypeError(f"Expected a TableImpl but got {type(table)} instead.") - - backend = type(tables[0]) - for table in tables: - if type(table) is not backend: - raise ValueError( - f"Can't mix tables with different backends. Expected '{backend}' but" - f" found '{type(table)}'." - ) - - -@builtin_verb() -def alias(tbl: AbstractTableImpl, name: str | None = None): - """Creates a new table object with a different name and reassigns column UUIDs. - Must be used before performing a self-join.""" - validate_table_args(tbl) - return tbl.alias(name) - - -@builtin_verb() -def collect(tbl: AbstractTableImpl): - validate_table_args(tbl) - return tbl.collect() - - -@builtin_verb() -def export(tbl: AbstractTableImpl): - validate_table_args(tbl) - return tbl.export() - - -@builtin_verb() -def build_query(tbl: AbstractTableImpl): - return tbl.build_query() - - -@builtin_verb() -def show_query(tbl: AbstractTableImpl): - if query := tbl.build_query(): - print(query) - else: - print(f"No query to show for {type(tbl).__name__}") - - return tbl - - -@builtin_verb() -def select(tbl: AbstractTableImpl, *args: Column | LambdaColumn): - if len(args) == 1 and args[0] is Ellipsis: - # >> select(...) -> Select all columns - args = [ - tbl.cols[uuid].as_column(name, tbl) - for name, uuid in tbl.named_cols.fwd.items() - ] - - # Validate input - validate_table_args(tbl) - check_cols_available(tbl, cols_in_expressions(args), "select") - check_lambdas_valid(tbl, *args) - - cols = [] - positive_selection = None - for col in args: - col, is_pos = sign_peeler(col) - if positive_selection is None: - positive_selection = is_pos - else: - if is_pos is not positive_selection: - raise ValueError( - "All columns in input must have the same sign." - " Can't mix selection with deselection." - ) - - if not isinstance(col, (Column, LambdaColumn)): - raise TypeError( - "Arguments to select verb must be of type 'Column' or 'LambdaColumn'" - f" and not {type(col)}." 
- ) - cols.append(col) - - selects = [] - for col in cols: - if isinstance(col, Column): - selects.append(tbl.named_cols.bwd[col.uuid]) - elif isinstance(col, LambdaColumn): - selects.append(col.name) - - # Invert selection - if positive_selection is False: - exclude = set(selects) - selects.clear() - for name in tbl.selects: - if name not in exclude: - selects.append(name) - - new_tbl = tbl.copy() - new_tbl.preverb_hook("select", *args) - new_tbl.selects = ordered_set(selects) - new_tbl.select(*args) - return new_tbl - - -@builtin_verb() -def rename(tbl: AbstractTableImpl, name_map: dict[str, str]): - # Type check - for k, v in name_map.items(): - if not isinstance(k, str) or not isinstance(v, str): - raise TypeError( - f"Key and Value of `name_map` must both be strings: ({k!r}, {v!r})" - ) - - # Reference col that doesn't exist - if missing_cols := name_map.keys() - tbl.named_cols.fwd.keys(): - raise KeyError("Table has no columns named: " + ", ".join(missing_cols)) - - # Can't rename two cols to the same name - _seen = set() - if duplicate_names := { - name for name in name_map.values() if name in _seen or _seen.add(name) - }: - raise ValueError( - "Can't rename multiple columns to the same name: " - + ", ".join(duplicate_names) - ) - - # Can't rename a column to one that already exists - unmodified_cols = tbl.named_cols.fwd.keys() - name_map.keys() - if duplicate_names := unmodified_cols & set(name_map.values()): - raise ValueError( - "Table already contains columns named: " + ", ".join(duplicate_names) - ) - - # Rename - new_tbl = tbl.copy() - new_tbl.selects = ordered_set(name_map.get(name, name) for name in new_tbl.selects) - - uuid_name_map = {new_tbl.named_cols.fwd[old]: new for old, new in name_map.items()} - for uuid in uuid_name_map: - del new_tbl.named_cols.bwd[uuid] - for uuid, name in uuid_name_map.items(): - new_tbl.named_cols.bwd[uuid] = name - - return new_tbl - - -@builtin_verb() -def mutate(tbl: AbstractTableImpl, **kwargs: SymbolicExpression): - validate_table_args(tbl) - check_cols_available(tbl, cols_in_expressions(kwargs.values()), "mutate") - - new_tbl = tbl.copy() - new_tbl.preverb_hook("mutate", **kwargs) - kwargs = {k: new_tbl.resolve_lambda_cols(v) for k, v in kwargs.items()} - - for name, expr in kwargs.items(): - uid = Column.generate_col_uuid() - col = ColumnMetaData.from_expr(uid, expr, new_tbl, verb="mutate") - - if dtypes.NoneDType().same_kind(col.dtype): - raise ExpressionTypeError( - f"Column '{name}' has an invalid type: {col.dtype}" - ) - - new_tbl.selects.add(name) - new_tbl.named_cols.fwd[name] = uid - new_tbl.available_cols.add(uid) - new_tbl.cols[uid] = col - - new_tbl.mutate(**kwargs) - return new_tbl - - -@builtin_verb() -def join( - left: AbstractTableImpl, - right: AbstractTableImpl, - on: SymbolicExpression, - how: Literal["inner", "left", "outer"], - *, - validate: Literal["1:1", "1:m", "m:1", "m:m"] = "m:m", - suffix: str | None = None, # appended to cols of the right table -): - validate_table_args(left, right) - - if left.grouped_by or right.grouped_by: - raise ValueError("Can't join grouped tables. 
You first have to ungroup them.") - - # Check args only contains valid columns - check_cols_available((left, right), cols_in_expression(on), "join") - - if how not in ("inner", "left", "outer"): - raise ValueError( - "join type must be one of 'inner', 'left' or 'outer' (value provided:" - f" {how})" - ) - - new_left = left.copy() - new_left.preverb_hook("join", right, on, how, validate=validate) - - if set(new_left.named_cols.fwd.values()) & set(right.named_cols.fwd.values()): - raise ValueError( - f"{how} join of `{left.name}` and `{right.name}` failed: " - f"duplicate columns detected. If you want to do a self-join or join a " - f"table twice, use `alias` on one table before the join." - ) - - if suffix is not None: - # check that the user-provided suffix does not lead to collisions - if collisions := set(new_left.named_cols.fwd.keys()) & set( - name + suffix for name in right.named_cols.fwd.keys() - ): - raise ValueError( - f"{how} join of `{left.name}` and `{right.name}` failed: " - f"using the suffix `{suffix}` for right columns, the following column " - f"names appear both in the left and right table: {collisions}" - ) - else: - # try `_{right.name}`, then `_{right.name}1`, `_{right.name}2` and so on - cnt = 0 - suffix = "_" + right.name - for rname in right.named_cols.fwd.keys(): - while rname + suffix in new_left.named_cols.fwd.keys(): - cnt += 1 - suffix = "_" + right.name + str(cnt) - - new_left.selects |= {name + suffix for name in right.selects} - new_left.named_cols.fwd.update( - {name + suffix: uuid for name, uuid in right.named_cols.fwd.items()} - ) - new_left.available_cols.update(right.available_cols) - new_left.cols.update(right.cols) - - # By resolving lambdas this late, we enable the user to use lambda columns - # to reference mutated columns from the right side of the join. - # -> `C.columnname_righttablename` is a valid lambda in the on condition. - check_lambdas_valid(new_left, on) - on = new_left.resolve_lambda_cols(on) - - new_left.join(right, on, how, validate=validate) - return new_left - - -inner_join = functools.partial(join, how="inner") -left_join = functools.partial(join, how="left") -outer_join = functools.partial(join, how="outer") - - -@builtin_verb() -def filter(tbl: AbstractTableImpl, *args: SymbolicExpression): - # TODO: Type check expression - validate_table_args(tbl) - check_cols_available(tbl, cols_in_expressions(args), "filter") - args = [tbl.resolve_lambda_cols(arg) for arg in args] - - new_tbl = tbl.copy() - new_tbl.preverb_hook("filter", *args) - new_tbl.filter(*args) - return new_tbl - - -@builtin_verb() -def arrange(tbl: AbstractTableImpl, *args: Column | LambdaColumn): - if len(args) == 0: - return tbl - - # Validate Input - validate_table_args(tbl) - check_cols_available(tbl, cols_in_expressions(args), "arrange") - check_lambdas_valid(tbl, *args) - - ordering = translate_ordering(tbl, args) - - new_tbl = tbl.copy() - new_tbl.preverb_hook("arrange", *args) - new_tbl.arrange(ordering) - return new_tbl - - -@builtin_verb() -def group_by(tbl: AbstractTableImpl, *args: Column | LambdaColumn, add=False): - # Validate Input - validate_table_args(tbl) - check_cols_available(tbl, cols_in_expressions(args), "group_by") - check_lambdas_valid(tbl, *args) - - # WARNING: Depending on the SQL backend, you might - # only be allowed to reference columns - if not args: - raise ValueError( - "Expected columns to group by, but none were specified. To remove the" - " grouping use the ungroup verb instead." 
- ) - for col in args: - if not isinstance(col, (Column, LambdaColumn)): - raise TypeError( - "Arguments to group_by verb must be of type 'Column' or 'LambdaColumn'" - f" and not '{type(col)}'." - ) - - args = [tbl.resolve_lambda_cols(arg) for arg in args] - - new_tbl = tbl.copy() - new_tbl.preverb_hook("group_by", *args, add=add) - if add: - new_tbl.grouped_by |= ordered_set(args) - else: - new_tbl.grouped_by = ordered_set(args) - new_tbl.group_by(*args) - return new_tbl - - -@builtin_verb() -def ungroup(tbl: AbstractTableImpl): - """Remove all groupings from table.""" - validate_table_args(tbl) - - new_tbl = tbl.copy() - new_tbl.preverb_hook("ungroup") - new_tbl.grouped_by.clear() - new_tbl.ungroup() - return new_tbl - - -@builtin_verb() -def summarise(tbl: AbstractTableImpl, **kwargs: SymbolicExpression): - # Validate Input - validate_table_args(tbl) - check_cols_available(tbl, cols_in_expressions(kwargs.values()), "summarise") - - new_tbl = tbl.copy() - new_tbl.preverb_hook("summarise", **kwargs) - kwargs = {k: new_tbl.resolve_lambda_cols(v) for k, v in kwargs.items()} - - # TODO: Validate that the functions are actually aggregating functions. - ... - - # Calculate state for new table - selects = ordered_set() - named_cols = bidict() - available_cols = set() - cols = {} - - # Add grouping cols to beginning of select. - for col in tbl.grouped_by: - selects.add(tbl.named_cols.bwd[col.uuid]) - available_cols.add(col.uuid) - named_cols.fwd[col.name] = col.uuid - - # Add summarizing cols to the end of the select. - for name, expr in kwargs.items(): - if name in selects: - raise ValueError( - f"Column with name '{name}' already in select. The new summarised" - " columns must have a different name than the grouping columns." - ) - uid = Column.generate_col_uuid() - col = ColumnMetaData.from_expr(uid, expr, new_tbl, verb="summarise") - - if dtypes.NoneDType().same_kind(col.dtype): - raise ExpressionTypeError( - f"Column '{name}' has an invalid type: {col.dtype}" - ) - if col.ftype != OPType.AGGREGATE: - raise FunctionTypeError( - f"Expression for column '{name}' doesn't summarise any values." - ) - - selects.add(name) - named_cols.fwd[name] = uid - available_cols.add(uid) - cols[uid] = col - - # Update new_tbl - new_tbl.selects = ordered_set(selects) - new_tbl.named_cols = named_cols - new_tbl.available_cols = available_cols - new_tbl.cols.update(cols) - new_tbl.intrinsic_grouped_by = new_tbl.grouped_by.copy() - new_tbl.summarise(**kwargs) - - # Reduce the grouping level by one -> drop last - if len(new_tbl.grouped_by): - new_tbl.grouped_by.pop_back() - - if len(new_tbl.grouped_by): - new_tbl.group_by(*new_tbl.grouped_by) - else: - new_tbl.ungroup() - - return new_tbl - - -@builtin_verb() -def slice_head(tbl: AbstractTableImpl, n: int, *, offset: int = 0): - validate_table_args(tbl) - if not isinstance(n, int): - raise TypeError("'n' must be an int") - if not isinstance(offset, int): - raise TypeError("'offset' must be an int") - if n <= 0: - raise ValueError(f"'n' must be a positive integer (value: {n})") - if offset < 0: - raise ValueError(f"'offset' can't be negative (value: {offset})") - - if tbl.grouped_by: - raise ValueError("Can't slice table that is grouped. 
Must ungroup first.") - - new_tbl = tbl.copy() - new_tbl.preverb_hook("slice_head") - new_tbl.slice_head(n, offset) - return new_tbl diff --git a/src/pydiverse/transform/errors/__init__.py b/src/pydiverse/transform/errors/__init__.py index 107dbe2e..8e71df7c 100644 --- a/src/pydiverse/transform/errors/__init__.py +++ b/src/pydiverse/transform/errors/__init__.py @@ -1,39 +1,18 @@ from __future__ import annotations -class OperatorNotSupportedError(Exception): - """ - Exception raised when a specific operation is not supported by a backend. - """ - - -class ExpressionError(Exception): - """ - Generic exception related to an invalid expression. - """ - - -class ExpressionTypeError(ExpressionError): +class DataTypeError(Exception): """ Exception related to invalid types in an expression """ -class FunctionTypeError(ExpressionError): +class FunctionTypeError(Exception): """ Exception related to function type """ -class AlignmentError(Exception): - """ - Raised when something isn't aligned. - """ - - -# WARNINGS - - class NonStandardBehaviourWarning(UserWarning): """ Category for when a specific backend deviates from diff --git a/src/pydiverse/transform/ops/core.py b/src/pydiverse/transform/ops/core.py index ff987d7e..790d8010 100644 --- a/src/pydiverse/transform/ops/core.py +++ b/src/pydiverse/transform/ops/core.py @@ -5,10 +5,10 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from pydiverse.transform.core.registry import OperatorSignature + from pydiverse.transform.tree.registry import OperatorSignature __all__ = [ - "OPType", + "Ftype", "Operator", "OperatorExtension", "Arity", @@ -22,7 +22,7 @@ ] -class OPType(enum.IntEnum): +class Ftype(enum.IntEnum): EWISE = 1 AGGREGATE = 2 WINDOW = 3 @@ -55,7 +55,7 @@ class Operator: """ name: str = NotImplemented - ftype: OPType = NotImplemented + ftype: Ftype = NotImplemented signatures: list[str] = None context_kwargs: set[str] = None @@ -134,21 +134,21 @@ class Binary(Arity): class ElementWise(Operator): - ftype = OPType.EWISE + ftype = Ftype.EWISE class Aggregate(Operator): - ftype = OPType.AGGREGATE + ftype = Ftype.AGGREGATE context_kwargs = { - "partition_by", # list[Column, LambdaColumn] + "partition_by", # list[Col] "filter", # SymbolicExpression (NOT a list) } class Window(Operator): - ftype = OPType.WINDOW + ftype = Ftype.WINDOW context_kwargs = { - "arrange", # list[Column | LambdaColumn] + "arrange", # list[Col] "partition_by", } diff --git a/src/pydiverse/transform/ops/datetime.py b/src/pydiverse/transform/ops/datetime.py index 39b7b9a2..e7788791 100644 --- a/src/pydiverse/transform/ops/datetime.py +++ b/src/pydiverse/transform/ops/datetime.py @@ -30,18 +30,18 @@ class DtExtract(ElementWise, Unary): class DateExtract(ElementWise, Unary): - signatures = ["date -> int"] + signatures = ["datetime -> int", "date -> int"] -class DtYear(DtExtract, DateExtract): +class DtYear(DateExtract): name = "dt.year" -class DtMonth(DtExtract, DateExtract): +class DtMonth(DateExtract): name = "dt.month" -class DtDay(DtExtract, DateExtract): +class DtDay(DateExtract): name = "dt.day" @@ -61,11 +61,11 @@ class DtMillisecond(DtExtract): name = "dt.millisecond" -class DtDayOfWeek(DtExtract, DateExtract): +class DtDayOfWeek(DateExtract): name = "dt.day_of_week" -class DtDayOfYear(DtExtract, DateExtract): +class DtDayOfYear(DateExtract): name = "dt.day_of_year" diff --git a/src/pydiverse/transform/ops/logical.py b/src/pydiverse/transform/ops/logical.py index 1af64df9..f4c1ef9c 100644 --- a/src/pydiverse/transform/ops/logical.py +++ 
b/src/pydiverse/transform/ops/logical.py @@ -1,7 +1,7 @@ from __future__ import annotations -from pydiverse.transform.core import dtypes from pydiverse.transform.ops.core import Binary, ElementWise, Operator, Unary +from pydiverse.transform.tree import dtypes __all__ = [ "Equal", @@ -95,7 +95,7 @@ class IsIn(ElementWise, Logical): name = "isin" signatures = [ # TODO: A signature like "T, const list[const T] -> bool" would be better - "T, const T... -> bool", + "T, T... -> bool", ] diff --git a/src/pydiverse/transform/ops/markers.py b/src/pydiverse/transform/ops/markers.py index 2b1ed462..621498bb 100644 --- a/src/pydiverse/transform/ops/markers.py +++ b/src/pydiverse/transform/ops/markers.py @@ -2,19 +2,24 @@ from pydiverse.transform.ops.core import Marker -__all__ = [ - "NullsFirst", - "NullsLast", -] +__all__ = ["NullsFirst", "NullsLast", "Ascending", "Descending"] -# Mark order-by column that it should be ordered with NULLs first class NullsFirst(Marker): name = "nulls_first" signatures = ["T -> T"] -# Mark order-by column that it should be ordered with NULLs last class NullsLast(Marker): name = "nulls_last" signatures = ["T -> T"] + + +class Ascending(Marker): + name = "ascending" + signatures = ["T -> T"] + + +class Descending(Marker): + name = "descending" + signatures = ["T -> T"] diff --git a/src/pydiverse/transform/pipe/c.py b/src/pydiverse/transform/pipe/c.py new file mode 100644 index 00000000..70f9bcee --- /dev/null +++ b/src/pydiverse/transform/pipe/c.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +from pydiverse.transform.tree.col_expr import ColName + + +class MC(type): + def __getattr__(cls, name: str) -> ColName: + return ColName(name) + + def __getitem__(cls, name: str) -> ColName: + return ColName(name) + + +class C(metaclass=MC): + pass diff --git a/src/pydiverse/transform/pipe/functions.py b/src/pydiverse/transform/pipe/functions.py new file mode 100644 index 00000000..6837ae0d --- /dev/null +++ b/src/pydiverse/transform/pipe/functions.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +from collections.abc import Iterable + +from pydiverse.transform.tree import dtypes +from pydiverse.transform.tree.col_expr import ( + ColExpr, + ColFn, + WhenClause, + wrap_literal, +) + +__all__ = ["count", "row_number", "rank", "when", "dense_rank", "min", "max"] + + +def clean_kwargs(**kwargs) -> dict[str, list[ColExpr]]: + return {key: wrap_literal(val) for key, val in kwargs.items() if val is not None} + + +def when(condition: ColExpr) -> WhenClause: + if condition.dtype() is not None and not isinstance(condition.dtype(), dtypes.Bool): + raise TypeError( + "argument for `when` must be of boolean type, but has type " + f"`{condition.dtype()}`" + ) + + return WhenClause([], wrap_literal(condition)) + + +def count( + expr: ColExpr | None = None, + *, + filter: ColExpr | Iterable[ColExpr] | None = None, # noqa: A002 +): + if expr is None: + return ColFn("count", **clean_kwargs(filter=filter)) + else: + return ColFn("count", wrap_literal(expr)) + + +def row_number( + *, + arrange: ColExpr | Iterable[ColExpr], + partition_by: ColExpr | list[ColExpr] | None = None, +): + return ColFn( + "row_number", **clean_kwargs(arrange=arrange, partition_by=partition_by) + ) + + +def rank( + *, + arrange: ColExpr | Iterable[ColExpr], + partition_by: ColExpr | Iterable[ColExpr] | None = None, +): + return ColFn("rank", **clean_kwargs(arrange=arrange, partition_by=partition_by)) + + +def dense_rank( + *, + arrange: ColExpr | Iterable[ColExpr], + partition_by: ColExpr | 
Iterable[ColExpr] | None = None, +): + return ColFn( + "dense_rank", **clean_kwargs(arrange=arrange, partition_by=partition_by) + ) + + +def min(arg: ColExpr, *additional_args: ColExpr): + return ColFn("__least", wrap_literal(arg), *wrap_literal(additional_args)) + + +def max(arg: ColExpr, *additional_args: ColExpr): + return ColFn("__greatest", wrap_literal(arg), *wrap_literal(additional_args)) diff --git a/src/pydiverse/transform/pipe/pipeable.py b/src/pydiverse/transform/pipe/pipeable.py new file mode 100644 index 00000000..dd8beeb6 --- /dev/null +++ b/src/pydiverse/transform/pipe/pipeable.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +from functools import partial, reduce, wraps + + +class Pipeable: + def __init__(self, f=None, calls=None): + if f is not None: + if calls is not None: + raise ValueError + self.calls = [f] + else: + self.calls = calls + + def __rshift__(self, other) -> Pipeable: + """ + Pipeable >> other + -> Lazy. Extend pipe. + """ + if isinstance(other, Pipeable): + return Pipeable(calls=self.calls + other.calls) + elif callable(other): + return Pipeable(calls=self.calls + [other]) + + raise RuntimeError + + def __rrshift__(self, other): + """ + other >> Pipeable + -> Eager. + """ + if callable(other): + return Pipeable(calls=[other] + self.calls) + return self(other) + + def __call__(self, arg): + return reduce(lambda x, f: f(x), self.calls, arg) + + +class inverse_partial(partial): + """ + Just like partial, but the arguments get applied to the back instead of the front. + This means that a function `def x(a, b, c)` decorated with `@inverse_partial(1, 2)` + that gets called with `x(0)` is equivalent to calling `x(0, 1, 2)` on the non + decorated function. + """ + + def __call__(self, /, *args, **keywords): + keywords = {**self.keywords, **keywords} + # ↙ *args moved to front. + return self.func(*args, *self.args, **keywords) + + +# TODO: validate that the first arg is a table here + + +def verb(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + return Pipeable(inverse_partial(fn, *args, **kwargs)) + + return wrapper + + +def builtin_verb(backends=None): + def decorator(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + return Pipeable(inverse_partial(fn, *args, **kwargs)) + + return wrapper + + return decorator diff --git a/src/pydiverse/transform/pipe/table.py b/src/pydiverse/transform/pipe/table.py new file mode 100644 index 00000000..503d1dbb --- /dev/null +++ b/src/pydiverse/transform/pipe/table.py @@ -0,0 +1,129 @@ +from __future__ import annotations + +import copy +import dataclasses +from collections.abc import Iterable +from html import escape + +import sqlalchemy as sqa + +from pydiverse.transform.backend.table_impl import TableImpl +from pydiverse.transform.tree.ast import AstNode +from pydiverse.transform.tree.col_expr import ( + Col, + ColExpr, +) + + +# TODO: if we decide that select controls the C-space, the columns in _select will +# always be the same as those that we have to keep in _schema. However, we still need +# _select for the order. +class Table: + __slots__ = ["_ast", "_cache"] + + """ + All attributes of a table are columns except for the `_ast` attribute + which is a reference to the underlying abstract syntax tree. 
+ """ + + # TODO: define exactly what can be given for the two + def __init__(self, resource, backend=None, *, name: str | None = None): + import polars as pl + + from pydiverse.transform.backend import ( + PolarsImpl, + SqlAlchemy, + SqlImpl, + ) + + if isinstance(resource, TableImpl): + self._ast: AstNode = resource + elif isinstance(resource, (pl.DataFrame, pl.LazyFrame)): + if name is None: + name = "?" + self._ast = PolarsImpl(name, resource) + elif isinstance(resource, (str, sqa.Table)): + if isinstance(backend, SqlAlchemy): + self._ast = SqlImpl(resource, backend, name) + + if self._ast is None: + raise AssertionError + + self._cache = Cache(self._ast.cols, list(self._ast.cols.values()), []) + + def __getitem__(self, key: str) -> Col: + if not isinstance(key, str): + raise TypeError( + f"argument to __getitem__ (bracket `[]` operator) on a Table must be a " + f"str, got {type(key)} instead." + ) + if (col := self._cache.cols.get(key)) is None: + raise ValueError( + f"column `{key}` does not exist in table `{self._ast.name}`" + ) + return col + + def __getattr__(self, name: str) -> Col: + if name in ("__copy__", "__deepcopy__", "__setstate__", "__getstate__"): + # for hasattr to work correctly on dunder methods + raise AttributeError + if (col := self._cache.cols.get(name)) is None: + raise ValueError( + f"column `{name}` does not exist in table `{self._ast.name}`" + ) + return col + + def __setstate__(self, d): # to avoid very annoying AttributeErrors + for slot, val in d[1].items(): + setattr(self, slot, val) + + def __iter__(self) -> Iterable[ColExpr]: + cols = copy.copy(self._cache.select) + yield from cols + + def __len__(self) -> int: + return len(self._cache.select) + + def __str__(self): + try: + from pydiverse.transform.backend.targets import Polars + from pydiverse.transform.pipe.verbs import export + + return ( + f"Table: {self.name}, backend: {type(self._impl).__name__}\n" + f"{self >> export(Polars())}" + ) + except Exception as e: + return ( + f"Table: {self.name}, backend: {type(self._impl).__name__}\n" + "failed to collect table due to an exception. " + f"{type(e).__name__}: {str(e)}" + ) + + def _repr_html_(self) -> str | None: + html = ( + f"Table {self.name} using" + f" {type(self._impl).__name__} backend:
" + ) + try: + from pydiverse.transform.backend.targets import Polars + from pydiverse.transform.pipe.verbs import export + + # TODO: For lazy backend only show preview (eg. take first 20 rows) + html += (self >> export(Polars()))._repr_html_() + except Exception as e: + html += ( + "
Failed to collect table due to an exception:\n"
+                f"{escape(e.__class__.__name__)}: {escape(str(e))}
" + ) + return html + + def _repr_pretty_(self, p, cycle): + p.text(str(self) if not cycle else "...") + + +@dataclasses.dataclass(slots=True) +class Cache: + cols: dict[str, Col] + select: list[Col] + partition_by: list[Col] diff --git a/src/pydiverse/transform/pipe/verbs.py b/src/pydiverse/transform/pipe/verbs.py new file mode 100644 index 00000000..8348645e --- /dev/null +++ b/src/pydiverse/transform/pipe/verbs.py @@ -0,0 +1,464 @@ +from __future__ import annotations + +import copy +import uuid +from collections.abc import Iterable +from typing import Any + +from pydiverse.transform.backend.table_impl import TableImpl +from pydiverse.transform.backend.targets import Target +from pydiverse.transform.errors import FunctionTypeError +from pydiverse.transform.ops.core import Ftype +from pydiverse.transform.pipe.pipeable import builtin_verb +from pydiverse.transform.pipe.table import Table +from pydiverse.transform.tree import dtypes, verbs +from pydiverse.transform.tree.ast import AstNode +from pydiverse.transform.tree.col_expr import ( + Col, + ColExpr, + ColFn, + ColName, + Order, + wrap_literal, +) +from pydiverse.transform.tree.verbs import ( + Arrange, + Filter, + GroupBy, + Join, + JoinHow, + JoinValidate, + Mutate, + Rename, + Select, + SliceHead, + Summarise, + Ungroup, + Verb, +) + +__all__ = [ + "alias", + "collect", + "build_query", + "show_query", + "select", + "drop", + "rename", + "mutate", + "join", + "left_join", + "inner_join", + "outer_join", + "filter", + "arrange", + "group_by", + "ungroup", + "summarise", + "slice_head", + "export", +] + + +@builtin_verb() +def alias(table: Table, new_name: str | None = None): + if new_name is None: + new_name = table._ast.name + new = copy.copy(table) + new._ast, nd_map, uuid_map = table._ast._clone() + new._ast.name = new_name + new._cache = copy.copy(table._cache) + + new._cache.cols = { + name: Col(name, nd_map[col._ast], uuid_map[col._uuid], col._dtype, col._ftype) + for name, col in table._cache.cols.items() + } + new._cache.partition_by = [ + Col(col.name, nd_map[col._ast], uuid_map[col._uuid], col._dtype, col._ftype) + for col in table._cache.partition_by + ] + new._cache.select = [ + Col(col.name, nd_map[col._ast], uuid_map[col._uuid], col._dtype, col._ftype) + for col in table._cache.select + ] + return new + + +@builtin_verb() +def collect(table: Table) -> Table: ... 
+
+
+@builtin_verb()
+def export(table: Table, target: Target):
+    check_table_references(table._ast)
+    table = table >> alias()
+    SourceBackend: type[TableImpl] = get_backend(table._ast)
+    return SourceBackend.export(table._ast, target, table._cache.select)
+
+
+@builtin_verb()
+def build_query(table: Table) -> str:
+    check_table_references(table._ast)
+    table = table >> alias()
+    SourceBackend: type[TableImpl] = get_backend(table._ast)
+    return SourceBackend.build_query(table._ast, table._cache.select)
+
+
+@builtin_verb()
+def show_query(table: Table):
+    if query := table >> build_query():
+        print(query)
+    else:
+        print(f"no query to show for {table._ast.name}")
+
+    return table
+
+
+@builtin_verb()
+def select(table: Table, *args: Col | ColName):
+    new = copy.copy(table)
+    new._ast = Select(table._ast, preprocess_arg(args, table))
+    new._cache = copy.copy(table._cache)
+    new._cache.select = new._ast.select
+    return new
+
+
+@builtin_verb()
+def drop(table: Table, *args: Col | ColName):
+    dropped_uuids = {col._uuid for col in preprocess_arg(args, table)}
+    # `select` is itself a wrapped verb returning a Pipeable, so it has to be
+    # applied via `>>` instead of being called with the table directly
+    return table >> select(
+        *(col for col in table._cache.cols.values() if col._uuid not in dropped_uuids)
+    )
+
+
+@builtin_verb()
+def rename(table: Table, name_map: dict[str, str]):
+    if not isinstance(name_map, dict):
+        raise TypeError("`name_map` argument to `rename` must be a dict")
+    if len(name_map) == 0:
+        return table
+
+    new = copy.copy(table)
+    new._ast = Rename(table._ast, name_map)
+    new._cache = copy.copy(table._cache)
+    new._cache.cols = copy.copy(table._cache.cols)
+
+    for name, _ in name_map.items():
+        if name not in new._cache.cols:
+            raise ValueError(
+                f"no column with name `{name}` in table `{table._ast.name}`"
+            )
+        del new._cache.cols[name]
+
+    for name, replacement in name_map.items():
+        if replacement in new._cache.cols:
+            raise ValueError(f"duplicate column name `{replacement}`")
+        new._cache.cols[replacement] = table._cache.cols[name]
+
+    return new
+
+
+@builtin_verb()
+def mutate(table: Table, **kwargs: ColExpr):
+    if len(kwargs) == 0:
+        return table
+
+    new = copy.copy(table)
+    new._ast = Mutate(
+        table._ast,
+        list(kwargs.keys()),
+        preprocess_arg(kwargs.values(), table),
+        [uuid.uuid1() for _ in kwargs.keys()],
+    )
+
+    new._cache = copy.copy(table._cache)
+    new._cache.cols = copy.copy(table._cache.cols)
+    for name, val, uid in zip(new._ast.names, new._ast.values, new._ast.uuids):
+        new._cache.cols[name] = Col(
+            name, new._ast, uid, val.dtype(), val.ftype(agg_is_window=True)
+        )
+
+    overwritten = {
+        col_name for col_name in new._ast.names if col_name in new._cache.cols
+    }
+    new._cache.select = [
+        col for col in table._cache.select if col.name not in overwritten
+    ] + [new._cache.cols[name] for name in new._ast.names]
+
+    return new
+
+
+@builtin_verb()
+def filter(table: Table, *predicates: ColExpr):
+    if len(predicates) == 0:
+        return table
+
+    new = copy.copy(table)
+    new._ast = Filter(table._ast, preprocess_arg(predicates, table))
+
+    for cond in new._ast.filters:
+        if not isinstance(cond.dtype(), dtypes.Bool):
+            raise TypeError(
+                "predicates given to `filter` must be of boolean type.\n"
+                f"hint: {cond} is of type {cond.dtype()} instead."
+ ) + + return new + + +@builtin_verb() +def arrange(table: Table, *order_by: ColExpr): + if len(order_by) == 0: + return table + + new = copy.copy(table) + new._ast = Arrange( + table._ast, + preprocess_arg((Order.from_col_expr(ord) for ord in order_by), table), + ) + + return new + + +@builtin_verb() +def group_by(table: Table, *cols: Col | ColName, add=False): + if len(cols) == 0: + return table + + new = copy.copy(table) + new._ast = GroupBy(table._ast, preprocess_arg(cols, table), add) + new._cache = copy.copy(table._cache) + if add: + new._cache.partition_by = table._cache.partition_by + new._ast.group_by + else: + new._cache.partition_by = new._ast.group_by + + return new + + +@builtin_verb() +def ungroup(table: Table): + new = copy.copy(table) + new._ast = Ungroup(table._ast) + new._cache = copy.copy(table._cache) + new._cache.partition_by = [] + return new + + +@builtin_verb() +def summarise(table: Table, **kwargs: ColExpr): + new = copy.copy(table) + new._ast = Summarise( + table._ast, + list(kwargs.keys()), + preprocess_arg(kwargs.values(), table, update_partition_by=False), + [uuid.uuid1() for _ in kwargs.keys()], + ) + + partition_by_uuids = {col._uuid for col in table._cache.partition_by} + + def check_summarise_col_expr(expr: ColExpr, agg_fn_above: bool): + if ( + isinstance(expr, Col) + and expr._uuid not in partition_by_uuids + and not agg_fn_above + ): + raise FunctionTypeError( + f"column `{expr}` is neither aggregated nor part of the grouping " + "columns." + ) + + elif isinstance(expr, ColFn): + if expr.ftype(agg_is_window=False) == Ftype.WINDOW: + raise FunctionTypeError( + f"forbidden window function `{expr.name}` in `summarise`" + ) + elif expr.ftype(agg_is_window=False) == Ftype.AGGREGATE: + agg_fn_above = True + + for child in expr.iter_children(): + check_summarise_col_expr(child, agg_fn_above) + + for root in new._ast.values: + check_summarise_col_expr(root, False) + + new._cache = copy.copy(table._cache) + new._cache.cols = table._cache.cols | { + name: Col(name, new._ast, uid, val.dtype(), val.ftype(agg_is_window=False)) + for name, val, uid in zip(new._ast.names, new._ast.values, new._ast.uuids) + } + + new._cache.select = table._cache.partition_by + [ + new._cache.cols[name] for name in new._ast.names + ] + new._cache.partition_by = [] + + return new + + +@builtin_verb() +def slice_head(table: Table, n: int, *, offset: int = 0): + if table._cache.partition_by: + raise ValueError("cannot apply `slice_head` to a grouped table") + + new = copy.copy(table) + new._ast = SliceHead(table._ast, n, offset) + return new + + +@builtin_verb() +def join( + left: Table, + right: Table, + on: ColExpr, + how: JoinHow, + *, + validate: JoinValidate = "m:m", + suffix: str | None = None, # appended to cols of the right table +): + if left._cache.partition_by: + raise ValueError(f"cannot join grouped table `{left._ast.name}`") + elif right._cache.partition_by: + raise ValueError(f"cannot join grouped table `{right._ast.name}`") + + # TODO: more sophisticated resolution + if suffix is None and right._ast.name: + suffix = f"_{right._ast.name}" + if suffix is None: + suffix = "_right" + + new = copy.copy(left) + new._ast = Join( + left._ast, right._ast, preprocess_arg(on, left), how, validate, suffix + ) + new._cache = copy.copy(left._cache) + new._cache.cols = left._cache.cols | { + name + suffix: col for name, col in right._cache.cols.items() + } + new._cache.select = left._cache.select + right._cache.select + + return new + + +@builtin_verb() +def inner_join( + left: Table, + 
right: Table,
+    on: ColExpr,
+    *,
+    validate: JoinValidate = "m:m",
+    suffix: str | None = None,
+):
+    return left >> join(right, on, "inner", validate=validate, suffix=suffix)
+
+
+@builtin_verb()
+def left_join(
+    left: Table,
+    right: Table,
+    on: ColExpr,
+    *,
+    validate: JoinValidate = "m:m",
+    suffix: str | None = None,
+):
+    return left >> join(right, on, "left", validate=validate, suffix=suffix)
+
+
+@builtin_verb()
+def outer_join(
+    left: Table,
+    right: Table,
+    on: ColExpr,
+    *,
+    validate: JoinValidate = "m:m",
+    suffix: str | None = None,
+):
+    return left >> join(right, on, "outer", validate=validate, suffix=suffix)
+
+
+def preprocess_arg(arg: Any, table: Table, *, update_partition_by: bool = True) -> Any:
+    if isinstance(arg, dict):
+        return {
+            key: preprocess_arg(val, table, update_partition_by=update_partition_by)
+            for key, val in arg.items()
+        }
+    if isinstance(arg, Iterable) and not isinstance(arg, str):
+        return [
+            preprocess_arg(elem, table, update_partition_by=update_partition_by)
+            for elem in arg
+        ]
+    if isinstance(arg, Order):
+        return Order(
+            preprocess_arg(
+                arg.order_by, table, update_partition_by=update_partition_by
+            ),
+            arg.descending,
+            arg.nulls_last,
+        )
+    else:
+        arg = wrap_literal(arg)
+        assert isinstance(arg, ColExpr)
+
+        arg = arg.map_subtree(
+            lambda col: col if not isinstance(col, ColName) else table[col.name]
+        )
+
+        if not update_partition_by:
+            return arg
+
+        from pydiverse.transform.backend.polars import PolarsImpl
+
+        for desc in arg.iter_subtree():
+            if (
+                isinstance(desc, ColFn)
+                and "partition_by" not in desc.context_kwargs
+                and (
+                    PolarsImpl.registry.get_op(desc.name).ftype
+                    in (Ftype.WINDOW, Ftype.AGGREGATE)
+                )
+            ):
+                desc.context_kwargs["partition_by"] = table._cache.partition_by
+
+        return arg
+
+
+def get_backend(nd: AstNode) -> type[TableImpl]:
+    if isinstance(nd, Verb):
+        return get_backend(nd.child)
+    assert isinstance(nd, TableImpl) and nd is not TableImpl
+    return nd.__class__
+
+
+# checks whether there are duplicate tables and whether all cols used in expressions
+# are from descendants
+def check_table_references(nd: AstNode) -> set[AstNode]:
+    if isinstance(nd, verbs.Verb):
+        subtree = check_table_references(nd.child)
+
+        if isinstance(nd, verbs.Join):
+            right_tables = check_table_references(nd.right)
+            if intersection := subtree & right_tables:
+                raise ValueError(
+                    f"table `{list(intersection)[0]}` occurs twice in the table "
+                    "tree.\nhint: To join two tables derived from a common table, "
+                    "apply `>> alias()` to one of them before the join."
+                )
+
+            if len(right_tables) > len(subtree):
+                subtree, right_tables = right_tables, subtree
+            subtree |= right_tables
+
+        for col in nd.iter_col_nodes():
+            if isinstance(col, Col) and col._ast not in subtree:
+                raise ValueError(
+                    f"table `{col._ast.name}` referenced via column `{col}` cannot be "
+                    "used at this point. The current table is not derived "
+                    "from it."
+ ) + + subtree.add(nd) + return subtree + + else: + return {nd} diff --git a/src/pydiverse/transform/polars/polars_table.py b/src/pydiverse/transform/polars/polars_table.py deleted file mode 100644 index fc02b4c0..00000000 --- a/src/pydiverse/transform/polars/polars_table.py +++ /dev/null @@ -1,826 +0,0 @@ -from __future__ import annotations - -import functools -import itertools -import operator -import uuid -from typing import Any, Callable, Literal - -import polars as pl - -from pydiverse.transform import ops -from pydiverse.transform.core import dtypes -from pydiverse.transform.core.expressions.expressions import ( - BaseExpression, - CaseExpression, - Column, - FunctionCall, - LiteralColumn, -) -from pydiverse.transform.core.expressions.symbolic_expressions import SymbolicExpression -from pydiverse.transform.core.expressions.translator import ( - Translator, - TypedValue, -) -from pydiverse.transform.core.registry import TypedOperatorImpl -from pydiverse.transform.core.table_impl import AbstractTableImpl -from pydiverse.transform.core.util import OrderingDescriptor -from pydiverse.transform.core.util.util import translate_ordering -from pydiverse.transform.errors import ( - AlignmentError, - ExpressionError, - FunctionTypeError, -) -from pydiverse.transform.ops.core import OPType - - -class PolarsEager(AbstractTableImpl): - def __init__(self, name: str, df: pl.DataFrame): - self.df = df - self.join_translator = JoinTranslator() - - cols = { - col.name: Column(col.name, self, _pdt_dtype(col.dtype)) - for col in df.iter_columns() - } - self.underlying_col_name: dict[uuid.UUID, str] = { - col.uuid: f"{name}_{col.name}_{col.uuid.int}" for col in cols.values() - } - self.df = self.df.rename( - {col.name: self.underlying_col_name[col.uuid] for col in cols.values()} - ) - super().__init__(name, cols) - - def mutate(self, **kwargs): - uuid_to_kwarg: dict[uuid.UUID, (str, BaseExpression)] = { - self.named_cols.fwd[k]: (k, v) for (k, v) in kwargs.items() - } - self.underlying_col_name.update( - { - uuid: f"{self.name}_{col_name}_mut_{uuid.int}" - for uuid, (col_name, _) in uuid_to_kwarg.items() - } - ) - - polars_exprs = [ - self.cols[uuid].compiled().alias(self.underlying_col_name[uuid]) - for uuid in uuid_to_kwarg.keys() - ] - self.df = self.df.with_columns(*polars_exprs) - - def join( - self, - right: PolarsEager, - on: SymbolicExpression, - how: Literal["inner", "left", "outer"], - *, - validate: Literal["1:1", "1:m", "m:1", "m:m"] = "m:m", - ): - # get the columns on which the data frames are joined - left_on: list[str] = [] - right_on: list[str] = [] - for col1, col2 in self.join_translator.translate(on): - if col2.uuid in self.cols and col1.uuid in right.cols: - col1, col2 = col2, col1 - assert col1.uuid in self.cols and col2.uuid in right.cols - left_on.append(self.underlying_col_name[col1.uuid]) - right_on.append(right.underlying_col_name[col2.uuid]) - - self.underlying_col_name.update(right.underlying_col_name) - - self.df = self.df.join( - right.df, - how=how, - left_on=left_on, - right_on=right_on, - validate=validate, - coalesce=False, - ) - - def filter(self, *args: SymbolicExpression): - if not args: - return - pl_expr, dtype = self.compiler.translate(functools.reduce(operator.and_, args)) - assert isinstance(dtype, dtypes.Bool) - self.df = self.df.filter(pl_expr()) - - def alias(self, new_name: str | None = None): - new_name = new_name or self.name - return self.__class__(new_name, self.export()) - - def arrange(self, ordering: list[OrderingDescriptor]): - self.df = 
self.df.sort( - by=[self.compiler.translate(o.order).value() for o in ordering], - nulls_last=[not o.nulls_first for o in ordering], - descending=[not o.asc for o in ordering], - ) - - def summarise(self, **kwargs: SymbolicExpression): - uuid_to_kwarg: dict[uuid.UUID, (str, BaseExpression)] = { - self.named_cols.fwd[k]: (k, v) for (k, v) in kwargs.items() - } - self.underlying_col_name.update( - { - uuid: f"{self.name}_{col_name}_summarise_{uuid.int}" - for uuid, (col_name, _) in uuid_to_kwarg.items() - } - ) - - agg_exprs: list[pl.Expr] = [ - self.cols[uuid].compiled().alias(self.underlying_col_name[uuid]) - for uuid in uuid_to_kwarg.keys() - ] - group_exprs: list[pl.Expr] = [ - pl.col(self.underlying_col_name[col.uuid]) for col in self.grouped_by - ] - - if self.grouped_by: - # retain the cols the table was grouped by and add the aggregation cols - self.df = self.df.group_by(*group_exprs).agg(*agg_exprs) - else: - self.df = self.df.select(*agg_exprs) - - def export(self) -> pl.DataFrame: - return self.df.select( - **{ - name: self.underlying_col_name[uuid] - for (name, uuid) in self.selected_cols() - } - ) - - def slice_head(self, n: int, offset: int): - self.df = self.df.slice(offset, n) - - def is_aligned_with(self, col: Column | LiteralColumn) -> bool: - if isinstance(col, Column): - return ( - isinstance(col.table, type(self)) - and col.table.df.height == self.df.height - ) - if isinstance(col, LiteralColumn): - return issubclass(col.backend, type(self)) and ( - not isinstance(col.typed_value.value, pl.Series) - or len(col.typed_value.value) == self.df.height - ) # not a series => scalar - - class ExpressionCompiler( - AbstractTableImpl.ExpressionCompiler[ - "PolarsEager", TypedValue[Callable[[], pl.Expr]] - ] - ): - def _translate_col( - self, col: Column, **kwargs - ) -> TypedValue[Callable[[], pl.Expr]]: - def value(): - return pl.col(self.backend.underlying_col_name[col.uuid]) - - return TypedValue(value, col.dtype) - - def _translate_literal_col( - self, col: LiteralColumn, **kwargs - ) -> TypedValue[Callable[[], pl.Expr]]: - if not self.backend.is_aligned_with(col): - raise AlignmentError( - f"literal column {col} not aligned with table {self.backend.name}." 
- ) - - def value(**kw): - return col.typed_value.value - - return TypedValue(value, col.typed_value.dtype, col.typed_value.ftype) - - def _translate_function( - self, - implementation: TypedOperatorImpl, - op_args: list[TypedValue[Callable[[], pl.Expr]]], - context_kwargs: dict[str, Any], - *, - verb: str | None = None, - **kwargs, - ) -> TypedValue[Callable[[], pl.Expr]]: - pl_result_type = _pl_dtype(implementation.rtype) - - internal_kwargs = {} - - op = implementation.operator - ftype = ( - OPType.WINDOW - if op.ftype == OPType.AGGREGATE and verb != "summarise" - else op.ftype - ) - - grouping = context_kwargs.get("partition_by") - # the `partition_by=` grouping overrides the `group_by` grouping - if grouping is not None: # translate possible lambda cols - grouping = [self.backend.resolve_lambda_cols(col) for col in grouping] - else: # use the current grouping of the table - grouping = self.backend.grouped_by - - ordering = context_kwargs.get("arrange") - if ordering: - ordering = translate_ordering(self.backend, ordering) - by = [self._translate(o.order).value() for o in ordering] - descending = [not o.asc for o in ordering] - nulls_last = [not o.nulls_first for o in ordering] - - filter_cond = context_kwargs.get("filter") - if filter_cond: - filter_cond = self.translate( - self.backend.resolve_lambda_cols(filter_cond) - ) - - args: list[Callable[[], pl.Expr]] = [arg.value for arg in op_args] - dtypes: list[dtypes.DType] = [arg.dtype for arg in op_args] - if ftype == OPType.WINDOW and ordering and not grouping: - # order the args. if the table is grouped by group_by or - # partition_by=, the groups will be sorted via over(order_by=) - # anyways so it need not be done here. - def ordered_arg(arg): - return arg().sort_by( - by=by, descending=descending, nulls_last=nulls_last - ) - - args = [ - arg if dtype.const else functools.partial(ordered_arg, arg) - for arg, dtype in zip(args, dtypes) - ] - - if ftype in (OPType.WINDOW, OPType.AGGREGATE) and filter_cond: - # filtering needs to be done before applying the operator. We filter - # all non-constant arguments, although there should always be only - # one of these. - def filtered_value(value): - return value().filter(filter_cond.value()) - - assert len(list(filter(lambda arg: not arg.dtype.const, op_args))) == 1 - args = [ - arg if dtype.const else functools.partial(filtered_value, arg) - for arg, dtype in zip(args, dtypes) - ] - - if op.name in ("rank", "dense_rank"): - assert len(args) == 0 - args = [ - functools.partial( - lambda ordering: pl.struct( - *self.backend._merge_desc_nulls_last(ordering) - ), - ordering, - ) - ] - ordering = None - - def value(**kw): - return implementation( - *[arg(**kw) for arg in args], - _tbl=self.backend, - _result_type=pl_result_type, - **internal_kwargs, - ) - - if ftype == OPType.AGGREGATE: - if context_kwargs.get("filter"): - # TODO: allow AGGRRGATE + `filter` context_kwarg - raise NotImplementedError - - if context_kwargs.get("partition_by"): - # technically, it probably wouldn't be too hard to support this in - # polars. - assert verb == "summarise" - raise ValueError( - f"cannot use keyword argument `partition_by` for the " - f"aggregation function `{op.name}` inside `summarise`." - ) - - # TODO: in the grouping / filter expressions, we should probably call - # validate_table_args. look what it does and use it. - # TODO: what happens if I put None or similar in a filter / partition_by? 
- if ftype == OPType.WINDOW: - if verb == "summarise": - raise FunctionTypeError( - "window function are not allowed inside summarise" - ) - - # if `verb` != "muatate", we should give a warning that this only works - # for polars - - if grouping: - # when doing sort_by -> over in polars, for whatever reason the - # `nulls_last` argument is ignored. thus when both a grouping and an - # arrangment are specified, we manually add the descending and - # nulls_last markers to the ordering. - order_by = None - if ordering: - order_by = self.backend._merge_desc_nulls_last(ordering) - - def partitioned_value(value): - group_exprs: list[pl.Expr] = [ - pl.col(self.backend.underlying_col_name[col.uuid]) - for col in grouping - ] - return value().over(*group_exprs, order_by=order_by) - - value = functools.partial(partitioned_value, value) - - elif ordering: - if op.ftype == OPType.AGGREGATE: - # TODO: don't fail, but give a warning that `arrange` is useless - # here - ... - - # the function was executed on the ordered arguments. here we - # restore the original order of the table. - def sorted_value(value): - inv_permutation = pl.int_range( - 0, pl.len(), dtype=pl.Int64 - ).sort_by( - by=by, - descending=descending, - nulls_last=nulls_last, - ) - return value().sort_by(inv_permutation) - - # need to bind `value` inside `filtered_value` so that it refers to - # the original `value`. - value = functools.partial(sorted_value, value) - - return TypedValue( - value, - implementation.rtype, - PolarsEager._get_op_ftype( - op_args, - op, - OPType.WINDOW - if op.ftype == OPType.AGGREGATE and verb != "summarise" - else None, - ), - ) - - def _translate_case( - self, - expr: CaseExpression, - switching_on: TypedValue[Callable[[], pl.Expr]] | None, - cases: list[ - tuple[ - TypedValue[Callable[[], pl.Expr]], TypedValue[Callable[[], pl.Expr]] - ] - ], - default: TypedValue[Callable[[], pl.Expr]], - **kwargs, - ) -> TypedValue[Callable[[], pl.Expr]]: - def value(): - if switching_on is not None: - switching_on_v = switching_on.value() - conds = [ - match_expr.value() == switching_on_v for match_expr, _ in cases - ] - else: - conds = [case[0].value() for case in cases] - - pl_expr = pl.when(conds[0]).then(cases[0][1].value()) - for cond, (_, value) in zip(conds[1:], cases[1:]): - pl_expr = pl_expr.when(cond).then(value.value()) - return pl_expr.otherwise(default.value()) - - result_dtype, result_ftype = self._translate_case_common( - expr, switching_on, cases, default, **kwargs - ) - - return TypedValue(value, result_dtype, result_ftype) - - def _translate_literal_value(self, expr): - def value(): - return pl.lit(expr) - - return value - - class AlignedExpressionEvaluator( - AbstractTableImpl.AlignedExpressionEvaluator[TypedValue[pl.Series]] - ): - def _translate_col(self, col: Column, **kwargs) -> TypedValue[pl.Series]: - return TypedValue( - col.table.df.get_column(col.table.underlying_col_name[col.uuid]), - col.table.cols[col.uuid].dtype, - ) - - def _translate_literal_col( - self, expr: LiteralColumn, **kwargs - ) -> TypedValue[pl.Series]: - return expr.typed_value - - def _translate_function( - self, - implementation: TypedOperatorImpl, - op_args: list[TypedValue[pl.Series]], - context_kwargs: dict[str, Any], - **kwargs, - ) -> TypedValue[pl.Series]: - args = [arg.value for arg in op_args] - op = implementation.operator - - arg_lens = {arg.len() for arg in args if isinstance(arg, pl.Series)} - if len(arg_lens) >= 2: - raise AlignmentError( - f"arguments for function {implementation.operator.name} are not " - 
f"aligned. they have lengths {list(arg_lens)} but all lengths must " - f"be equal." - ) - - value = implementation(*args) - - return TypedValue( - value, - implementation.rtype, - PolarsEager._get_op_ftype( - op_args, op, OPType.WINDOW if op.ftype == OPType.AGGREGATE else None - ), - ) - - # merges descending and null_last markers into the ordering expression - def _merge_desc_nulls_last( - self, ordering: list[OrderingDescriptor] - ) -> list[pl.Expr]: - with_signs = [] - for o in ordering: - numeric = ( - self.compiler.translate(o.order).value().rank("dense").cast(pl.Int64) - ) - with_signs.append(numeric if o.asc else -numeric) - return [ - x.fill_null( - -(pl.len().cast(pl.Int64) + 1) - if o.nulls_first - else pl.len().cast(pl.Int64) + 1 - ) - for x, o in zip(with_signs, ordering) - ] - - -class JoinTranslator(Translator[tuple]): - """ - This translator takes a conjunction (AND) of equality checks and returns - a tuple of tuple where the inner tuple contains the left and right column - of the equality checks. - """ - - def _translate(self, expr, **kwargs): - if isinstance(expr, Column): - return expr - if isinstance(expr, FunctionCall): - if expr.name == "__eq__": - c1 = expr.args[0] - c2 = expr.args[1] - assert isinstance(c1, Column) and isinstance(c2, Column) - return ((c1, c2),) - if expr.name == "__and__": - return tuple(itertools.chain(*expr.args)) - raise ExpressionError( - f"invalid ON clause element: {expr}. only a conjunction of equalities" - " is supported" - ) - - -def _pdt_dtype(t: pl.DataType) -> dtypes.DType: - if t.is_float(): - return dtypes.Float() - elif t.is_integer(): - return dtypes.Int() - elif isinstance(t, pl.Boolean): - return dtypes.Bool() - elif isinstance(t, pl.String): - return dtypes.String() - elif isinstance(t, pl.Datetime): - return dtypes.DateTime() - elif isinstance(t, pl.Date): - return dtypes.Date() - elif isinstance(t, pl.Duration): - return dtypes.Duration() - - raise TypeError(f"polars type {t} is not supported") - - -def _pl_dtype(t: dtypes.DType) -> pl.DataType: - if isinstance(t, dtypes.Float): - return pl.Float64() - elif isinstance(t, dtypes.Int): - return pl.Int64() - elif isinstance(t, dtypes.Bool): - return pl.Boolean() - elif isinstance(t, dtypes.String): - return pl.String() - elif isinstance(t, dtypes.DateTime): - return pl.Datetime() - elif isinstance(t, dtypes.Date): - return pl.Date() - elif isinstance(t, dtypes.Duration): - return pl.Duration() - - raise TypeError(f"pydiverse.transform type {t} not supported for polars") - - -with PolarsEager.op(ops.Mean()) as op: - - @op.auto - def _mean(x): - return x.mean() - - -with PolarsEager.op(ops.Min()) as op: - - @op.auto - def _min(x): - return x.min() - - -with PolarsEager.op(ops.Max()) as op: - - @op.auto - def _max(x): - return x.max() - - -with PolarsEager.op(ops.Sum()) as op: - - @op.auto - def _sum(x): - return x.sum() - - -with PolarsEager.op(ops.All()) as op: - - @op.auto - def _all(x): - return x.all() - - -with PolarsEager.op(ops.Any()) as op: - - @op.auto - def _any(x): - return x.any() - - -with PolarsEager.op(ops.IsNull()) as op: - - @op.auto - def _is_null(x): - return x.is_null() - - -with PolarsEager.op(ops.IsNotNull()) as op: - - @op.auto - def _is_not_null(x): - return x.is_not_null() - - -with PolarsEager.op(ops.FillNull()) as op: - - @op.auto - def _fill_null(x, y): - return x.fill_null(y) - - -with PolarsEager.op(ops.DtYear()) as op: - - @op.auto - def _dt_year(x): - return x.dt.year() - - -with PolarsEager.op(ops.DtMonth()) as op: - - @op.auto - def _dt_month(x): - 
return x.dt.month() - - -with PolarsEager.op(ops.DtDay()) as op: - - @op.auto - def _dt_day(x): - return x.dt.day() - - -with PolarsEager.op(ops.DtHour()) as op: - - @op.auto - def _dt_hour(x): - return x.dt.hour() - - -with PolarsEager.op(ops.DtMinute()) as op: - - @op.auto - def _dt_minute(x): - return x.dt.minute() - - -with PolarsEager.op(ops.DtSecond()) as op: - - @op.auto - def _dt_second(x): - return x.dt.second() - - -with PolarsEager.op(ops.DtMillisecond()) as op: - - @op.auto - def _dt_millisecond(x): - return x.dt.millisecond() - - -with PolarsEager.op(ops.DtDayOfWeek()) as op: - - @op.auto - def _dt_day_of_week(x): - return x.dt.weekday() - - -with PolarsEager.op(ops.DtDayOfYear()) as op: - - @op.auto - def _dt_day_of_year(x): - return x.dt.ordinal_day() - - -with PolarsEager.op(ops.DtDays()) as op: - - @op.auto - def _days(x): - return x.dt.total_days() - - -with PolarsEager.op(ops.DtHours()) as op: - - @op.auto - def _hours(x): - return x.dt.total_hours() - - -with PolarsEager.op(ops.DtMinutes()) as op: - - @op.auto - def _minutes(x): - return x.dt.total_minutes() - - -with PolarsEager.op(ops.DtSeconds()) as op: - - @op.auto - def _seconds(x): - return x.dt.total_seconds() - - -with PolarsEager.op(ops.DtMilliseconds()) as op: - - @op.auto - def _milliseconds(x): - return x.dt.total_milliseconds() - - -with PolarsEager.op(ops.Sub()) as op: - - @op.extension(ops.DtSub) - def _dt_sub(lhs, rhs): - return lhs - rhs - - -with PolarsEager.op(ops.RSub()) as op: - - @op.extension(ops.DtRSub) - def _dt_rsub(rhs, lhs): - return lhs - rhs - - -with PolarsEager.op(ops.Add()) as op: - - @op.extension(ops.DtDurAdd) - def _dt_dur_add(lhs, rhs): - return lhs + rhs - - -with PolarsEager.op(ops.RAdd()) as op: - - @op.extension(ops.DtDurRAdd) - def _dt_dur_radd(rhs, lhs): - return lhs + rhs - - -with PolarsEager.op(ops.RowNumber()) as op: - - @op.auto - def _row_number(): - return pl.int_range(start=1, end=pl.len() + 1, dtype=pl.Int64) - - -with PolarsEager.op(ops.Rank()) as op: - - @op.auto - def _rank(x): - return x.rank("min").cast(pl.Int64) - - -with PolarsEager.op(ops.DenseRank()) as op: - - @op.auto - def _dense_rank(x): - return x.rank("dense").cast(pl.Int64) - - -with PolarsEager.op(ops.Shift()) as op: - - @op.auto - def _shift(x, n, fill_value=None): - return x.shift(n, fill_value=fill_value) - - -with PolarsEager.op(ops.IsIn()) as op: - - @op.auto - def _isin(x, *values): - return x.is_in([pl.select(v).item() for v in values]) - - -with PolarsEager.op(ops.StrContains()) as op: - - @op.auto - def _contains(x, y): - return x.str.contains(y) - - -with PolarsEager.op(ops.StrStartsWith()) as op: - - @op.auto - def _starts_with(x, y): - return x.str.starts_with(y) - - -with PolarsEager.op(ops.StrEndsWith()) as op: - - @op.auto - def _ends_with(x, y): - return x.str.ends_with(y) - - -with PolarsEager.op(ops.StrToLower()) as op: - - @op.auto - def _lower(x): - return x.str.to_lowercase() - - -with PolarsEager.op(ops.StrToUpper()) as op: - - @op.auto - def _upper(x): - return x.str.to_uppercase() - - -with PolarsEager.op(ops.StrReplaceAll()) as op: - - @op.auto - def _replace_all(x, to_replace, replacement): - return x.str.replace_all(to_replace, replacement) - - -with PolarsEager.op(ops.StrLen()) as op: - - @op.auto - def _string_length(x): - return x.str.len_chars().cast(pl.Int64) - - -with PolarsEager.op(ops.StrStrip()) as op: - - @op.auto - def _str_strip(x): - return x.str.strip_chars() - - -with PolarsEager.op(ops.StrSlice()) as op: - - @op.auto - def _str_slice(x, offset, length): - 
return x.str.slice(offset, length) - - -with PolarsEager.op(ops.Count()) as op: - - @op.auto - def _count(x=None): - return pl.len() if x is None else x.count() - - -with PolarsEager.op(ops.Greatest()) as op: - - @op.auto - def _greatest(*x): - return pl.max_horizontal(*x) - - -with PolarsEager.op(ops.Least()) as op: - - @op.auto - def _least(*x): - return pl.min_horizontal(*x) diff --git a/src/pydiverse/transform/sql/duckdb.py b/src/pydiverse/transform/sql/duckdb.py deleted file mode 100644 index 5b0f0fb3..00000000 --- a/src/pydiverse/transform/sql/duckdb.py +++ /dev/null @@ -1,7 +0,0 @@ -from __future__ import annotations - -from pydiverse.transform.sql.sql_table import SQLTableImpl - - -class DuckDBTableImpl(SQLTableImpl): - _dialect_name = "duckdb" diff --git a/src/pydiverse/transform/sql/mssql.py b/src/pydiverse/transform/sql/mssql.py deleted file mode 100644 index c5cc2e92..00000000 --- a/src/pydiverse/transform/sql/mssql.py +++ /dev/null @@ -1,324 +0,0 @@ -from __future__ import annotations - -import sqlalchemy as sa - -from pydiverse.transform import ops -from pydiverse.transform._typing import CallableT -from pydiverse.transform.core import dtypes -from pydiverse.transform.core.expressions import TypedValue -from pydiverse.transform.core.expressions.expressions import Column -from pydiverse.transform.core.registry import TypedOperatorImpl -from pydiverse.transform.core.util import OrderingDescriptor -from pydiverse.transform.ops import Operator, OPType -from pydiverse.transform.sql.sql_table import SQLTableImpl -from pydiverse.transform.util.warnings import warn_non_standard - - -class MSSqlTableImpl(SQLTableImpl): - _dialect_name = "mssql" - - def _build_select_select(self, select): - s = [] - for name, uuid_ in self.selected_cols(): - sql_col = self.cols[uuid_].compiled(self.sql_columns) - if not isinstance(sql_col, sa.sql.ColumnElement): - sql_col = sa.literal(sql_col) - if dtypes.Bool().same_kind(self.cols[uuid_].dtype): - # Make sure that any boolean values get stored as bit - sql_col = sa.cast(sql_col, sa.Boolean()) - s.append(sql_col.label(name)) - return select.with_only_columns(*s) - - def _order_col( - self, col: sa.SQLColumnExpression, ordering: OrderingDescriptor - ) -> list[sa.SQLColumnExpression]: - # MSSQL doesn't support nulls first / nulls last - order_by_expressions = [] - - # asc implies nulls first - if not ordering.nulls_first and ordering.asc: - order_by_expressions.append(sa.func.iif(col.is_(None), 1, 0)) - - # desc implies nulls last - if ordering.nulls_first and not ordering.asc: - order_by_expressions.append(sa.func.iif(col.is_(None), 0, 1)) - - order_by_expressions.append(col.asc() if ordering.asc else col.desc()) - return order_by_expressions - - class ExpressionCompiler(SQLTableImpl.ExpressionCompiler): - def translate(self, expr, **kwargs): - mssql_bool_as_bit = True - if verb := kwargs.get("verb"): - mssql_bool_as_bit = verb not in ("filter", "join") - - return super().translate( - expr, **kwargs, mssql_bool_as_bit=mssql_bool_as_bit - ) - - def _translate(self, expr, **kwargs): - if context := kwargs.get("context"): - if context == "case_val": - kwargs["mssql_bool_as_bit"] = True - elif context == "case_cond": - kwargs["mssql_bool_as_bit"] = False - - return super()._translate(expr, **kwargs) - - def _translate_col(self, col: Column, **kwargs): - # If mssql_bool_as_bit is true, then we can just return the - # precompiled col. Otherwise, we must recompile it to ensure - # we return booleans as bools and not as bits. 
- if kwargs.get("mssql_bool_as_bit") is True: - return super()._translate_col(col, **kwargs) - - # Can either be a base SQL column, or a reference to an expression - if col.uuid in self.backend.sql_columns: - is_bool = dtypes.Bool().same_kind(self.backend.cols[col.uuid].dtype) - - def sql_col(cols, **kw): - sql_col = cols[col.uuid] - if is_bool: - return mssql_convert_bit_to_bool(sql_col) - return sql_col - - return TypedValue(sql_col, col.dtype, OPType.EWISE) - - meta_data = self.backend.cols[col.uuid] - return self._translate(meta_data.expr, **kwargs) - - def _translate_function_value( - self, implementation, op_args, context_kwargs, *, verb=None, **kwargs - ): - value = super()._translate_function_value( - implementation, - op_args, - context_kwargs, - verb=verb, - **kwargs, - ) - - bool_as_bit = kwargs.get("mssql_bool_as_bit") - returns_bool_as_bit = mssql_op_returns_bool_as_bit(implementation) - return mssql_convert_bool_bit_value(value, bool_as_bit, returns_bool_as_bit) - - def _translate_function_arguments(self, expr, operator, **kwargs): - kwargs["mssql_bool_as_bit"] = mssql_op_wants_bool_as_bit(operator) - return super()._translate_function_arguments(expr, operator, **kwargs) - - -# Boolean / Bit Conversion -# -# MSSQL doesn't have a boolean type. This means that expressions that -# return a boolean (e.g. ==, !=, >) can't be used in other expressions -# without casting to the BIT type. -# Conversely, after casting to BIT, we sometimes may need to convert -# back to booleans. - - -def mssql_op_wants_bool_as_bit(operator: Operator) -> bool: - # These operations want boolean types (not BIT) as input - exceptions = [ - ops.logical.BooleanBinary, - ops.logical.Invert, - ] - - for exception in exceptions: - if isinstance(operator, exception): - return False - - return True - - -def mssql_op_returns_bool_as_bit(implementation: TypedOperatorImpl) -> bool | None: - if not dtypes.Bool().same_kind(implementation.rtype): - return None - - # These operations return boolean types (not BIT) - if isinstance(implementation.operator, ops.logical.Logical): - return False - - return True - - -def mssql_convert_bit_to_bool(x: sa.SQLColumnExpression): - return x == sa.literal_column("1") - - -def mssql_convert_bool_to_bit(x: sa.SQLColumnExpression): - return sa.case( - (x, sa.literal_column("1")), - (sa.not_(x), sa.literal_column("0")), - ) - - -def mssql_convert_bool_bit_value( - value_func: CallableT, - wants_bool_as_bit: bool | None, - is_bool_as_bit: bool | None, -) -> CallableT: - if wants_bool_as_bit is True and is_bool_as_bit is False: - - def value(*args, **kwargs): - x = value_func(*args, **kwargs) - return mssql_convert_bool_to_bit(x) - - return value - - if wants_bool_as_bit is False and is_bool_as_bit is True: - - def value(*args, **kwargs): - x = value_func(*args, **kwargs) - return mssql_convert_bit_to_bool(x) - - return value - - return value_func - - -# Operators - - -with MSSqlTableImpl.op(ops.Equal()) as op: - - @op("str, str -> bool") - def _eq(x, y): - warn_non_standard( - "MSSQL ignores trailing whitespace when comparing strings", - ) - return x == y - - -with MSSqlTableImpl.op(ops.NotEqual()) as op: - - @op("str, str -> bool") - def _ne(x, y): - warn_non_standard( - "MSSQL ignores trailing whitespace when comparing strings", - ) - return x != y - - -with MSSqlTableImpl.op(ops.Less()) as op: - - @op("str, str -> bool") - def _lt(x, y): - warn_non_standard( - "MSSQL ignores trailing whitespace when comparing strings", - ) - return x < y - - -with MSSqlTableImpl.op(ops.LessEqual()) 
as op: - - @op("str, str -> bool") - def _le(x, y): - warn_non_standard( - "MSSQL ignores trailing whitespace when comparing strings", - ) - return x <= y - - -with MSSqlTableImpl.op(ops.Greater()) as op: - - @op("str, str -> bool") - def _gt(x, y): - warn_non_standard( - "MSSQL ignores trailing whitespace when comparing strings", - ) - return x > y - - -with MSSqlTableImpl.op(ops.GreaterEqual()) as op: - - @op("str, str -> bool") - def _ge(x, y): - warn_non_standard( - "MSSQL ignores trailing whitespace when comparing strings", - ) - return x >= y - - -with MSSqlTableImpl.op(ops.Pow()) as op: - - @op.auto - def _pow(lhs, rhs): - # In MSSQL, the output type of pow is the same as the input type. - # This means, that if lhs is a decimal, then we may very easily loose - # a lot of precision if the exponent is <= 1 - # https://learn.microsoft.com/en-us/sql/t-sql/functions/power-transact-sql?view=sql-server-ver16 - return sa.func.POWER(sa.cast(lhs, sa.Double()), rhs, type_=sa.Double()) - - -with MSSqlTableImpl.op(ops.RPow()) as op: - - @op.auto - def _rpow(rhs, lhs): - return _pow(lhs, rhs) - - -with MSSqlTableImpl.op(ops.StrLen()) as op: - - @op.auto - def _str_length(x): - warn_non_standard( - "MSSQL ignores trailing whitespace when computing string length", - ) - return sa.func.LENGTH(x, type_=sa.Integer()) - - -with MSSqlTableImpl.op(ops.StrReplaceAll()) as op: - - @op.auto - def _replace(x, y, z): - x = x.collate("Latin1_General_CS_AS") - return sa.func.REPLACE(x, y, z, type_=x.type) - - -with MSSqlTableImpl.op(ops.StrStartsWith()) as op: - - @op.auto - def _startswith(x, y): - x = x.collate("Latin1_General_CS_AS") - return x.startswith(y, autoescape=True) - - -with MSSqlTableImpl.op(ops.StrEndsWith()) as op: - - @op.auto - def _endswith(x, y): - x = x.collate("Latin1_General_CS_AS") - return x.endswith(y, autoescape=True) - - -with MSSqlTableImpl.op(ops.StrContains()) as op: - - @op.auto - def _contains(x, y): - x = x.collate("Latin1_General_CS_AS") - return x.contains(y, autoescape=True) - - -with MSSqlTableImpl.op(ops.StrSlice()) as op: - - @op.auto - def _str_slice(x, offset, length): - return sa.func.SUBSTRING(x, offset + 1, length) - - -with MSSqlTableImpl.op(ops.DtDayOfWeek()) as op: - - @op.auto - def _day_of_week(x): - # Offset DOW such that Mon=1, Sun=7 - _1 = sa.literal_column("1") - _2 = sa.literal_column("2") - _7 = sa.literal_column("7") - return (sa.extract("dow", x) + sa.text("@@DATEFIRST") - _2) % _7 + _1 - - -with MSSqlTableImpl.op(ops.Mean()) as op: - - @op.auto - def _mean(x): - return sa.func.AVG(sa.cast(x, sa.Double()), type_=sa.Double()) diff --git a/src/pydiverse/transform/sql/postgres.py b/src/pydiverse/transform/sql/postgres.py deleted file mode 100644 index a5c3bbb0..00000000 --- a/src/pydiverse/transform/sql/postgres.py +++ /dev/null @@ -1,120 +0,0 @@ -from __future__ import annotations - -import sqlalchemy as sa - -from pydiverse.transform import ops -from pydiverse.transform.sql.sql_table import SQLTableImpl - - -class PostgresTableImpl(SQLTableImpl): - _dialect_name = "postgresql" - - -with PostgresTableImpl.op(ops.Less()) as op: - - @op("str, str -> bool") - def _lt(x, y): - return x < y.collate("POSIX") - - -with PostgresTableImpl.op(ops.LessEqual()) as op: - - @op("str, str -> bool") - def _le(x, y): - return x <= y.collate("POSIX") - - -with PostgresTableImpl.op(ops.Greater()) as op: - - @op("str, str -> bool") - def _gt(x, y): - return x > y.collate("POSIX") - - -with PostgresTableImpl.op(ops.GreaterEqual()) as op: - - @op("str, str -> bool") - def 
_ge(x, y): - return x >= y.collate("POSIX") - - -with PostgresTableImpl.op(ops.Round()) as op: - - @op.auto - def _round(x, decimals=0): - if decimals == 0: - if isinstance(x.type, sa.Integer): - return x - return sa.func.ROUND(x, type_=x.type) - - if isinstance(x.type, sa.Float): - # Postgres doesn't support rounding of doubles to specific precision - # -> Must first cast to numeric - return sa.func.ROUND(sa.cast(x, sa.Numeric), decimals, type_=sa.Numeric) - - return sa.func.ROUND(x, decimals, type_=x.type) - - -with PostgresTableImpl.op(ops.DtSecond()) as op: - - @op.auto - def _second(x): - return sa.func.FLOOR(sa.extract("second", x), type_=sa.Integer()) - - -with PostgresTableImpl.op(ops.DtMillisecond()) as op: - - @op.auto - def _millisecond(x): - _1000 = sa.literal_column("1000") - return sa.func.FLOOR(sa.extract("milliseconds", x) % _1000, type_=sa.Integer()) - - -with PostgresTableImpl.op(ops.Greatest()) as op: - - @op("str... -> str") - def _greatest(*x): - # TODO: Determine return type - return sa.func.GREATEST(*(e.collate("POSIX") for e in x)) - - -with PostgresTableImpl.op(ops.Least()) as op: - - @op("str... -> str") - def _least(*x): - # TODO: Determine return type - return sa.func.LEAST(*(e.collate("POSIX") for e in x)) - - -with PostgresTableImpl.op(ops.Any()) as op: - - @op.auto - def _any(x, *, _window_partition_by=None, _window_order_by=None): - return sa.func.coalesce(sa.func.BOOL_OR(x, type_=sa.Boolean()), sa.false()) - - @op.auto(variant="window") - def _any(x, *, _window_partition_by=None, _window_order_by=None): - return sa.func.coalesce( - sa.func.BOOL_OR(x, type_=sa.Boolean()).over( - partition_by=_window_partition_by, - order_by=_window_order_by, - ), - sa.false(), - ) - - -with PostgresTableImpl.op(ops.All()) as op: - - @op.auto - def _all(x): - return sa.func.coalesce(sa.func.BOOL_AND(x, type_=sa.Boolean()), sa.false()) - - @op.auto(variant="window") - def _all(x, *, _window_partition_by=None, _window_order_by=None): - return sa.func.coalesce( - sa.func.BOOL_AND(x, type_=sa.Boolean()).over( - partition_by=_window_partition_by, - order_by=_window_order_by, - ), - sa.false(), - ) diff --git a/src/pydiverse/transform/sql/sql_table.py b/src/pydiverse/transform/sql/sql_table.py deleted file mode 100644 index 52a5b2c7..00000000 --- a/src/pydiverse/transform/sql/sql_table.py +++ /dev/null @@ -1,1257 +0,0 @@ -from __future__ import annotations - -import functools -import inspect -import itertools -import operator as py_operator -import uuid -import warnings -from collections.abc import Iterable -from dataclasses import dataclass -from functools import reduce -from typing import TYPE_CHECKING, Any, Callable, Generic, Literal - -import polars as pl -import sqlalchemy as sa -from sqlalchemy import sql - -from pydiverse.transform import ops -from pydiverse.transform._typing import ImplT -from pydiverse.transform.core import dtypes -from pydiverse.transform.core.expressions import ( - Column, - LiteralColumn, - SymbolicExpression, - iterate_over_expr, -) -from pydiverse.transform.core.expressions.translator import TypedValue -from pydiverse.transform.core.table_impl import AbstractTableImpl, ColumnMetaData -from pydiverse.transform.core.util import OrderingDescriptor, translate_ordering -from pydiverse.transform.errors import AlignmentError, FunctionTypeError -from pydiverse.transform.ops import OPType - -if TYPE_CHECKING: - from pydiverse.transform.core.registry import TypedOperatorImpl - - -class SQLTableImpl(AbstractTableImpl): - """SQL backend - - Attributes: - tbl: The 
underlying SQLAlchemy table object. - engine: The SQLAlchemy engine. - sql_columns: A dict mapping from uuids to SQLAlchemy column objects - (only those contained in `tbl`). - - alignment_hash: A hash value that allows checking if two tables are - 'aligned'. In the case of SQL this means that two tables NUST NOT - share the same alignment hash unless they were derived from the - same table(s) and are guaranteed to have the same number of columns - in the same order. In other words: Two tables MUST only have the - same alignment hash if a literal column derived from one of them - can be used by the other table and produces the same result. - """ - - __registered_dialects: dict[str, type[SQLTableImpl]] = {} - _dialect_name: str - - def __new__(cls, *args, **kwargs): - if cls != SQLTableImpl or (not args and not kwargs): - return super().__new__(cls) - - signature = inspect.signature(cls.__init__) - engine = signature.bind(None, *args, **kwargs).arguments["engine"] - - # If calling SQLTableImpl(engine), then we want to dynamically instantiate - # the correct dialect specific subclass based on the engine dialect. - if isinstance(engine, str): - dialect = sa.engine.make_url(engine).get_dialect().name - else: - dialect = engine.dialect.name - - dialect_specific_cls = SQLTableImpl.__registered_dialects.get(dialect, cls) - return super(SQLTableImpl, dialect_specific_cls).__new__(dialect_specific_cls) - - def __init_subclass__(cls, **kwargs): - super().__init_subclass__(**kwargs) - - # Whenever a new subclass if SQLTableImpl is defined, it must contain the - # `_dialect_name` attribute. This allows us to dynamically instantiate it - # when calling SQLTableImpl(engine) based on the dialect name found - # in the engine url (see __new__). - dialect_name = getattr(cls, "_dialect_name", None) - if dialect_name is None: - raise ValueError( - "All subclasses of SQLTableImpl must have a `_dialect_name` attribute." - f" But {cls.__name__}._dialect_name is None." - ) - - if dialect_name in SQLTableImpl.__registered_dialects: - warnings.warn( - f"Already registered a SQLTableImpl for dialect {dialect_name}" - ) - SQLTableImpl.__registered_dialects[dialect_name] = cls - - def __init__( - self, - engine: sa.Engine | str, - table, - _dtype_hints: dict[str, dtypes.DType] = None, - ): - self.engine = sa.create_engine(engine) if isinstance(engine, str) else engine - tbl = self._create_table(table, self.engine) - - columns = { - col.name: Column( - name=col.name, - table=self, - dtype=self._get_dtype(col, hints=_dtype_hints), - ) - for col in tbl.columns - } - - self.replace_tbl(tbl, columns) - super().__init__(name=self.tbl.name, columns=columns) - - def is_aligned_with(self, col: Column | LiteralColumn) -> bool: - if isinstance(col, Column): - if not isinstance(col.table, type(self)): - return False - return col.table.alignment_hash == self.alignment_hash - - if isinstance(col, LiteralColumn): - return all( - self.is_aligned_with(atom) - for atom in iterate_over_expr(col.expr, expand_literal_col=True) - if isinstance(atom, Column) - ) - - raise ValueError - - @classmethod - def _html_repr_expr(cls, expr): - if isinstance(expr, sa.sql.ColumnElement): - return str(expr.compile(compile_kwargs={"literal_binds": True})) - return super()._html_repr_expr(expr) - - @staticmethod - def _create_table(tbl, engine=None): - """Return a sa.Table - - :param tbl: a sa.Table or string of form 'table_name' - or 'schema_name.table_name'. 
- """ - if isinstance(tbl, sa.sql.FromClause): - return tbl - - if not isinstance(tbl, str): - raise ValueError(f"tbl must be a sqlalchemy Table or string, but was {tbl}") - - schema, table_name = tbl.split(".") if "." in tbl else [None, tbl] - return sa.Table( - table_name, - sa.MetaData(), - schema=schema, - autoload_with=engine, - ) - - @staticmethod - def _get_dtype( - col: sa.Column, hints: dict[str, dtypes.DType] = None - ) -> dtypes.DType: - """Determine the dtype of a column. - - :param col: The sqlalchemy column object. - :param hints: In some situations sqlalchemy can't determine the dtype of - a column. Instead of throwing an exception we can use these type - hints as a fallback. - :return: Appropriate dtype string. - """ - - type_ = col.type - if isinstance(type_, sa.Integer): - return dtypes.Int() - if isinstance(type_, sa.Numeric): - return dtypes.Float() - if isinstance(type_, sa.String): - return dtypes.String() - if isinstance(type_, sa.Boolean): - return dtypes.Bool() - if isinstance(type_, sa.DateTime): - return dtypes.DateTime() - if isinstance(type_, sa.Date): - return dtypes.Date() - if isinstance(type_, sa.Interval): - return dtypes.Duration() - if isinstance(type_, sa.Time): - raise NotImplementedError("Unsupported type: Time") - - if hints is not None: - if dtype := hints.get(col.name): - return dtype - - raise NotImplementedError(f"Unsupported type: {type_}") - - def replace_tbl(self, new_tbl, columns: dict[str:Column]): - if isinstance(new_tbl, sql.Select): - # noinspection PyNoneFunctionAssignment - new_tbl = new_tbl.subquery() - - self.tbl = new_tbl - self.alignment_hash = generate_alignment_hash() - - self.sql_columns = { - col.uuid: self.tbl.columns[col.name] for col in columns.values() - } # from uuid to sqlalchemy column - - if hasattr(self, "cols"): - # TODO: Clean up... This feels a bit hacky - for col in columns.values(): - self.cols[col.uuid] = ColumnMetaData.from_expr(col.uuid, col, self) - if hasattr(self, "intrinsic_grouped_by"): - self.intrinsic_grouped_by.clear() - - self.joins: list[JoinDescriptor] = [] - self.wheres: list[SymbolicExpression] = [] - self.having: list[SymbolicExpression] = [] - self.order_bys: list[OrderingDescriptor] = [] - self.limit_offset: tuple[int, int] | None = None - - def build_select(self) -> sql.Select: - # Validate current state - if len(self.selects) == 0: - raise ValueError("Can't execute a SQL query without any SELECT statements.") - - # Start building query - select = self.tbl.select() - - # `select_from` is required if no table is explicitly referenced - # inside the SELECT. e.g. 
`SELECT COUNT(*) AS count` - select = select.select_from(self.tbl) - - # FROM - select = self._build_select_from(select) - - # WHERE - select = self._build_select_where(select) - - # GROUP BY - select = self._build_select_group_by(select) - - # HAVING - select = self._build_select_having(select) - - # LIMIT / OFFSET - select = self._build_select_limit_offset(select) - - # SELECT - select = self._build_select_select(select) - - # ORDER BY - select = self._build_select_order_by(select) - - return select - - def _build_select_from(self, select): - for join in self.joins: - compiled, _ = self.compiler.translate(join.on, verb="join") - on = compiled(self.sql_columns) - - select = select.join( - join.right.tbl, - onclause=on, - isouter=join.how != "inner", - full=join.how == "outer", - ) - - return select - - def _build_select_where(self, select): - if not self.wheres: - return select - - # Combine wheres using ands - combined_where = functools.reduce( - py_operator.and_, map(SymbolicExpression, self.wheres) - )._ - compiled, where_dtype = self.compiler.translate(combined_where, verb="filter") - assert isinstance(where_dtype, dtypes.Bool) - where = compiled(self.sql_columns) - return select.where(where) - - def _build_select_group_by(self, select): - if not self.intrinsic_grouped_by: - return select - - compiled_gb, group_by_dtypes = zip( - *( - self.compiler.translate(group_by, verb="group_by") - for group_by in self.intrinsic_grouped_by - ) - ) - group_bys = (compiled(self.sql_columns) for compiled in compiled_gb) - return select.group_by(*group_bys) - - def _build_select_having(self, select): - if not self.having: - return select - - # Combine havings using ands - combined_having = functools.reduce( - py_operator.and_, map(SymbolicExpression, self.having) - )._ - compiled, having_dtype = self.compiler.translate(combined_having, verb="filter") - assert isinstance(having_dtype, dtypes.Bool) - having = compiled(self.sql_columns) - return select.having(having) - - def _build_select_limit_offset(self, select): - if self.limit_offset is None: - return select - - limit, offset = self.limit_offset - return select.limit(limit).offset(offset) - - def _build_select_select(self, select): - # Convert self.selects to SQLAlchemy Expressions - s = [] - for name, uuid_ in self.selected_cols(): - sql_col = self.cols[uuid_].compiled(self.sql_columns) - if not isinstance(sql_col, sa.sql.ColumnElement): - sql_col = sa.literal(sql_col) - s.append(sql_col.label(name)) - return select.with_only_columns(*s) - - def _build_select_order_by(self, select): - if not self.order_bys: - return select - - o = [] - for o_by in self.order_bys: - compiled, _ = self.compiler.translate(o_by.order, verb="arrange") - col = compiled(self.sql_columns) - o.extend(self._order_col(col, o_by)) - - return select.order_by(*o) - - #### Verb Operations #### - - def preverb_hook(self, verb: str, *args, **kwargs) -> None: - def has_any_ftype_cols(ftypes: OPType | tuple[OPType, ...], cols: Iterable): - if isinstance(ftypes, OPType): - ftypes = (ftypes,) - return any( - self.cols[c.uuid].ftype in ftypes - for v in cols - for c in iterate_over_expr(self.resolve_lambda_cols(v)) - if isinstance(c, Column) - ) - - requires_subquery = False - clear_order = False - - if self.limit_offset is not None: - # The LIMIT / TOP clause is executed at the very end of the query. - # This means we must create a subquery for any verb that modifies - # the rows. 
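(Why slicing forces a subquery, as a small self-contained SQLAlchemy sketch; table and column names are illustrative.)

```python
import sqlalchemy as sa

t = sa.table("t", sa.column("x"))

# WHERE is evaluated before LIMIT, so adding the filter to the same SELECT
# would filter first and slice afterwards:
filter_then_slice = sa.select(t).where(t.c.x > 0).limit(10)

# to filter the already *sliced* rows, the LIMIT must go into a subquery:
sliced = sa.select(t).limit(10).subquery()
slice_then_filter = sa.select(sliced).where(sliced.c.x > 0)
```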
- if verb in ( - "join", - "filter", - "arrange", - "group_by", - "summarise", - ): - requires_subquery = True - - if verb == "mutate": - # Window functions can't be nested, thus a subquery is required - requires_subquery |= has_any_ftype_cols(OPType.WINDOW, kwargs.values()) - elif verb == "filter": - # Window functions aren't allowed in where clause - requires_subquery |= has_any_ftype_cols(OPType.WINDOW, args) - elif verb == "summarise": - # The result of the aggregate is always ordered according to the - # grouping columns. We must clear the order_bys so that the order - # is consistent with eager execution. We can do this because aggregate - # functions are independent of the order. - clear_order = True - - # If the grouping level is different from the grouping level of the - # tbl object, or if on of the input columns is a window or aggregate - # function, we must make a subquery. - requires_subquery |= ( - bool(self.intrinsic_grouped_by) - and self.grouped_by != self.intrinsic_grouped_by - ) - requires_subquery |= has_any_ftype_cols( - (OPType.AGGREGATE, OPType.WINDOW), kwargs.values() - ) - - # TODO: It would be nice if this could be done without having to select all - # columns. As a potential challenge for the hackathon I propose a mean - # of even creating the subqueries lazyly. This means that we could - # perform some kind of query optimization before submitting the actual - # query. Eg: Instead of selecting all possible columns, only select - # those that actually get used. - if requires_subquery: - columns = { - name: self.cols[uuid].as_column(name, self) - for name, uuid in self.named_cols.fwd.items() - } - - original_selects = self.selects.copy() - self.selects |= columns.keys() - subquery = self.build_select() - - self.replace_tbl(subquery, columns) - self.selects = original_selects - - if clear_order: - self.order_bys.clear() - - def alias(self, name=None): - if name is None: - suffix = format(uuid.uuid1().int % 0x7FFFFFFF, "X") - name = f"{self.name}_{suffix}" - - # TODO: If the table has not been modified, a simple `.alias()` - # would produce nicer queries. - subquery = self.build_select().subquery(name=name) - # In some situations sqlalchemy fails to determine the datatype of a column. - # To circumvent this, we can pass on the information we know. 
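(A minimal demonstration of the inference gap that the dtype-hint fallback mentioned above works around; self-contained SQLAlchemy.)

```python
import sqlalchemy as sa

# the type of a computed expression is not always known to sqlalchemy:
sub = sa.select(sa.literal_column("1 + 1").label("y")).subquery()
print(sub.c.y.type)  # NULL (NullType) -> a dtype hint is needed downstream
```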
- dtype_hints = { - name: self.cols[self.named_cols.fwd[name]].dtype for name in self.selects - } - - return self.__class__(self.engine, subquery, _dtype_hints=dtype_hints) - - def collect(self): - select = self.build_select() - with self.engine.connect() as conn: - try: - # TODO: check for which pandas versions this is needed: - # Temporary fix for pandas bug (https://github.com/pandas-dev/pandas/issues/35484) - # Taken from siuba - from pandas.io import sql as _pd_sql - - class _FixedSqlDatabase(_pd_sql.SQLDatabase): - def execute(self, *args, **kwargs): - return self.connectable.execute(*args, **kwargs) - - sql_db = _FixedSqlDatabase(conn) - result = sql_db.read_sql(select).convert_dtypes() - except AttributeError: - import pandas as pd - - result = pd.read_sql_query(select, con=conn) - - # Add metadata - result.attrs["name"] = self.name - return result - - def export(self): - with self.engine.connect() as conn: - if isinstance(self, DuckDBTableImpl): - result = pl.read_database(self.build_query(), connection=conn) - else: - result = pl.read_database(self.build_select(), connection=conn) - return result - - def build_query(self) -> str: - query = self.build_select() - return str( - query.compile( - dialect=self.engine.dialect, compile_kwargs={"literal_binds": True} - ) - ) - - def join( - self, - right: SQLTableImpl, - on: SymbolicExpression, - how: Literal["inner", "left", "outer"], - *, - validate: Literal["1:1", "1:m", "m:1", "m:m"] = "m:m", - ): - self.alignment_hash = generate_alignment_hash() - - # If right has joins already, merging them becomes extremely difficult - # This is because the ON clauses could contain NULL checks in which case - # the joins aren't associative anymore. - if right.joins: - raise ValueError( - "Can't automatically combine joins if the right side already contains a" - " JOIN clause." - ) - - if right.limit_offset is not None: - raise ValueError( - "The right table can't be sliced when performing a join." - " Wrap the right side in a subquery to fix this." - ) - - # TODO: Handle GROUP BY and SELECTS on left / right side - - # Combine the WHERE clauses - if how == "inner": - # Inner Join: The WHERES can be combined - self.wheres.extend(right.wheres) - elif how == "left": - # WHERES from right must go into the ON clause - on = reduce(py_operator.and_, (on, *right.wheres)) - elif how == "outer": - # For outer joins, the WHERE clause can't easily be merged. - # The best solution for now is to move them into a subquery. - if self.wheres: - raise ValueError( - "Filters can't precede outer joins. Wrap the left side in a" - " subquery to fix this." - ) - if right.wheres: - raise ValueError( - "Filters can't precede outer joins. Wrap the right side in a" - " subquery to fix this." - ) - - if validate != "m:m": - warnings.warn("SQL table backend ignores join validation argument.") - - descriptor = JoinDescriptor(right, on, how) - self.joins.append(descriptor) - - self.sql_columns.update(right.sql_columns) - - def filter(self, *args): - self.alignment_hash = generate_alignment_hash() - - if self.intrinsic_grouped_by: - for arg in args: - # If a condition involves only grouping columns, it can be - # moved into the wheres instead of the havings. 
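(The WHERE / HAVING distinction from the comment above, as a self-contained SQLAlchemy sketch; names are illustrative.)

```python
import sqlalchemy as sa

t = sa.table("t", sa.column("g"), sa.column("x"))
agg = sa.func.sum(t.c.x)

# a predicate on a grouping column can run before aggregation -> WHERE
q1 = sa.select(t.c.g, agg).where(t.c.g > 0).group_by(t.c.g)

# a predicate involving an aggregate must run after it -> HAVING
q2 = sa.select(t.c.g, agg).group_by(t.c.g).having(agg > 100)
```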
- only_grouping_cols = all( - col in self.intrinsic_grouped_by - for col in iterate_over_expr(arg, expand_literal_col=True) - if isinstance(col, Column) - ) - - if only_grouping_cols: - self.wheres.append(arg) - else: - self.having.append(arg) - else: - self.wheres.extend(args) - - def arrange(self, ordering): - self.alignment_hash = generate_alignment_hash() - - # Merge order bys and remove duplicate columns - order_bys = [] - order_by_columns = set() - for o_by in ordering + self.order_bys: - if o_by.order in order_by_columns: - continue - order_bys.append(o_by) - order_by_columns.add(o_by.order) - - self.order_bys = order_bys - - def summarise(self, **kwargs): - self.alignment_hash = generate_alignment_hash() - - def slice_head(self, n: int, offset: int): - if self.limit_offset is None: - self.limit_offset = (n, offset) - else: - old_n, old_o = self.limit_offset - self.limit_offset = (min(abs(old_n - offset), n), old_o + offset) - - #### EXPRESSIONS #### - - def _order_col( - self, col: sa.SQLColumnExpression, ordering: OrderingDescriptor - ) -> list[sa.SQLColumnExpression]: - col = col.asc() if ordering.asc else col.desc() - col = col.nullsfirst() if ordering.nulls_first else col.nullslast() - return [col] - - class ExpressionCompiler( - AbstractTableImpl.ExpressionCompiler[ - "SQLTableImpl", - TypedValue[Callable[[dict[uuid.UUID, sa.Column]], sql.ColumnElement]], - ] - ): - def _translate_col(self, col, **kwargs): - # Can either be a base SQL column, or a reference to an expression - if col.uuid in self.backend.sql_columns: - - def sql_col(cols, **kw): - return cols[col.uuid] - - return TypedValue(sql_col, col.dtype, OPType.EWISE) - - meta_data = self.backend.cols[col.uuid] - return TypedValue(meta_data.compiled, meta_data.dtype, meta_data.ftype) - - def _translate_literal_col(self, expr, **kwargs): - if not self.backend.is_aligned_with(expr): - raise AlignmentError( - "Literal column isn't aligned with this table. " - f"Literal Column: {expr}" - ) - - def sql_col(cols, **kw): - return expr.typed_value.value - - return TypedValue(sql_col, expr.typed_value.dtype, expr.typed_value.ftype) - - def _translate_function( - self, implementation, op_args, context_kwargs, *, verb=None, **kwargs - ): - value = self._translate_function_value( - implementation, - op_args, - context_kwargs, - verb=verb, - **kwargs, - ) - operator = implementation.operator - - if operator.ftype == OPType.AGGREGATE and verb == "mutate": - # Aggregate function in mutate verb -> window function - over_value = self.over_clause(value, implementation, context_kwargs) - ftype = self.backend._get_op_ftype( - op_args, operator, OPType.WINDOW, strict=True - ) - return TypedValue(over_value, implementation.rtype, ftype) - - elif operator.ftype == OPType.WINDOW: - if verb != "mutate": - raise FunctionTypeError( - "Window function are only allowed inside a mutate." 
- ) - - over_value = self.over_clause(value, implementation, context_kwargs) - ftype = self.backend._get_op_ftype(op_args, operator, strict=True) - return TypedValue(over_value, implementation.rtype, ftype) - - else: - ftype = self.backend._get_op_ftype(op_args, operator, strict=True) - return TypedValue(value, implementation.rtype, ftype) - - def _translate_function_value( - self, - implementation: TypedOperatorImpl, - op_args: list, - context_kwargs: dict, - *, - verb=None, - **kwargs, - ): - impl_dtypes = implementation.impl.signature.args - if implementation.impl.signature.is_vararg: - impl_dtypes = itertools.chain( - impl_dtypes[:-1], - itertools.repeat(impl_dtypes[-1]), - ) - - def value(cols, *, variant=None, internal_kwargs=None, **kw): - args = [] - for arg, dtype in zip(op_args, impl_dtypes): - if dtype.const: - args.append(arg.value(cols, as_sql_literal=False)) - else: - args.append(arg.value(cols)) - - kwargs = { - "_tbl": self.backend, - "_verb": verb, - **(internal_kwargs or {}), - } - - if variant is not None: - if variant_impl := implementation.get_variant(variant): - return variant_impl(*args, **kwargs) - - return implementation(*args, **kwargs) - - return value - - def _translate_case(self, expr, switching_on, cases, default, **kwargs): - def value(*args, **kw): - default_ = default.value(*args, **kwargs) - - if switching_on is not None: - switching_on_ = switching_on.value(*args, **kwargs) - return sa.case( - { - cond.value(*args, **kw): val.value(*args, **kw) - for cond, val in cases - }, - value=switching_on_, - else_=default_, - ) - - return sa.case( - *( - (cond.value(*args, **kw), val.value(*args, **kw)) - for cond, val in cases - ), - else_=default_, - ) - - result_dtype, result_ftype = self._translate_case_common( - expr, switching_on, cases, default, **kwargs - ) - return TypedValue(value, result_dtype, result_ftype) - - def _translate_literal_value(self, expr): - def literal_func(*args, as_sql_literal=True, **kwargs): - if as_sql_literal: - return sa.literal(expr) - return expr - - return literal_func - - def over_clause( - self, - value: Callable, - implementation: TypedOperatorImpl, - context_kwargs: dict, - ): - operator = implementation.operator - if operator.ftype not in (OPType.AGGREGATE, OPType.WINDOW): - raise FunctionTypeError - - wants_order_by = operator.ftype == OPType.WINDOW - - # PARTITION BY - grouping = context_kwargs.get("partition_by") - if grouping is not None: - grouping = [self.backend.resolve_lambda_cols(col) for col in grouping] - else: - grouping = self.backend.grouped_by - - compiled_pb = tuple(self.translate(col).value for col in grouping) - - # ORDER BY - def order_by_clause_generator(ordering: OrderingDescriptor): - compiled, _ = self.translate(ordering.order) - - def clause(*args, **kwargs): - col = compiled(*args, **kwargs) - return self.backend._order_col(col, ordering) - - return clause - - if wants_order_by: - arrange = context_kwargs.get("arrange") - if not arrange: - raise TypeError("Missing 'arrange' argument.") - - ordering = translate_ordering(self.backend, arrange) - compiled_ob = [order_by_clause_generator(o_by) for o_by in ordering] - - # New value callable - def over_value(*args, **kwargs): - pb = sql.expression.ClauseList( - *(compiled(*args, **kwargs) for compiled in compiled_pb) - ) - ob = ( - sql.expression.ClauseList( - *( - clause - for compiled in compiled_ob - for clause in compiled(*args, **kwargs) - ) - ) - if wants_order_by - else None - ) - - # Some operators need to further modify the OVER expression - # To 
do this, we allow registering a variant called "window" - if implementation.has_variant("window"): - return value( - *args, - variant="window", - internal_kwargs={ - "_window_partition_by": pb, - "_window_order_by": ob, - }, - **kwargs, - ) - - # If now window variant has been defined, just apply generic OVER clause - return value(*args, **kwargs).over( - partition_by=pb, - order_by=ob, - ) - - return over_value - - class AlignedExpressionEvaluator( - AbstractTableImpl.AlignedExpressionEvaluator[TypedValue[sql.ColumnElement]] - ): - def translate(self, expr, check_alignment=True, **kwargs): - if check_alignment: - alignment_hashes = { - col.table.alignment_hash - for col in iterate_over_expr(expr, expand_literal_col=True) - if isinstance(col, Column) - } - if len(alignment_hashes) >= 2: - raise AlignmentError( - "Expression contains columns from different tables that aren't" - " aligned." - ) - - return super().translate(expr, check_alignment=check_alignment, **kwargs) - - def _translate_col(self, col, **kwargs): - backend = col.table - if col.uuid in backend.sql_columns: - sql_col = backend.sql_columns[col.uuid] - return TypedValue(sql_col, col.dtype) - - meta_data = backend.cols[col.uuid] - return TypedValue( - meta_data.compiled(backend.sql_columns), - meta_data.dtype, - meta_data.ftype, - ) - - def _translate_literal_col(self, expr, **kwargs): - assert issubclass(expr.backend, SQLTableImpl) - return expr.typed_value - - def _translate_function( - self, implementation, op_args, context_kwargs, **kwargs - ): - # Aggregate function -> window function - value = implementation(*(arg.value for arg in op_args)) - operator = implementation.operator - override_ftype = ( - OPType.WINDOW if operator.ftype == OPType.AGGREGATE else None - ) - ftype = SQLTableImpl._get_op_ftype( - op_args, operator, override_ftype, strict=True - ) - - if operator.ftype == OPType.AGGREGATE: - value = value.over() - if operator.ftype == OPType.WINDOW: - raise NotImplementedError("How to handle window functions?") - - return TypedValue(value, implementation.rtype, ftype) - - -@dataclass -class JoinDescriptor(Generic[ImplT]): - __slots__ = ("right", "on", "how") - - right: ImplT - on: Any - how: str - - -def generate_alignment_hash(): - # It should be possible to have an alternative hash value that - # is a bit more lenient -> If the same set of operations get applied - # to a table in two different orders that produce the same table - # object, their hash could also be equal. 
- return uuid.uuid1() - - -#### BACKEND SPECIFIC OPERATORS ################################################ - - -with SQLTableImpl.op(ops.FloorDiv(), check_super=False) as op: - if sa.__version__ < "2": - - @op.auto - def _floordiv(lhs, rhs): - return sa.cast(lhs / rhs, sa.Integer()) - - else: - - @op.auto - def _floordiv(lhs, rhs): - return lhs // rhs - - -with SQLTableImpl.op(ops.RFloorDiv(), check_super=False) as op: - - @op.auto - def _rfloordiv(rhs, lhs): - return _floordiv(lhs, rhs) - - -with SQLTableImpl.op(ops.Pow()) as op: - - @op.auto - def _pow(lhs, rhs): - if isinstance(lhs.type, sa.Float) or isinstance(rhs.type, sa.Float): - type_ = sa.Double() - elif isinstance(lhs.type, sa.Numeric) or isinstance(rhs, sa.Numeric): - type_ = sa.Numeric() - else: - type_ = sa.Double() - - return sa.func.POW(lhs, rhs, type_=type_) - - -with SQLTableImpl.op(ops.RPow()) as op: - - @op.auto - def _rpow(rhs, lhs): - return _pow(lhs, rhs) - - -with SQLTableImpl.op(ops.Xor()) as op: - - @op.auto - def _xor(lhs, rhs): - return lhs != rhs - - -with SQLTableImpl.op(ops.RXor()) as op: - - @op.auto - def _rxor(rhs, lhs): - return lhs != rhs - - -with SQLTableImpl.op(ops.Pos()) as op: - - @op.auto - def _pos(x): - return x - - -with SQLTableImpl.op(ops.Abs()) as op: - - @op.auto - def _abs(x): - return sa.func.ABS(x, type_=x.type) - - -with SQLTableImpl.op(ops.Round()) as op: - - @op.auto - def _round(x, decimals=0): - return sa.func.ROUND(x, decimals, type_=x.type) - - -with SQLTableImpl.op(ops.IsIn()) as op: - - @op.auto - def _isin(x, *values, _verb=None): - if _verb == "filter": - # In WHERE and HAVING clause, we can use the IN operator - return x.in_(values) - # In SELECT we must replace it with the corresponding boolean expression - return reduce(py_operator.or_, map(lambda v: x == v, values)) - - -with SQLTableImpl.op(ops.IsNull()) as op: - - @op.auto - def _is_null(x): - return x.is_(sa.null()) - - -with SQLTableImpl.op(ops.IsNotNull()) as op: - - @op.auto - def _is_not_null(x): - return x.is_not(sa.null()) - - -#### String Functions #### - - -with SQLTableImpl.op(ops.StrStrip()) as op: - - @op.auto - def _str_strip(x): - return sa.func.TRIM(x, type_=x.type) - - -with SQLTableImpl.op(ops.StrLen()) as op: - - @op.auto - def _str_length(x): - return sa.func.LENGTH(x, type_=sa.Integer()) - - -with SQLTableImpl.op(ops.StrToUpper()) as op: - - @op.auto - def _upper(x): - return sa.func.UPPER(x, type_=x.type) - - -with SQLTableImpl.op(ops.StrToLower()) as op: - - @op.auto - def _upper(x): - return sa.func.LOWER(x, type_=x.type) - - -with SQLTableImpl.op(ops.StrReplaceAll()) as op: - - @op.auto - def _replace(x, y, z): - return sa.func.REPLACE(x, y, z, type_=x.type) - - -with SQLTableImpl.op(ops.StrStartsWith()) as op: - - @op.auto - def _startswith(x, y): - return x.startswith(y, autoescape=True) - - -with SQLTableImpl.op(ops.StrEndsWith()) as op: - - @op.auto - def _endswith(x, y): - return x.endswith(y, autoescape=True) - - -with SQLTableImpl.op(ops.StrContains()) as op: - - @op.auto - def _contains(x, y): - return x.contains(y, autoescape=True) - - -with SQLTableImpl.op(ops.StrSlice()) as op: - - @op.auto - def _str_slice(x, offset, length): - # SQL has 1-indexed strings but we do it 0-indexed - return sa.func.SUBSTR(x, offset + 1, length) - - -#### Datetime Functions #### - - -with SQLTableImpl.op(ops.DtYear()) as op: - - @op.auto - def _year(x): - return sa.extract("year", x) - - -with SQLTableImpl.op(ops.DtMonth()) as op: - - @op.auto - def _month(x): - return sa.extract("month", x) - - -with 
SQLTableImpl.op(ops.DtDay()) as op: - - @op.auto - def _day(x): - return sa.extract("day", x) - - -with SQLTableImpl.op(ops.DtHour()) as op: - - @op.auto - def _hour(x): - return sa.extract("hour", x) - - -with SQLTableImpl.op(ops.DtMinute()) as op: - - @op.auto - def _minute(x): - return sa.extract("minute", x) - - -with SQLTableImpl.op(ops.DtSecond()) as op: - - @op.auto - def _second(x): - return sa.extract("second", x) - - -with SQLTableImpl.op(ops.DtMillisecond()) as op: - - @op.auto - def _millisecond(x): - return sa.extract("milliseconds", x) % 1000 - - -with SQLTableImpl.op(ops.DtDayOfWeek()) as op: - - @op.auto - def _day_of_week(x): - return sa.extract("dow", x) - - -with SQLTableImpl.op(ops.DtDayOfYear()) as op: - - @op.auto - def _day_of_year(x): - return sa.extract("doy", x) - - -#### Generic Functions #### - - -with SQLTableImpl.op(ops.Greatest()) as op: - - @op.auto - def _greatest(*x): - # TODO: Determine return type - return sa.func.GREATEST(*x) - - -with SQLTableImpl.op(ops.Least()) as op: - - @op.auto - def _least(*x): - # TODO: Determine return type - return sa.func.LEAST(*x) - - -#### Summarising Functions #### - - -with SQLTableImpl.op(ops.Mean()) as op: - - @op.auto - def _mean(x): - type_ = sa.Numeric() - if isinstance(x.type, sa.Float): - type_ = sa.Double() - - return sa.func.AVG(x, type_=type_) - - -with SQLTableImpl.op(ops.Min()) as op: - - @op.auto - def _min(x): - return sa.func.min(x) - - -with SQLTableImpl.op(ops.Max()) as op: - - @op.auto - def _max(x): - return sa.func.max(x) - - -with SQLTableImpl.op(ops.Sum()) as op: - - @op.auto - def _sum(x): - return sa.func.sum(x) - - -with SQLTableImpl.op(ops.Any()) as op: - - @op.auto - def _any(x, *, _window_partition_by=None, _window_order_by=None): - return sa.func.coalesce(sa.func.max(x), sa.false()) - - @op.auto(variant="window") - def _any(x, *, _window_partition_by=None, _window_order_by=None): - return sa.func.coalesce( - sa.func.max(x).over( - partition_by=_window_partition_by, - order_by=_window_order_by, - ), - sa.false(), - ) - - -with SQLTableImpl.op(ops.All()) as op: - - @op.auto - def _all(x): - return sa.func.coalesce(sa.func.min(x), sa.false()) - - @op.auto(variant="window") - def _all(x, *, _window_partition_by=None, _window_order_by=None): - return sa.func.coalesce( - sa.func.min(x).over( - partition_by=_window_partition_by, - order_by=_window_order_by, - ), - sa.false(), - ) - - -with SQLTableImpl.op(ops.Count()) as op: - - @op.auto - def _count(x=None): - if x is None: - # Get the number of rows - return sa.func.count() - else: - # Count non null values - return sa.func.count(x) - - -#### Window Functions #### - - -with SQLTableImpl.op(ops.Shift()) as op: - - @op.auto - def _shift(): - raise RuntimeError("This is a stub") - - @op.auto(variant="window") - def _shift( - x, - by, - empty_value=None, - *, - _window_partition_by=None, - _window_order_by=None, - ): - if by == 0: - return x - if by > 0: - return sa.func.LAG(x, by, empty_value, type_=x.type).over( - partition_by=_window_partition_by, order_by=_window_order_by - ) - if by < 0: - return sa.func.LEAD(x, -by, empty_value, type_=x.type).over( - partition_by=_window_partition_by, order_by=_window_order_by - ) - - -with SQLTableImpl.op(ops.RowNumber()) as op: - - @op.auto - def _row_number(): - return sa.func.ROW_NUMBER(type_=sa.Integer()) - - -with SQLTableImpl.op(ops.Rank()) as op: - - @op.auto - def _rank(): - return sa.func.rank() - - -with SQLTableImpl.op(ops.DenseRank()) as op: - - @op.auto - def _dense_rank(): - return 
sa.func.dense_rank()
-
-
-from .mssql import MSSqlTableImpl  # noqa
-from .duckdb import DuckDBTableImpl  # noqa
-from .postgres import PostgresTableImpl  # noqa
-from .sqlite import SQLiteTableImpl  # noqa
diff --git a/src/pydiverse/transform/tree/__init__.py b/src/pydiverse/transform/tree/__init__.py
new file mode 100644
index 00000000..d75d93dc
--- /dev/null
+++ b/src/pydiverse/transform/tree/__init__.py
@@ -0,0 +1,7 @@
+from __future__ import annotations
+
+from . import col_expr
+from .ast import AstNode
+from .col_expr import Col
+
+__all__ = ["AstNode", "Col", "col_expr"]
diff --git a/src/pydiverse/transform/tree/ast.py b/src/pydiverse/transform/tree/ast.py
new file mode 100644
index 00000000..ec3db41b
--- /dev/null
+++ b/src/pydiverse/transform/tree/ast.py
@@ -0,0 +1,17 @@
+from __future__ import annotations
+
+from collections.abc import Iterable
+from uuid import UUID
+
+
+class AstNode:
+    __slots__ = ["name"]
+
+    name: str
+
+    def clone(self) -> AstNode:
+        return self._clone()[0]
+
+    def _clone(self) -> tuple[AstNode, dict[AstNode, AstNode], dict[UUID, UUID]]: ...
+
+    def iter_subtree(self) -> Iterable[AstNode]: ...
diff --git a/src/pydiverse/transform/tree/col_expr.py b/src/pydiverse/transform/tree/col_expr.py
new file mode 100644
index 00000000..5ad49d05
--- /dev/null
+++ b/src/pydiverse/transform/tree/col_expr.py
@@ -0,0 +1,510 @@
+from __future__ import annotations
+
+import copy
+import dataclasses
+import functools
+import html
+import itertools
+import operator
+from collections.abc import Callable, Generator, Iterable
+from typing import Any
+from uuid import UUID
+
+from pydiverse.transform.errors import FunctionTypeError
+from pydiverse.transform.ops.core import Ftype
+from pydiverse.transform.tree import dtypes
+from pydiverse.transform.tree.ast import AstNode
+from pydiverse.transform.tree.dtypes import Dtype, python_type_to_pdt
+from pydiverse.transform.tree.registry import OperatorRegistry
+
+
+class ColExpr:
+    __slots__ = ["_dtype", "_ftype"]
+
+    __contains__ = None
+    __iter__ = None
+
+    def __init__(self, _dtype: Dtype | None = None, _ftype: Ftype | None = None):
+        self._dtype = _dtype
+        self._ftype = _ftype
+
+    def __getattr__(self, name: str) -> FnAttr:
+        if name.startswith("_") and name.endswith("_"):
+            # so that hasattr() works correctly on dunder attributes
+            raise AttributeError(f"`ColExpr` has no attribute `{name}`")
+        return FnAttr(name, self)
+
+    def __bool__(self):
+        raise TypeError(
+            "cannot call __bool__() on a ColExpr. hint: A ColExpr cannot be "
+            "converted to a boolean or used with the and, or, not keywords"
+        )
+
+    def __setstate__(self, d):  # to avoid very annoying AttributeErrors
+        for slot, val in d[1].items():
+            setattr(self, slot, val)
+
+    def _repr_html_(self) -> str:
+        return f"<pre>{html.escape(repr(self))}</pre>"
" + + def _repr_pretty_(self, p, cycle): + p.text(str(self) if not cycle else "...") + + def dtype(self) -> Dtype: + return self._dtype + + def ftype(self, *, agg_is_window: bool) -> Ftype: + return self._ftype + + def map( + self, mapping: dict[tuple | ColExpr, ColExpr], *, default: ColExpr = None + ) -> CaseExpr: + return CaseExpr( + ( + ( + self.isin( + *wrap_literal(key if isinstance(key, Iterable) else (key,)) + ), + wrap_literal(val), + ) + for key, val in mapping.items() + ), + wrap_literal(default), + ) + + def iter_children(self) -> Iterable[ColExpr]: + return iter(()) + + # yields all ColExpr`s appearing in the subtree of `self`. Python builtin types + # and `Order` expressions are not yielded. + def iter_subtree(self) -> Iterable[ColExpr]: + for node in self.iter_children(): + yield from node.iter_subtree() + yield self + + def map_subtree(self, g: Callable[[ColExpr], ColExpr]) -> ColExpr: + return g(self) + + +class Col(ColExpr): + __slots__ = ["name", "_ast", "_uuid"] + + def __init__( + self, name: str, _ast: AstNode, _uuid: UUID, _dtype: Dtype, _ftype: Ftype + ): + self.name = name + self._ast = _ast + self._uuid = _uuid + super().__init__(_dtype, _ftype) + + def __repr__(self) -> str: + return f"<{self._ast.name}.{self.name}" f"({self.dtype()})>" + + def __str__(self) -> str: + try: + from pydiverse.transform.backend.targets import Polars + from pydiverse.transform.pipe.verbs import export, select + + df = self.table >> select(self) >> export(Polars()) + return str(df) + except Exception as e: + return ( + repr(self) + + f"\ncould evaluate {repr(self)} due to" + + f"{e.__class__.__name__}: {str(e)}" + ) + + def __hash__(self) -> int: + return hash(self.uuid) + + +class ColName(ColExpr): + __slots__ = ["name"] + + def __init__( + self, name: str, dtype: Dtype | None = None, ftype: Ftype | None = None + ): + self.name = name + super().__init__(dtype, ftype) + + def __repr__(self) -> str: + return f"" + + +class LiteralCol(ColExpr): + __slots__ = ["val"] + + def __init__(self, val: Any): + self.val = val + dtype = python_type_to_pdt(type(val)) + dtype.const = True + super().__init__(dtype, Ftype.EWISE) + + def __repr__(self): + return f"<{self.val} ({self.dtype()})>" + + +class ColFn(ColExpr): + __slots__ = ["name", "args", "context_kwargs"] + + def __init__(self, name: str, *args: ColExpr, **kwargs: list[ColExpr | Order]): + self.name = name + self.args = list(args) + self.context_kwargs = { + key: [val] if not isinstance(val, Iterable) else list(val) + for key, val in kwargs.items() + } + + if arrange := self.context_kwargs.get("arrange"): + self.context_kwargs["arrange"] = [ + Order.from_col_expr(expr) if isinstance(expr, ColExpr) else expr + for expr in arrange + ] + + if filters := self.context_kwargs.get("filter"): + if len(self.args) == 0: + assert self.name == "count" + self.args = [LiteralCol(0)] + + # TODO: check that this is an aggregation + + assert len(self.args) == 1 + self.args[0] = CaseExpr( + [ + ( + functools.reduce(operator.and_, (cond for cond in filters)), + self.args[0], + ) + ] + ) + del self.context_kwargs["filter"] + + super().__init__() + # try to eagerly resolve the types to get a nicer stack trace on type errors + self.dtype() + + def __repr__(self) -> str: + args = [repr(e) for e in self.args] + [ + f"{key}={repr(val)}" for key, val in self.context_kwargs.items() + ] + return f'{self.name}({", ".join(args)})' + + def iter_children(self) -> Iterable[ColExpr]: + yield from itertools.chain(self.args, *self.context_kwargs.values()) + + def 
+
+    def iter_children(self) -> Iterable[ColExpr]:
+        return iter(())
+
+    # Yields all `ColExpr`s appearing in the subtree of `self`. Python builtin types
+    # and `Order` expressions are not yielded.
+    def iter_subtree(self) -> Iterable[ColExpr]:
+        for node in self.iter_children():
+            yield from node.iter_subtree()
+        yield self
+
+    def map_subtree(self, g: Callable[[ColExpr], ColExpr]) -> ColExpr:
+        return g(self)
+
+
+class Col(ColExpr):
+    __slots__ = ["name", "_ast", "_uuid"]
+
+    def __init__(
+        self, name: str, _ast: AstNode, _uuid: UUID, _dtype: Dtype, _ftype: Ftype
+    ):
+        self.name = name
+        self._ast = _ast
+        self._uuid = _uuid
+        super().__init__(_dtype, _ftype)
+
+    def __repr__(self) -> str:
+        return f"<{self._ast.name}.{self.name}" f"({self.dtype()})>"
+
+    def __str__(self) -> str:
+        try:
+            from pydiverse.transform.backend.targets import Polars
+            from pydiverse.transform.pipe.verbs import export, select
+
+            df = self.table >> select(self) >> export(Polars())
+            return str(df)
+        except Exception as e:
+            return (
+                repr(self)
+                + f"\ncould not evaluate {repr(self)} due to "
+                + f"{e.__class__.__name__}: {str(e)}"
+            )
+
+    def __hash__(self) -> int:
+        return hash(self._uuid)
+
+
+class ColName(ColExpr):
+    __slots__ = ["name"]
+
+    def __init__(
+        self, name: str, dtype: Dtype | None = None, ftype: Ftype | None = None
+    ):
+        self.name = name
+        super().__init__(dtype, ftype)
+
+    def __repr__(self) -> str:
+        return f"<C.{self.name}>"
+
+
+class LiteralCol(ColExpr):
+    __slots__ = ["val"]
+
+    def __init__(self, val: Any):
+        self.val = val
+        dtype = python_type_to_pdt(type(val))
+        dtype.const = True
+        super().__init__(dtype, Ftype.EWISE)
+
+    def __repr__(self):
+        return f"<{self.val} ({self.dtype()})>"
+
+
+class ColFn(ColExpr):
+    __slots__ = ["name", "args", "context_kwargs"]
+
+    def __init__(self, name: str, *args: ColExpr, **kwargs: list[ColExpr | Order]):
+        self.name = name
+        self.args = list(args)
+        self.context_kwargs = {
+            key: [val] if not isinstance(val, Iterable) else list(val)
+            for key, val in kwargs.items()
+        }
+
+        if arrange := self.context_kwargs.get("arrange"):
+            self.context_kwargs["arrange"] = [
+                Order.from_col_expr(expr) if isinstance(expr, ColExpr) else expr
+                for expr in arrange
+            ]
+
+        if filters := self.context_kwargs.get("filter"):
+            if len(self.args) == 0:
+                assert self.name == "count"
+                self.args = [LiteralCol(0)]
+
+            # TODO: check that this is an aggregation
+
+            assert len(self.args) == 1
+            self.args[0] = CaseExpr(
+                [
+                    (
+                        functools.reduce(operator.and_, (cond for cond in filters)),
+                        self.args[0],
+                    )
+                ]
+            )
+            del self.context_kwargs["filter"]
+
+        super().__init__()
+        # try to eagerly resolve the types to get a nicer stack trace on type errors
+        self.dtype()
+
+    def __repr__(self) -> str:
+        args = [repr(e) for e in self.args] + [
+            f"{key}={repr(val)}" for key, val in self.context_kwargs.items()
+        ]
+        return f'{self.name}({", ".join(args)})'
+
+    def iter_children(self) -> Iterable[ColExpr]:
+        yield from itertools.chain(self.args, *self.context_kwargs.values())
+
+    def map_subtree(self, g: Callable[[ColExpr], ColExpr]) -> ColExpr:
+        new_fn = copy.copy(self)
+        new_fn.args = [arg.map_subtree(g) for arg in self.args]
+
+        new_fn.context_kwargs = {
+            key: [val.map_subtree(g) for val in arr]
+            for key, arr in self.context_kwargs.items()
+        }
+        return g(new_fn)
+
+    def dtype(self) -> Dtype:
+        if self._dtype is not None:
+            return self._dtype
+
+        arg_dtypes = [arg.dtype() for arg in self.args]
+        if None in arg_dtypes:
+            return None
+
+        from pydiverse.transform.backend import PolarsImpl
+
+        self._dtype = PolarsImpl.registry.get_impl(self.name, arg_dtypes).return_type
+        return self._dtype
+
+    def ftype(self, *, agg_is_window: bool):
+        """
+        Determine the ftype based on the arguments.
+
+            e(e) -> e    a(e) -> a      w(e) -> w
+            e(a) -> a    a(a) -> Err    w(a) -> w
+            e(w) -> w    a(w) -> Err    w(w) -> Err
+
+        If the operator ftype is incompatible with the arguments, this function
+        raises an exception.
+        """
+
+        # TODO: This causes wrong results if ftype is called once with
+        # agg_is_window=True and then with agg_is_window=False.
+        if self._ftype is not None:
+            return self._ftype
+
+        ftypes = [arg.ftype(agg_is_window=agg_is_window) for arg in self.args]
+        if None in ftypes:
+            return None
+
+        from pydiverse.transform.backend.polars import PolarsImpl
+
+        op = PolarsImpl.registry.get_op(self.name)
+
+        actual_ftype = (
+            Ftype.WINDOW if op.ftype == Ftype.AGGREGATE and agg_is_window else op.ftype
+        )
+
+        if actual_ftype == Ftype.EWISE:
+            # this assert is ok since window functions in `summarise` are already
+            # kicked out by the `Summarise` constructor.
+            assert not (Ftype.WINDOW in ftypes and Ftype.AGGREGATE in ftypes)
+
+            if Ftype.WINDOW in ftypes:
+                self._ftype = Ftype.WINDOW
+            elif Ftype.AGGREGATE in ftypes:
+                self._ftype = Ftype.AGGREGATE
+            else:
+                self._ftype = actual_ftype
+
+        else:
+            self._ftype = actual_ftype
+
+            # forbid nested window / aggregation functions
+            for node in self.iter_subtree():
+                if (
+                    node is not self
+                    and isinstance(node, ColFn)
+                    and (
+                        (desc_ftype := PolarsImpl.registry.get_op(node.name).ftype)
+                        in (Ftype.AGGREGATE, Ftype.WINDOW)
+                    )
+                ):
+                    assert isinstance(self, ColFn)
+                    ftype_string = {
+                        Ftype.AGGREGATE: "aggregation",
+                        Ftype.WINDOW: "window",
+                    }
+                    raise FunctionTypeError(
+                        f"{ftype_string[desc_ftype]} function `{node.name}` nested "
+                        f"inside {ftype_string[self._ftype]} function `{self.name}`.\n"
+                        "hint: There may be at most one window / aggregation function "
+                        "in a column expression on any path from the root to a leaf."
+                    )
+
+        return self._ftype
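+
+
+# A sketch of the ftype composition rules documented in `ColFn.ftype` (with a
+# hypothetical table `t` and `f` standing for `pydiverse.transform.pipe.functions`,
+# as used in the tests of this patch):
+#
+#     t.col1 + 1                # e(e) -> e: stays elementwise
+#     t.col1.max() + 1          # e(a) -> a: elementwise on top of an aggregation
+#     f.rank(arrange=[t.col1])  # a window function
+#     t.col1.max().max()        # a(a) -> Err: rejected with FunctionTypeError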
+
+
+@dataclasses.dataclass(slots=True)
+class FnAttr:
+    name: str
+    arg: ColExpr
+
+    def __getattr__(self, name) -> FnAttr:
+        return FnAttr(f"{self.name}.{name}", self.arg)
+
+    def __call__(self, *args, **kwargs) -> ColExpr:
+        return ColFn(
+            self.name,
+            wrap_literal(self.arg),
+            *wrap_literal(args),
+            **wrap_literal(kwargs),
+        )
+
+    def __repr__(self) -> str:
+        return f"<{self.__class__.__name__} {self.name}({self.arg})>"
+
+
+@dataclasses.dataclass(slots=True)
+class WhenClause:
+    cases: list[tuple[ColExpr, ColExpr]]
+    cond: ColExpr
+
+    def then(self, value: ColExpr) -> CaseExpr:
+        return CaseExpr((*self.cases, (self.cond, wrap_literal(value))))
+
+    def __repr__(self) -> str:
+        return f"<{self.__class__.__name__} {self.cond}>"
+
+
+class CaseExpr(ColExpr):
+    __slots__ = ["cases", "default_val"]
+
+    def __init__(
+        self,
+        cases: Iterable[tuple[ColExpr, ColExpr]],
+        default_val: ColExpr | None = None,
+    ):
+        self.cases = list(cases)
+
+        # We distinguish between `None` and `LiteralCol(None)` as `default_val`: the
+        # former signals that the user has not set a default value yet, the latter
+        # that the user explicitly set `None` as the default value.
+        self.default_val = default_val
+        super().__init__()
+        self.dtype()
+
+    def __repr__(self) -> str:
+        return (
+            "<case_expr "
+            + functools.reduce(
+                operator.add, (f"{cond} -> {val}, " for cond, val in self.cases), ""
+            )
+            + f"default={self.default_val}>"
+        )
+
+    def iter_children(self) -> Iterable[ColExpr]:
+        yield from itertools.chain.from_iterable(self.cases)
+        if self.default_val is not None:
+            yield self.default_val
+
+    def map_subtree(self, g: Callable[[ColExpr], ColExpr]) -> ColExpr:
+        new_case_expr = copy.copy(self)
+        new_case_expr.cases = [
+            (cond.map_subtree(g), val.map_subtree(g)) for cond, val in self.cases
+        ]
+        new_case_expr.default_val = (
+            self.default_val.map_subtree(g) if self.default_val is not None else None
+        )
+        return g(new_case_expr)
+
+    def dtype(self):
+        if self._dtype is not None:
+            return self._dtype
+
+        try:
+            val_types = [val.dtype() for _, val in self.cases]
+            if self.default_val is not None:
+                val_types.append(self.default_val.dtype())
+
+            if None in val_types:
+                return None
+
+            self._dtype = dtypes.promote_dtypes(
+                [dtype.without_modifiers() for dtype in val_types]
+            )
+        except Exception as e:
+            raise TypeError(f"invalid case expression: {e}") from e
+
+        for cond, _ in self.cases:
+            if cond.dtype() is not None and not isinstance(cond.dtype(), dtypes.Bool):
+                raise TypeError(
+                    f"argument `{cond}` for `when` must be of boolean type, but has "
+                    f"type `{cond.dtype()}`"
+                )
+
+        return self._dtype
+
+    def ftype(self, *, agg_is_window: bool):
+        if self._ftype is not None:
+            return self._ftype
+
+        val_ftypes = set()
+        if self.default_val is not None and not self.default_val.dtype().const:
+            val_ftypes.add(self.default_val.ftype(agg_is_window=agg_is_window))
+
+        for _, val in self.cases:
+            if val.dtype() is not None and not val.dtype().const:
+                val_ftypes.add(val.ftype(agg_is_window=agg_is_window))
+
+        if None in val_ftypes:
+            return None
+
+        if len(val_ftypes) == 0:
+            self._ftype = Ftype.EWISE
+        elif len(val_ftypes) == 1:
+            (self._ftype,) = val_ftypes
+        elif Ftype.WINDOW in val_ftypes:
+            self._ftype = Ftype.WINDOW
+        else:
+            # AGGREGATE and EWISE are incompatible
+            raise FunctionTypeError(
+                "incompatible function types found in case statement: "
+                + ", ".join(str(ftype) for ftype in val_ftypes)
+            )
+
+        return self._ftype
+
+    def when(self, condition: ColExpr) -> WhenClause:
+        if self.default_val is not None:
+            raise TypeError(
+                "cannot call `when` on a closed case expression, i.e. after "
+                "`otherwise` has been called"
+            )
+
+        if condition.dtype() is not None and not isinstance(
+            condition.dtype(), dtypes.Bool
+        ):
+            raise TypeError(
+                "argument for `when` must be of boolean type, but has type "
+                f"`{condition.dtype()}`"
+            )
+
+        return WhenClause(self.cases, wrap_literal(condition))
+
+    def otherwise(self, value: ColExpr) -> CaseExpr:
+        if self.default_val is not None:
+            raise TypeError("default value is already set on this case expression")
+        return CaseExpr(self.cases, wrap_literal(value))
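+
+
+# A sketch of the fluent case API built from `CaseExpr` and `WhenClause`
+# (mirroring the updated tests in this patch): `pdt.when(...)` opens a case
+# expression, `.then(...)` attaches a value, further `.when(...)` calls chain,
+# and `.otherwise(...)` sets the default value and closes the expression:
+#
+#     (
+#         pdt.when(C.col1 == C.col2).then(1)
+#         .when(C.col2 == C.col3).then(2)
+#         .otherwise(C.col1 + C.col2)
+#     )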
expression after") + + if condition.dtype() is not None and not isinstance( + condition.dtype(), dtypes.Bool + ): + raise TypeError( + "argument for `when` must be of boolean type, but has type " + f"`{condition.dtype()}`" + ) + + return WhenClause(self.cases, wrap_literal(condition)) + + def otherwise(self, value: ColExpr) -> CaseExpr: + if self.default_val is not None: + raise TypeError("default value is already set on this case expression") + return CaseExpr(self.cases, wrap_literal(value)) + + +@dataclasses.dataclass(slots=True) +class Order: + order_by: ColExpr + descending: bool = False + nulls_last: bool | None = None + + # The given `expr` may contain nulls_last markers or descending markers. The + # order_by of the Order does not contain these special functions and can thus be + # translated normally. + @staticmethod + def from_col_expr(expr: ColExpr) -> Order: + descending = None + nulls_last = None + while isinstance(expr, ColFn): + if descending is None: + if expr.name == "descending": + descending = True + elif expr.name == "ascending": + descending = False + + if nulls_last is None: + if expr.name == "nulls_last": + nulls_last = True + elif expr.name == "nulls_first": + nulls_last = False + + if expr.name in ("descending", "ascending", "nulls_last", "nulls_first"): + assert len(expr.args) == 1 + assert len(expr.context_kwargs) == 0 + expr = expr.args[0] + else: + break + + if descending is None: + descending = False + + return Order(expr, descending, nulls_last) + + def iter_subtree(self) -> Iterable[ColExpr]: + yield from self.order_by.iter_subtree() + + def map_subtree(self, g: Callable[[ColExpr], ColExpr]) -> Order: + return Order(self.order_by.map_subtree(g), self.descending, self.nulls_last) + + +# Add all supported dunder methods to `ColExpr`. This has to be done, because Python +# doesn't call __getattr__ for dunder methods. 
+def create_operator(op): + def impl(*args, **kwargs): + return ColFn(op, *wrap_literal(args), **wrap_literal(kwargs)) + + return impl + + +for dunder in OperatorRegistry.SUPPORTED_DUNDER: + setattr(ColExpr, dunder, create_operator(dunder)) +del create_operator + + +def wrap_literal(expr: Any) -> Any: + if isinstance(expr, ColExpr | Order): + return expr + elif isinstance(expr, dict): + return {key: wrap_literal(val) for key, val in expr.items()} + elif isinstance(expr, (list, tuple)): + return expr.__class__(wrap_literal(elem) for elem in expr) + elif isinstance(expr, Generator): + return (wrap_literal(elem) for elem in expr) + else: + return LiteralCol(expr) diff --git a/src/pydiverse/transform/core/dtypes.py b/src/pydiverse/transform/tree/dtypes.py similarity index 74% rename from src/pydiverse/transform/core/dtypes.py rename to src/pydiverse/transform/tree/dtypes.py index 11dccd87..be06ec35 100644 --- a/src/pydiverse/transform/core/dtypes.py +++ b/src/pydiverse/transform/tree/dtypes.py @@ -1,12 +1,14 @@ from __future__ import annotations +import datetime from abc import ABC, abstractmethod +from types import NoneType from pydiverse.transform._typing import T -from pydiverse.transform.errors import ExpressionTypeError +from pydiverse.transform.errors import DataTypeError -class DType(ABC): +class Dtype(ABC): def __init__(self, *, const: bool = False, vararg: bool = False): self.const = const self.vararg = vararg @@ -44,7 +46,7 @@ def without_modifiers(self: T) -> T: """Returns a copy of `self` with all modifiers removed""" return type(self)() - def same_kind(self, other: DType) -> bool: + def same_kind(self, other: Dtype) -> bool: """Check if `other` is of the same type as self. More specifically, `other` must be a stricter subtype of `self`. @@ -62,14 +64,14 @@ def same_kind(self, other: DType) -> bool: return True - def can_promote_to(self, other: DType) -> bool: + def can_promote_to(self, other: Dtype) -> bool: return other.same_kind(self) -class Int(DType): +class Int(Dtype): name = "int" - def can_promote_to(self, other: DType) -> bool: + def can_promote_to(self, other: Dtype) -> bool: if super().can_promote_to(other): return True @@ -83,31 +85,31 @@ def can_promote_to(self, other: DType) -> bool: return False -class Float(DType): +class Float(Dtype): name = "float" -class String(DType): +class String(Dtype): name = "str" -class Bool(DType): +class Bool(Dtype): name = "bool" -class DateTime(DType): +class DateTime(Dtype): name = "datetime" -class Date(DType): +class Date(Dtype): name = "date" -class Duration(DType): +class Duration(Dtype): name = "duration" -class Template(DType): +class Template(Dtype): name = None def __init__(self, name, **kwargs): @@ -117,13 +119,13 @@ def __init__(self, name, **kwargs): def without_modifiers(self: T) -> T: return type(self)(self.name) - def same_kind(self, other: DType) -> bool: + def same_kind(self, other: Dtype) -> bool: if not super().same_kind(other): return False return self.name == other.name - def modifiers_compatible(self, other: DType) -> bool: + def modifiers_compatible(self, other: Dtype) -> bool: """ Check if another dtype object is compatible with the modifiers of the template. 
""" @@ -132,13 +134,34 @@ def modifiers_compatible(self, other: DType) -> bool: return True -class NoneDType(DType): +class NoneDtype(Dtype): """DType used to represent the `None` value.""" name = "none" -def dtype_from_string(t: str) -> DType: +def python_type_to_pdt(t: type) -> Dtype: + if t is int: + return Int() + elif t is float: + return Float() + elif t is bool: + return Bool() + elif t is str: + return String() + elif t is datetime.datetime: + return DateTime() + elif t is datetime.date: + return Date() + elif t is datetime.timedelta: + return Duration() + elif t is NoneType: + return NoneDtype() + + raise TypeError(f"pydiverse.transform does not support python builtin type {t}") + + +def dtype_from_string(t: str) -> Dtype: parts = [part for part in t.split(" ") if part] is_const = False @@ -182,20 +205,20 @@ def dtype_from_string(t: str) -> DType: if base_type == "duration": return Duration(const=is_const, vararg=is_vararg) if base_type == "none": - return NoneDType(const=is_const, vararg=is_vararg) + return NoneDtype(const=is_const, vararg=is_vararg) raise ValueError(f"Unknown type '{base_type}'") -def promote_dtypes(dtypes: list[DType]) -> DType: +def promote_dtypes(dtypes: list[Dtype]) -> Dtype: if len(dtypes) == 0: - raise ValueError("Expected non empty list of dtypes") + raise ValueError("expected non empty list of dtypes") promoted = dtypes[0] for dtype in dtypes[1:]: - if isinstance(dtype, NoneDType): + if isinstance(dtype, NoneDtype): continue - if isinstance(promoted, NoneDType): + if isinstance(promoted, NoneDtype): promoted = dtype continue @@ -205,6 +228,6 @@ def promote_dtypes(dtypes: list[DType]) -> DType: promoted = dtype continue - raise ExpressionTypeError(f"Incompatible types {dtype} and {promoted}.") + raise DataTypeError(f"incompatible types {dtype} and {promoted}") return promoted diff --git a/src/pydiverse/transform/core/registry.py b/src/pydiverse/transform/tree/registry.py similarity index 88% rename from src/pydiverse/transform/core/registry.py rename to src/pydiverse/transform/tree/registry.py index cce6dca0..023d9ed5 100644 --- a/src/pydiverse/transform/core/registry.py +++ b/src/pydiverse/transform/tree/registry.py @@ -9,8 +9,8 @@ from functools import partial from typing import TYPE_CHECKING, Callable -from pydiverse.transform.core import dtypes -from pydiverse.transform.errors import ExpressionTypeError +from pydiverse.transform.errors import DataTypeError +from pydiverse.transform.tree import dtypes if TYPE_CHECKING: from pydiverse.transform.ops import Operator, OperatorExtension @@ -112,14 +112,14 @@ class TypedOperatorImpl: operator: Operator impl: OperatorImpl - rtype: dtypes.DType + return_type: dtypes.Dtype @classmethod - def from_operator_impl(cls, impl: OperatorImpl, rtype: dtypes.DType): + def from_operator_impl(cls, impl: OperatorImpl, return_type: dtypes.Dtype): return cls( operator=impl.operator, impl=impl, - rtype=rtype, + return_type=return_type, ) def __call__(self, *args, **kwargs): @@ -194,9 +194,9 @@ class OperatorRegistry: def __init__(self, name, super_registry=None): self.name = name - self.super_registry = super_registry + self.super_registry: OperatorRegistry | None = super_registry self.registered_ops: set[Operator] = set() - self.implementations: dict[str, OperatorImplementationStore] = dict() + self.implementations: dict[str, OperatorImplStore] = dict() self.check_super: dict[str, bool] = dict() def register_op(self, operator: Operator, check_super=True): @@ -222,23 +222,23 @@ def register_op(self, operator: Operator, 
check_super=True): " in this registry." ) - self.implementations[name] = OperatorImplementationStore(operator) + self.implementations[name] = OperatorImplStore(operator) self.check_super[name] = check_super self.registered_ops.add(operator) self.ALL_REGISTERED_OPS.add(name) - def get_operator(self, name: str) -> Operator | None: + def get_op(self, name: str) -> Operator | None: if impl_store := self.implementations.get(name, None): return impl_store.operator # If operation hasn't been defined in this registry, go to the parent # registry and check if it has been defined there. if self.super_registry is None or not self.check_super.get(name, True): - raise ValueError(f"No implementation for operator '{name}' found") - return self.super_registry.get_operator(name) + raise ValueError(f"no implementation for operator `{name}` found") + return self.super_registry.get_op(name) - def add_implementation( + def add_impl( self, operator: Operator, impl: Callable, @@ -247,30 +247,30 @@ def add_implementation( ): if operator not in self.registered_ops: raise ValueError( - f"Operator {operator} ({operator.name}) hasn't been registered in this" - f" operator registry '{self.name}'" + f"operator `{operator}` ({operator.name}) hasn't been registered in the" + f" operator registry `{self.name}` yet" ) signature = OperatorSignature.parse(signature) operator.validate_signature(signature) - implementation_store = self.implementations[operator.name] + impl_store = self.implementations[operator.name] op_impl = OperatorImpl(operator, impl, signature) if variant: - implementation_store.add_variant(variant, op_impl) + impl_store.add_variant(variant, op_impl) else: - implementation_store.add_implementation(op_impl) + impl_store.add_impl(op_impl) - def get_implementation(self, name, args_signature) -> TypedOperatorImpl: + def get_impl(self, name, args_signature) -> TypedOperatorImpl: if name not in self.ALL_REGISTERED_OPS: - raise ValueError(f"No operator named '{name}'.") + raise ValueError(f"operator named `{name}` does not exist") for dtype in args_signature: - if not isinstance(dtype, dtypes.DType): + if not isinstance(dtype, dtypes.Dtype): raise TypeError( - "Expected elements of `args_signature` to be of type DType," - f" but found element of type {type(dtype).__name__} instead." + "expected elements of `args_signature` to be of type Dtype, " + f"found element of type {type(dtype).__name__} instead" ) if store := self.implementations.get(name): @@ -280,11 +280,11 @@ def get_implementation(self, name, args_signature) -> TypedOperatorImpl: # If operation hasn't been defined in this registry, go to the parent # registry and check if it has been defined there. if self.super_registry is None or not self.check_super.get(name, True): - raise ValueError( - f"No implementation for operator '{name}' found that matches signature" - f" '{args_signature}'." + raise TypeError( + f"invalid usage of operator `{name}` with arguments of type " + f"{args_signature}" ) - return self.super_registry.get_implementation(name, args_signature) + return self.super_registry.get_impl(name, args_signature) class OperatorSignature: @@ -314,7 +314,7 @@ class OperatorSignature: """ - def __init__(self, args: list[dtypes.DType], rtype: dtypes.DType): + def __init__(self, args: list[dtypes.Dtype], rtype: dtypes.Dtype): """ :param args: Tuple of argument types. :param rtype: The return type. 
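
For reference, a minimal sketch of how such a signature is built from its string form. This assumes the "args -> rtype" grammar handled by `OperatorSignature.parse` and `dtype_from_string`; the module paths are those introduced by this patch:

from pydiverse.transform.tree import dtypes
from pydiverse.transform.tree.registry import OperatorSignature

# "int, int -> int": two integer arguments, integer return type (assumed grammar).
sig = OperatorSignature.parse("int, int -> int")
assert isinstance(sig.rtype, dtypes.Int)
assert len(sig.args) == 2 and all(isinstance(arg, dtypes.Int) for arg in sig.args)
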
@@ -389,7 +389,7 @@ def is_vararg(self) -> bool: return self.args[-1].vararg -class OperatorImplementationStore: +class OperatorImplStore: """ Stores all implementations for a specific operation in a trie according to their signature. This enables us to easily find the best matching @@ -399,9 +399,9 @@ class OperatorImplementationStore: @dataclasses.dataclass class TrieNode: __slots__ = ("value", "operator", "children") - value: dtypes.DType + value: dtypes.Dtype operator: OperatorImpl | None - children: list[OperatorImplementationStore.TrieNode] + children: list[OperatorImplStore.TrieNode] def __repr__(self): self_text = f"({self.value} - {self.operator})" @@ -415,7 +415,7 @@ def __init__(self, operator: Operator): self.operator = operator self.root = self.TrieNode("ROOT", None, []) # type: ignore - def add_implementation(self, operator: OperatorImpl): + def add_impl(self, operator: OperatorImpl): node = self.get_node(operator.signature, create_missing=True) if node.operator is not None: raise ValueError( @@ -453,7 +453,7 @@ def get_node(self, signature: OperatorSignature, create_missing: bool = True): return node def find_best_match( - self, signature: tuple[dtypes.DType] + self, signature: tuple[dtypes.Dtype] ) -> TypedOperatorImpl | None: matches = list(self._find_matches(signature)) @@ -483,8 +483,8 @@ def find_best_match( return TypedOperatorImpl.from_operator_impl(best_match[0].operator, rtype) def _find_matches( - self, signature: tuple[dtypes.DType] - ) -> Iterable[TrieNode, dict[str, dtypes.DType, tuple[int, ...]]]: + self, signature: tuple[dtypes.Dtype] + ) -> Iterable[TrieNode, dict[str, dtypes.Dtype, tuple[int, ...]]]: """Yield all operators that match the input signature""" # Case 0 arguments: @@ -494,16 +494,16 @@ def _find_matches( # Case 1+ args: def does_match( - dtype: dtypes.DType, - node: OperatorImplementationStore.TrieNode, + dtype: dtypes.Dtype, + node: OperatorImplStore.TrieNode, ) -> bool: if isinstance(node.value, dtypes.Template): return node.value.modifiers_compatible(dtype) return dtype.can_promote_to(node.value) - stack: list[ - tuple[OperatorImplementationStore.TrieNode, int, dict, tuple[int, ...]] - ] = [(child, 0, dict(), tuple()) for child in self.root.children] + stack: list[tuple[OperatorImplStore.TrieNode, int, dict, tuple[int, ...]]] = [ + (child, 0, dict(), tuple()) for child in self.root.children + ] while stack: node, s_i, templates, type_promotion_indices = stack.pop() @@ -535,7 +535,7 @@ def does_match( for name, types_ in templates.items() } yield node, templates, type_promotion_indices - except ExpressionTypeError: + except DataTypeError: print(f"Can't promote: {templates}") pass @@ -573,7 +573,7 @@ def __call__(self, signature: str, *, variant: str = None): raise TypeError("Signature must be of type str.") def decorator(func): - self.registry.add_implementation( + self.registry.add_impl( self.operator, func, signature, @@ -591,7 +591,7 @@ def auto(self, func: Callable = None, *, variant: str = None): raise ValueError(f"Operator {self.operator} has not default signatures.") for signature in self.operator.signatures: - self.registry.add_implementation( + self.registry.add_impl( self.operator, func, signature, @@ -609,7 +609,7 @@ def extension(self, extension: type[OperatorExtension], variant: str = None): def decorator(func): for sig in extension.signatures: - self.registry.add_implementation(self.operator, func, sig, variant) + self.registry.add_impl(self.operator, func, sig, variant) return func return decorator diff --git 
a/src/pydiverse/transform/tree/verbs.py b/src/pydiverse/transform/tree/verbs.py new file mode 100644 index 00000000..24e312fd --- /dev/null +++ b/src/pydiverse/transform/tree/verbs.py @@ -0,0 +1,213 @@ +from __future__ import annotations + +import copy +import dataclasses +import uuid +from collections.abc import Callable, Iterable +from typing import Literal +from uuid import UUID + +from pydiverse.transform.tree.ast import AstNode +from pydiverse.transform.tree.col_expr import Col, ColExpr, Order + +JoinHow = Literal["inner", "left", "outer"] + +JoinValidate = Literal["1:1", "1:m", "m:1", "m:m"] + + +@dataclasses.dataclass(eq=False, slots=True) +class Verb(AstNode): + child: AstNode + + def __post_init__(self): + self.name = self.child.name + + def _clone(self) -> tuple[Verb, dict[AstNode, AstNode], dict[UUID, UUID]]: + child, nd_map, uuid_map = self.child._clone() + cloned = copy.copy(self) + cloned.child = child + + cloned.map_col_nodes( + lambda col: Col( + col.name, nd_map[col._ast], uuid_map[col._uuid], col._dtype, col._ftype + ) + if isinstance(col, Col) + else copy.copy(col) + ) + nd_map[self] = cloned + + return cloned, nd_map, uuid_map + + def iter_subtree(self) -> Iterable[AstNode]: + yield from self.child.iter_subtree() + yield self + + def iter_col_roots(self) -> Iterable[ColExpr]: + return iter(()) + + def iter_col_nodes(self) -> Iterable[ColExpr]: + for col in self.iter_col_roots(): + yield from col.iter_subtree() + + def map_col_roots(self, g: Callable[[ColExpr], ColExpr]): ... + + def map_col_nodes(self, g: Callable[[ColExpr], ColExpr]): + self.map_col_roots(lambda root: root.map_subtree(g)) + + +@dataclasses.dataclass(eq=False, slots=True) +class Select(Verb): + select: list[Col] + + def iter_col_roots(self) -> Iterable[ColExpr]: + yield from self.select + + def map_col_roots(self, g: Callable[[ColExpr], ColExpr]): + self.select = [g(col) for col in self.select] + + +@dataclasses.dataclass(eq=False, slots=True) +class Drop(Verb): + drop: list[Col] + + def iter_col_roots(self) -> Iterable[ColExpr]: + yield from self.drop + + def map_col_roots(self, g: Callable[[ColExpr], ColExpr]): + self.drop = [g(col) for col in self.drop] + + +@dataclasses.dataclass(eq=False, slots=True) +class Rename(Verb): + name_map: dict[str, str] + + +@dataclasses.dataclass(eq=False, slots=True) +class Mutate(Verb): + names: list[str] + values: list[ColExpr] + uuids: list[UUID] + + def iter_col_roots(self) -> Iterable[ColExpr]: + yield from self.values + + def map_col_roots(self, g: Callable[[ColExpr], ColExpr]): + self.values = [g(val) for val in self.values] + + def _clone(self) -> tuple[Verb, dict[AstNode, AstNode], dict[UUID, UUID]]: + cloned, nd_map, uuid_map = Verb._clone(self) + assert isinstance(cloned, Mutate) + cloned.uuids = [uuid.uuid1() for _ in self.names] + uuid_map.update( + {old_uid: new_uid for old_uid, new_uid in zip(self.uuids, cloned.uuids)} + ) + return cloned, nd_map, uuid_map + + +@dataclasses.dataclass(eq=False, slots=True) +class Filter(Verb): + filters: list[ColExpr] + + def iter_col_roots(self) -> Iterable[ColExpr]: + yield from self.filters + + def map_col_roots(self, g: Callable[[ColExpr], ColExpr]): + self.filters = [g(predicate) for predicate in self.filters] + + +@dataclasses.dataclass(eq=False, slots=True) +class Summarise(Verb): + names: list[str] + values: list[ColExpr] + uuids: list[UUID] + + def iter_col_roots(self) -> Iterable[ColExpr]: + yield from self.values + + def map_col_roots(self, g: Callable[[ColExpr], ColExpr]): + self.values = [g(val) for val in 
self.values] + + def _clone(self) -> tuple[Verb, dict[AstNode, AstNode], dict[UUID, UUID]]: + cloned, nd_map, uuid_map = Verb._clone(self) + assert isinstance(cloned, Summarise) + cloned.uuids = [uuid.uuid1() for _ in self.names] + uuid_map.update( + {old_uid: new_uid for old_uid, new_uid in zip(self.uuids, cloned.uuids)} + ) + return cloned, nd_map, uuid_map + + +@dataclasses.dataclass(eq=False, slots=True) +class Arrange(Verb): + order_by: list[Order] + + def iter_col_roots(self) -> Iterable[ColExpr]: + yield from (ord.order_by for ord in self.order_by) + + def map_col_roots(self, g: Callable[[ColExpr], ColExpr]): + self.order_by = [ + Order(g(ord.order_by), ord.descending, ord.nulls_last) + for ord in self.order_by + ] + + +@dataclasses.dataclass(eq=False, slots=True) +class SliceHead(Verb): + n: int + offset: int + + +@dataclasses.dataclass(eq=False, slots=True) +class GroupBy(Verb): + group_by: list[Col] + add: bool + + def iter_col_roots(self) -> Iterable[ColExpr]: + yield from self.group_by + + def map_col_roots(self, g: Callable[[ColExpr], ColExpr]): + self.group_by = [g(col) for col in self.group_by] + + +@dataclasses.dataclass(eq=False, slots=True) +class Ungroup(Verb): ... + + +@dataclasses.dataclass(eq=False, slots=True) +class Join(Verb): + right: AstNode + on: ColExpr + how: JoinHow + validate: JoinValidate + suffix: str + + def _clone(self) -> tuple[Join, dict[AstNode, AstNode], dict[UUID, UUID]]: + child, nd_map, uuid_map = self.child._clone() + right_child, right_nd_map, right_uuid_map = self.right._clone() + nd_map.update(right_nd_map) + uuid_map.update(right_uuid_map) + + cloned = copy.copy(self) + cloned.child = child + cloned.right = right_child + cloned.on = self.on.map_subtree( + lambda col: Col( + col.name, nd_map[col._ast], uuid_map[col._uuid], col._dtype, col._ftype + ) + if isinstance(col, Col) + else copy.copy(col) + ) + + nd_map[self] = cloned + return cloned, nd_map, uuid_map + + def iter_subtree(self) -> Iterable[AstNode]: + yield from self.child.iter_subtree() + yield from self.right.iter_subtree() + yield self + + def iter_col_roots(self) -> Iterable[ColExpr]: + yield self.on + + def map_col_roots(self, g: Callable[[ColExpr], ColExpr]): + self.on = g(self.on) diff --git a/tests/test_backend_equivalence/conftest.py b/tests/test_backend_equivalence/conftest.py index 503a5338..da619439 100644 --- a/tests/test_backend_equivalence/conftest.py +++ b/tests/test_backend_equivalence/conftest.py @@ -1,6 +1,6 @@ from __future__ import annotations -from datetime import datetime +from datetime import date, datetime import polars as pl import pytest @@ -13,6 +13,7 @@ { "col1": [1, 2, 3, 4], "col2": ["a", "baa", "c", "d"], + "cnull": [None, 2, None, None], } ), "df2": pl.DataFrame( @@ -94,6 +95,28 @@ datetime(2004, 12, 31, 23, 59, 59, 456_789), datetime(1970, 1, 1), ], + "cdate": [ + date(2017, 3, 2), + date(1998, 1, 12), + date(1999, 12, 31), + date(2024, 9, 23), + date(2018, 8, 13), + None, + date(2010, 5, 1), + date(2016, 2, 27), + date(2000, 1, 1), + ], + # "cdur": [ + # None, + # timedelta(1, 4, 2, 5), + # timedelta(0, 11, 14, 10000), + # timedelta(12, 2, 3), + # timedelta(4, 3, 1, 2, 3, 4), + # timedelta(0, 0, 0, 0, 1), + # timedelta(0, 1, 0, 1, 0, 1), + # None, + # timedelta(), + # ], } ), } diff --git a/tests/test_backend_equivalence/test_arrange.py b/tests/test_backend_equivalence/test_arrange.py index 140ddccb..a3c507f2 100644 --- a/tests/test_backend_equivalence/test_arrange.py +++ b/tests/test_backend_equivalence/test_arrange.py @@ -1,7 +1,7 @@ from __future__ 
import annotations from pydiverse.transform import C -from pydiverse.transform.core.verbs import ( +from pydiverse.transform.pipe.verbs import ( arrange, mutate, ) @@ -9,12 +9,14 @@ def test_noop(df1): - assert_result_equal(df1, lambda t: t >> arrange()) + assert_result_equal( + df1, lambda t: t >> arrange(), may_throw=True, exception=TypeError + ) def test_arrange(df2): - assert_result_equal(df2, lambda t: t >> arrange(t.col1)) - assert_result_equal(df2, lambda t: t >> arrange(-t.col1)) + assert_result_equal(df2, lambda t: t >> arrange(t.col1.ascending())) + assert_result_equal(df2, lambda t: t >> arrange((-t.col1).descending())) assert_result_equal(df2, lambda t: t >> arrange(t.col3)) assert_result_equal(df2, lambda t: t >> arrange(-t.col3)) @@ -55,7 +57,7 @@ def test_nulls_first(df4): lambda t: t >> arrange( t.col1.nulls_first(), - -t.col2.nulls_first(), + t.col2.descending().nulls_first(), t.col5.nulls_first(), ), check_row_order=True, @@ -68,7 +70,7 @@ def test_nulls_last(df4): lambda t: t >> arrange( t.col1.nulls_last(), - -t.col2.nulls_last(), + t.col2.nulls_last().descending(), t.col5.nulls_last(), ), check_row_order=True, @@ -81,8 +83,8 @@ def test_nulls_first_last_mixed(df4): lambda t: t >> arrange( t.col1.nulls_first(), - -t.col2.nulls_last(), - -t.col5, + t.col2.nulls_last().descending(), + t.col5.descending().nulls_last(), ), check_row_order=True, ) @@ -91,6 +93,8 @@ def test_nulls_first_last_mixed(df4): def test_arrange_after_mutate(df4): assert_result_equal( df4, - lambda t: t >> mutate(x=t.col1 <= t.col2) >> arrange(C.x, C.col4), + lambda t: t + >> mutate(x=t.col1 <= t.col2) + >> arrange(C.x.nulls_last(), C.col4.nulls_first()), check_row_order=True, ) diff --git a/tests/test_backend_equivalence/test_dtypes.py b/tests/test_backend_equivalence/test_dtypes.py new file mode 100644 index 00000000..96f67a51 --- /dev/null +++ b/tests/test_backend_equivalence/test_dtypes.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +from pydiverse.transform.pipe.verbs import alias, filter, inner_join, mutate +from tests.util.assertion import assert_result_equal + + +def test_dtypes(df1): + assert_result_equal( + df1, + lambda t: t + >> filter(t.col1 % 2 == 1) + >> inner_join(s := t >> mutate(u=t.col1 % 2) >> alias(), t.col1 == s.u), + ) diff --git a/tests/test_backend_equivalence/test_filter.py b/tests/test_backend_equivalence/test_filter.py index c09aeffe..004b3510 100644 --- a/tests/test_backend_equivalence/test_filter.py +++ b/tests/test_backend_equivalence/test_filter.py @@ -1,7 +1,7 @@ from __future__ import annotations from pydiverse.transform import C -from pydiverse.transform.core.verbs import ( +from pydiverse.transform.pipe.verbs import ( filter, mutate, ) @@ -9,7 +9,9 @@ def test_noop(df2): - assert_result_equal(df2, lambda t: t >> filter()) + assert_result_equal( + df2, lambda t: t >> filter(), may_throw=True, exception=TypeError + ) assert_result_equal(df2, lambda t: t >> filter(t.col1 == t.col1)) @@ -44,9 +46,17 @@ def test_filter_isin(df4): lambda t: t >> filter( C.col1.isin(0, 2), + C.col2.isin(0, t.col1 * t.col2), ), ) + assert_result_equal( + df4, + lambda t: t >> filter((-(t.col4 // 2 - 1)).isin(1, 4, t.col1 + t.col2)), + ) + + assert_result_equal(df4, lambda t: t >> filter(t.col1.isin(None))) + assert_result_equal( df4, lambda t: t diff --git a/tests/test_backend_equivalence/test_group_by.py b/tests/test_backend_equivalence/test_group_by.py index 45f1cff0..29205f4b 100644 --- a/tests/test_backend_equivalence/test_group_by.py +++ 
b/tests/test_backend_equivalence/test_group_by.py @@ -3,8 +3,8 @@ import pytest from pydiverse.transform import C -from pydiverse.transform.core import functions -from pydiverse.transform.core.verbs import ( +from pydiverse.transform.pipe import functions +from pydiverse.transform.pipe.verbs import ( arrange, filter, group_by, @@ -14,7 +14,7 @@ select, ungroup, ) -from tests.util import assert_result_equal, full_sort +from tests.util import assert_result_equal def test_ungroup(df3): @@ -44,7 +44,7 @@ def test_mutate(df3, df4): lambda t, u: t >> group_by(t.col1, t.col2) >> mutate(col1=t.col1 * t.col2) - >> arrange(-t.col3.nulls_last()) + >> arrange(t.col3.descending().nulls_last()) >> ungroup() >> left_join(u, t.col2 == u.col2) >> mutate( @@ -83,8 +83,8 @@ def test_ungrouped_join(df1, df3, how): lambda t, u: t >> group_by(t.col1) >> ungroup() - >> join(u, t.col1 == u.col1, how=how) - >> full_sort(), + >> join(u, t.col1 == u.col1, how=how), + check_row_order=False, ) diff --git a/tests/test_backend_equivalence/test_join.py b/tests/test_backend_equivalence/test_join.py index 87896c3d..0b1b2cd1 100644 --- a/tests/test_backend_equivalence/test_join.py +++ b/tests/test_backend_equivalence/test_join.py @@ -4,8 +4,8 @@ import pytest -from pydiverse.transform.core.expressions.lambda_getter import C -from pydiverse.transform.core.verbs import ( +from pydiverse.transform.pipe.c import C +from pydiverse.transform.pipe.verbs import ( alias, join, left_join, @@ -13,7 +13,7 @@ outer_join, select, ) -from tests.util import assert_result_equal, full_sort +from tests.util import assert_result_equal @pytest.mark.parametrize( @@ -65,14 +65,15 @@ def test_join(df1, df2, how): def test_join_and_select(df1, df2, how): assert_result_equal( (df1, df2), - lambda t, u: t >> select() >> join(u, t.col1 == u.col1, how=how) >> full_sort(), + lambda t, u: t >> select() >> join(u, t.col1 == u.col1, how=how), + check_row_order=False, ) assert_result_equal( (df1, df2), lambda t, u: t - >> join(u >> select(), (t.col1 == u.col1) & (t.col1 == u.col2), how=how) - >> full_sort(), + >> join(u >> select(), (t.col1 == u.col1) & (t.col1 == u.col2), how=how), + check_row_order=False, ) @@ -100,22 +101,18 @@ def test_self_join(df3, how): def self_join_1(t): u = t >> alias("self_join") - return t >> join(u, t.col1 == u.col1, how=how) >> full_sort() + return t >> join(u, t.col1 == u.col1, how=how) - assert_result_equal(df3, self_join_1) + assert_result_equal(df3, self_join_1, check_row_order=False) def self_join_2(t): u = t >> alias("self_join") - return ( - t - >> join(u, (t.col1 == u.col1) & (t.col2 == u.col2), how=how) - >> full_sort() - ) + return t >> join(u, (t.col1 == u.col1) & (t.col2 == u.col2), how=how) - assert_result_equal(df3, self_join_2) + assert_result_equal(df3, self_join_2, check_row_order=False) def self_join_3(t): u = t >> alias("self_join") - return t >> join(u, (t.col2 == u.col3), how=how) >> full_sort() + return t >> join(u, (t.col2 == u.col3), how=how) - assert_result_equal(df3, self_join_3) + assert_result_equal(df3, self_join_3, check_row_order=False) diff --git a/tests/test_backend_equivalence/test_mutate.py b/tests/test_backend_equivalence/test_mutate.py index 07cd9c48..11f88976 100644 --- a/tests/test_backend_equivalence/test_mutate.py +++ b/tests/test_backend_equivalence/test_mutate.py @@ -1,11 +1,10 @@ from __future__ import annotations from pydiverse.transform import C -from pydiverse.transform.core.verbs import ( +from pydiverse.transform.pipe.verbs import ( mutate, select, ) -from 
pydiverse.transform.errors import ExpressionTypeError from tests.util import assert_result_equal @@ -39,9 +38,7 @@ def test_literals(df1): def test_none(df4): - assert_result_equal( - df4, lambda t: t >> mutate(x=None), exception=ExpressionTypeError - ) + assert_result_equal(df4, lambda t: t >> mutate(x=None)) assert_result_equal( df4, lambda t: t diff --git a/tests/test_backend_equivalence/test_ops/test_case_expression.py b/tests/test_backend_equivalence/test_ops/test_case_expression.py index fb0d9321..180648c7 100644 --- a/tests/test_backend_equivalence/test_ops/test_case_expression.py +++ b/tests/test_backend_equivalence/test_ops/test_case_expression.py @@ -1,13 +1,13 @@ from __future__ import annotations +import pydiverse.transform as pdt from pydiverse.transform import C -from pydiverse.transform import functions as f -from pydiverse.transform.core.verbs import ( +from pydiverse.transform.errors import FunctionTypeError +from pydiverse.transform.pipe.verbs import ( group_by, mutate, summarise, ) -from pydiverse.transform.errors import ExpressionTypeError, FunctionTypeError from tests.util import assert_result_equal @@ -16,16 +16,7 @@ def test_mutate_case_ewise(df4): df4, lambda t: t >> mutate( - x=C.col1.case( - (0, 1), - (1, 2), - (2, 2), - ), - y=C.col1.case( - (0, 0), - (1, None), - default=10.5, - ), + x=C.col1.map({0: 1, (1, 2): 2}), y=C.col1.map({0: 0, 1: None}, default=10.4) ), ) @@ -33,11 +24,11 @@ def test_mutate_case_ewise(df4): df4, lambda t: t >> mutate( - x=f.case( - (C.col1 == C.col2, 1), - (C.col2 == C.col3, 2), - default=(C.col1 + C.col2), - ) + x=pdt.when(C.col1 == C.col2) + .then(1) + .when(C.col2 == C.col3) + .then(2) + .otherwise(C.col1 + C.col2), ), ) @@ -47,12 +38,14 @@ def test_mutate_case_window(df4): df4, lambda t: t >> mutate( - x=f.case( - (C.col1.max() == 1, 1), - (C.col1.max() == 2, 2), - (C.col1.max() == 3, 3), - (C.col1.max() == 4, 4), - ) + x=pdt.when(C.col1.max() == 1) + .then(1) + .when(C.col1.max() == 2) + .then(2) + .when(C.col1.max() == 3) + .then(3) + .when(C.col1.max() == 4) + .then(4) ), ) @@ -60,10 +53,18 @@ def test_mutate_case_window(df4): df4, lambda t: t >> mutate( - u=C.col1.shift(1, 1729, arrange=[-t.col3, t.col4]), - x=C.col1.shift(1, 0, arrange=[C.col4]).case( - (1, C.col2.shift(1, -1, arrange=[C.col2, C.col4])), - (2, C.col3.shift(2, -2, arrange=[C.col3, C.col4])), + u=C.col1.shift( + 1, 1729, arrange=[t.col3.descending().nulls_last(), t.col4.nulls_last()] + ), + x=C.col1.shift(1, 0, arrange=[C.col4.nulls_first()]).map( + { + 1: C.col2.shift( + 1, -1, arrange=[C.col2.nulls_last(), C.col4.nulls_first()] + ), + 2: C.col3.shift( + 2, -2, arrange=[C.col3.nulls_last(), C.col4.nulls_last()] + ), + } ), ), ) @@ -73,12 +74,14 @@ def test_mutate_case_window(df4): df4, lambda t: t >> mutate( - x=C.col1.shift(1, 0, arrange=[C.col4]) - .case( - (1, 2), - (2, 3), + x=C.col1.shift(1, 0, arrange=[C.col4.nulls_last()]) + .map( + { + 1: 2, + 2: 3, + } ) - .shift(1, -1, arrange=[-C.col4]) + .shift(1, -1, arrange=[C.col4.descending().nulls_first()]) ), may_throw=True, ) @@ -92,46 +95,35 @@ def test_summarise_case(df4): C.col1, ) >> summarise( - x=C.col2.max().case( - (0, C.col1.min()), # Int - (1, C.col2.mean() + 0.5), # Float - (2, 2), # ftype=EWISE - ), - y=f.case( - (C.col2.max() > 2, 1), - (C.col2.max() < 2, C.col2.min()), - default=C.col3.mean(), + x=C.col2.max().map( + { + 0: C.col1.min(), + 1: C.col2.mean() + 0.5, + 2: 2, + } ), + y=pdt.when(C.col2.max() > 2) + .then(1) + .when(C.col2.max() < 2) + .then(C.col2.min()) + .otherwise(C.col3.mean()), 
), ) def test_invalid_value_dtype(df4): - # Incompatible types String and Float - assert_result_equal( - df4, - lambda t: t - >> mutate( - x=C.col1.case( - (0, "a"), - (1, 1.1), - ) - ), - exception=ExpressionTypeError, - ) - - -def test_invalid_result_dtype(df4): - # Invalid result type: none assert_result_equal( df4, lambda t: t >> mutate( - x=f.case( - default=None, + x=C.col1.map( + { + 0: "a", + 1: 1.1, + } ) ), - exception=ExpressionTypeError, + exception=TypeError, ) @@ -140,8 +132,10 @@ def test_invalid_ftype(df1): df1, lambda t: t >> summarise( - x=f.rank(arrange=[C.col1]).case( - (1, C.col1.max()), + x=pdt.rank(arrange=[C.col1]).map( + { + 1: C.col1.max(), + }, default=None, ) ), @@ -152,10 +146,7 @@ def test_invalid_ftype(df1): df1, lambda t: t >> summarise( - x=f.case( - (f.rank(arrange=[C.col1]) == 1, 1), - default=None, - ) + x=pdt.when(pdt.rank(arrange=[C.col1]) == 1).then(1).otherwise(None) ), exception=FunctionTypeError, ) diff --git a/tests/test_backend_equivalence/test_ops/test_functions.py b/tests/test_backend_equivalence/test_ops/test_functions.py index acfcb1a2..f82d31aa 100644 --- a/tests/test_backend_equivalence/test_ops/test_functions.py +++ b/tests/test_backend_equivalence/test_ops/test_functions.py @@ -1,22 +1,32 @@ from __future__ import annotations +import pydiverse.transform as pdt from pydiverse.transform import C -from pydiverse.transform import functions as f -from pydiverse.transform.core.verbs import mutate +from pydiverse.transform.pipe.verbs import mutate +from pydiverse.transform.tree.col_expr import LiteralCol from tests.fixtures.backend import skip_backends from tests.util import assert_result_equal def test_count(df4): assert_result_equal( - df4, lambda t: t >> mutate(**{col._.name + "_count": f.count(col) for col in t}) + df4, + lambda t: t + >> mutate(**{col.name + "_count": pdt.count(col) for col in t}) + >> mutate(o=LiteralCol(0).count(filter=t.col3 == 2)) + >> mutate(u=pdt.count(), v=pdt.count(filter=t.col4 > 0)), ) def test_row_number(df4): assert_result_equal( df4, - lambda t: t >> mutate(row_number=f.row_number(arrange=[-C.col1, C.col5])), + lambda t: t + >> mutate( + row_number=pdt.row_number( + arrange=[C.col1.descending().nulls_first(), C.col5.nulls_last()] + ) + ), ) @@ -28,14 +38,14 @@ def test_min(df4): df4, lambda t: t >> mutate( - int1=f.min(C.col1 + 2, C.col2, 9), - int2=f.min(C.col1 * C.col2, 0), - int3=f.min(C.col1 * C.col2, C.col2 * C.col3, 2 - C.col3), - int4=f.min(C.col1), - float1=f.min(C.col1, 1.5), - float2=f.min(1, C.col1 + 1.5, C.col2, 2.2), - str1=f.min(C.col5, "c"), - str2=f.min(C.col5, "C"), + int1=pdt.min(C.col1 + 2, C.col2, 9), + int2=pdt.min(C.col1 * C.col2, 0), + int3=pdt.min(C.col1 * C.col2, C.col2 * C.col3, 2 - C.col3), + int4=pdt.min(C.col1), + float1=pdt.min(C.col1, 1.5), + float2=pdt.min(1, C.col1 + 1.5, C.col2, 2.2), + str1=pdt.min(C.col5, "c"), + str2=pdt.min(C.col5, "C"), ), ) @@ -48,13 +58,13 @@ def test_max(df4): df4, lambda t: t >> mutate( - int1=f.max(C.col1 + 2, C.col2, 9), - int2=f.max(C.col1 * C.col2, 0), - int3=f.max(C.col1 * C.col2, C.col2 * C.col3, 2 - C.col3), - int4=f.max(C.col1), - float1=f.max(C.col1, 1.5), - float2=f.max(1, C.col1 + 1.5, C.col2, 2.2), - str1=f.max(C.col5, "c"), - str2=f.max(C.col5, "C"), + int1=pdt.max(C.col1 + 2, C.col2, 9), + int2=pdt.max(C.col1 * C.col2, 0), + int3=pdt.max(C.col1 * C.col2, C.col2 * C.col3, 2 - C.col3), + int4=pdt.max(C.col1), + float1=pdt.max(C.col1, 1.5), + float2=pdt.max(1, C.col1 + 1.5, C.col2, 2.2), + str1=pdt.max(C.col5, "c"), + str2=pdt.max(C.col5, 
"C"), ), ) diff --git a/tests/test_backend_equivalence/test_ops/test_ops_datetime.py b/tests/test_backend_equivalence/test_ops/test_ops_datetime.py index 7c2bbebb..8bd74f4a 100644 --- a/tests/test_backend_equivalence/test_ops/test_ops_datetime.py +++ b/tests/test_backend_equivalence/test_ops/test_ops_datetime.py @@ -3,7 +3,7 @@ from datetime import datetime from pydiverse.transform import C -from pydiverse.transform.core.verbs import ( +from pydiverse.transform.pipe.verbs import ( filter, mutate, ) @@ -65,6 +65,7 @@ def test_year(df_datetime): >> mutate( x=C.col1.dt.year(), y=C.col2.dt.year(), + z=t.cdate.dt.year(), ), ) @@ -76,6 +77,7 @@ def test_month(df_datetime): >> mutate( x=C.col1.dt.month(), y=C.col2.dt.month(), + z=t.cdate.dt.month(), ), ) @@ -83,11 +85,7 @@ def test_month(df_datetime): def test_day(df_datetime): assert_result_equal( df_datetime, - lambda t: t - >> mutate( - x=C.col1.dt.day(), - y=C.col2.dt.day(), - ), + lambda t: t >> mutate(x=C.col1.dt.day(), y=C.col2.dt.day(), z=t.cdate.dt.day()), ) @@ -101,6 +99,12 @@ def test_hour(df_datetime): ), ) + assert_result_equal( + df_datetime, + lambda t: t >> mutate(z=t.cdate.dt.hour()), + exception=ValueError, + ) + def test_minute(df_datetime): assert_result_equal( @@ -155,3 +159,11 @@ def test_day_of_year(df_datetime): y=C.col2.dt.day_of_year(), ), ) + + +def test_duration_add(df_datetime): + assert_result_equal(df_datetime, lambda t: t >> mutate(z=t.cdur + t.cdur)) + + +def test_dt_subtract(df_datetime): + assert_result_equal(df_datetime, lambda t: t >> mutate(z=t.col1 - t.col2)) diff --git a/tests/test_backend_equivalence/test_ops/test_ops_string.py b/tests/test_backend_equivalence/test_ops/test_ops_string.py index 77eacd9d..90b1f897 100644 --- a/tests/test_backend_equivalence/test_ops/test_ops_string.py +++ b/tests/test_backend_equivalence/test_ops/test_ops_string.py @@ -1,7 +1,7 @@ from __future__ import annotations from pydiverse.transform import C -from pydiverse.transform.core.verbs import ( +from pydiverse.transform.pipe.verbs import ( filter, mutate, ) diff --git a/tests/test_backend_equivalence/test_rename.py b/tests/test_backend_equivalence/test_rename.py index da831d9f..3febfc92 100644 --- a/tests/test_backend_equivalence/test_rename.py +++ b/tests/test_backend_equivalence/test_rename.py @@ -1,13 +1,15 @@ from __future__ import annotations -from pydiverse.transform.core.verbs import ( +from pydiverse.transform.pipe.verbs import ( rename, ) from tests.util import assert_result_equal def test_noop(df3): - assert_result_equal(df3, lambda t: t >> rename({})) + assert_result_equal( + df3, lambda t: t >> rename({}), may_throw=True, exception=TypeError + ) assert_result_equal(df3, lambda t: t >> rename({"col1": "col1"})) diff --git a/tests/test_backend_equivalence/test_select.py b/tests/test_backend_equivalence/test_select.py index 3fc0fc7a..f246a174 100644 --- a/tests/test_backend_equivalence/test_select.py +++ b/tests/test_backend_equivalence/test_select.py @@ -1,8 +1,6 @@ from __future__ import annotations -from pydiverse.transform import C -from pydiverse.transform.core.verbs import ( - mutate, +from pydiverse.transform.pipe.verbs import ( select, ) from tests.util import assert_result_equal @@ -15,20 +13,3 @@ def test_simple_select(df1): def test_reorder(df1): assert_result_equal(df1, lambda t: t >> select(t.col2, t.col1)) - - -def test_ellipsis(df3): - assert_result_equal(df3, lambda t: t >> select(...)) - assert_result_equal(df3, lambda t: t >> select(t.col1) >> select(...)) - assert_result_equal( - df3, lambda t: t 
>> mutate(x=t.col1 * 2) >> select() >> select(...) - ) - - -def test_negative_select(df3): - assert_result_equal(df3, lambda t: t >> select(-t.col1)) - assert_result_equal(df3, lambda t: t >> select(-C.col1, -t.col2)) - assert_result_equal( - df3, - lambda t: t >> select() >> mutate(x=t.col1 * 2) >> select(-C.col3), - ) diff --git a/tests/test_backend_equivalence/test_slice_head.py b/tests/test_backend_equivalence/test_slice_head.py index 84f2da1a..195206d7 100644 --- a/tests/test_backend_equivalence/test_slice_head.py +++ b/tests/test_backend_equivalence/test_slice_head.py @@ -1,8 +1,8 @@ from __future__ import annotations from pydiverse.transform import C -from pydiverse.transform.core import functions as f -from pydiverse.transform.core.verbs import ( +from pydiverse.transform.pipe import functions as f +from pydiverse.transform.pipe.verbs import ( arrange, filter, group_by, @@ -12,7 +12,7 @@ slice_head, summarise, ) -from tests.util import assert_result_equal, full_sort +from tests.util import assert_result_equal def test_simple(df3): @@ -61,13 +61,6 @@ def test_chained(df3): ) -def test_with_select(df3): - assert_result_equal( - df3, - lambda t: t >> select() >> arrange(*t) >> slice_head(4, offset=2) >> select(*t), - ) - - def test_with_mutate(df3): assert_result_equal( df3, @@ -83,18 +76,17 @@ def test_with_join(df1, df2): assert_result_equal( (df1, df2), lambda t, u: t - >> full_sort() >> arrange(*t) >> slice_head(3) - >> left_join(u, t.col1 == u.col1) - >> full_sort(), + >> left_join(u, t.col1 == u.col1), + check_row_order=False, ) assert_result_equal( (df1, df2), lambda t, u: t - >> left_join(u >> arrange(*t) >> slice_head(2, offset=1), t.col1 == u.col1) - >> full_sort(), + >> left_join(u >> arrange(*t) >> slice_head(2, offset=1), t.col1 == u.col1), + check_row_order=False, exception=ValueError, may_throw=True, ) diff --git a/tests/test_backend_equivalence/test_summarise.py b/tests/test_backend_equivalence/test_summarise.py index a92667e0..778f732e 100644 --- a/tests/test_backend_equivalence/test_summarise.py +++ b/tests/test_backend_equivalence/test_summarise.py @@ -1,7 +1,8 @@ from __future__ import annotations from pydiverse.transform import C -from pydiverse.transform.core.verbs import ( +from pydiverse.transform.errors import FunctionTypeError +from pydiverse.transform.pipe.verbs import ( arrange, filter, group_by, @@ -9,7 +10,6 @@ select, summarise, ) -from pydiverse.transform.errors import ExpressionTypeError, FunctionTypeError from tests.util import assert_result_equal @@ -94,24 +94,24 @@ def test_filter(df3): ) -# def test_filter_argument(df3): -# assert_result_equal( -# df3, -# lambda t: t -# >> group_by(t.col2) -# >> summarise(u=t.col4.sum(filter=(t.col1 != 0))), -# ) +def test_filter_argument(df3): + assert_result_equal( + df3, + lambda t: t + >> group_by(t.col2) + >> summarise(u=t.col4.sum(filter=(t.col1 != 0))), + ) -# assert_result_equal( -# df3, -# lambda t: t -# >> group_by(t.col4, t.col1) -# >> summarise( -# u=(t.col3 * t.col4 - t.col2).sum( -# filter=(t.col5.isin("a", "e", "i", "o", "u")) -# ) -# ), -# ) + assert_result_equal( + df3, + lambda t: t + >> group_by(t.col4, t.col1) + >> summarise( + u=(t.col3 * t.col4 - t.col2).sum( + filter=(t.col5.isin("a", "e", "i", "o", "u")) + ) + ), + ) def test_arrange(df3): @@ -153,9 +153,7 @@ def test_not_summarising(df4): def test_none(df4): - assert_result_equal( - df4, lambda t: t >> summarise(x=None), exception=ExpressionTypeError - ) + assert_result_equal(df4, lambda t: t >> summarise(x=None)) # TODO: Implement more 
test cases for summarise verb @@ -169,7 +167,7 @@ def test_op_min(df4): df4, lambda t: t >> group_by(t.col1) - >> summarise(**{c._.name + "_min": c.min() for c in t}), + >> summarise(**{c.name + "_min": c.min() for c in t}), ) @@ -178,7 +176,7 @@ def test_op_max(df4): df4, lambda t: t >> group_by(t.col1) - >> summarise(**{c._.name + "_max": c.max() for c in t}), + >> summarise(**{c.name + "_max": c.max() for c in t}), ) @@ -202,3 +200,16 @@ def test_op_all(df4): df4, lambda t: t >> group_by(t.col1) >> mutate(all=(C.col2 != C.col3).all()), ) + + +def test_group_cols_in_agg(df3): + assert_result_equal( + df3, + lambda t: t >> group_by(t.col1, t.col2) >> summarise(u=t.col1 + t.col2), + ) + + assert_result_equal( + df3, + lambda t: t >> group_by(t.col1, t.col2) >> summarise(u=t.col1 + t.col3), + exception=FunctionTypeError, + ) diff --git a/tests/test_backend_equivalence/test_syntax.py b/tests/test_backend_equivalence/test_syntax.py index 2ea9f873..c26aee02 100644 --- a/tests/test_backend_equivalence/test_syntax.py +++ b/tests/test_backend_equivalence/test_syntax.py @@ -1,7 +1,7 @@ from __future__ import annotations from pydiverse.transform import C -from pydiverse.transform.core.verbs import ( +from pydiverse.transform.pipe.verbs import ( mutate, select, ) @@ -11,16 +11,4 @@ def test_lambda_cols(df3): assert_result_equal(df3, lambda t: t >> select(C.col1, C.col2)) assert_result_equal(df3, lambda t: t >> mutate(col1=C.col1, col2=C.col1)) - assert_result_equal(df3, lambda t: t >> select(C.col10), exception=ValueError) - - -def test_columns_pipeable(df3): - assert_result_equal(df3, lambda t: t.col1 >> mutate(x=t.col1)) - - # Test invalid operations - assert_result_equal(df3, lambda t: t.col1 >> mutate(x=t.col2), exception=ValueError) - - assert_result_equal(df3, lambda t: t.col1 >> mutate(x=C.col2), exception=ValueError) - - assert_result_equal(df3, lambda t: (t.col1 + 1) >> select(), exception=TypeError) diff --git a/tests/test_backend_equivalence/test_window_function.py b/tests/test_backend_equivalence/test_window_function.py index 7ffece39..3cc27b0c 100644 --- a/tests/test_backend_equivalence/test_window_function.py +++ b/tests/test_backend_equivalence/test_window_function.py @@ -1,18 +1,19 @@ from __future__ import annotations from pydiverse.transform import C -from pydiverse.transform.core import functions as f -from pydiverse.transform.core.verbs import ( +from pydiverse.transform.errors import FunctionTypeError +from pydiverse.transform.pipe import functions as f +from pydiverse.transform.pipe.verbs import ( arrange, filter, group_by, + join, mutate, select, summarise, ungroup, ) -from pydiverse.transform.errors import FunctionTypeError -from tests.util import assert_result_equal, full_sort +from tests.util import assert_result_equal def test_simple_ungrouped(df3): @@ -38,20 +39,31 @@ def test_simple_grouped(df3): ) -def test_partition_by_argument(df3): +def test_partition_by_argument(df3, df4): assert_result_equal( df3, lambda t: t >> mutate( - u=t.col1.min(partition_by=[t.col3]), - v=t.col4.sum(partition_by=[t.col2]), - w=f.rank(arrange=[-t.col5, t.col4], partition_by=[t.col2]), + u=t.col1.min(partition_by=t.col3), + v=t.col4.sum(partition_by=t.col2), + w=f.rank( + arrange=[t.col5.descending().nulls_last(), t.col4.nulls_first()], + partition_by=[t.col2], + ), x=f.row_number( arrange=[t.col4.nulls_last()], partition_by=[t.col1, t.col2] ), ), ) + assert_result_equal( + (df3, df4), + lambda t, u: t + >> join(u, t.col1 == u.col3, how="left") + >> group_by(t.col2) + >> mutate(y=(u.col3 + 
t.col1).max(partition_by=(col for col in t))), + ) + assert_result_equal( df3, lambda t: t @@ -203,7 +215,6 @@ def test_arrange_argument(df3): lambda t: t >> group_by(t.col1) >> mutate(x=C.col4.shift(1, arrange=[-C.col3])) - >> full_sort() >> select(C.x), ) @@ -212,25 +223,18 @@ def test_arrange_argument(df3): lambda t: t >> group_by(t.col2) >> mutate(x=f.row_number(arrange=[-C.col4])) - >> full_sort() >> select(C.x), ) # Ungrouped assert_result_equal( df3, - lambda t: t - >> mutate(x=C.col4.shift(1, arrange=[-C.col3])) - >> full_sort() - >> select(C.x), + lambda t: t >> mutate(x=C.col4.shift(1, arrange=[-C.col3])) >> select(C.x), ) assert_result_equal( df3, - lambda t: t - >> mutate(x=f.row_number(arrange=[-C.col4])) - >> full_sort() - >> select(C.x), + lambda t: t >> mutate(x=f.row_number(arrange=[-C.col4])) >> select(C.x), ) @@ -260,6 +264,18 @@ def test_complex(df3): >> arrange(C.span), ) + assert_result_equal( + df3, + lambda t: t + >> group_by(t.col1, t.col2) + >> summarise(mean3=t.col3.mean(), u=t.col4.max()) + >> group_by(C.u) + >> mutate(minM3=C.mean3.min(), maxM3=C.mean3.max()) + >> mutate(span=C.maxM3 - C.minM3) + >> filter(C.span < 3) + >> arrange(C.span), + ) + def test_nested_bool(df4): assert_result_equal( @@ -268,8 +284,8 @@ def test_nested_bool(df4): >> group_by(t.col1) >> mutate(x=t.col1 <= t.col2, y=(t.col3 * 4) >= C.col4) >> mutate( - xshift=C.x.shift(1, arrange=[t.col4]), - yshift=C.y.shift(-1, arrange=[t.col4]), + xshift=C.x.shift(1, arrange=[t.col4.nulls_last()]), + yshift=C.y.shift(-1, arrange=[t.col4.nulls_first()]), ) >> mutate(xAndY=C.x & C.y, xAndYshifted=C.xshift & C.yshift), ) @@ -284,10 +300,10 @@ def test_op_shift(df4): lambda t: t >> group_by(t.col1) >> mutate( - shift1=t.col2.shift(1, arrange=[t.col4]), - shift2=t.col4.shift(-2, 0, arrange=[t.col4]), - shift3=t.col4.shift(0, arrange=[t.col4]), - u=C.col1.shift(1, 0, arrange=[t.col4]), + shift1=t.col2.shift(1, arrange=[t.col4.nulls_first()]), + shift2=t.col4.shift(-2, 0, arrange=[t.col4.nulls_last()]), + shift3=t.col4.shift(0, arrange=[t.col4.nulls_first()]), + u=C.col1.shift(1, 0, arrange=[t.col4.nulls_last()]), ), ) @@ -295,8 +311,8 @@ def test_op_shift(df4): df4, lambda t: t >> mutate( - u=t.col1.shift(1, 0, arrange=[t.col2, t.col4]), - v=t.col1.shift(2, 1, arrange=[-t.col4.nulls_first()]), + u=t.col1.shift(1, 0, arrange=[t.col2.nulls_last(), t.col4.nulls_first()]), + v=t.col1.shift(2, 1, arrange=[t.col4.descending().nulls_first()]), ), ) @@ -307,8 +323,10 @@ def test_op_row_number(df4): lambda t: t >> group_by(t.col1) >> mutate( - row_number1=f.row_number(arrange=[-C.col4.nulls_last()]), - row_number2=f.row_number(arrange=[C.col2, C.col3, t.col4]), + row_number1=f.row_number(arrange=[C.col4.descending().nulls_last()]), + row_number2=f.row_number( + arrange=[C.col2.nulls_last(), C.col3.nulls_first(), t.col4.nulls_last()] + ), ), ) @@ -316,8 +334,10 @@ def test_op_row_number(df4): df4, lambda t: t >> mutate( - u=f.row_number(arrange=[-C.col4.nulls_last()]), - v=f.row_number(arrange=[-t.col3, t.col4]), + u=f.row_number(arrange=[C.col4.descending().nulls_last()]), + v=f.row_number( + arrange=[t.col3.descending().nulls_first(), t.col4.nulls_first()] + ), ), ) @@ -328,12 +348,12 @@ def test_op_rank(df4): lambda t: t >> group_by(t.col1) >> mutate( - rank1=f.rank(arrange=[t.col1]), - rank2=f.rank(arrange=[t.col2]), + rank1=f.rank(arrange=[t.col1.nulls_last()]), + rank2=f.rank(arrange=[t.col2.nulls_first()]), rank3=f.rank(arrange=[t.col2.nulls_last()]), rank4=f.rank(arrange=[t.col5.nulls_first()]), - 
diff --git a/tests/test_core.py b/tests/test_core.py
index d9ad88f3..d2d6fafc 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -3,18 +3,11 @@
 import pytest
 
 from pydiverse.transform import C
-from pydiverse.transform.core import AbstractTableImpl, Table, dtypes
-from pydiverse.transform.core.dispatchers import (
-    col_to_table,
+from pydiverse.transform.pipe.pipeable import (
     inverse_partial,
-    unwrap_tables,
     verb,
-    wrap_tables,
 )
-from pydiverse.transform.core.expressions import Column, SymbolicExpression
-from pydiverse.transform.core.expressions.translator import TypedValue
-from pydiverse.transform.core.util import bidict, ordered_set, sign_peeler
-from pydiverse.transform.core.verbs import (
+from pydiverse.transform.pipe.verbs import (
     arrange,
     collect,
     filter,
@@ -25,16 +18,6 @@
 )
 
 
-@pytest.fixture
-def tbl1():
-    return Table(MockTableImpl("mock1", ["col1", "col2"]))
-
-
-@pytest.fixture
-def tbl2():
-    return Table(MockTableImpl("mock2", ["col1", "col2", "col3"]))
-
-
 class TestTable:
     def test_getattr(self, tbl1):
         assert tbl1.col1._.name == "col1"
@@ -44,16 +27,11 @@
             _ = tbl1.colXXX
 
     def test_getitem(self, tbl1):
-        assert tbl1.col1._ == tbl1["col1"]._
-        assert tbl1.col2._ == tbl1["col2"]._
-
-        assert tbl1.col2._ == tbl1[tbl1.col2]._
-        assert tbl1.col2._ == tbl1[C.col2]._
+        assert tbl1.col1 == tbl1["col1"]
+        assert tbl1.col2 == tbl1["col2"]
 
-    def test_setitem(self, tbl1):
-        tbl1["col1"] = 1
-        tbl1[tbl1.col1] = 1
-        tbl1[C.col1] = 1
+        assert tbl1.col2 == tbl1[tbl1.col2]
+        assert tbl1.col2 == tbl1[C.col2]
 
     def test_iter(self, tbl1, tbl2):
         assert len(list(tbl1)) == len(list(tbl1._impl.selected_cols()))
@@ -116,51 +94,6 @@ def subtract(v1, v2):
         assert 5 >> subtract(3) == 2
         assert 5 >> add_10 >> subtract(5) == 10
-
-    def test_col_to_table(self, tbl1):
-        assert col_to_table(15) == 15
-        assert col_to_table(tbl1) == tbl1
-
-        c1_tbl = col_to_table(tbl1.col1._)
-        assert isinstance(c1_tbl, AbstractTableImpl)
-        assert c1_tbl.available_cols == {tbl1.col1._.uuid}
-        assert list(c1_tbl.named_cols.fwd) == ["col1"]
-
-    def test_unwrap_tables(self):
-        impl_1 = MockTableImpl("impl_1", dict())
-        impl_2 = MockTableImpl("impl_2", dict())
-        tbl_1 = Table(impl_1)
-        tbl_2 = Table(impl_2)
-
-        assert unwrap_tables(15) == 15
-        assert unwrap_tables(impl_1) == impl_1
-        assert unwrap_tables(tbl_1) == impl_1
-
-        assert unwrap_tables([tbl_1]) == [impl_1]
-        assert unwrap_tables([[tbl_1], tbl_2]) == [[impl_1], impl_2]
-
-        assert unwrap_tables((tbl_1, tbl_2)) == (impl_1, impl_2)
-        assert unwrap_tables((tbl_1, (tbl_2, 15))) == (impl_1, (impl_2, 15))
-
-        assert unwrap_tables({tbl_1: tbl_1, 15: (15, tbl_2)}) == {
-            tbl_1: impl_1,
-            15: (15, impl_2),
-        }
-
-    def test_wrap_tables(self):
-        impl_1 = MockTableImpl("impl_1", dict())
-        impl_2 = MockTableImpl("impl_2", dict())
-        tbl_1 = Table(impl_1)
-        tbl_2 = Table(impl_2)
-
-        assert wrap_tables(15) == 15
-        assert wrap_tables(tbl_1) == tbl_1
-        assert wrap_tables(impl_1) == tbl_1
-
-        assert wrap_tables([impl_1]) == [tbl_1]
-        assert wrap_tables([impl_1, [impl_2]]) == [tbl_1, [tbl_2]]
-
-        assert wrap_tables((impl_1,)) == (tbl_1,)
 
 
 class TestBuiltinVerbs:
     def test_collect(self, tbl1):
@@ -352,114 +285,3 @@ def test_arrange(self, tbl1, tbl2):
             tbl1 >> arrange(tbl2.col1)
         with pytest.raises(ValueError):
             tbl1 >> arrange(tbl1.col1, -tbl2.col1)
-
-    def test_col_pipeable(self, tbl1, tbl2):
-        result = tbl1.col1 >> mutate(x=tbl1.col1 * 2)
-
-        assert result._impl.selects == ordered_set(["col1", "x"])
-        assert list(result._impl.named_cols.fwd) == ["col1", "x"]
-
-        with pytest.raises(TypeError):
-            (tbl1.col1 + 2) >> mutate(x=1)
-
-
-class TestDataStructures:
-    def test_bidict(self):
-        d = bidict({"a": 1, "b": 2, "c": 3})
-
-        assert len(d) == 3
-        assert tuple(d.fwd.keys()) == ("a", "b", "c")
-        assert tuple(d.fwd.values()) == (1, 2, 3)
-
-        assert tuple(d.fwd.keys()) == tuple(d.bwd.values())
-        assert tuple(d.bwd.keys()) == tuple(d.fwd.values())
-
-        d.fwd["d"] = 4
-        d.bwd[4] = "x"
-
-        assert tuple(d.fwd.keys()) == ("a", "b", "c", "x")
-        assert tuple(d.fwd.values()) == (1, 2, 3, 4)
-        assert tuple(d.fwd.keys()) == tuple(d.bwd.values())
-        assert tuple(d.bwd.keys()) == tuple(d.fwd.values())
-
-        assert "x" in d.fwd
-        assert "d" not in d.fwd
-
-        d.clear()
-
-        assert len(d) == 0
-        assert len(d.fwd.items()) == len(d.fwd) == 0
-        assert len(d.bwd.items()) == len(d.bwd) == 0
-
-        with pytest.raises(ValueError):
-            bidict({"a": 1, "b": 1})
-
-    def test_ordered_set(self):
-        s = ordered_set([0, 1, 2])
-        assert list(s) == [0, 1, 2]
-
-        s.add(1)  # Already in set -> Noop
-        assert list(s) == [0, 1, 2]
-        s.add(3)  # Not in set -> add to the end
-        assert list(s) == [0, 1, 2, 3]
-
-        s.remove(1)
-        assert list(s) == [0, 2, 3]
-        s.add(1)
-        assert list(s) == [0, 2, 3, 1]
-
-        assert 1 in s
-        assert 4 not in s
-        assert len(s) == 4
-
-        s.clear()
-        assert len(s) == 0
-        assert list(s) == []
-
-        # Set Operations
-
-        s1 = ordered_set([0, 1, 2, 3])
-        s2 = ordered_set([5, 4, 3, 2])
-
-        assert not s1.isdisjoint(s2)
-        assert list(s1 | s2) == [0, 1, 2, 3, 5, 4]
-        assert list(s1 ^ s2) == [0, 1, 5, 4]
-        assert list(s1 & s2) == [3, 2]
-        assert list(s1 - s2) == [0, 1]
-
-        # Pop order
-
-        s = ordered_set([0, 1, 2, 3])
-        assert s.pop() == 0
-        assert s.pop() == 1
-        assert s.pop_back() == 3
-        assert s.pop_back() == 2
-
-
-class TestUtil:
-    def test_sign_peeler(self):
-        x = object()
-        sx = SymbolicExpression(x)
-        assert sign_peeler(sx._) == (x, True)
-        assert sign_peeler((+sx)._) == (x, True)
-        assert sign_peeler((-sx)._) == (x, False)
-        assert sign_peeler((--sx)._) == (x, True)  # noqa: B002
-        assert sign_peeler((--+sx)._) == (x, True)  # noqa: B002
-        assert sign_peeler((-++--sx)._) == (x, False)  # noqa: B002
-
-
-class MockTableImpl(AbstractTableImpl):
-    def __init__(self, name, col_names):
-        super().__init__(
-            name, {name: Column(name, self, dtypes.Int()) for name in col_names}
-        )
-
-    def resolve_lambda_cols(self, expr):
-        return expr
-
-    def collect(self):
-        return list(self.selects)
-
-    class ExpressionCompiler(AbstractTableImpl.ExpressionCompiler):
-        def _translate(self, expr, **kwargs):
-            return TypedValue(None, dtypes.Int())
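
# The pipe machinery kept by this cleanup is small: @verb wraps an ordinary
# function so that `x >> verb(args)` calls it as verb(x, args). A minimal
# sketch, mirroring the subtract/add_10 pattern in the retained tests:
from pydiverse.transform.pipe.pipeable import verb

@verb
def subtract(v1, v2):
    return v1 - v2

assert 5 >> subtract(3) == 2
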
diff --git a/tests/test_expressions.py b/tests/test_expressions.py
deleted file mode 100644
index 92baa69a..00000000
--- a/tests/test_expressions.py
+++ /dev/null
@@ -1,42 +0,0 @@
-from __future__ import annotations
-
-import pytest
-
-from pydiverse.transform import C
-from pydiverse.transform.core.expressions import FunctionCall, SymbolicExpression
-
-
-def compare_sexpr(expr1, expr2):
-    # Must compare using repr, because using == would result in another sexpr
-    expr1 = expr1 if not isinstance(expr1, SymbolicExpression) else expr1._
-    expr2 = expr2 if not isinstance(expr2, SymbolicExpression) else expr2._
-    assert expr1 == expr2
-
-
-class TestExpressions:
-    def test_symbolic_expression(self):
-        s1 = SymbolicExpression(1)
-        s2 = SymbolicExpression(2)
-
-        compare_sexpr(s1 + s1, FunctionCall("__add__", 1, 1))
-        compare_sexpr(s1 + s2, FunctionCall("__add__", 1, 2))
-        compare_sexpr(s1 + 10, FunctionCall("__add__", 1, 10))
-        compare_sexpr(10 + s1, FunctionCall("__radd__", 1, 10))
-
-        compare_sexpr(s1.argument(), FunctionCall("argument", 1))
-        compare_sexpr(s1.str.argument(), FunctionCall("str.argument", 1))
-        compare_sexpr(s1.argument(s2, 3), FunctionCall("argument", 1, 2, 3))
-
-    def test_lambda_col(self):
-        compare_sexpr(C.something, C["something"])
-        compare_sexpr(C.something.chained(), C["something"].chained())
-
-    def test_banned_methods(self):
-        s1 = SymbolicExpression(1)
-
-        with pytest.raises(TypeError):
-            bool(s1)
-        with pytest.raises(TypeError):
-            _ = s1 in s1
-        with pytest.raises(TypeError):
-            _ = iter(s1)
reg.get_implementation("op1", parse_dtypes("int", "str")) + reg.get_impl("op1", parse_dtypes("int", "str")) with pytest.raises(ValueError): - reg.get_implementation( + reg.get_impl( "not_implemented", parse_dtypes( "int", ), ) - reg.add_implementation(op1, lambda: 100, "-> int") - assert reg.get_implementation("op1", tuple())() == 100 + reg.add_impl(op1, lambda: 100, "-> int") + assert reg.get_impl("op1", tuple())() == 100 def test_template(self): reg = OperatorRegistry("TestRegistry") @@ -143,58 +143,58 @@ def test_template(self): reg.register_op(op2) reg.register_op(op3) - reg.add_implementation(op1, lambda: 1, "T, T -> bool") - reg.add_implementation(op1, lambda: 2, "T, U -> U") + reg.add_impl(op1, lambda: 1, "T, T -> bool") + reg.add_impl(op1, lambda: 2, "T, U -> U") with pytest.raises(ValueError, match="already defined"): - reg.add_implementation(op1, lambda: 3, "T, U -> U") + reg.add_impl(op1, lambda: 3, "T, U -> U") - assert reg.get_implementation("op1", parse_dtypes("int", "int"))() == 1 - assert reg.get_implementation("op1", parse_dtypes("int", "str"))() == 2 + assert reg.get_impl("op1", parse_dtypes("int", "int"))() == 1 + assert reg.get_impl("op1", parse_dtypes("int", "str"))() == 2 # int can be promoted to float; results in "float, float -> bool" signature - assert reg.get_implementation("op1", parse_dtypes("int", "float"))() == 1 - assert reg.get_implementation("op1", parse_dtypes("float", "int"))() == 1 + assert reg.get_impl("op1", parse_dtypes("int", "float"))() == 1 + assert reg.get_impl("op1", parse_dtypes("float", "int"))() == 1 # More template matching... Also check matching precedence - reg.add_implementation(op2, lambda: 1, "int, int, int -> int") - reg.add_implementation(op2, lambda: 2, "int, str, T -> int") - reg.add_implementation(op2, lambda: 3, "int, T, str -> int") - reg.add_implementation(op2, lambda: 4, "int, T, T -> int") - reg.add_implementation(op2, lambda: 5, "T, T, T -> int") - reg.add_implementation(op2, lambda: 6, "A, T, T -> int") - - assert reg.get_implementation("op2", parse_dtypes("int", "int", "int"))() == 1 - assert reg.get_implementation("op2", parse_dtypes("int", "str", "str"))() == 2 - assert reg.get_implementation("op2", parse_dtypes("int", "int", "str"))() == 3 - assert reg.get_implementation("op2", parse_dtypes("int", "bool", "bool"))() == 4 - assert reg.get_implementation("op2", parse_dtypes("str", "str", "str"))() == 5 - assert reg.get_implementation("op2", parse_dtypes("float", "str", "str"))() == 6 + reg.add_impl(op2, lambda: 1, "int, int, int -> int") + reg.add_impl(op2, lambda: 2, "int, str, T -> int") + reg.add_impl(op2, lambda: 3, "int, T, str -> int") + reg.add_impl(op2, lambda: 4, "int, T, T -> int") + reg.add_impl(op2, lambda: 5, "T, T, T -> int") + reg.add_impl(op2, lambda: 6, "A, T, T -> int") + + assert reg.get_impl("op2", parse_dtypes("int", "int", "int"))() == 1 + assert reg.get_impl("op2", parse_dtypes("int", "str", "str"))() == 2 + assert reg.get_impl("op2", parse_dtypes("int", "int", "str"))() == 3 + assert reg.get_impl("op2", parse_dtypes("int", "bool", "bool"))() == 4 + assert reg.get_impl("op2", parse_dtypes("str", "str", "str"))() == 5 + assert reg.get_impl("op2", parse_dtypes("float", "str", "str"))() == 6 with pytest.raises(ValueError): - reg.get_implementation("op2", parse_dtypes("int", "bool", "float")) + reg.get_impl("op2", parse_dtypes("int", "bool", "float")) # Return type - reg.add_implementation(op3, lambda: 1, "T -> T") - reg.add_implementation(op3, lambda: 2, "int, T, U -> T") - reg.add_implementation(op3, 
lambda: 3, "str, T, U -> U") + reg.add_impl(op3, lambda: 1, "T -> T") + reg.add_impl(op3, lambda: 2, "int, T, U -> T") + reg.add_impl(op3, lambda: 3, "str, T, U -> U") with pytest.raises(ValueError, match="already defined."): - reg.add_implementation(op3, lambda: 4, "int, T, U -> U") + reg.add_impl(op3, lambda: 4, "int, T, U -> U") assert isinstance( - reg.get_implementation("op3", parse_dtypes("str")).rtype, + reg.get_impl("op3", parse_dtypes("str")).return_type, dtypes.String, ) assert isinstance( - reg.get_implementation("op3", parse_dtypes("int")).rtype, + reg.get_impl("op3", parse_dtypes("int")).return_type, dtypes.Int, ) assert isinstance( - reg.get_implementation("op3", parse_dtypes("int", "int", "float")).rtype, + reg.get_impl("op3", parse_dtypes("int", "int", "float")).return_type, dtypes.Int, ) assert isinstance( - reg.get_implementation("op3", parse_dtypes("str", "int", "float")).rtype, + reg.get_impl("op3", parse_dtypes("str", "int", "float")).return_type, dtypes.Float, ) @@ -204,12 +204,12 @@ def test_vararg(self): op1 = self.Op1() reg.register_op(op1) - reg.add_implementation(op1, lambda: 1, "int... -> int") - reg.add_implementation(op1, lambda: 2, "int, int... -> int") - reg.add_implementation(op1, lambda: 3, "int, T... -> T") + reg.add_impl(op1, lambda: 1, "int... -> int") + reg.add_impl(op1, lambda: 2, "int, int... -> int") + reg.add_impl(op1, lambda: 3, "int, T... -> T") assert ( - reg.get_implementation( + reg.get_impl( "op1", parse_dtypes( "int", @@ -217,12 +217,12 @@ def test_vararg(self): )() == 1 ) - assert reg.get_implementation("op1", parse_dtypes("int", "int"))() == 2 - assert reg.get_implementation("op1", parse_dtypes("int", "int", "int"))() == 2 - assert reg.get_implementation("op1", parse_dtypes("int", "str", "str"))() == 3 + assert reg.get_impl("op1", parse_dtypes("int", "int"))() == 2 + assert reg.get_impl("op1", parse_dtypes("int", "int", "int"))() == 2 + assert reg.get_impl("op1", parse_dtypes("int", "str", "str"))() == 3 assert isinstance( - reg.get_implementation("op1", parse_dtypes("int", "str", "str")).rtype, + reg.get_impl("op1", parse_dtypes("int", "str", "str")).return_type, dtypes.String, ) @@ -233,13 +233,13 @@ def test_variant(self): reg.register_op(op1) with pytest.raises(ValueError): - reg.add_implementation(op1, lambda: 2, "-> int", variant="VAR") + reg.add_impl(op1, lambda: 2, "-> int", variant="VAR") - reg.add_implementation(op1, lambda: 1, "-> int") - reg.add_implementation(op1, lambda: 2, "-> int", variant="VAR") + reg.add_impl(op1, lambda: 1, "-> int") + reg.add_impl(op1, lambda: 2, "-> int", variant="VAR") - assert reg.get_implementation("op1", tuple())() == 1 - assert reg.get_implementation("op1", tuple()).get_variant("VAR")() == 2 + assert reg.get_impl("op1", tuple())() == 1 + assert reg.get_impl("op1", tuple()).get_variant("VAR")() == 2 with pytest.raises(ValueError): - reg.add_implementation(op1, lambda: 2, "-> int", variant="VAR") + reg.add_impl(op1, lambda: 2, "-> int", variant="VAR") diff --git a/tests/test_polars_table.py b/tests/test_polars_table.py index be65d3c7..d3b53cc7 100644 --- a/tests/test_polars_table.py +++ b/tests/test_polars_table.py @@ -6,14 +6,11 @@ import pytest from pydiverse.transform import C -from pydiverse.transform.core import dtypes -from pydiverse.transform.core import functions as f -from pydiverse.transform.core.alignment import aligned, eval_aligned -from pydiverse.transform.core.dispatchers import Pipeable, verb -from pydiverse.transform.core.table import Table -from pydiverse.transform.core.verbs 
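
# A hedged sketch of the renamed registry API (add_implementation -> add_impl,
# get_implementation -> get_impl, .rtype -> .return_type). The Operator
# subclass below is a made-up stand-in for the tests' Op1/Op2 helper classes,
# which may set further attributes.
from pydiverse.transform.ops import Operator
from pydiverse.transform.tree import dtypes
from pydiverse.transform.tree.registry import OperatorRegistry

class MyOp(Operator):
    name = "my_op"  # assumption: a name attribute is all this sketch needs

reg = OperatorRegistry("Example")
op = MyOp()
reg.register_op(op)
reg.add_impl(op, lambda: 1, "int, int -> int")
impl = reg.get_impl("my_op", (dtypes.Int(), dtypes.Int()))
assert impl() == 1
assert isinstance(impl.return_type, dtypes.Int)
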
diff --git a/tests/test_polars_table.py b/tests/test_polars_table.py
index be65d3c7..d3b53cc7 100644
--- a/tests/test_polars_table.py
+++ b/tests/test_polars_table.py
@@ -6,14 +6,11 @@
 import pytest
 
 from pydiverse.transform import C
-from pydiverse.transform.core import dtypes
-from pydiverse.transform.core import functions as f
-from pydiverse.transform.core.alignment import aligned, eval_aligned
-from pydiverse.transform.core.dispatchers import Pipeable, verb
-from pydiverse.transform.core.table import Table
-from pydiverse.transform.core.verbs import *
-from pydiverse.transform.errors import AlignmentError
-from pydiverse.transform.polars.polars_table import PolarsEager
+from pydiverse.transform.pipe import functions as f
+from pydiverse.transform.pipe.pipeable import verb
+from pydiverse.transform.pipe.table import Table
+from pydiverse.transform.pipe.verbs import *
+from pydiverse.transform.tree import dtypes
 from tests.util import assert_equal
 
 df1 = pl.DataFrame(
@@ -77,65 +74,61 @@
 )
 
 
-@pytest.fixture(params=["numpy", "arrow"])
+@pytest.fixture
 def dtype_backend(request):
     return request.param
 
 
 @pytest.fixture
 def tbl1():
-    return Table(PolarsEager("df1", df1))
+    return Table(df1, name="df1")
 
 
 @pytest.fixture
 def tbl2():
-    return Table(PolarsEager("df2", df2))
+    return Table(df2, name="df2")
 
 
 @pytest.fixture
 def tbl3():
-    return Table(PolarsEager("df3", df3))
+    return Table(df3, name="df3")
 
 
 @pytest.fixture
 def tbl4():
-    return Table(PolarsEager("df4", df4.clone()))
+    return Table(df4, name="df4")
 
 
 @pytest.fixture
 def tbl_left():
-    return Table(PolarsEager("df_left", df_left.clone()))
+    return Table(df_left, name="df_left")
 
 
 @pytest.fixture
 def tbl_right():
-    return Table(PolarsEager("df_right", df_right.clone()))
+    return Table(df_right, name="df_right")
 
 
 @pytest.fixture
 def tbl_dt():
-    return Table(PolarsEager("df_dt", df_dt))
+    return Table(df_dt)
 
 
-def assert_not_inplace(tbl: Table[PolarsEager], operation: Pipeable):
-    """
-    Operations should not happen in-place. They should always return a new dataframe.
-    """
-    initial = tbl._impl.df.clone()
-    tbl >> operation
-    after = tbl._impl.df
+class TestPolarsLazyImpl:
+    def test_dtype(self, tbl1, tbl2):
+        assert isinstance(tbl1.col1.dtype(), dtypes.Int)
+        assert isinstance(tbl1.col2.dtype(), dtypes.String)
 
-    assert initial.equals(after)
+        assert isinstance(tbl2.col1.dtype(), dtypes.Int)
+        assert isinstance(tbl2.col2.dtype(), dtypes.Int)
+        assert isinstance(tbl2.col3.dtype(), dtypes.Float)
 
+        # test that column expression type errors are checked immediately
+        with pytest.raises(TypeError):
+            tbl1.col1 + tbl1.col2
 
-class TestPolarsEager:
-    def test_dtype(self, tbl1, tbl2):
-        assert isinstance(tbl1.col1._.dtype, dtypes.Int)
-        assert isinstance(tbl1.col2._.dtype, dtypes.String)
-
-        assert isinstance(tbl2.col1._.dtype, dtypes.Int)
-        assert isinstance(tbl2.col2._.dtype, dtypes.Int)
-        assert isinstance(tbl2.col3._.dtype, dtypes.Float)
+        # here, transform cannot resolve the type, so this should raise an error
+        C.col1 + tbl1.col2
 
     def test_build_query(self, tbl1):
         assert (tbl1 >> build_query()) is None
@@ -145,14 +138,11 @@ def test_export(self, tbl1):
         assert_equal(tbl1, df1)
 
     def test_select(self, tbl1):
-        assert_not_inplace(tbl1, select(tbl1.col1))
         assert_equal(tbl1 >> select(tbl1.col1), df1.select("col1"))
         assert_equal(tbl1 >> select(tbl1.col2), df1.select("col2"))
         assert_equal(tbl1 >> select(), df1.select())
 
     def test_mutate(self, tbl1):
-        assert_not_inplace(tbl1, mutate(x=tbl1.col1))
-
         assert_equal(
             tbl1 >> mutate(col1times2=tbl1.col1 * 2),
             pl.DataFrame(
@@ -173,7 +163,7 @@ def test_mutate(self, tbl1):
             ),
         )
 
         # Check proper column referencing
         t = tbl1 >> mutate(col2=tbl1.col1, col1=tbl1.col2) >> select()
         assert_equal(
             t >> mutate(x=t.col1, y=t.col2),
@@ -185,8 +175,6 @@ def test_mutate(self, tbl1):
         )
 
     def test_join(self, tbl_left, tbl_right):
-        assert_not_inplace(tbl_left, join(tbl_right, tbl_left.a == tbl_right.b, "left"))
-
         assert_equal(
             tbl_left
             >> join(tbl_right, tbl_left.a == tbl_right.b, "left")
@@ -226,24 +214,21 @@ def test_join(self, tbl_left, tbl_right):
             df_left.join(df_left, on="a", coalesce=False, suffix="_df_left"),
         )
 
-        assert_equal(
-            tbl_right
-            >> inner_join(
-                tbl_right2 := tbl_right >> alias(), tbl_right.b == tbl_right2.b
-            )
-            >> inner_join(
-                tbl_right3 := tbl_right >> alias(), tbl_right.b == tbl_right3.b
-            ),
-            df_right.join(df_right, "b", suffix="_df_right", coalesce=False).join(
-                df_right, "b", suffix="_df_right1", coalesce=False
-            ),
-        )
+        # assert_equal(
+        #     tbl_right
+        #     >> inner_join(
+        #         tbl_right2 := tbl_right >> alias(), tbl_right.b == tbl_right2.b
+        #     )
+        #     >> inner_join(
+        #         tbl_right3 := tbl_right >> alias(), tbl_right.b == tbl_right3.b
+        #     ),
+        #     df_right.join(df_right, "b", suffix="_df_right", coalesce=False).join(
+        #         df_right, "b", suffix="_df_right1", coalesce=False
+        #     ),
+        # )
 
     def test_filter(self, tbl1, tbl2):
-        assert_not_inplace(tbl1, filter(tbl1.col1 == 3))
-
         # Simple filter expressions
-        assert_equal(tbl1 >> filter(), df1)
         assert_equal(tbl1 >> filter(tbl1.col1 == tbl1.col1), df1)
         assert_equal(tbl1 >> filter(tbl1.col1 == 3), df1.filter(pl.col("col1") == 3))
@@ -260,7 +245,6 @@ def test_filter(self, tbl1, tbl2):
 
     def test_arrange(self, tbl2, tbl4):
         tbl4.col1.nulls_first()
-        assert_not_inplace(tbl2, arrange(tbl2.col2))
 
         assert_equal(
             tbl2 >> arrange(tbl2.col3) >> select(tbl2.col3),
@@ -281,8 +265,8 @@ def test_arrange(self, tbl2, tbl4):
             tbl4
             >> arrange(
                 tbl4.col1.nulls_first(),
-                -tbl4.col2.nulls_first(),
-                -tbl4.col5.nulls_first(),
+                tbl4.col2.nulls_first().descending(),
+                tbl4.col5.nulls_first().descending(),
             ),
             df4.sort(
                 ["col1", "col2", "col5"],
@@ -295,8 +279,8 @@ def test_arrange(self, tbl2, tbl4):
             tbl4
             >> arrange(
                 tbl4.col1.nulls_last(),
-                -tbl4.col2.nulls_last(),
-                -tbl4.col5.nulls_last(),
+                tbl4.col2.descending().nulls_last(),
+                tbl4.col5.descending().nulls_last(),
             ),
             df4.sort(
                 ["col1", "col2", "col5"],
@@ -363,10 +347,8 @@ def test_group_by(self, tbl3):
         )
 
     def test_alias(self, tbl1, tbl2):
-        assert_not_inplace(tbl1, alias("tblxxx"))
-
         x = tbl2 >> alias("x")
-        assert x._impl.name == "x"
+        assert x._ast.name == "x"
 
         # Check that applying alias doesn't change the output
         a = (
@@ -379,7 +361,7 @@ def test_alias(self, tbl1, tbl2):
 
         assert_equal(a, b)
 
-        # Self Join
+        # self join
         assert_equal(
             tbl2 >> join(x, tbl2.col1 == x.col1, "left", suffix="_right"),
             df2.join(
@@ -394,7 +376,8 @@ def test_alias(self, tbl1, tbl2):
     def test_window_functions(self, tbl3):
         # Everything else should stay the same
         assert_equal(
-            tbl3 >> mutate(x=f.row_number(arrange=[-C.col4])) >> select(*tbl3), df3
+            tbl3 >> mutate(x=f.row_number(arrange=[-C.col4])) >> select(*tbl3),
+            df3,
         )
 
         assert_equal(
@@ -424,14 +407,14 @@ def test_window_functions(self, tbl3):
 
     def test_slice_head(self, tbl3):
         @verb
-        def slice_head_custom(tbl: Table, n: int, *, offset: int = 0):
+        def slice_head_custom(table: Table, n: int, *, offset: int = 0):
             t = (
-                tbl
+                table
                 >> mutate(_n=f.row_number(arrange=[]))
                 >> alias()
                 >> filter((offset < C._n) & (C._n <= (n + offset)))
             )
-            return t >> select(*[c for c in t if c._.name != "_n"])
+            return t >> select(*[C[col.name] for col in table if col.name != "_n"])
 
         assert_equal(
             tbl3 >> slice_head(6),
@@ -454,12 +437,13 @@ def test_case_expression(self, tbl3):
                 tbl3
                 >> select()
                 >> mutate(
-                    col1=C.col1.case(
-                        (0, 1),
-                        (1, 2),
-                        (2, 3),
-                        default=-1,
-                    )
+                    col1=f.when(C.col1 == 0)
+                    .then(1)
+                    .when(C.col1 == 1)
+                    .then(2)
+                    .when(C.col1 == 2)
+                    .then(3)
+                    .otherwise(-1)
                 )
             ),
             (df3.select("col1") + 1),
@@ -470,26 +454,11 @@ def test_case_expression(self, tbl3):
                 tbl3
                 >> select()
                 >> mutate(
-                    x=C.col1.case(
-                        (C.col2, 1),
-                        (C.col3, 2),
-                        default=0,
-                    )
-                )
-            ),
-            pl.DataFrame({"x": [1, 1, 0, 0, 0, 2, 1, 1, 0, 0, 2, 0]}),
-        )
-
-        assert_equal(
-            (
-                tbl3
-                >> select()
-                >> mutate(
-                    x=f.case(
-                        (C.col1 == C.col2, 1),
-                        (C.col1 == C.col3, 2),
-                        default=C.col4,
-                    )
+                    x=f.when(C.col1 == C.col2)
+                    .then(1)
+                    .when(C.col1 == C.col3)
+                    .then(2)
+                    .otherwise(C.col4)
                 )
             ),
             pl.DataFrame({"x": [1, 1, 2, 3, 4, 2, 1, 1, 8, 9, 2, 11]}),
@@ -525,18 +494,6 @@ def test_lambda_column(self, tbl1, tbl2):
             >> join(tbl2, tbl1.col1 == tbl2.col1, "left"),
         )
 
-        # Join that also uses lambda for the right table
-        assert_equal(
-            tbl1
-            >> select()
-            >> mutate(a=tbl1.col1)
-            >> join(tbl2, C.a == C.col1_custom_suffix, "left", suffix="_custom_suffix"),
-            tbl1
-            >> select()
-            >> mutate(a=tbl1.col1)
-            >> join(tbl2, tbl1.col1 == tbl2.col1, "left", suffix="_custom_suffix"),
-        )
-
         # Filter
         assert_equal(
             tbl1 >> mutate(a=tbl1.col1 * 2) >> filter(C.a % 2 == 0),
@@ -549,40 +506,11 @@ def test_lambda_column(self, tbl1, tbl2):
             tbl1 >> arrange(tbl1.col1) >> mutate(a=tbl1.col1 * 2),
         )
 
-    def test_table_setitem(self, tbl_left, tbl_right):
-        tl = tbl_left >> alias("df_left")
-        tr = tbl_right >> alias("df_right")
-
-        # Iterate over cols and modify
-        for col in tl:
-            tl[col] = (col * 2) % 3
-        for col in tr:
-            tr[col] = (col * 2) % 5
-
-        # Check if it worked...
-        assert_equal(
-            (tl >> join(tr, C.a == C.b, "left", suffix="")),
-            (
-                tbl_left
-                >> mutate(a=(tbl_left.a * 2) % 3)
-                >> join(
-                    tbl_right
-                    >> mutate(b=(tbl_right.b * 2) % 5, c=(tbl_right.c * 2) % 5),
-                    C.a == C.b,
-                    "left",
-                    suffix="",
-                )
-            ),
-        )
-
     def test_custom_verb(self, tbl1):
         @verb
-        def double_col1(tbl):
-            tbl[C.col1] = C.col1 * 2
-            return tbl
-
-        # Custom verb should not mutate input object
-        assert_not_inplace(tbl1, double_col1())
+        def double_col1(table):
+            table >>= mutate(col1=C.col1 * 2)
+            return table
 
         assert_equal(tbl1 >> double_col1(), tbl1 >> mutate(col1=C.col1 * 2))
 
@@ -603,11 +531,11 @@ def test_null(self, tbl4):
             tbl4 >> mutate(u=tbl4.col3.fill_null(tbl4.col2)),
             df4.with_columns(pl.col("col3").fill_null(pl.col("col2")).alias("u")),
         )
-        assert_equal(
-            tbl4 >> mutate(u=tbl4.col3.fill_null(tbl4.col2)),
-            tbl4
-            >> mutate(u=f.case((tbl4.col3.is_null(), tbl4.col2), default=tbl4.col3)),
-        )
+        # assert_equal(
+        #     tbl4 >> mutate(u=tbl4.col3.fill_null(tbl4.col2)),
+        #     tbl4
+        #     >> mutate(u=f.case((tbl4.col3.is_null(), tbl4.col2), default=tbl4.col3)),
+        # )
 
     def test_datetime(self, tbl_dt):
         assert_equal(
@@ -629,83 +557,12 @@ def test_datetime(self, tbl_dt):
         )
 
 
-class TestPolarsAligned:
-    def test_eval_aligned(self, tbl1, tbl3, tbl_left, tbl_right):
-        # No exception with correct length
-        eval_aligned(tbl_left.a + tbl_left.a)
-        eval_aligned(tbl_left.a + tbl_right.b)
-
-        with pytest.raises(AlignmentError):
-            eval_aligned(tbl1.col1 + tbl3.col1)
-
-        # Test aggregate functions still work
-        eval_aligned(tbl1.col1 + tbl3.col1.mean())
-
-        # Test that `with_` argument gets enforced
-        eval_aligned(tbl1.col1 + tbl1.col1, with_=tbl1)
-        eval_aligned(tbl_left.a * 2, with_=tbl_left)
-        eval_aligned(tbl_left.a * 2, with_=tbl_right)  # Same length
-        eval_aligned(
-            tbl1.col1.mean(), with_=tbl_left
-        )  # Aggregate is aligned with everything
-
-        with pytest.raises(AlignmentError):
-            eval_aligned(tbl3.col1 * 2, with_=tbl1)
-
-    def test_aligned_decorator(self, tbl1, tbl3, tbl_left, tbl_right):
-        @aligned(with_="a")
-        def f(a, b):
-            return a + b
-
-        f(tbl3.col1, tbl3.col2)
-        f(tbl_left.a, tbl_right.b)
-
-        with pytest.raises(AlignmentError):
-            f(tbl1.col1, tbl3.col1)
-
-        # Bad Alignment of return type
-        @aligned(with_="a")
-        def f(a, b):
-            return a.mean() + b
-
-        with pytest.raises(AlignmentError):
-            f(tbl1.col1, tbl3.col1)
-
-        # Invalid with_ argument
-        with pytest.raises(ValueError):
-            aligned(with_="x")(lambda a: 0)
-
-    def test_col_addition(self, tbl_left, tbl_right):
-        @aligned(with_="a")
-        def f(a, b):
-            return a + b
-
-        assert_equal(
-            tbl_left >> mutate(x=f(tbl_left.a, tbl_right.b)) >> select(C.x),
-            pl.DataFrame({"x": (df_left.get_column("a") + df_right.get_column("b"))}),
-        )
-
-        with pytest.raises(AlignmentError):
-            f(tbl_left.a, (tbl_right >> filter(C.b == 2)).b)
-
-        with pytest.raises(AlignmentError):
-            x = f(tbl_left.a, tbl_right.b)
-            tbl_left >> filter(C.a <= 3) >> mutate(x=x)
-
-
 class TestPrintAndRepr:
     def test_table_str(self, tbl1):
-        # Table: df1, backend: PolarsEager
-        #   col1 col2
-        # 0    1    a
-        # 1    2    b
-        # 2    3    c
-        # 3    4    d
-
         tbl_str = str(tbl1)
 
         assert "df1" in tbl_str
-        assert "PolarsEager" in tbl_str
+        assert "PolarsImpl" in tbl_str
         assert str(df1) in tbl_str
 
     def test_table_repr_html(self, tbl1):
@@ -713,19 +570,8 @@ def test_table_repr_html(self, tbl1):
         assert "exception" not in tbl1._repr_html_()
 
     def test_col_str(self, tbl1):
-        # Symbolic Expression:
-        # dtype: int
-        #
-        # 0    1
-        # 1    2
-        # 2    3
-        # 3    4
-        # Name: df1_col1_XXXXXXXX, dtype: Int64
-
         col1_str = str(tbl1.col1)
-        series = tbl1._impl.df.get_column(
-            tbl1._impl.underlying_col_name[tbl1.col1._.uuid]
-        )
+        series = tbl1._impl.df.collect().get_column("col1")
 
         assert str(series) in col1_str
         assert "exception" not in col1_str
@@ -741,12 +587,5 @@ def test_expr_html_repr(self, tbl1):
         assert "exception" not in (tbl1.col1 * 2)._repr_html_()
 
     def test_lambda_str(self, tbl1):
-        assert "exception" in str(C.col)
-        assert "exception" in str(C.col1 + tbl1.col1)
-
-    def test_eval_expr_str(self, tbl_left, tbl_right):
-        valid = tbl_left.a + tbl_right.b
-        invalid = tbl_left.a + (tbl_right >> filter(C.b == 2)).b
-
-        assert "exception" not in str(valid)
-        assert "exception" in str(invalid)
+        assert "exception" not in str(C.col)
+        assert "exception" not in str(C.col1 + tbl1.col1)
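
# A hedged sketch of the two constructor and case-expression changes this file
# migrates to: a Table is built directly from a polars frame, and a chained
# f.when(...).then(...).otherwise(...) replaces the removed Column.case().
import polars as pl
from pydiverse.transform import C, Table
from pydiverse.transform.pipe import functions as f
from pydiverse.transform.pipe.verbs import mutate

t = Table(pl.DataFrame({"col1": [0, 1, 2, 5]}), name="t")
t >> mutate(
    x=f.when(C.col1 == 0).then(1).when(C.col1 == 1).then(2).otherwise(-1)
)
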
diff --git a/tests/test_sql_table.py b/tests/test_sql_table.py
index cd6fa11e..126c3d49 100644
--- a/tests/test_sql_table.py
+++ b/tests/test_sql_table.py
@@ -1,18 +1,14 @@
 from __future__ import annotations
 
-import sqlite3
-
 import polars as pl
 import pytest
-import sqlalchemy as sa
+import sqlalchemy as sqa
 
 from pydiverse.transform import C
-from pydiverse.transform.core import functions as f
-from pydiverse.transform.core.alignment import aligned, eval_aligned
-from pydiverse.transform.core.table import Table
-from pydiverse.transform.core.verbs import *
-from pydiverse.transform.errors import AlignmentError
-from pydiverse.transform.sql.sql_table import SQLTableImpl
+from pydiverse.transform.backend.targets import Polars, SqlAlchemy
+from pydiverse.transform.pipe import functions as f
+from pydiverse.transform.pipe.table import Table
+from pydiverse.transform.pipe.verbs import *
 from tests.util import assert_equal
 
 df1 = pl.DataFrame(
@@ -66,7 +62,13 @@
 
 @pytest.fixture
 def engine():
-    engine = sa.create_engine("sqlite:///:memory:")
+    engine = sqa.create_engine("sqlite:///:memory:")
+    # engine = sqa.create_engine("postgresql://sqa:Pydiverse23@127.0.0.1:6543")
+    # engine = sqa.create_engine(
+    #     "mssql+pyodbc://sqa:PydiQuant27@127.0.0.1:1433"
+    #     "/master?driver=ODBC+Driver+18+for+SQL+Server&encrypt=no"
+    # )
+
     df1.write_database("df1", engine, if_table_exists="replace")
     df2.write_database("df2", engine, if_table_exists="replace")
     df3.write_database("df3", engine, if_table_exists="replace")
@@ -78,35 +80,35 @@ def engine():
 
 @pytest.fixture
 def tbl1(engine):
-    return Table(SQLTableImpl(engine, "df1"))
+    return Table("df1", SqlAlchemy(engine))
 
 
 @pytest.fixture
 def tbl2(engine):
-    return Table(SQLTableImpl(engine, "df2"))
+    return Table("df2", SqlAlchemy(engine))
 
 
 @pytest.fixture
 def tbl3(engine):
-    return Table(SQLTableImpl(engine, "df3"))
+    return Table("df3", SqlAlchemy(engine))
 
 
 @pytest.fixture
 def tbl4(engine):
-    return Table(SQLTableImpl(engine, "df4"))
+    return Table("df4", SqlAlchemy(engine))
 
 
 @pytest.fixture
 def tbl_left(engine):
-    return Table(SQLTableImpl(engine, "df_left"))
+    return Table("df_left", SqlAlchemy(engine))
 
 
 @pytest.fixture
 def tbl_right(engine):
-    return Table(SQLTableImpl(engine, "df_right"))
+    return Table("df_right", SqlAlchemy(engine))
 
 
-class TestSQLTable:
+class TestSqlTable:
     def test_build_query(self, tbl1):
         query_str = tbl1 >> build_query()
         expected_out = "SELECT df1.col1 AS col1, df1.col2 AS col2 FROM df1"
@@ -123,11 +125,11 @@ def test_show_query(self, tbl1, capfd):
         tbl1 >> show_query() >> collect()
 
     def test_export(self, tbl1):
-        assert_equal(tbl1 >> export(), df1)
+        assert_equal(tbl1 >> export(Polars()), df1)
 
-    def test_select(self, tbl1, tbl2):
-        assert_equal(tbl1 >> select(tbl1.col1), df1[["col1"]])
-        assert_equal(tbl1 >> select(tbl1.col2), df1[["col2"]])
+    def test_select(self, tbl1):
+        assert_equal(tbl1 >> select(tbl1.col1), df1.select("col1"))
+        assert_equal(tbl1 >> select(tbl1.col2), df1.select("col2"))
 
     def test_mutate(self, tbl1):
         assert_equal(
@@ -162,40 +164,37 @@ def test_mutate(self, tbl1):
         )
 
     def test_join(self, tbl_left, tbl_right):
-        assert_equal(
-            tbl_left
-            >> join(tbl_right, tbl_left.a == tbl_right.b, "left", suffix="")
-            >> select(tbl_left.a, tbl_right.b),
-            pl.DataFrame({"a": [1, 2, 2, 3, 4], "b": [1, 2, 2, None, None]}),
-        )
+        # assert_equal(
+        #     tbl_left
+        #     >> join(tbl_right, tbl_left.a == tbl_right.b, "left", suffix="")
+        #     >> select(tbl_left.a, tbl_right.b),
+        #     pl.DataFrame({"a": [1, 2, 2, 3, 4], "b": [1, 2, 2, None, None]}),
+        # )
+
+        # assert_equal(
+        #     tbl_left
+        #     >> join(tbl_right, tbl_left.a == tbl_right.b, "inner", suffix="")
+        #     >> select(tbl_left.a, tbl_right.b),
+        #     pl.DataFrame({"a": [1, 2, 2], "b": [1, 2, 2]}),
+        # )
 
         assert_equal(
-            tbl_left
-            >> join(tbl_right, tbl_left.a == tbl_right.b, "inner", suffix="")
-            >> select(tbl_left.a, tbl_right.b),
-            pl.DataFrame({"a": [1, 2, 2], "b": [1, 2, 2]}),
+            (
+                tbl_left
+                >> join(tbl_right, tbl_left.a == tbl_right.b, "outer", suffix="_1729")
+                >> select(tbl_left.a, tbl_right.b)
+            ),
+            pl.DataFrame(
+                {
+                    "a": [1, 2, 2, 3, 4, None],
+                    "b_1729": [1, 2, 2, None, None, 0],
+                }
+            ),
+            check_row_order=False,
         )
 
-        if sqlite3.sqlite_version_info >= (3, 39, 0):
-            assert_equal(
-                (
-                    tbl_left
-                    >> join(
-                        tbl_right, tbl_left.a == tbl_right.b, "outer", suffix="_1729"
-                    )
-                    >> select(tbl_left.a, tbl_right.b)
-                ),
-                pl.DataFrame(
-                    {
-                        "a": [1.0, 2.0, 2.0, 3.0, 4.0, None],
-                        "b_1729": [1.0, 2.0, 2.0, None, None, 0.0],
-                    }
-                ),
-            )
-
-    def test_filter(self, tbl1, tbl2):
+    def test_filter(self, tbl1):
         # Simple filter expressions
-        assert_equal(tbl1 >> filter(), df1)
         assert_equal(tbl1 >> filter(tbl1.col1 == tbl1.col1), df1)
         assert_equal(tbl1 >> filter(tbl1.col1 == 3), df1.filter(pl.col("col1") == 3))
@@ -242,6 +241,7 @@ def test_summarise(self, tbl3):
         assert_equal(
             tbl3 >> group_by(tbl3.col1) >> summarise(mean=tbl3.col4.mean()),
             pl.DataFrame({"col1": [0, 1, 2], "mean": [1.5, 5.5, 9.5]}),
+            check_row_order=False,
         )
 
         assert_equal(
@@ -279,7 +279,7 @@ def test_group_by(self, tbl3):
 
     def test_alias(self, tbl1, tbl2):
         x = tbl2 >> alias("x")
-        assert x._impl.name == "x"
+        assert x._ast.name == "x"
 
         # Check that applying alias doesn't change the output
         a = (
@@ -298,7 +298,6 @@ def test_alias(self, tbl1, tbl2):
             >> join(x, tbl2.col1 == x.col1, "left", suffix="42")
             >> alias("self_join")
         )
-        self_join >>= arrange(*self_join)
 
         self_join_expected = df2.join(
             df2,
@@ -308,11 +307,8 @@ def test_alias(self, tbl1, tbl2):
             coalesce=False,
             suffix="42",
         )
-        self_join_expected = self_join_expected.sort(
-            by=[col._.name for col in self_join]
-        )
 
-        assert_equal(self_join, self_join_expected)
+        assert_equal(self_join, self_join_expected, check_row_order=False)
 
     def test_lambda_column(self, tbl1, tbl2):
         # Select
@@ -344,18 +340,6 @@ def test_lambda_column(self, tbl1, tbl2):
             >> join(tbl2, tbl1.col1 * 2 == tbl2.col1, "left"),
         )
 
-        # Join that also uses lambda for the right table
-        assert_equal(
-            tbl1
-            >> select()
-            >> mutate(a=tbl1.col1)
-            >> join(tbl2, C.a == C.col1_df2, "left"),
-            tbl1
-            >> select()
-            >> mutate(a=tbl1.col1)
-            >> join(tbl2, tbl1.col1 == tbl2.col1, "left"),
-        )
-
         # Filter
         assert_equal(
             tbl1 >> mutate(a=tbl1.col1 * 2) >> filter(C.a % 2 == 0),
@@ -368,31 +352,6 @@ def test_lambda_column(self, tbl1, tbl2):
             tbl1 >> arrange(tbl1.col1) >> mutate(a=tbl1.col1 * 2),
         )
 
-    def test_table_setitem(self, tbl_left, tbl_right):
-        tl = tbl_left >> alias("df_left")
-        tr = tbl_right >> alias("df_right")
-
-        # Iterate over cols and modify
-        for col in tl:
-            tl[col] = (col * 2) % 3
-        for col in tr:
-            tr[col] = (col * 2) % 5
-
-        # Check if it worked...
-        assert_equal(
-            (tl >> join(tr, C.a == tr.b, "left")),
-            (
-                tbl_left
-                >> mutate(a=(tbl_left.a * 2) % 3)
-                >> join(
-                    tbl_right
-                    >> mutate(b=(tbl_right.b * 2) % 5, c=(tbl_right.c * 2) % 5),
-                    C.a == C.b_df_right,
-                    "left",
-                )
-            ),
-        )
-
     def test_select_without_tbl_ref(self, tbl2):
         assert_equal(
             tbl2 >> summarise(count=f.count()),
@@ -414,84 +373,35 @@ def test_null_comparison(self, tbl4):
             df4.with_columns(pl.col("col3").is_null().alias("u")),
         )
 
-
-class TestSQLAligned:
-    def test_eval_aligned(self, tbl1, tbl3, tbl_left, tbl_right):
-        # Columns must be from same table
-        eval_aligned(tbl_left.a + tbl_left.a)
-        eval_aligned(tbl3.col1 + tbl3.col2)
-
-        # Derived columns are also OK
-        tbl1_mutate = tbl1 >> mutate(x=tbl1.col1 * 2)
-        eval_aligned(tbl1.col1 + tbl1_mutate.x)
-
-        with pytest.raises(AlignmentError):
-            eval_aligned(tbl1.col1 + tbl3.col1)
-        with pytest.raises(AlignmentError):
-            eval_aligned(tbl_left.a + tbl_right.b)
-        with pytest.raises(AlignmentError):
-            eval_aligned(tbl1.col1 + tbl3.col1.mean())
-        with pytest.raises(AlignmentError):
-            tbl1_joined = tbl1 >> join(tbl3, tbl1.col1 == tbl3.col1, how="left")
-            eval_aligned(tbl1.col1 + tbl1_joined.col1)
-
-        # Test that `with_` argument gets enforced
-        eval_aligned(tbl1.col1 + tbl1.col1, with_=tbl1)
-        eval_aligned(tbl_left.a * 2, with_=tbl_left)
-        eval_aligned(tbl1.col1, with_=tbl1_mutate)
-
-        with pytest.raises(AlignmentError):
-            eval_aligned(tbl1.col1.mean(), with_=tbl_left)
-
-        with pytest.raises(AlignmentError):
-            eval_aligned(tbl3.col1 * 2, with_=tbl1)
-
-        with pytest.raises(AlignmentError):
-            eval_aligned(tbl_left.a, with_=tbl_right)
-
-    def test_aligned_decorator(self, tbl1, tbl3, tbl_left, tbl_right):
-        @aligned(with_="a")
-        def f(a, b):
-            return a + b
-
-        f(tbl3.col1, tbl3.col2)
-        f(tbl_right.b, tbl_right.c)
-
-        with pytest.raises(AlignmentError):
-            f(tbl1.col1, tbl3.col1)
-
-        with pytest.raises(AlignmentError):
-            f(tbl_left.a, tbl_right.b)
-
-        # Check with_ parameter gets enforced
-        @aligned(with_="a")
-        def f(a, b):
-            return b
-
-        f(tbl1.col1, tbl1.col2)
-        with pytest.raises(AlignmentError):
-            f(tbl1.col1, tbl3.col1)
-
-        # Invalid with_ argument
-        with pytest.raises(ValueError):
-            aligned(with_="x")(lambda a: 0)
-
-    def test_col_addition(self, tbl3):
-        @aligned(with_="a")
-        def f(a, b):
-            return a + b
-
+    def test_case_expression(self, tbl3):
         assert_equal(
-            tbl3 >> mutate(x=f(tbl3.col1, tbl3.col2)) >> select(C.x),
-            pl.DataFrame({"x": [0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3]}),
+            (
+                tbl3
+                >> select()
+                >> mutate(
+                    col1=f.when(C.col1 == 0)
+                    .then(1)
+                    .when(C.col1 == 1)
+                    .then(2)
+                    .when(C.col1 == 2)
+                    .then(3)
+                    .otherwise(-1)
+                )
+            ),
            (df3.select("col1") + 1),
         )
 
-        # Test if it also works with derived tables
-        tbl3_mutate = tbl3 >> mutate(x=tbl3.col1 * 2)
-        tbl3 >> mutate(x=f(tbl3_mutate.col1, tbl3_mutate.x))
-
-        with pytest.raises(AlignmentError):
-            tbl3 >> arrange(C.col1) >> mutate(x=f(tbl3.col1, tbl3.col2))
-
-        with pytest.raises(AlignmentError):
-            tbl3 >> filter(C.col1 == 1) >> mutate(x=f(tbl3.col1, tbl3.col2))
+        assert_equal(
+            (
+                tbl3
+                >> select()
+                >> mutate(
+                    x=f.when(C.col1 == C.col2)
+                    .then(1)
+                    .when(C.col1 == C.col3)
+                    .then(2)
+                    .otherwise(C.col4)
+                )
+            ),
+            pl.DataFrame({"x": [1, 1, 2, 3, 4, 2, 1, 1, 8, 9, 2, 11]}),
+        )
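
# A hedged sketch of the backend-target flow the SQL tests now follow: a Table
# is bound to an engine through SqlAlchemy(...) and materialized back into a
# polars frame with export(Polars()). Table and column names are made up.
import polars as pl
import sqlalchemy as sqa
from pydiverse.transform import Table
from pydiverse.transform.backend.targets import Polars, SqlAlchemy
from pydiverse.transform.pipe.verbs import export, filter

engine = sqa.create_engine("sqlite:///:memory:")
pl.DataFrame({"col1": [1, 2, 3]}).write_database(
    "df1", engine, if_table_exists="replace"
)
tbl = Table("df1", SqlAlchemy(engine))
out = tbl >> filter(tbl.col1 == 3) >> export(Polars())  # back to a polars frame
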
diff --git a/tests/util/__init__.py b/tests/util/__init__.py
index 68f9d5fd..1d71ca96 100644
--- a/tests/util/__init__.py
+++ b/tests/util/__init__.py
@@ -1,4 +1,3 @@
 from __future__ import annotations
 
 from .assertion import assert_equal, assert_result_equal
-from .verbs import full_sort
diff --git a/tests/util/assertion.py b/tests/util/assertion.py
index f9d1aacd..0fec9339 100644
--- a/tests/util/assertion.py
+++ b/tests/util/assertion.py
@@ -8,13 +8,14 @@
 from polars.testing import assert_frame_equal
 
 from pydiverse.transform import Table
-from pydiverse.transform.core.verbs import export, show_query
+from pydiverse.transform.backend.targets import Polars
 from pydiverse.transform.errors import NonStandardBehaviourWarning
+from pydiverse.transform.pipe.verbs import export, show_query
 
 
 def assert_equal(left, right, check_dtypes=False, check_row_order=True):
-    left_df = left >> export() if isinstance(left, Table) else left
-    right_df = right >> export() if isinstance(right, Table) else right
+    left_df = left >> export(Polars()) if isinstance(left, Table) else left
+    right_df = right >> export(Polars()) if isinstance(right, Table) else right
 
     try:
         assert_frame_equal(
@@ -65,9 +66,9 @@ def assert_result_equal(
 
     if exception and not may_throw:
         with pytest.raises(exception):
-            pipe_factory(*x) >> export()
+            pipe_factory(*x) >> export(Polars())
         with pytest.raises(exception):
-            pipe_factory(*y) >> export()
+            pipe_factory(*y) >> export(Polars())
         return
 
     did_raise_warning = False
@@ -77,10 +78,10 @@ def assert_result_equal(
     query_x = pipe_factory(*x)
     query_y = pipe_factory(*y)
 
-    dfx: pl.DataFrame = (query_x >> export()).with_columns(
+    dfx: pl.DataFrame = (query_x >> export(Polars())).with_columns(
         pl.col(pl.Decimal(scale=10)).cast(pl.Float64)
     )
-    dfy: pl.DataFrame = (query_y >> export()).with_columns(
+    dfy: pl.DataFrame = (query_y >> export(Polars())).with_columns(
         pl.col(pl.Decimal(scale=10)).cast(pl.Float64)
     )
diff --git a/tests/util/backend.py b/tests/util/backend.py
index ea8c97e2..957cad28 100644
--- a/tests/util/backend.py
+++ b/tests/util/backend.py
@@ -4,12 +4,11 @@
 
 import polars as pl
 
-from pydiverse.transform.core import Table
-from pydiverse.transform.polars.polars_table import PolarsEager
-from pydiverse.transform.sql.sql_table import SQLTableImpl
+from pydiverse.transform.backend.targets import SqlAlchemy
+from pydiverse.transform.pipe.table import Table
 
 
-def _cached_impl(fn):
+def _cached_table(fn):
     cache = {}
 
     @functools.wraps(fn)
@@ -25,16 +24,16 @@ def wrapped(df: pl.DataFrame, name: str):
     return wrapped
 
 
-@_cached_impl
-def polars_impl(df: pl.DataFrame, name: str):
-    return PolarsEager(name, df)
+@_cached_table
+def polars_table(df: pl.DataFrame, name: str):
+    return Table(df, name=name)
 
 
 _sql_engine_cache = {}
 
 
-def _sql_table(df: pl.DataFrame, name: str, url: str, dtypes_map: dict = None):
-    import sqlalchemy as sa
+def sql_table(df: pl.DataFrame, name: str, url: str, dtypes_map: dict = None):
+    import sqlalchemy as sqa
 
     global _sql_engine_cache
 
@@ -43,7 +42,7 @@ def sql_table(df: pl.DataFrame, name: str, url: str, dtypes_map: dict = None):
     if url in _sql_engine_cache:
         engine = _sql_engine_cache[url]
     else:
-        engine = sa.create_engine(url)
+        engine = sqa.create_engine(url)
         _sql_engine_cache[url] = engine
 
     sql_dtypes = {}
@@ -54,56 +53,45 @@ def sql_table(df: pl.DataFrame, name: str, url: str, dtypes_map: dict = None):
     df.write_database(
         name, engine, if_table_exists="replace", engine_options={"dtype": sql_dtypes}
     )
-    return SQLTableImpl(engine, name)
+    return Table(name, SqlAlchemy(engine))
 
 
-@_cached_impl
-def sqlite_impl(df: pl.DataFrame, name: str):
-    return _sql_table(df, name, "sqlite:///:memory:")
+@_cached_table
+def sqlite_table(df: pl.DataFrame, name: str):
+    return sql_table(df, name, "sqlite:///:memory:")
 
 
-@_cached_impl
-def duckdb_impl(df: pl.DataFrame, name: str):
-    return _sql_table(df, name, "duckdb:///:memory:")
+@_cached_table
+def duckdb_table(df: pl.DataFrame, name: str):
+    return sql_table(df, name, "duckdb:///:memory:")
 
 
-@_cached_impl
-def postgres_impl(df: pl.DataFrame, name: str):
+@_cached_table
+def postgres_table(df: pl.DataFrame, name: str):
     url = "postgresql://sa:Pydiverse23@127.0.0.1:6543"
-    return _sql_table(df, name, url)
+    return sql_table(df, name, url)
 
 
-@_cached_impl
-def mssql_impl(df: pl.DataFrame, name: str):
+@_cached_table
+def mssql_table(df: pl.DataFrame, name: str):
     from sqlalchemy.dialects.mssql import DATETIME2
 
     url = (
         "mssql+pyodbc://sa:PydiQuant27@127.0.0.1:1433"
         "/master?driver=ODBC+Driver+18+for+SQL+Server&encrypt=no"
     )
-    return _sql_table(
+    return sql_table(
         df,
         name,
         url,
-        dtypes_map={
-            pl.Datetime(): DATETIME2(),
-        },
+        dtypes_map={pl.Datetime(): DATETIME2()},
     )
 
 
-def impl_to_table_callable(fn):
-    @functools.wraps(fn)
-    def wrapped(df: pl.DataFrame, name: str):
-        impl = fn(df, name)
-        return Table(impl)
-
-    return wrapped
-
-
 BACKEND_TABLES = {
-    "polars": impl_to_table_callable(polars_impl),
-    "sqlite": impl_to_table_callable(sqlite_impl),
-    "duckdb": impl_to_table_callable(duckdb_impl),
-    "postgres": impl_to_table_callable(postgres_impl),
-    "mssql": impl_to_table_callable(mssql_impl),
+    "polars": polars_table,
+    "sqlite": sqlite_table,
+    "duckdb": duckdb_table,
+    "postgres": postgres_table,
+    "mssql": mssql_table,
}
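
# A hedged usage sketch for the renamed helpers above: every BACKEND_TABLES
# entry is now a (cached) factory that returns a Table directly, without the
# removed impl_to_table_callable indirection.
import polars as pl
from tests.util.backend import BACKEND_TABLES

df = pl.DataFrame({"a": [1, 2, 3]})
make_table = BACKEND_TABLES["polars"]  # or "sqlite", "duckdb", "postgres", "mssql"
t1 = make_table(df, "t1")
t2 = make_table(df, "t1")  # _cached_table should hand back the cached table
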
diff --git a/tests/util/verbs.py b/tests/util/verbs.py
deleted file mode 100644
index 719518ac..00000000
--- a/tests/util/verbs.py
+++ /dev/null
@@ -1,13 +0,0 @@
-from __future__ import annotations
-
-from pydiverse.transform import Table, verb
-from pydiverse.transform.core.verbs import arrange
-
-
-@verb
-def full_sort(t: Table):
-    """
-    Ordering after join is not determined.
-    This helper applies a deterministic ordering to a table.
-    """
-    return t >> arrange(*t)
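
# Why full_sort could be deleted outright: assert_equal (see the assertion.py
# hunk above) now takes check_row_order=False, so tests no longer pipe through
# an extra arrange to make row order deterministic after joins. A hedged
# sketch of the new pattern, with tbl_left / tbl_right / expected assumed to
# exist as in the SQL tests:
from pydiverse.transform.pipe.verbs import join, select
from tests.util import assert_equal

assert_equal(
    tbl_left
    >> join(tbl_right, tbl_left.a == tbl_right.b, "outer", suffix="_1729")
    >> select(tbl_left.a, tbl_right.b),
    expected,  # a polars frame whose rows may come back in any order
    check_row_order=False,
)
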