Skip to content

Commit

Permalink
MSSQL implement bool <-> bit casting
Browse files Browse the repository at this point in the history
  • Loading branch information
NMAC427 committed Jul 24, 2023
1 parent 42aa1d1 commit 2281da4
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 7 deletions.
4 changes: 2 additions & 2 deletions src/pydiverse/transform/core/expressions/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def _translate(self, expr, accept_literal_col=True, **kwargs):
expr_args, expr_kwargs = operator.mutate_args(expr.args, expr.kwargs)
expr = FunctionCall(expr.name, *expr_args, **expr_kwargs)

op_args, op_kwargs, context_kwargs = self.__translate_function_arguments(
op_args, op_kwargs, context_kwargs = self._translate_function_arguments(
expr, operator, **kwargs
)

Expand Down Expand Up @@ -119,7 +119,7 @@ def _translate_function(
def _translate_literal(self, expr, **kwargs) -> T:
raise NotImplementedError

def __translate_function_arguments(
def _translate_function_arguments(
self, expr: FunctionCall, operator: Operator, **kwargs
) -> tuple[list[T], dict[str, T], dict[str, Any]]:
op_args = [self._translate(arg, **kwargs) for arg in expr.args]
Expand Down
144 changes: 144 additions & 0 deletions src/pydiverse/transform/lazy/sql_table/dialects/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,30 @@
import sqlalchemy as sa

from pydiverse.transform import ops
from pydiverse.transform._typing import CallableT
from pydiverse.transform.core import dtypes
from pydiverse.transform.core.expressions import TypedValue
from pydiverse.transform.core.registry import TypedOperatorImpl
from pydiverse.transform.core.util import OrderingDescriptor
from pydiverse.transform.lazy.sql_table.sql_table import SQLTableImpl
from pydiverse.transform.ops import Operator, OPType


class MSSqlTableImpl(SQLTableImpl):
_dialect_name = "mssql"

def _build_select_select(self, select):
    """Restrict ``select`` to the selected columns, casting booleans.

    Every ``(name, uuid)`` pair yielded by ``self.selected_cols()`` is
    compiled against ``self.sql_columns`` and labelled with ``name``.
    Columns whose dtype is of the ``Bool`` kind are additionally wrapped
    in a cast to ``sa.Boolean()``.

    :param select: the select statement being built.
    :return: the result of ``select.with_only_columns(...)`` called with
        the labelled column expressions, in selection order.
    """

    def as_output_column(name, uuid_):
        col = self.cols[uuid_]
        compiled = col.compiled(self.sql_columns)

        # Anything that isn't already a SQL element gets wrapped in a
        # literal so that it can be cast and labelled below.
        if not isinstance(compiled, sa.ColumnElement):
            compiled = sa.literal(compiled)

        # Make sure that any boolean values get stored as bit
        if dtypes.Bool().same_kind(col.dtype):
            compiled = sa.cast(compiled, sa.Boolean())

        return compiled.label(name)

    output_columns = [
        as_output_column(name, uuid_) for name, uuid_ in self.selected_cols()
    ]
    return select.with_only_columns(*output_columns)

def _order_col(
self, col: sa.SQLColumnExpression, ordering: OrderingDescriptor
) -> list[sa.SQLColumnExpression]:
Expand All @@ -27,6 +44,133 @@ def _order_col(
order_by_expressions.append(col.asc() if ordering.asc else col.desc())
return order_by_expressions

class ExpressionCompiler(SQLTableImpl.ExpressionCompiler):
    """Expression compiler that tracks the bool/BIT representation.

    A ``mssql_bool_as_bit`` keyword is threaded through the translation
    calls. ``True`` means the caller wants boolean values as BIT values,
    ``False`` means it wants them as boolean expressions.
    """

    def translate(self, expr, **kwargs):
        """Translate ``expr``, deriving ``mssql_bool_as_bit`` from ``verb``.

        The flag is ``False`` only when a truthy ``verb`` equal to
        ``"filter"`` is passed; with any other verb, or none, it is
        ``True``.
        """
        verb = kwargs.get("verb")
        mssql_bool_as_bit = (verb not in ("filter",)) if verb else True
        return super().translate(
            expr, **kwargs, mssql_bool_as_bit=mssql_bool_as_bit
        )

    def _translate_col(self, expr, **kwargs):
        """Translate a column reference, honouring ``mssql_bool_as_bit``."""
        # If mssql_bool_as_bit is true, then we can just return the
        # precompiled col. Otherwise, we must recompile it to ensure
        # we return booleans as bools and not as bits.
        if kwargs.get("mssql_bool_as_bit") is True:
            return super()._translate_col(expr, **kwargs)

        backend = self.backend

        # Not a base SQL column: it is a reference to an expression, so
        # translate that expression again with the current kwargs.
        if expr.uuid not in backend.sql_columns:
            return self._translate(backend.cols[expr.uuid].expr, **kwargs)

        # NOTE(review): a base SQL column is returned unchanged here, even
        # when its dtype is Bool - confirm whether it needs converting
        # from BIT in this branch.
        def sql_col(cols):
            return cols[expr.uuid]

        return TypedValue(sql_col, expr.dtype, OPType.EWISE)

    def _translate_function_value(
        self, expr, implementation, op_args, context_kwargs, verb=None, **kwargs
    ):
        """Translate a function call and adapt its bool/BIT representation.

        The value produced by the parent class is passed through
        ``mssql_convert_bool_bit_value``, which compares what the caller
        wants (``mssql_bool_as_bit``) with what ``implementation`` returns.
        """
        raw_value = super()._translate_function_value(
            expr,
            implementation,
            op_args,
            context_kwargs,
            verb,
            **kwargs,
        )

        wanted = kwargs.get("mssql_bool_as_bit")
        produced = mssql_op_returns_bool_as_bit(implementation)
        return mssql_convert_bool_bit_value(raw_value, wanted, produced)

    def _translate_function_arguments(self, expr, operator, **kwargs):
        """Translate arguments in the representation ``operator`` expects."""
        # The operator, not the outer context, decides how its own
        # arguments must be represented, so override the inherited flag.
        kwargs = {
            **kwargs,
            "mssql_bool_as_bit": mssql_op_wants_bool_as_bit(operator),
        }
        return super()._translate_function_arguments(expr, operator, **kwargs)


# Boolean / Bit Conversion
#
# MSSQL doesn't have a boolean type. This means that the result of an
# expression that returns a boolean (e.g. ==, !=, >) can't be used as a
# value in other expressions without first being converted to the BIT type.
# Conversely, a BIT value must be converted back to a boolean before it
# can be used where MSSQL expects a boolean (e.g. as a filter condition).


def mssql_op_wants_bool_as_bit(operator: Operator) -> bool:
    """Return whether ``operator`` takes its boolean inputs as BIT values.

    :param operator: the operator whose arguments are about to be translated.
    :return: ``False`` if ``operator`` is an instance of
        ``ops.logical.BooleanBinary`` or ``ops.logical.Invert``,
        ``True`` for every other operator.
    """
    # These operations want boolean types (not BIT) as input
    boolean_input_ops = (
        ops.logical.BooleanBinary,
        ops.logical.Invert,
    )
    return not isinstance(operator, boolean_input_ops)


def mssql_op_returns_bool_as_bit(implementation: TypedOperatorImpl) -> bool | None:
    """Describe how ``implementation`` represents its boolean result.

    :param implementation: the typed operator implementation being called.
    :return: ``None`` if the return type is not of the ``Bool`` kind;
        ``False`` if the operator is an instance of
        ``ops.logical.BooleanBinary``, ``ops.logical.Invert`` or
        ``ops.logical.Comparison``; ``True`` otherwise.
    """
    returns_bool_kind = dtypes.Bool().same_kind(implementation.rtype)
    if not returns_bool_kind:
        # Nothing to convert for non-boolean results.
        return None

    # These operations return boolean types (not BIT)
    boolean_result_ops = (
        ops.logical.BooleanBinary,
        ops.logical.Invert,
        ops.logical.Comparison,
    )
    return not isinstance(implementation.operator, boolean_result_ops)


def mssql_convert_bit_to_bool(x: sa.SQLColumnExpression):
    """Turn a BIT-valued SQL expression into a boolean expression.

    :param x: SQL expression holding a BIT value.
    :return: the comparison ``x = 1``.
    """
    # T-SQL's IS predicate only accepts NULL as its right-hand side, so
    # ``x IS 1`` is not valid on MSSQL; use an equality comparison instead.
    # The literal is emitted inline, matching ``mssql_convert_bool_to_bit``.
    return x == sa.literal_column("1")


def mssql_convert_bool_to_bit(x: sa.SQLColumnExpression):
    """Turn a boolean SQL expression into a ``1`` / ``0`` value.

    :param x: boolean SQL expression.
    :return: a ``CASE`` expression yielding the literal ``1`` when ``x``
        holds and the literal ``0`` when ``NOT x`` holds. No ``ELSE``
        branch is given.
    """
    bit_one = sa.literal_column("1")
    bit_zero = sa.literal_column("0")

    when_true = (x, bit_one)
    when_false = (sa.not_(x), bit_zero)
    return sa.case(when_true, when_false)


def mssql_convert_bool_bit_value(
    value_func: CallableT,
    wants_bool_as_bit: bool | None,
    is_bool_as_bit: bool | None,
) -> CallableT:
    """Wrap ``value_func`` so its result has the requested representation.

    :param value_func: callable producing the SQL expression.
    :param wants_bool_as_bit: ``True`` if the consumer wants a BIT value,
        ``False`` if it wants a boolean.
    :param is_bool_as_bit: ``True`` if ``value_func`` produces a BIT value,
        ``False`` if it produces a boolean.
    :return: ``value_func`` itself unless one flag is exactly ``True`` and
        the other exactly ``False``; in that case a wrapper that forwards
        all arguments to ``value_func`` and converts the result.
    """
    # Identity checks on purpose: ``None`` (and any non-bool value) on
    # either side never triggers a conversion.
    to_bit = wants_bool_as_bit is True and is_bool_as_bit is False
    to_bool = wants_bool_as_bit is False and is_bool_as_bit is True

    if not (to_bit or to_bool):
        return value_func

    def value(*args, **kwargs):
        produced = value_func(*args, **kwargs)
        if to_bit:
            return mssql_convert_bool_to_bit(produced)
        return mssql_convert_bit_to_bool(produced)

    return value


# Operators

with MSSqlTableImpl.op(ops.Mean()) as op:

Expand Down
10 changes: 5 additions & 5 deletions src/pydiverse/transform/lazy/sql_table/sql_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def build_select(self) -> sql.Select:

def _build_select_from(self, select):
for join in self.joins:
compiled, _ = self.compiler.translate(join.on)
compiled, _ = self.compiler.translate(join.on, verb="join")
on = compiled(self.sql_columns)

select = select.join(
Expand All @@ -268,7 +268,7 @@ def _build_select_where(self, select):
combined_where = functools.reduce(
operator.and_, map(SymbolicExpression, self.wheres)
)._
compiled, where_dtype = self.compiler.translate(combined_where)
compiled, where_dtype = self.compiler.translate(combined_where, verb="filter")
assert isinstance(where_dtype, dtypes.Bool)
where = compiled(self.sql_columns)
return select.where(where)
Expand All @@ -279,7 +279,7 @@ def _build_select_group_by(self, select):

compiled_gb, group_by_dtypes = zip(
*(
self.compiler.translate(group_by)
self.compiler.translate(group_by, verb="group_by")
for group_by in self.intrinsic_grouped_by
)
)
Expand All @@ -294,7 +294,7 @@ def _build_select_having(self, select):
combined_having = functools.reduce(
operator.and_, map(SymbolicExpression, self.having)
)._
compiled, having_dtype = self.compiler.translate(combined_having)
compiled, having_dtype = self.compiler.translate(combined_having, verb="filter")
assert isinstance(having_dtype, dtypes.Bool)
having = compiled(self.sql_columns)
return select.having(having)
Expand Down Expand Up @@ -322,7 +322,7 @@ def _build_select_order_by(self, select):

o = []
for o_by in self.order_bys:
compiled, _ = self.compiler.translate(o_by.order)
compiled, _ = self.compiler.translate(o_by.order, verb="arrange")
col = compiled(self.sql_columns)
o.extend(self._order_col(col, o_by))

Expand Down

0 comments on commit 2281da4

Please sign in to comment.