Skip to content

Commit

Permalink
Properly determine result type of greatest and least for postgres and…
Browse files Browse the repository at this point in the history
… sqlite
  • Loading branch information
NMAC427 committed Aug 17, 2023
1 parent c7ffb30 commit 9113fcc
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 11 deletions.
13 changes: 8 additions & 5 deletions src/pydiverse/transform/lazy/sql_table/dialects/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
import sqlalchemy as sa

from pydiverse.transform import ops
from pydiverse.transform.lazy.sql_table.sql_table import SQLTableImpl
from pydiverse.transform.lazy.sql_table.sql_table import (
SQLTableImpl,
determine_sa_type_union,
)


class PostgresTableImpl(SQLTableImpl):
Expand Down Expand Up @@ -74,16 +77,16 @@ def _millisecond(x):

@op("str... -> str")
def _greatest(*x):
# TODO: Determine return type
return sa.func.GREATEST(*(e.collate("POSIX") for e in x))
type_ = determine_sa_type_union([y.type for y in x])
return sa.func.GREATEST(*(e.collate("POSIX") for e in x), type_=type_)


with PostgresTableImpl.op(ops.Least()) as op:

@op("str... -> str")
def _least(*x):
# TODO: Determine return type
return sa.func.LEAST(*(e.collate("POSIX") for e in x))
type_ = determine_sa_type_union([y.type for y in x])
return sa.func.LEAST(*(e.collate("POSIX") for e in x), type_=type_)


with PostgresTableImpl.op(ops.Any()) as op:
Expand Down
13 changes: 8 additions & 5 deletions src/pydiverse/transform/lazy/sql_table/dialects/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
import sqlalchemy as sa

from pydiverse.transform import ops
from pydiverse.transform.lazy.sql_table.sql_table import SQLTableImpl
from pydiverse.transform.lazy.sql_table.sql_table import (
SQLTableImpl,
determine_sa_type_union,
)
from pydiverse.transform.util.warnings import warn_non_standard


Expand Down Expand Up @@ -82,8 +85,8 @@ def _greatest(*x):
left = _greatest(*x[:mid])
right = _greatest(*x[mid:])

# TODO: Determine return type
return sa.func.coalesce(sa.func.MAX(left, right), left, right)
type_ = determine_sa_type_union([left.type, right.type])
return sa.func.coalesce(sa.func.MAX(left, right, type_=type_), left, right)


with SQLiteTableImpl.op(ops.Least()) as op:
Expand All @@ -99,8 +102,8 @@ def _least(*x):
left = _least(*x[:mid])
right = _least(*x[mid:])

# TODO: Determine return type
return sa.func.coalesce(sa.func.MIN(left, right), left, right)
type_ = determine_sa_type_union([left.type, right.type])
return sa.func.coalesce(sa.func.MIN(left, right, type_=type_), left, right)


with SQLiteTableImpl.op(ops.StringJoin()) as op:
Expand Down
2 changes: 1 addition & 1 deletion src/pydiverse/transform/lazy/sql_table/sql_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,7 +825,7 @@ def determine_sa_type_union(types: list[sa.types.TypeEngine]):

candidate = types[0]
for t in types[1:]:
if isinstance(t, type(candidate)):
if t == type(candidate):
continue

if isinstance(candidate, sa.Integer):
Expand Down

0 comments on commit 9113fcc

Please sign in to comment.