From 2281da499142debba7fb19a7030ffa7d66ddb047 Mon Sep 17 00:00:00 2001 From: Nicolas Camenisch Date: Sun, 23 Jul 2023 09:17:59 +0200 Subject: [PATCH] MSSQL implement bool <-> bit casting --- .../transform/core/expressions/translator.py | 4 +- .../lazy/sql_table/dialects/mssql.py | 144 ++++++++++++++++++ .../transform/lazy/sql_table/sql_table.py | 10 +- 3 files changed, 151 insertions(+), 7 deletions(-) diff --git a/src/pydiverse/transform/core/expressions/translator.py b/src/pydiverse/transform/core/expressions/translator.py index 8d6087a..b7e8b93 100644 --- a/src/pydiverse/transform/core/expressions/translator.py +++ b/src/pydiverse/transform/core/expressions/translator.py @@ -76,7 +76,7 @@ def _translate(self, expr, accept_literal_col=True, **kwargs): expr_args, expr_kwargs = operator.mutate_args(expr.args, expr.kwargs) expr = FunctionCall(expr.name, *expr_args, **expr_kwargs) - op_args, op_kwargs, context_kwargs = self.__translate_function_arguments( + op_args, op_kwargs, context_kwargs = self._translate_function_arguments( expr, operator, **kwargs ) @@ -119,7 +119,7 @@ def _translate_function( def _translate_literal(self, expr, **kwargs) -> T: raise NotImplementedError - def __translate_function_arguments( + def _translate_function_arguments( self, expr: FunctionCall, operator: Operator, **kwargs ) -> tuple[list[T], dict[str, T], dict[str, Any]]: op_args = [self._translate(arg, **kwargs) for arg in expr.args] diff --git a/src/pydiverse/transform/lazy/sql_table/dialects/mssql.py b/src/pydiverse/transform/lazy/sql_table/dialects/mssql.py index d882c1c..1dbae0b 100644 --- a/src/pydiverse/transform/lazy/sql_table/dialects/mssql.py +++ b/src/pydiverse/transform/lazy/sql_table/dialects/mssql.py @@ -3,13 +3,30 @@ import sqlalchemy as sa from pydiverse.transform import ops +from pydiverse.transform._typing import CallableT +from pydiverse.transform.core import dtypes +from pydiverse.transform.core.expressions import TypedValue +from pydiverse.transform.core.registry import TypedOperatorImpl from pydiverse.transform.core.util import OrderingDescriptor from pydiverse.transform.lazy.sql_table.sql_table import SQLTableImpl +from pydiverse.transform.ops import Operator, OPType class MSSqlTableImpl(SQLTableImpl): _dialect_name = "mssql" + def _build_select_select(self, select): + s = [] + for name, uuid_ in self.selected_cols(): + sql_col = self.cols[uuid_].compiled(self.sql_columns) + if not isinstance(sql_col, sa.ColumnElement): + sql_col = sa.literal(sql_col) + if dtypes.Bool().same_kind(self.cols[uuid_].dtype): + # Make sure that any boolean values get stored as bit + sql_col = sa.cast(sql_col, sa.Boolean()) + s.append(sql_col.label(name)) + return select.with_only_columns(*s) + def _order_col( self, col: sa.SQLColumnExpression, ordering: OrderingDescriptor ) -> list[sa.SQLColumnExpression]: @@ -27,6 +44,133 @@ def _order_col( order_by_expressions.append(col.asc() if ordering.asc else col.desc()) return order_by_expressions + class ExpressionCompiler(SQLTableImpl.ExpressionCompiler): + def translate(self, expr, **kwargs): + mssql_bool_as_bit = True + if verb := kwargs.get("verb"): + mssql_bool_as_bit = verb not in ("filter",) + + return super().translate( + expr, **kwargs, mssql_bool_as_bit=mssql_bool_as_bit + ) + + def _translate_col(self, expr, **kwargs): + # If mssql_bool_as_bit is true, then we can just return the + # precompiled col. Otherwise, we must recompile it to ensure + # we return booleans as bools and not as bits. + if kwargs.get("mssql_bool_as_bit") is True: + return super()._translate_col(expr, **kwargs) + + # Can either be a base SQL column, or a reference to an expression + if expr.uuid in self.backend.sql_columns: + + def sql_col(cols): + return cols[expr.uuid] + + return TypedValue(sql_col, expr.dtype, OPType.EWISE) + + col = self.backend.cols[expr.uuid] + return self._translate(col.expr, **kwargs) + + def _translate_function_value( + self, expr, implementation, op_args, context_kwargs, verb=None, **kwargs + ): + value = super()._translate_function_value( + expr, + implementation, + op_args, + context_kwargs, + verb, + **kwargs, + ) + + bool_as_bit = kwargs.get("mssql_bool_as_bit") + returns_bool_as_bit = mssql_op_returns_bool_as_bit(implementation) + return mssql_convert_bool_bit_value(value, bool_as_bit, returns_bool_as_bit) + + def _translate_function_arguments(self, expr, operator, **kwargs): + kwargs["mssql_bool_as_bit"] = mssql_op_wants_bool_as_bit(operator) + return super()._translate_function_arguments(expr, operator, **kwargs) + + +# Boolean / Bit Conversion +# +# MSSQL doesn't have a boolean type. This means that expressions that +# return a boolean (e.g. ==, !=, >) can't be used in other expressions +# without casting to the BIT type. +# Conversely, after casting to BIT, we sometimes may need to convert +# back to booleans. + + +def mssql_op_wants_bool_as_bit(operator: Operator) -> bool: + # These operations want boolean types (not BIT) as input + exceptions = [ + ops.logical.BooleanBinary, + ops.logical.Invert, + ] + + for exception in exceptions: + if isinstance(operator, exception): + return False + + return True + + +def mssql_op_returns_bool_as_bit(implementation: TypedOperatorImpl) -> bool | None: + if not dtypes.Bool().same_kind(implementation.rtype): + return None + + # These operations return boolean types (not BIT) + exceptions = [ + ops.logical.BooleanBinary, + ops.logical.Invert, + ops.logical.Comparison, + ] + + operator = implementation.operator + for exception in exceptions: + if isinstance(operator, exception): + return False + + return True + + +def mssql_convert_bit_to_bool(x: sa.SQLColumnExpression): + return x.is_(1) + + +def mssql_convert_bool_to_bit(x: sa.SQLColumnExpression): + return sa.case( + (x, sa.literal_column("1")), + (sa.not_(x), sa.literal_column("0")), + ) + + +def mssql_convert_bool_bit_value( + value_func: CallableT, + wants_bool_as_bit: bool | None, + is_bool_as_bit: bool | None, +) -> CallableT: + if wants_bool_as_bit is True and is_bool_as_bit is False: + + def value(*args, **kwargs): + x = value_func(*args, **kwargs) + return mssql_convert_bool_to_bit(x) + + return value + + if wants_bool_as_bit is False and is_bool_as_bit is True: + + def value(*args, **kwargs): + x = value_func(*args, **kwargs) + return mssql_convert_bit_to_bool(x) + + return value + + return value_func + + +# Operators with MSSqlTableImpl.op(ops.Mean()) as op: diff --git a/src/pydiverse/transform/lazy/sql_table/sql_table.py b/src/pydiverse/transform/lazy/sql_table/sql_table.py index af87c01..c97d9ba 100644 --- a/src/pydiverse/transform/lazy/sql_table/sql_table.py +++ b/src/pydiverse/transform/lazy/sql_table/sql_table.py @@ -248,7 +248,7 @@ def build_select(self) -> sql.Select: def _build_select_from(self, select): for join in self.joins: - compiled, _ = self.compiler.translate(join.on) + compiled, _ = self.compiler.translate(join.on, verb="join") on = compiled(self.sql_columns) select = select.join( @@ -268,7 +268,7 @@ def _build_select_where(self, select): combined_where = functools.reduce( operator.and_, map(SymbolicExpression, self.wheres) )._ - compiled, where_dtype = self.compiler.translate(combined_where) + compiled, where_dtype = self.compiler.translate(combined_where, verb="filter") assert isinstance(where_dtype, dtypes.Bool) where = compiled(self.sql_columns) return select.where(where) @@ -279,7 +279,7 @@ def _build_select_group_by(self, select): compiled_gb, group_by_dtypes = zip( *( - self.compiler.translate(group_by) + self.compiler.translate(group_by, verb="group_by") for group_by in self.intrinsic_grouped_by ) ) @@ -294,7 +294,7 @@ def _build_select_having(self, select): combined_having = functools.reduce( operator.and_, map(SymbolicExpression, self.having) )._ - compiled, having_dtype = self.compiler.translate(combined_having) + compiled, having_dtype = self.compiler.translate(combined_having, verb="filter") assert isinstance(having_dtype, dtypes.Bool) having = compiled(self.sql_columns) return select.having(having) @@ -322,7 +322,7 @@ def _build_select_order_by(self, select): o = [] for o_by in self.order_bys: - compiled, _ = self.compiler.translate(o_by.order) + compiled, _ = self.compiler.translate(o_by.order, verb="arrange") col = compiled(self.sql_columns) o.extend(self._order_col(col, o_by))