From e6e8dfc57eac174ce39b1ed578e7001159b9571c Mon Sep 17 00:00:00 2001
From: Ville Puuska <40150442+VillePuuska@users.noreply.github.com>
Date: Thu, 8 Aug 2024 13:43:54 +0000
Subject: [PATCH] add support for decimal datatype

---
 tests/conftest.py        |  3 +++
 tests/test_dataframes.py | 52 +++++++++++++++++++++++++++++-------
 tests/test_sql.py        | 13 +++++++--
 uchelper/dataframe.py    | 57 ++++++++++++++++++++++++----------------
 4 files changed, 91 insertions(+), 34 deletions(-)

diff --git a/tests/conftest.py b/tests/conftest.py
index b7d2a25..d4b6e24 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -69,6 +69,7 @@ def _random_df() -> pl.DataFrame:
     uuids = [str(uuid.uuid4()) for _ in range(RANDOM_DF_ROWS)]
     ints = [random.randint(0, 10000) for _ in range(RANDOM_DF_ROWS)]
     floats = [random.uniform(0, 10000) for _ in range(RANDOM_DF_ROWS)]
+    decimals = [random.uniform(0, 10000) for _ in range(RANDOM_DF_ROWS)]
     strings = [
         "".join(
             random.choices(population=string.ascii_letters, k=random.randint(2, 256))
@@ -80,12 +81,14 @@ def _random_df() -> pl.DataFrame:
             "id": uuids,
             "ints": ints,
             "floats": floats,
+            "decimals": decimals,
             "strings": strings,
         },
         schema={
             "id": pl.String,
             "ints": pl.Int64,
             "floats": pl.Float64,
+            "decimals": pl.Decimal(precision=10, scale=5),
             "strings": pl.String,
         },
     )

diff --git a/tests/test_dataframes.py b/tests/test_dataframes.py
index bb8c2ca..b793961 100644
--- a/tests/test_dataframes.py
+++ b/tests/test_dataframes.py
@@ -69,13 +69,26 @@ def test_basic_dataframe_operations(
             position=2,
             nullable=False,
         ),
+        Column(
+            name="decimals",
+            data_type=DataType.DECIMAL,
+            type_precision=10,
+            type_scale=5,
+            position=3,
+            nullable=False,
+        ),
         Column(
             name="strings",
             data_type=DataType.STRING,
-            position=3,
+            position=4,
             nullable=False,
         ),
     ]
+    # Polars does not support DECIMAL when reading CSVs
+    if file_type == FileType.CSV:
+        columns[3].data_type = DataType.DOUBLE
+        columns[3].type_precision = 0
+        columns[3].type_scale = 0
     client.create_table(
         Table(
             name=table_name,
@@ -89,6 +102,9 @@
     )
 
     df = random_df()
+    # Polars does not support DECIMAL when reading CSVs
+    if file_type == FileType.CSV:
+        df = df.cast({"decimals": pl.Float64})
 
     client.write_table(
         df,
@@ -145,6 +161,9 @@
     )
 
     df4 = random_df()
+    # Polars does not support DECIMAL when reading CSVs
+    if file_type == FileType.CSV:
+        df4 = df4.cast({"decimals": pl.Float64})
 
     # Test OVERWRITE writes
     client.write_table(
@@ -172,6 +191,9 @@
     df5 = random_df()
     df5 = df5.cast({"ints": pl.String})
+    # Polars does not support DECIMAL when reading CSVs
+    if file_type == FileType.CSV:
+        df5 = df5.cast({"decimals": pl.Float64})
 
     table = client.get_table(
         catalog=default_catalog, schema=default_schema, table=table_name
     )
@@ -257,23 +279,31 @@
             position=2,
             nullable=False,
         ),
+        Column(
+            name="decimals",
+            data_type=DataType.DECIMAL,
+            type_precision=10,
+            type_scale=5,
+            position=3,
+            nullable=False,
+        ),
         Column(
             name="strings",
             data_type=DataType.STRING,
-            position=3,
+            position=4,
             nullable=False,
         ),
         Column(
             name="part1",
             data_type=DataType.LONG,
-            position=4,
+            position=5,
             nullable=False,
             partition_index=0,
         ),
         Column(
             name="part2",
             data_type=DataType.LONG,
-            position=5,
+            position=6,
             nullable=False,
             partition_index=1,
         ),
@@ -360,8 +390,8 @@
     # every partition. Otherwise, we might not overwrite all data
     # in the case of a partitioned Parquet table.
     df_concat = pl.concat([df, df2])
-    df4 = df4.replace_column(4, df_concat.select("part1").to_series())
-    df4 = df4.replace_column(5, df_concat.select("part2").to_series())
+    df4 = df4.replace_column(5, df_concat.select("part1").to_series())
+    df4 = df4.replace_column(6, df_concat.select("part2").to_series())
 
     # Test OVERWRITE writes
     client.write_table(
@@ -388,8 +418,8 @@
     df5 = pl.concat([random_partitioned_df(), random_partitioned_df()])
     df5 = df5.cast({"ints": pl.String})
-    df5 = df5.replace_column(4, df_concat.select("part1").to_series())
-    df5 = df5.replace_column(5, df_concat.select("part2").to_series())
+    df5 = df5.replace_column(5, df_concat.select("part1").to_series())
+    df5 = df5.replace_column(6, df_concat.select("part2").to_series())
 
     table = client.get_table(
         catalog=default_catalog, schema=default_schema, table=table_name
     )
@@ -467,6 +497,10 @@
     if not partitioned:
         df = random_df()
+        # Polars does not support DECIMAL when reading CSVs
+        if file_type == FileType.CSV:
+            df = df.cast({"decimals": pl.Float64})
+
         client.create_as_table(
             df=df,
             catalog=default_catalog,
@@ -643,7 +677,7 @@
     # CSV
     with tempfile.TemporaryDirectory() as tmpdir:
         filepath = os.path.join(tmpdir, "sgvsavdavsdsvd.csv")
-        df = random_df()
+        df = random_df().cast({"decimals": pl.Float64})
        df.write_csv(file=filepath)
 
         client.register_as_table(

diff --git a/tests/test_sql.py b/tests/test_sql.py
index f78d496..2c0e80f 100644
--- a/tests/test_sql.py
+++ b/tests/test_sql.py
@@ -35,8 +35,17 @@ def test_sql(
     client.sql("ATTACH 'test_cat' AS test_cat (TYPE UC_CATALOG)")
 
     for cat_name in [default_catalog, "test_cat"]:
-        df1 = random_df().with_columns(pl.lit(1).alias("source"))
-        df2 = random_df().with_columns(pl.lit(2).alias("source"))
+        # DuckDB does not support DECIMAL
+        df1 = (
+            random_df()
+            .with_columns(pl.lit(1).alias("source"))
+            .cast({"decimals": pl.Float64})
+        )
+        df2 = (
+            random_df()
+            .with_columns(pl.lit(2).alias("source"))
+            .cast({"decimals": pl.Float64})
+        )
 
         with tempfile.TemporaryDirectory() as tmpdir:
             client.create_as_table(

diff --git a/uchelper/dataframe.py b/uchelper/dataframe.py
index b88a5f1..3c5518f 100644
--- a/uchelper/dataframe.py
+++ b/uchelper/dataframe.py
@@ -19,54 +19,59 @@ class SchemaEvolution(str, Enum):
     OVERWRITE = "OVERWRITE"
 
 
-def polars_type_to_uc_type(t: pl.DataType) -> DataType:
+def polars_type_to_uc_type(t: pl.DataType) -> tuple[DataType, int, int]:
+    """
+    Converts a polars.DataType to a (DataType, precision, scale) tuple.
+    """
     match t:
         case pl.Decimal:
-            return DataType.DECIMAL
+            return (DataType.DECIMAL, t.precision, t.scale)
         case pl.Float32:
-            return DataType.FLOAT
+            return (DataType.FLOAT, 0, 0)
         case pl.Float64:
-            return DataType.DOUBLE
+            return (DataType.DOUBLE, 0, 0)
         case pl.Int8:
-            return DataType.BYTE
+            return (DataType.BYTE, 0, 0)
         case pl.Int16:
-            return DataType.SHORT
+            return (DataType.SHORT, 0, 0)
         case pl.Int32:
-            return DataType.INT
+            return (DataType.INT, 0, 0)
         case pl.Int64:
-            return DataType.LONG
+            return (DataType.LONG, 0, 0)
         case pl.Date:
-            return DataType.DATE
+            return (DataType.DATE, 0, 0)
         case pl.Datetime:
-            return DataType.TIMESTAMP
+            return (DataType.TIMESTAMP, 0, 0)
         case pl.Array:
-            return DataType.ARRAY
+            return (DataType.ARRAY, 0, 0)
         case pl.List:
-            return DataType.ARRAY
+            return (DataType.ARRAY, 0, 0)
        case pl.Struct:
-            return DataType.STRUCT
+            return (DataType.STRUCT, 0, 0)
         case pl.String | pl.Utf8:
-            return DataType.STRING
+            return (DataType.STRING, 0, 0)
         case pl.Binary:
-            return DataType.BINARY
+            return (DataType.BINARY, 0, 0)
         case pl.Boolean:
-            return DataType.BOOLEAN
+            return (DataType.BOOLEAN, 0, 0)
         case pl.Null:
-            return DataType.NULL
+            return (DataType.NULL, 0, 0)
         case _:
             raise UnsupportedOperationError(f"Unsupported datatype: {t}")
     # Why did mypy start complaining about missing return here after bumping Polars to 1.3.0?
-    return DataType.NULL
+    return (DataType.NULL, 0, 0)
 
 
-# TODO: Decimal scale and precision
 def df_schema_to_uc_schema(df: pl.DataFrame | pl.LazyFrame) -> list[Column]:
     res = []
     for i, (col_name, col_type) in enumerate(df.schema.items()):
+        t = polars_type_to_uc_type(col_type)
         res.append(
             Column(
                 name=col_name,
-                data_type=polars_type_to_uc_type(col_type),
+                data_type=t[0],
+                type_precision=t[1],
+                type_scale=t[2],
                 position=i,
                 nullable=True,
             )
@@ -74,7 +79,9 @@ def df_schema_to_uc_schema(df: pl.DataFrame | pl.LazyFrame) -> list[Column]:
     return res
 
 
-def uc_type_to_polars_type(t: DataType) -> pl.DataType:
+def uc_type_to_polars_type(
+    t: DataType, precision: int = 0, scale: int = 0
+) -> pl.DataType:
     match t:
         case DataType.BOOLEAN:
             return cast(pl.DataType, pl.Boolean)
@@ -99,7 +106,7 @@
         case DataType.BINARY:
             return cast(pl.DataType, pl.Binary)
         case DataType.DECIMAL:
-            return cast(pl.DataType, pl.Decimal)
+            return cast(pl.DataType, pl.Decimal(precision=precision, scale=scale))
         case DataType.ARRAY:
             return cast(pl.DataType, pl.Array)
         case DataType.STRUCT:
@@ -112,7 +119,6 @@
     raise UnsupportedOperationError(f"Unsupported datatype: {t.value}")
 
 
-# TODO: Decimal scale and precision
 def uc_schema_to_df_schema(cols: list[Column]) -> dict[str, pl.DataType]:
     return {col.name: uc_type_to_polars_type(col.data_type) for col in cols}
 
@@ -125,6 +131,11 @@ def check_schema_equality(left: list[Column], right: list[Column]) -> bool:
             return False
         if left_col.data_type != right_col.data_type:
             return False
+        if left_col.data_type == DataType.DECIMAL and (
+            left_col.type_precision != right_col.type_precision
+            or left_col.type_scale != right_col.type_scale
+        ):
+            return False
     return True
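
A quick round-trip sketch of the decimal handling added above (illustrative
only, not part of the patch; it assumes the helpers are importable from
uchelper.dataframe as edited here, and a polars version with decimal support):

    from decimal import Decimal

    import polars as pl

    from uchelper.dataframe import (
        df_schema_to_uc_schema,
        polars_type_to_uc_type,
        uc_type_to_polars_type,
    )

    df = pl.DataFrame(
        {"amount": [Decimal("123.45678")]},
        schema={"amount": pl.Decimal(precision=10, scale=5)},
    )

    # polars -> UC: precision and scale now travel alongside the type enum.
    uc_type, precision, scale = polars_type_to_uc_type(df.schema["amount"])
    assert (precision, scale) == (10, 5)

    # The generated UC columns carry the same precision and scale.
    cols = df_schema_to_uc_schema(df)
    assert (cols[0].type_precision, cols[0].type_scale) == (10, 5)

    # UC -> polars: precision and scale must be passed explicitly. Note that
    # uc_schema_to_df_schema still calls uc_type_to_polars_type without them,
    # so DECIMAL columns going through that helper fall back to the (0, 0)
    # defaults.
    assert uc_type_to_polars_type(uc_type, precision, scale) == pl.Decimal(10, 5)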