Skip to content

Commit

Permalink
rename Int to Int64 for consistency with Float64
Browse files Browse the repository at this point in the history
  • Loading branch information
finn-rudolph committed Sep 30, 2024
1 parent 289d57f commit 081712c
Show file tree
Hide file tree
Showing 9 changed files with 25 additions and 25 deletions.
2 changes: 1 addition & 1 deletion src/pydiverse/transform/backend/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def export(cls, nd: AstNode, target: Target, final_select: list[Col]):

@classmethod
def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> Cast:
if cast.val.dtype() == dtypes.Float64 and cast.target_type == dtypes.Int:
if cast.val.dtype() == dtypes.Float64 and cast.target_type == dtypes.Int64:
return sqa.func.trunc(cls.compile_col_expr(cast.val, sqa_col)).cast(
sqa.BigInteger()
)
Expand Down
4 changes: 2 additions & 2 deletions src/pydiverse/transform/backend/polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ def polars_type_to_pdt(t: pl.DataType) -> dtypes.Dtype:
if t.is_float():
return dtypes.Float64()
elif t.is_integer():
return dtypes.Int()
return dtypes.Int64()
elif isinstance(t, pl.Boolean):
return dtypes.Bool()
elif isinstance(t, pl.String):
Expand All @@ -371,7 +371,7 @@ def polars_type_to_pdt(t: pl.DataType) -> dtypes.Dtype:
def pdt_type_to_polars(t: dtypes.Dtype) -> pl.DataType:
if isinstance(t, (dtypes.Float64, dtypes.Decimal)):
return pl.Float64()
elif isinstance(t, dtypes.Int):
elif isinstance(t, dtypes.Int64):
return pl.Int64()
elif isinstance(t, dtypes.Bool):
return pl.Boolean()
Expand Down
2 changes: 1 addition & 1 deletion src/pydiverse/transform/backend/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> Cast:
compiled_val = cls.compile_col_expr(cast.val, sqa_col)

if isinstance(cast.val.dtype(), dtypes.Float64):
if isinstance(cast.target_type, dtypes.Int):
if isinstance(cast.target_type, dtypes.Int64):
return sqa.func.trunc(compiled_val).cast(sqa.BigInteger())

