From 3b914f904261a85f87da77f39240862db6bd6eaf Mon Sep 17 00:00:00 2001 From: Ville Puuska <40150442+VillePuuska@users.noreply.github.com> Date: Sat, 10 Aug 2024 13:13:40 +0000 Subject: [PATCH] add support for merging delta table schema, fix missing partition cols in UC schema after overwrite or schema merge that changed schema --- tests/test_dataframes.py | 96 ++++++++++++++++++++++++++++++++++++++++ uchelper/dataframe.py | 56 ++++++++++++++++++++--- 2 files changed, 147 insertions(+), 5 deletions(-) diff --git a/tests/test_dataframes.py b/tests/test_dataframes.py index 6ceb737..7c3fa5b 100644 --- a/tests/test_dataframes.py +++ b/tests/test_dataframes.py @@ -350,6 +350,15 @@ def test_partitioned_dataframe_operations( assert modified_col.data_type == DataType.LONG assert modified_col.position == 1 + partition_cols = sorted( + [ + (col.partition_index, col.name) + for col in table.columns + if col.partition_index is not None + ] + ) + assert partition_cols == [(0, "part1"), (1, "part2")] + client.write_table( df5, catalog=default_catalog, @@ -366,6 +375,15 @@ def test_partitioned_dataframe_operations( assert modified_col.data_type == DataType.STRING assert modified_col.position == 1 + partition_cols = sorted( + [ + (col.partition_index, col.name) + for col in table.columns + if col.partition_index is not None + ] + ) + assert partition_cols == [(0, "part1"), (1, "part2")] + assert_frame_equal( df5, client.read_table( @@ -552,3 +570,81 @@ def test_register_as_table( catalog=default_catalog, schema=default_schema, name=table_name ) assert_frame_equal(df, df_read, check_row_order=False) + + +@pytest.mark.parametrize( + "partitioned", + [True, False], +) +def test_write_delta_table_merge_schema( + client: UCClient, + random_df: Callable[[], pl.DataFrame], + random_df_cols: list[Column], + random_partitioned_df: Callable[[], pl.DataFrame], + random_partitioned_df_cols: list[Column], + partitioned: bool, +): + assert client.health_check() + + default_catalog = "unity" + default_schema = "default" + table_name = "test_table" + + with tempfile.TemporaryDirectory() as tmpdir: + if not partitioned: + df = random_df() + df2 = ( + random_df() + .cast({"floats": pl.String}) + .rename({"floats": "more_strings"}) + ) + partition_cols = None + cols = random_df_cols + else: + df = random_partitioned_df() + df2 = ( + random_partitioned_df() + .cast({"floats": pl.String}) + .rename({"floats": "more_strings"}) + ) + partition_cols = ["part1", "part2"] + cols = random_partitioned_df_cols + + client.create_as_table( + df=df, + catalog=default_catalog, + schema=default_schema, + name=table_name, + file_type="delta", + table_type="external", + location="file://" + tmpdir, + partition_cols=partition_cols, + ) + + client.write_table( + df=df2, + catalog=default_catalog, + schema=default_schema, + name=table_name, + mode="append", + schema_evolution="merge", + ) + + df_read = client.read_table( + catalog=default_catalog, schema=default_schema, name=table_name + ) + assert_frame_equal( + pl.concat([df, df2], how="diagonal"), + df_read, + check_column_order=False, + check_row_order=False, + ) + + table = client.get_table( + catalog=default_catalog, schema=default_schema, table=table_name + ) + assert set( + (col.name, col.data_type, col.partition_index) for col in cols + ).union(set([("more_strings", DataType.STRING, None)])) == set( + (col.name, col.data_type, col.partition_index) for col in table.columns + ) diff --git a/uchelper/dataframe.py b/uchelper/dataframe.py index 2695886..d7ae06b 100644 --- a/uchelper/dataframe.py +++ b/uchelper/dataframe.py @@ -62,10 +62,19 @@ def polars_type_to_uc_type(t: pl.DataType) -> tuple[DataType, int, int]: return (DataType.NULL, 0, 0) -def df_schema_to_uc_schema(df: pl.DataFrame | pl.LazyFrame) -> list[Column]: +def df_schema_to_uc_schema( + df: pl.DataFrame | pl.LazyFrame, partition_cols: list[str] = [] +) -> list[Column]: res = [] - for i, (col_name, col_type) in enumerate(df.schema.items()): + if isinstance(df, pl.DataFrame): + schema = df.schema + elif isinstance(df, pl.LazyFrame): + schema = df.collect_schema() + for i, (col_name, col_type) in enumerate(schema.items()): t = polars_type_to_uc_type(col_type) + partition_ind = None + if col_name in partition_cols: + partition_ind = partition_cols.index(col_name) res.append( Column( name=col_name, @@ -74,6 +83,7 @@ def df_schema_to_uc_schema(df: pl.DataFrame | pl.LazyFrame) -> list[Column]: type_scale=t[2], position=i, nullable=True, + partition_index=partition_ind, ) ) return res @@ -124,6 +134,8 @@ def uc_schema_to_df_schema(cols: list[Column]) -> dict[str, pl.DataType]: def check_schema_equality(left: list[Column], right: list[Column]) -> bool: + if len(left) != len(right): + return False left = sorted(left, key=lambda x: x.position) right = sorted(right, key=lambda x: x.position) for left_col, right_col in zip(left, right): @@ -306,10 +318,42 @@ def write_table( raise_for_schema_mismatch(df=df, uc=table.columns) return None except SchemaMismatchError: - return df_schema_to_uc_schema(df=df) + return df_schema_to_uc_schema( + df=df, partition_cols=[col.name for col in partition_cols] + ) case FileType.DELTA, WriteMode.APPEND, SchemaEvolution.MERGE: - raise NotImplementedError + partition_cols = get_partition_columns(table.columns) + # needing to specify the cast is not neat, + # but mypy gets angry if we just pass this as a str to write_delta + write_mode = cast(Literal["append", "overwrite"], mode.value.lower()) + if len(partition_cols) > 0: + df.write_delta( + target=path, + mode=write_mode, + delta_write_options={ + "partition_by": [col.name for col in partition_cols], + "schema_mode": "merge", + "engine": "rust", + }, + ) + else: + df.write_delta( + target=path, + mode=write_mode, + delta_write_options={ + "schema_mode": "merge", + "engine": "rust", + }, + ) + try: + lf = pl.scan_delta(source=path) + raise_for_schema_mismatch(df=lf, uc=table.columns) + return None + except SchemaMismatchError: + return df_schema_to_uc_schema( + df=lf, partition_cols=[col.name for col in partition_cols] + ) case FileType.PARQUET, WriteMode.APPEND, SchemaEvolution.STRICT: partition_cols = get_partition_columns(table.columns) @@ -352,7 +396,9 @@ def write_table( raise_for_schema_mismatch(df=df, uc=table.columns) return None except SchemaMismatchError: - return df_schema_to_uc_schema(df=df) + return df_schema_to_uc_schema( + df=df, partition_cols=[col.name for col in partition_cols] + ) case FileType.CSV, WriteMode.OVERWRITE, SchemaEvolution.STRICT: raise_for_schema_mismatch(df=df, uc=table.columns)