diff --git a/src/pydiverse/transform/backend/duckdb.py b/src/pydiverse/transform/backend/duckdb.py index d14b036..33f1fa7 100644 --- a/src/pydiverse/transform/backend/duckdb.py +++ b/src/pydiverse/transform/backend/duckdb.py @@ -26,7 +26,7 @@ def export(cls, nd: AstNode, target: Target, final_select: list[Col]): @classmethod def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> Cast: - if cast.val.dtype() == dtypes.Float64 and cast.target_type == dtypes.Int: + if cast.val.dtype() == dtypes.Float64 and cast.target_type == dtypes.Int64: return sqa.func.trunc(cls.compile_col_expr(cast.val, sqa_col)).cast( sqa.BigInteger() ) diff --git a/src/pydiverse/transform/backend/polars.py b/src/pydiverse/transform/backend/polars.py index 2c7494e..783e62f 100644 --- a/src/pydiverse/transform/backend/polars.py +++ b/src/pydiverse/transform/backend/polars.py @@ -351,7 +351,7 @@ def polars_type_to_pdt(t: pl.DataType) -> dtypes.Dtype: if t.is_float(): return dtypes.Float64() elif t.is_integer(): - return dtypes.Int() + return dtypes.Int64() elif isinstance(t, pl.Boolean): return dtypes.Bool() elif isinstance(t, pl.String): @@ -371,7 +371,7 @@ def polars_type_to_pdt(t: pl.DataType) -> dtypes.Dtype: def pdt_type_to_polars(t: dtypes.Dtype) -> pl.DataType: if isinstance(t, (dtypes.Float64, dtypes.Decimal)): return pl.Float64() - elif isinstance(t, dtypes.Int): + elif isinstance(t, dtypes.Int64): return pl.Int64() elif isinstance(t, dtypes.Bool): return pl.Boolean() diff --git a/src/pydiverse/transform/backend/postgres.py b/src/pydiverse/transform/backend/postgres.py index 28b84df..183ad3e 100644 --- a/src/pydiverse/transform/backend/postgres.py +++ b/src/pydiverse/transform/backend/postgres.py @@ -16,7 +16,7 @@ def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> Cast: compiled_val = cls.compile_col_expr(cast.val, sqa_col) if isinstance(cast.val.dtype(), dtypes.Float64): - if isinstance(cast.target_type, dtypes.Int): + if isinstance(cast.target_type, dtypes.Int64): return sqa.func.trunc(compiled_val).cast(sqa.BigInteger()) if isinstance(cast.target_type, dtypes.String): diff --git a/src/pydiverse/transform/backend/sql.py b/src/pydiverse/transform/backend/sql.py index b04aba0..8bae93c 100644 --- a/src/pydiverse/transform/backend/sql.py +++ b/src/pydiverse/transform/backend/sql.py @@ -489,7 +489,7 @@ def compile_ast( @classmethod def sqa_type(cls, t: Dtype) -> sqa.types.TypeEngine: - if isinstance(t, dtypes.Int): + if isinstance(t, dtypes.Int64): return sqa.BigInteger() elif isinstance(t, dtypes.Float64): return sqa.Double() @@ -513,7 +513,7 @@ def sqa_type(cls, t: Dtype) -> sqa.types.TypeEngine: @classmethod def pdt_type(cls, t: sqa.types.TypeEngine) -> Dtype: if isinstance(t, sqa.Integer): - return dtypes.Int() + return dtypes.Int64() elif isinstance(t, sqa.Float): return dtypes.Float64() elif isinstance(t, (sqa.DECIMAL, sqa.NUMERIC)): diff --git a/src/pydiverse/transform/tree/col_expr.py b/src/pydiverse/transform/tree/col_expr.py index 6b098e5..c56a0f4 100644 --- a/src/pydiverse/transform/tree/col_expr.py +++ b/src/pydiverse/transform/tree/col_expr.py @@ -460,11 +460,11 @@ def dtype(self) -> Dtype: if not self.val.dtype().can_promote_to(self.target_type): valid_casts = { - (dtypes.String, dtypes.Int), + (dtypes.String, dtypes.Int64), (dtypes.String, dtypes.Float64), - (dtypes.Float64, dtypes.Int), + (dtypes.Float64, dtypes.Int64), (dtypes.DateTime, dtypes.Date), - (dtypes.Int, dtypes.String), + (dtypes.Int64, dtypes.String), (dtypes.Float64, dtypes.String), (dtypes.DateTime, dtypes.String), (dtypes.Date, dtypes.String), diff --git a/src/pydiverse/transform/tree/dtypes.py b/src/pydiverse/transform/tree/dtypes.py index 28c7171..b81fc8b 100644 --- a/src/pydiverse/transform/tree/dtypes.py +++ b/src/pydiverse/transform/tree/dtypes.py @@ -73,8 +73,8 @@ def can_promote_to(self, other: Dtype) -> bool: return other.same_kind(self) -class Int(Dtype): - name = "int" +class Int64(Dtype): + name = "int64" def can_promote_to(self, other: Dtype) -> bool: if super().can_promote_to(other): @@ -151,7 +151,7 @@ class NoneDtype(Dtype): def python_type_to_pdt(t: type) -> Dtype: if t is int: - return Int() + return Int64() elif t is float: return Float64() elif t is bool: @@ -200,7 +200,7 @@ def dtype_from_string(t: str) -> Dtype: return Template(base_type, const=is_const, vararg=is_vararg) if base_type == "int": - return Int(const=is_const, vararg=is_vararg) + return Int64(const=is_const, vararg=is_vararg) if base_type == "float64": return Float64(const=is_const, vararg=is_vararg) if base_type == "decimal": diff --git a/tests/test_backend_equivalence/test_ops/test_cast.py b/tests/test_backend_equivalence/test_ops/test_cast.py index 77a6f23..019edc8 100644 --- a/tests/test_backend_equivalence/test_ops/test_cast.py +++ b/tests/test_backend_equivalence/test_ops/test_cast.py @@ -17,25 +17,25 @@ def test_string_to_float(df_strings): def test_string_to_int(df_strings): assert_result_equal( df_strings, - lambda t: t >> mutate(u=t.d.cast(pdt.Int())), + lambda t: t >> mutate(u=t.d.cast(pdt.Int64())), ) def test_float_to_int(df_num): assert_result_equal( df_num, - lambda t: t >> mutate(**{col.name: col.cast(pdt.Int()) for col in t}), + lambda t: t >> mutate(**{col.name: col.cast(pdt.Int64()) for col in t}), ) assert_result_equal( df_num, - lambda t: t >> add_nan_inf_cols() >> mutate(u=C.inf.cast(pdt.Int())), + lambda t: t >> add_nan_inf_cols() >> mutate(u=C.inf.cast(pdt.Int64())), exception=Exception, may_throw=True, ) assert_result_equal( df_num, - lambda t: t >> add_nan_inf_cols() >> mutate(u=C.nan.cast(pdt.Int())), + lambda t: t >> add_nan_inf_cols() >> mutate(u=C.nan.cast(pdt.Int64())), exception=Exception, may_throw=True, ) diff --git a/tests/test_operator_registry.py b/tests/test_operator_registry.py index addae7a..558e265 100644 --- a/tests/test_operator_registry.py +++ b/tests/test_operator_registry.py @@ -24,13 +24,13 @@ def assert_signature( class TestOperatorSignature: def test_parse_simple(self): s = OperatorSignature.parse("int, int -> int") - assert_signature(s, [dtypes.Int(), dtypes.Int()], dtypes.Int()) + assert_signature(s, [dtypes.Int64(), dtypes.Int64()], dtypes.Int64()) s = OperatorSignature.parse("bool->bool ") assert_signature(s, [dtypes.Bool()], dtypes.Bool()) s = OperatorSignature.parse("-> int") - assert_signature(s, [], dtypes.Int()) + assert_signature(s, [], dtypes.Int64()) with pytest.raises(ValueError): OperatorSignature.parse("int, int -> ") @@ -112,7 +112,7 @@ def test_simple(self): assert reg.get_impl("op1", parse_dtypes("int", "int"))() == 1 assert isinstance( reg.get_impl("op1", parse_dtypes("int", "int")).return_type, - dtypes.Int, + dtypes.Int64, ) assert reg.get_impl("op2", parse_dtypes("int", "int"))() == 10 @@ -187,11 +187,11 @@ def test_template(self): ) assert isinstance( reg.get_impl("op3", parse_dtypes("int")).return_type, - dtypes.Int, + dtypes.Int64, ) assert isinstance( reg.get_impl("op3", parse_dtypes("int", "int", "float64")).return_type, - dtypes.Int, + dtypes.Int64, ) assert isinstance( reg.get_impl("op3", parse_dtypes("str", "int", "float64")).return_type, diff --git a/tests/test_polars_table.py b/tests/test_polars_table.py index 9571ff5..af2a48e 100644 --- a/tests/test_polars_table.py +++ b/tests/test_polars_table.py @@ -116,11 +116,11 @@ def tbl_dt(): class TestPolarsLazyImpl: def test_dtype(self, tbl1, tbl2): - assert isinstance(tbl1.col1.dtype(), dtypes.Int) + assert isinstance(tbl1.col1.dtype(), dtypes.Int64) assert isinstance(tbl1.col2.dtype(), dtypes.String) - assert isinstance(tbl2.col1.dtype(), dtypes.Int) - assert isinstance(tbl2.col2.dtype(), dtypes.Int) + assert isinstance(tbl2.col1.dtype(), dtypes.Int64) + assert isinstance(tbl2.col2.dtype(), dtypes.Int64) assert isinstance(tbl2.col3.dtype(), dtypes.Float64) # test that column expression type errors are checked immediately