add support for decimal datatype
VillePuuska authored Aug 8, 2024
1 parent 78b1142 commit e6e8dfc
Showing 4 changed files with 91 additions and 34 deletions.
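What the change amounts to: a Polars decimal column can now round-trip through the catalog with its precision and scale preserved, instead of collapsing to a bare DECIMAL. A minimal sketch of the new mapping, mirroring the test fixture and test expectations below (the uchelper import path is an assumption; only the names Column, DataType, and the constructor fields appear in the diff):

import polars as pl

from uchelper import Column, DataType  # assumed import path

# A frame with a decimal column, mirroring the updated test fixture.
df = pl.DataFrame(
    {"decimals": [1.5, 2.25]},
    schema={"decimals": pl.Decimal(precision=10, scale=5)},
)

# The matching catalog column now records precision and scale explicitly.
decimal_col = Column(
    name="decimals",
    data_type=DataType.DECIMAL,
    type_precision=10,
    type_scale=5,
    position=0,
    nullable=False,
)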
3 changes: 3 additions & 0 deletions tests/conftest.py
@@ -69,6 +69,7 @@ def _random_df() -> pl.DataFrame:
uuids = [str(uuid.uuid4()) for _ in range(RANDOM_DF_ROWS)]
ints = [random.randint(0, 10000) for _ in range(RANDOM_DF_ROWS)]
floats = [random.uniform(0, 10000) for _ in range(RANDOM_DF_ROWS)]
decimals = [random.uniform(0, 10000) for _ in range(RANDOM_DF_ROWS)]
strings = [
"".join(
random.choices(population=string.ascii_letters, k=random.randint(2, 256))
@@ -80,12 +81,14 @@ def _random_df() -> pl.DataFrame:
"id": uuids,
"ints": ints,
"floats": floats,
"decimals": decimals,
"strings": strings,
},
schema={
"id": pl.String,
"ints": pl.Int64,
"floats": pl.Float64,
"decimals": pl.Decimal(precision=10, scale=5),
"strings": pl.String,
},
)
52 changes: 43 additions & 9 deletions tests/test_dataframes.py
@@ -69,13 +69,26 @@ def test_basic_dataframe_operations(
position=2,
nullable=False,
),
Column(
name="decimals",
data_type=DataType.DECIMAL,
type_precision=10,
type_scale=5,
position=3,
nullable=False,
),
Column(
name="strings",
data_type=DataType.STRING,
- position=3,
+ position=4,
nullable=False,
),
]
# Polars does not support DECIMAL when reading CSVs
if file_type == FileType.CSV:
columns[3].data_type = DataType.DOUBLE
columns[3].type_precision = 0
columns[3].type_scale = 0
client.create_table(
Table(
name=table_name,
@@ -89,6 +102,9 @@
)

df = random_df()
# Polars does not support DECIMAL when reading CSVs
if file_type == FileType.CSV:
df = df.cast({"decimals": pl.Float64})

client.write_table(
df,
@@ -145,6 +161,9 @@ def test_basic_dataframe_operations(
)

df4 = random_df()
# Polars does not support DECIMAL when reading CSVs
if file_type == FileType.CSV:
df4 = df4.cast({"decimals": pl.Float64})

# Test OVERWRITE writes
client.write_table(
@@ -172,6 +191,9 @@ def test_basic_dataframe_operations(

df5 = random_df()
df5 = df5.cast({"ints": pl.String})
# Polars does not support DECIMAL when reading CSVs
if file_type == FileType.CSV:
df5 = df5.cast({"decimals": pl.Float64})

table = client.get_table(
catalog=default_catalog, schema=default_schema, table=table_name
@@ -257,23 +279,31 @@ def test_partitioned_dataframe_operations(
position=2,
nullable=False,
),
Column(
name="decimals",
data_type=DataType.DECIMAL,
type_precision=10,
type_scale=5,
position=3,
nullable=False,
),
Column(
name="strings",
data_type=DataType.STRING,
- position=3,
+ position=4,
nullable=False,
),
Column(
name="part1",
data_type=DataType.LONG,
- position=4,
+ position=5,
nullable=False,
partition_index=0,
),
Column(
name="part2",
data_type=DataType.LONG,
- position=5,
+ position=6,
nullable=False,
partition_index=1,
),
@@ -360,8 +390,8 @@ def test_partitioned_dataframe_operations(
# every partition. Otherwise, we might not overwrite all data
# in the case of a partitioned Parquet table.
df_concat = pl.concat([df, df2])
- df4 = df4.replace_column(4, df_concat.select("part1").to_series())
- df4 = df4.replace_column(5, df_concat.select("part2").to_series())
+ df4 = df4.replace_column(5, df_concat.select("part1").to_series())
+ df4 = df4.replace_column(6, df_concat.select("part2").to_series())

# Test OVERWRITE writes
client.write_table(
@@ -388,8 +418,8 @@ def test_partitioned_dataframe_operations(

df5 = pl.concat([random_partitioned_df(), random_partitioned_df()])
df5 = df5.cast({"ints": pl.String})
- df5 = df5.replace_column(4, df_concat.select("part1").to_series())
- df5 = df5.replace_column(5, df_concat.select("part2").to_series())
+ df5 = df5.replace_column(5, df_concat.select("part1").to_series())
+ df5 = df5.replace_column(6, df_concat.select("part2").to_series())

table = client.get_table(
catalog=default_catalog, schema=default_schema, table=table_name
@@ -467,6 +497,10 @@ def test_create_as_table(

if not partitioned:
df = random_df()
# Polars does not support DECIMAL when reading CSVs
if file_type == FileType.CSV:
df = df.cast({"decimals": pl.Float64})

client.create_as_table(
df=df,
catalog=default_catalog,
@@ -643,7 +677,7 @@ def test_register_as_table(
# CSV
with tempfile.TemporaryDirectory() as tmpdir:
filepath = os.path.join(tmpdir, "sgvsavdavsdsvd.csv")
- df = random_df()
+ df = random_df().cast({"decimals": pl.Float64})
df.write_csv(file=filepath)

client.register_as_table(
13 changes: 11 additions & 2 deletions tests/test_sql.py
@@ -35,8 +35,17 @@ def test_sql(
client.sql("ATTACH 'test_cat' AS test_cat (TYPE UC_CATALOG)")

for cat_name in [default_catalog, "test_cat"]:
- df1 = random_df().with_columns(pl.lit(1).alias("source"))
- df2 = random_df().with_columns(pl.lit(2).alias("source"))
+ # DuckDB does not support DECIMAL
+ df1 = (
+ random_df()
+ .with_columns(pl.lit(1).alias("source"))
+ .cast({"decimals": pl.Float64})
+ )
+ df2 = (
+ random_df()
+ .with_columns(pl.lit(2).alias("source"))
+ .cast({"decimals": pl.Float64})
+ )

with tempfile.TemporaryDirectory() as tmpdir:
client.create_as_table(
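The same workaround recurs in both test files above: where the storage path cannot represent decimals (Polars reading CSV, and the DuckDB SQL path), the tests first cast the column down to Float64 with df.cast({"decimals": pl.Float64}). A generic version of that pattern, as a hedged sketch (the helper name and the any-column generalization are illustrative, not part of the commit):

import polars as pl

def decimals_to_float64(df: pl.DataFrame) -> pl.DataFrame:
    # Cast every Decimal column to Float64 so the frame survives
    # formats and engines that cannot round-trip DECIMAL.
    targets = {
        name: pl.Float64
        for name, dtype in df.schema.items()
        if isinstance(dtype, pl.Decimal)
    }
    return df.cast(targets)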
57 changes: 34 additions & 23 deletions uchelper/dataframe.py
@@ -19,62 +19,69 @@ class SchemaEvolution(str, Enum):
OVERWRITE = "OVERWRITE"


- def polars_type_to_uc_type(t: pl.DataType) -> DataType:
+ def polars_type_to_uc_type(t: pl.DataType) -> tuple[DataType, int, int]:
"""
Converts the enum DataType to a polars.DataType
"""
match t:
case pl.Decimal:
- return DataType.DECIMAL
+ return (DataType.DECIMAL, t.precision, t.scale)
case pl.Float32:
- return DataType.FLOAT
+ return (DataType.FLOAT, 0, 0)
case pl.Float64:
- return DataType.DOUBLE
+ return (DataType.DOUBLE, 0, 0)
case pl.Int8:
- return DataType.BYTE
+ return (DataType.BYTE, 0, 0)
case pl.Int16:
- return DataType.SHORT
+ return (DataType.SHORT, 0, 0)
case pl.Int32:
- return DataType.INT
+ return (DataType.INT, 0, 0)
case pl.Int64:
- return DataType.LONG
+ return (DataType.LONG, 0, 0)
case pl.Date:
- return DataType.DATE
+ return (DataType.DATE, 0, 0)
case pl.Datetime:
- return DataType.TIMESTAMP
+ return (DataType.TIMESTAMP, 0, 0)
case pl.Array:
- return DataType.ARRAY
+ return (DataType.ARRAY, 0, 0)
case pl.List:
- return DataType.ARRAY
+ return (DataType.ARRAY, 0, 0)
case pl.Struct:
- return DataType.STRUCT
+ return (DataType.STRUCT, 0, 0)
case pl.String | pl.Utf8:
- return DataType.STRING
+ return (DataType.STRING, 0, 0)
case pl.Binary:
- return DataType.BINARY
+ return (DataType.BINARY, 0, 0)
case pl.Boolean:
- return DataType.BOOLEAN
+ return (DataType.BOOLEAN, 0, 0)
case pl.Null:
- return DataType.NULL
+ return (DataType.NULL, 0, 0)
case _:
raise UnsupportedOperationError(f"Unsupported datatype: {t}")
# Why did mypy start complaining about missing return here after bumping Polars to 1.3.0?
- return DataType.NULL
+ return (DataType.NULL, 0, 0)


- # TODO: Decimal scale and precision
def df_schema_to_uc_schema(df: pl.DataFrame | pl.LazyFrame) -> list[Column]:
res = []
for i, (col_name, col_type) in enumerate(df.schema.items()):
+ t = polars_type_to_uc_type(col_type)
res.append(
Column(
name=col_name,
- data_type=polars_type_to_uc_type(col_type),
+ data_type=t[0],
+ type_precision=t[1],
+ type_scale=t[2],
position=i,
nullable=True,
)
)
return res


- def uc_type_to_polars_type(t: DataType) -> pl.DataType:
+ def uc_type_to_polars_type(
+ t: DataType, precision: int = 0, scale: int = 0
+ ) -> pl.DataType:
match t:
case DataType.BOOLEAN:
return cast(pl.DataType, pl.Boolean)
@@ -99,7 +106,7 @@ def uc_type_to_polars_type(t: DataType) -> pl.DataType:
case DataType.BINARY:
return cast(pl.DataType, pl.Binary)
case DataType.DECIMAL:
- return cast(pl.DataType, pl.Decimal)
+ return cast(pl.DataType, pl.Decimal(precision=precision, scale=scale))
case DataType.ARRAY:
return cast(pl.DataType, pl.Array)
case DataType.STRUCT:
@@ -112,7 +119,6 @@ def uc_type_to_polars_type(t: DataType) -> pl.DataType:
raise UnsupportedOperationError(f"Unsupported datatype: {t.value}")


- # TODO: Decimal scale and precision
def uc_schema_to_df_schema(cols: list[Column]) -> dict[str, pl.DataType]:
return {col.name: uc_type_to_polars_type(col.data_type) for col in cols}

@@ -125,6 +131,11 @@ def check_schema_equality(left: list[Column], right: list[Column]) -> bool:
return False
if left_col.data_type != right_col.data_type:
return False
if left_col.data_type == DataType.DECIMAL and (
left_col.type_precision != right_col.type_precision
or left_col.type_scale != right_col.type_scale
):
return False
return True


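With both converters carrying precision and scale, schema comparison can now distinguish decimal columns that differ only in those parameters. A sketch of the intended behavior (the uchelper.dataframe import path is assumed from the file name above; the asserts state expectations, not test code from the commit):

import polars as pl

from uchelper.dataframe import (  # assumed import path
    check_schema_equality,
    df_schema_to_uc_schema,
    polars_type_to_uc_type,
)

# The converter now returns (DataType, precision, scale) instead of a bare DataType.
assert polars_type_to_uc_type(pl.Decimal(precision=10, scale=5))[1:] == (10, 5)

# Two schemas that differ only in scale no longer compare equal.
left = df_schema_to_uc_schema(pl.DataFrame(schema={"x": pl.Decimal(precision=10, scale=5)}))
right = df_schema_to_uc_schema(pl.DataFrame(schema={"x": pl.Decimal(precision=10, scale=2)}))
assert not check_schema_equality(left, right)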
