add support for merging delta table schema, fix missing partition cols in UC schema after overwrite or schema merge that changed schema
  • Loading branch information
1 parent 4810e81 commit 3b914f9
Showing 2 changed files with 147 additions and 5 deletions.
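In short, after this commit the following flow works end to end for Delta tables. This is a condensed sketch of the new test below; it assumes a reachable Unity Catalog server with the `unity`/`default` catalog and schema the tests use, and the `UCClient` import path is an assumption from the repo layout:

import polars as pl
from uchelper import UCClient  # import path assumed

client = UCClient()  # assumes a local Unity Catalog server with default settings

df = pl.DataFrame({"part1": [1], "floats": [1.5]})
client.create_as_table(
    df=df,
    catalog="unity",
    schema="default",
    name="test_table",
    file_type="delta",
    table_type="external",
    location="file:///tmp/test_table",  # illustrative path
    partition_cols=["part1"],
)

# Appending a frame whose schema differs used to hit NotImplementedError;
# with schema_evolution="merge" the Delta schema is evolved and the UC schema
# keeps its partition columns after the change.
df2 = df.cast({"floats": pl.String}).rename({"floats": "more_strings"})
client.write_table(
    df=df2,
    catalog="unity",
    schema="default",
    name="test_table",
    mode="append",
    schema_evolution="merge",
)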
96 changes: 96 additions & 0 deletions tests/test_dataframes.py
@@ -350,6 +350,15 @@ def test_partitioned_dataframe_operations(
assert modified_col.data_type == DataType.LONG
assert modified_col.position == 1

partition_cols = sorted(
[
(col.partition_index, col.name)
for col in table.columns
if col.partition_index is not None
]
)
assert partition_cols == [(0, "part1"), (1, "part2")]

client.write_table(
df5,
catalog=default_catalog,
@@ -366,6 +375,15 @@ def test_partitioned_dataframe_operations(
assert modified_col.data_type == DataType.STRING
assert modified_col.position == 1

partition_cols = sorted(
[
(col.partition_index, col.name)
for col in table.columns
if col.partition_index is not None
]
)
assert partition_cols == [(0, "part1"), (1, "part2")]

assert_frame_equal(
df5,
client.read_table(
@@ -552,3 +570,81 @@ def test_register_as_table(
catalog=default_catalog, schema=default_schema, name=table_name
)
assert_frame_equal(df, df_read, check_row_order=False)


@pytest.mark.parametrize(
"partitioned",
[True, False],
)
def test_write_delta_table_merge_schema(
client: UCClient,
random_df: Callable[[], pl.DataFrame],
random_df_cols: list[Column],
random_partitioned_df: Callable[[], pl.DataFrame],
random_partitioned_df_cols: list[Column],
partitioned: bool,
):
assert client.health_check()

default_catalog = "unity"
default_schema = "default"
table_name = "test_table"

with tempfile.TemporaryDirectory() as tmpdir:
if not partitioned:
df = random_df()
df2 = (
random_df()
.cast({"floats": pl.String})
.rename({"floats": "more_strings"})
)
partition_cols = None
cols = random_df_cols
else:
df = random_partitioned_df()
df2 = (
random_partitioned_df()
.cast({"floats": pl.String})
.rename({"floats": "more_strings"})
)
partition_cols = ["part1", "part2"]
cols = random_partitioned_df_cols

client.create_as_table(
df=df,
catalog=default_catalog,
schema=default_schema,
name=table_name,
file_type="delta",
table_type="external",
location="file://" + tmpdir,
partition_cols=partition_cols,
)

client.write_table(
df=df2,
catalog=default_catalog,
schema=default_schema,
name=table_name,
mode="append",
schema_evolution="merge",
)

df_read = client.read_table(
catalog=default_catalog, schema=default_schema, name=table_name
)
assert_frame_equal(
pl.concat([df, df2], how="diagonal"),
df_read,
check_column_order=False,
check_row_order=False,
)

table = client.get_table(
catalog=default_catalog, schema=default_schema, table=table_name
)
assert set(
(col.name, col.data_type, col.partition_index) for col in cols
).union(set([("more_strings", DataType.STRING, None)])) == set(
(col.name, col.data_type, col.partition_index) for col in table.columns
)
56 changes: 51 additions & 5 deletions uchelper/dataframe.py
@@ -62,10 +62,19 @@ def polars_type_to_uc_type(t: pl.DataType) -> tuple[DataType, int, int]:
return (DataType.NULL, 0, 0)


-def df_schema_to_uc_schema(df: pl.DataFrame | pl.LazyFrame) -> list[Column]:
+def df_schema_to_uc_schema(
+    df: pl.DataFrame | pl.LazyFrame, partition_cols: list[str] = []
+) -> list[Column]:
     res = []
-    for i, (col_name, col_type) in enumerate(df.schema.items()):
+    if isinstance(df, pl.DataFrame):
+        schema = df.schema
+    elif isinstance(df, pl.LazyFrame):
+        schema = df.collect_schema()
+    for i, (col_name, col_type) in enumerate(schema.items()):
t = polars_type_to_uc_type(col_type)
partition_ind = None
if col_name in partition_cols:
partition_ind = partition_cols.index(col_name)
res.append(
Column(
name=col_name,
@@ -74,6 +83,7 @@ def df_schema_to_uc_schema(df: pl.DataFrame | pl.LazyFrame) -> list[Column]:
type_scale=t[2],
position=i,
nullable=True,
partition_index=partition_ind,
)
)
return res
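To illustrate the new parameter: each partition column is tagged with its index within partition_cols, and every other column gets partition_index=None. A minimal sketch using the module's own function:

import polars as pl
from uchelper.dataframe import df_schema_to_uc_schema

df = pl.DataFrame({"part1": [1], "part2": [2], "floats": [1.0]})
cols = df_schema_to_uc_schema(df, partition_cols=["part1", "part2"])
# Partition columns carry their index within partition_cols; all others None.
assert [(c.name, c.partition_index) for c in cols] == [
    ("part1", 0),
    ("part2", 1),
    ("floats", None),
]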
@@ -124,6 +134,8 @@ def uc_schema_to_df_schema(cols: list[Column]) -> dict[str, pl.DataType]:


def check_schema_equality(left: list[Column], right: list[Column]) -> bool:
if len(left) != len(right):
return False
left = sorted(left, key=lambda x: x.position)
right = sorted(right, key=lambda x: x.position)
for left_col, right_col in zip(left, right):
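Context for the two lines added to check_schema_equality: zip() stops at the shorter input, so before this fix a schema that was a prefix of the other could pass the per-column comparison. A minimal illustration:

# zip() truncates to the shorter sequence; the third element is never compared,
# which is why an explicit length check is needed first.
assert list(zip([1, 2], [1, 2, 3])) == [(1, 1), (2, 2)]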
@@ -306,10 +318,42 @@ def write_table(
raise_for_schema_mismatch(df=df, uc=table.columns)
return None
except SchemaMismatchError:
-            return df_schema_to_uc_schema(df=df)
+            return df_schema_to_uc_schema(
+                df=df, partition_cols=[col.name for col in partition_cols]
+            )

case FileType.DELTA, WriteMode.APPEND, SchemaEvolution.MERGE:
-            raise NotImplementedError
partition_cols = get_partition_columns(table.columns)
# needing to specify the cast is not neat,
# but mypy gets angry if we just pass this as a str to write_delta
write_mode = cast(Literal["append", "overwrite"], mode.value.lower())
if len(partition_cols) > 0:
df.write_delta(
target=path,
mode=write_mode,
delta_write_options={
"partition_by": [col.name for col in partition_cols],
"schema_mode": "merge",
"engine": "rust",
},
)
else:
df.write_delta(
target=path,
mode=write_mode,
delta_write_options={
"schema_mode": "merge",
"engine": "rust",
},
)
try:
lf = pl.scan_delta(source=path)
raise_for_schema_mismatch(df=lf, uc=table.columns)
return None
except SchemaMismatchError:
return df_schema_to_uc_schema(
df=lf, partition_cols=[col.name for col in partition_cols]
)

case FileType.PARQUET, WriteMode.APPEND, SchemaEvolution.STRICT:
partition_cols = get_partition_columns(table.columns)
@@ -352,7 +396,9 @@ def write_table(
raise_for_schema_mismatch(df=df, uc=table.columns)
return None
except SchemaMismatchError:
-            return df_schema_to_uc_schema(df=df)
+            return df_schema_to_uc_schema(
+                df=df, partition_cols=[col.name for col in partition_cols]
+            )

case FileType.CSV, WriteMode.OVERWRITE, SchemaEvolution.STRICT:
raise_for_schema_mismatch(df=df, uc=table.columns)
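For reference, the polars/deltalake behavior that the DELTA + APPEND + MERGE branch above relies on can be exercised standalone. A sketch; the table path is illustrative:

import polars as pl

path = "/tmp/delta_merge_demo"  # illustrative location
pl.DataFrame({"ints": [1], "floats": [1.0]}).write_delta(target=path)

# schema_mode="merge" evolves the table schema instead of raising on mismatch.
pl.DataFrame({"ints": [2], "more_strings": ["x"]}).write_delta(
    target=path,
    mode="append",
    delta_write_options={"schema_mode": "merge", "engine": "rust"},
)

merged = pl.scan_delta(source=path).collect()
assert set(merged.columns) == {"ints", "floats", "more_strings"}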