if isinstance(cast.target_type, dtypes.String):
Expand Down
4 changes: 2 additions & 2 deletions src/pydiverse/transform/backend/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,7 @@ def compile_ast(

@classmethod
def sqa_type(cls, t: Dtype) -> sqa.types.TypeEngine:
if isinstance(t, dtypes.Int):
if isinstance(t, dtypes.Int64):
return sqa.BigInteger()
elif isinstance(t, dtypes.Float64):
return sqa.Double()
Expand All @@ -513,7 +513,7 @@ def sqa_type(cls, t: Dtype) -> sqa.types.TypeEngine:
@classmethod
def pdt_type(cls, t: sqa.types.TypeEngine) -> Dtype:
if isinstance(t, sqa.Integer):
return dtypes.Int()
return dtypes.Int64()
elif isinstance(t, sqa.Float):
return dtypes.Float64()
elif isinstance(t, (sqa.DECIMAL, sqa.NUMERIC)):
Expand Down
6 changes: 3 additions & 3 deletions src/pydiverse/transform/tree/col_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,11 +460,11 @@ def dtype(self) -> Dtype:

if not self.val.dtype().can_promote_to(self.target_type):
valid_casts = {
(dtypes.String, dtypes.Int),
(dtypes.String, dtypes.Int64),
(dtypes.String, dtypes.Float64),
(dtypes.Float64, dtypes.Int),
(dtypes.Float64, dtypes.Int64),
(dtypes.DateTime, dtypes.Date),
(dtypes.Int, dtypes.String),
(dtypes.Int64, dtypes.String),
(dtypes.Float64, dtypes.String),
(dtypes.DateTime, dtypes.String),
(dtypes.Date, dtypes.String),
Expand Down
8 changes: 4 additions & 4 deletions src/pydiverse/transform/tree/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ def can_promote_to(self, other: Dtype) -> bool:
return other.same_kind(self)


class Int(Dtype):
name = "int"
class Int64(Dtype):
name = "int64"

def can_promote_to(self, other: Dtype) -> bool:
if super().can_promote_to(other):
Expand Down Expand Up @@ -151,7 +151,7 @@ class NoneDtype(Dtype):

def python_type_to_pdt(t: type) -> Dtype:
if t is int:
return Int()
return Int64()
elif t is float:
return Float64()
elif t is bool:
Expand Down Expand Up @@ -200,7 +200,7 @@ def dtype_from_string(t: str) -> Dtype:
return Template(base_type, const=is_const, vararg=is_vararg)

if base_type == "int":
return Int(const=is_const, vararg=is_vararg)
return Int64(const=is_const, vararg=is_vararg)
if base_type == "float64":
return Float64(const=is_const, vararg=is_vararg)
if base_type == "decimal":
Expand Down
8 changes: 4 additions & 4 deletions tests/test_backend_equivalence/test_ops/test_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,25 @@ def test_string_to_float(df_strings):
def test_string_to_int(df_strings):
assert_result_equal(
df_strings,
lambda t: t >> mutate(u=t.d.cast(pdt.Int())),
lambda t: t >> mutate(u=t.d.cast(pdt.Int64())),
)


def test_float_to_int(df_num):
assert_result_equal(
df_num,
lambda t: t >> mutate(**{col.name: col.cast(pdt.Int()) for col in t}),
lambda t: t >> mutate(**{col.name: col.cast(pdt.Int64()) for col in t}),
)

assert_result_equal(
df_num,
lambda t: t >> add_nan_inf_cols() >> mutate(u=C.inf.cast(pdt.Int())),
lambda t: t >> add_nan_inf_cols() >> mutate(u=C.inf.cast(pdt.Int64())),
exception=Exception,
may_throw=True,
)
assert_result_equal(
df_num,
lambda t: t >> add_nan_inf_cols() >> mutate(u=C.nan.cast(pdt.Int())),
lambda t: t >> add_nan_inf_cols() >> mutate(u=C.nan.cast(pdt.Int64())),
exception=Exception,
may_throw=True,
)
Expand Down
10 changes: 5 additions & 5 deletions tests/test_operator_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@ def assert_signature(
class TestOperatorSignature:
def test_parse_simple(self):
s = OperatorSignature.parse("int, int -> int")
assert_signature(s, [dtypes.Int(), dtypes.Int()], dtypes.Int())
assert_signature(s, [dtypes.Int64(), dtypes.Int64()], dtypes.Int64())

s = OperatorSignature.parse("bool->bool ")
assert_signature(s, [dtypes.Bool()], dtypes.Bool())

s = OperatorSignature.parse("-> int")
assert_signature(s, [], dtypes.Int())
assert_signature(s, [], dtypes.Int64())

with pytest.raises(ValueError):
OperatorSignature.parse("int, int -> ")
Expand Down Expand Up @@ -112,7 +112,7 @@ def test_simple(self):
assert reg.get_impl("op1", parse_dtypes("int", "int"))() == 1
assert isinstance(
reg.get_impl("op1", parse_dtypes("int", "int")).return_type,
dtypes.Int,
dtypes.Int64,
)
assert reg.get_impl("op2", parse_dtypes("int", "int"))() == 10

Expand Down Expand Up @@ -187,11 +187,11 @@ def test_template(self):
)
assert isinstance(
reg.get_impl("op3", parse_dtypes("int")).return_type,
dtypes.Int,
dtypes.Int64,
)
assert isinstance(
reg.get_impl("op3", parse_dtypes("int", "int", "float64")).return_type,
dtypes.Int,
dtypes.Int64,
)
assert isinstance(
reg.get_impl("op3", parse_dtypes("str", "int", "float64")).return_type,
Expand Down
6 changes: 3 additions & 3 deletions tests/test_polars_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,11 @@ def tbl_dt():

class TestPolarsLazyImpl:
def test_dtype(self, tbl1, tbl2):
assert isinstance(tbl1.col1.dtype(), dtypes.Int)
assert isinstance(tbl1.col1.dtype(), dtypes.Int64)
assert isinstance(tbl1.col2.dtype(), dtypes.String)

assert isinstance(tbl2.col1.dtype(), dtypes.Int)
assert isinstance(tbl2.col2.dtype(), dtypes.Int)
assert isinstance(tbl2.col1.dtype(), dtypes.Int64)
assert isinstance(tbl2.col2.dtype(), dtypes.Int64)
assert isinstance(tbl2.col3.dtype(), dtypes.Float64)

# test that column expression type errors are checked immediately
Expand Down

0 comments on commit 081712c

Please sign in to comment.