Skip to content

Commit

Permalink
refactor write_table match case, change SchemaEvolution.UNION to MERGE to match Delta Lake
Browse files Browse the repository at this point in the history
  • Loading branch information
VillePuuska authored Aug 10, 2024
1 parent e69a6c6 commit 4810e81
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 76 deletions.
4 changes: 2 additions & 2 deletions uchelper/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def write_table(
name: str,
mode: Literal["append", "overwrite"] | WriteMode = WriteMode.APPEND,
schema_evolution: (
Literal["strict", "union", "overwrite"] | SchemaEvolution
Literal["strict", "merge", "overwrite"] | SchemaEvolution
) = SchemaEvolution.STRICT,
) -> None:
"""
Expand All @@ -298,7 +298,7 @@ def write_table(
`schema_evolution` specifies how to handle possible schema mismatches:
- SchemaEvolution.STRICT raises an Exception if there is a difference in schemas.
- SchemaEvolution.UNION will attempt to take the union of the schemas; raises if impossible.
- SchemaEvolution.MERGE will attempt to merge the schemas; raises if impossible.
- SchemaEvolution.OVERWRITE will attempt to cast the existing table to the schema of the new
DataFrame; raises if impossible.
"""
Expand Down
113 changes: 42 additions & 71 deletions uchelper/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class WriteMode(str, Enum):

class SchemaEvolution(str, Enum):
STRICT = "STRICT"
UNION = "UNION"
MERGE = "MERGE"
OVERWRITE = "OVERWRITE"


Expand Down Expand Up @@ -252,30 +252,23 @@ def write_table(
"""
Writes the Polars DataFrame `df` to the location of `table`.
If `mode` is APPEND, depending on the `schema_evolution` parameter, if the schema
stored in Unity Catalog needs to be updated, returns the new list of Columns.
If the schema does not need to be updated, returns None.
Returns None if the schema in Unity Catalog does NOT need to be updated.
Returns a list[Column] if the schema in Unity Catalog DOES need to be updated.
If `mode` is OVERWRITE, the function returns the list of Columns if it doesn't
match the previous schema in Unity Catalog. Otherwise returns None.
`schema_evolution` is completely ignored if `mode` is OVERWRITE.
In short: if this function returns None, Unity Catalog does not need an update;
otherwise update the schema in Unity Catalog with the returned list of Columns.
Raises UnsupportedOperationError for unsupported combination of `table.file_type`, `mode`, and `schema_evolution`.
"""
path = table.storage_location
assert path is not None
if not path.startswith("file://"):
raise UnsupportedOperationError("Only local storage is supported.")
path = path.removeprefix("file://")

# TODO: for the love of god NEVER let a match-case devolve like this again
match mode, table.file_type, schema_evolution:
case _, FileType.DELTA, SchemaEvolution.STRICT:
match table.file_type, mode, schema_evolution:
case FileType.DELTA, _, SchemaEvolution.STRICT:
raise_for_schema_mismatch(df=df, uc=table.columns)
partition_cols = get_partition_columns(table.columns)
# needing to do this cast is not great, but mypy gets angry if we just pass this as a str to write_delta
# needing to specify the cast is not neat,
# but mypy gets angry if we just pass this as a str to write_delta
write_mode = cast(Literal["append", "overwrite"], mode.value.lower())
if len(partition_cols) > 0:
df.write_delta(
Expand All @@ -289,12 +282,10 @@ def write_table(
df.write_delta(target=path, mode=write_mode)
return None

case _, FileType.DELTA, SchemaEvolution.UNION:
raise NotImplementedError

case WriteMode.OVERWRITE, FileType.DELTA, SchemaEvolution.OVERWRITE:
case FileType.DELTA, WriteMode.OVERWRITE, _:
partition_cols = get_partition_columns(table.columns)
# needing to do this cast is not great, but mypy gets angry if we just pass this as a str to write_delta
# needing to specify the cast is not neat,
# but mypy gets angry if we just pass this as a str to write_delta
write_mode = cast(Literal["append", "overwrite"], mode.value.lower())
if len(partition_cols) > 0:
df.write_delta(
Expand All @@ -317,12 +308,10 @@ def write_table(
except SchemaMismatchError:
return df_schema_to_uc_schema(df=df)

case _, FileType.DELTA, SchemaEvolution.OVERWRITE:
raise UnsupportedOperationError(
"Schema evolution OVERWRITE is only supported when write mode is also OVERWRITE."
)
case FileType.DELTA, WriteMode.APPEND, SchemaEvolution.MERGE:
raise NotImplementedError

case WriteMode.APPEND, FileType.PARQUET, SchemaEvolution.STRICT:
case FileType.PARQUET, WriteMode.APPEND, SchemaEvolution.STRICT:
partition_cols = get_partition_columns(table.columns)
if len(partition_cols) == 0:
raise UnsupportedOperationError(
Expand All @@ -341,49 +330,9 @@ def write_table(
)
return None

case WriteMode.APPEND, FileType.PARQUET, SchemaEvolution.UNION:
raise NotImplementedError

case WriteMode.APPEND, FileType.PARQUET, SchemaEvolution.OVERWRITE:
raise UnsupportedOperationError(
"Schema evolution OVERWRITE is only supported when write mode is also OVERWRITE."
)

case WriteMode.APPEND, _, _:
raise UnsupportedOperationError(
f"Appending is not supported for {table.file_type.value}."
)

case WriteMode.OVERWRITE, FileType.PARQUET, SchemaEvolution.STRICT:
raise_for_schema_mismatch(df=df, uc=table.columns)
partition_cols = get_partition_columns(table.columns)
if len(partition_cols) > 0:
df.write_parquet(
file=path,
use_pyarrow=True,
pyarrow_options={
"partition_cols": [col.name for col in partition_cols],
"basename_template": str(uuid.uuid4())
+ str(time.time()).replace(".", "")
+ "-{i}.parquet",
"existing_data_behavior": "delete_matching",
},
)
else:
df.write_parquet(file=path)
return None

case WriteMode.OVERWRITE, FileType.CSV, SchemaEvolution.STRICT:
raise_for_schema_mismatch(df=df, uc=table.columns)
df.write_csv(file=path)
return None

case WriteMode.OVERWRITE, FileType.AVRO, SchemaEvolution.STRICT:
raise_for_schema_mismatch(df=df, uc=table.columns)
df.write_avro(file=path)
return None

case WriteMode.OVERWRITE, FileType.PARQUET, SchemaEvolution.OVERWRITE:
case FileType.PARQUET, WriteMode.OVERWRITE, _:
if schema_evolution == SchemaEvolution.STRICT:
raise_for_schema_mismatch(df=df, uc=table.columns)
partition_cols = get_partition_columns(table.columns)
if len(partition_cols) > 0:
df.write_parquet(
Expand All @@ -405,26 +354,48 @@ def write_table(
except SchemaMismatchError:
return df_schema_to_uc_schema(df=df)

case WriteMode.OVERWRITE, FileType.CSV, SchemaEvolution.OVERWRITE:
case FileType.CSV, WriteMode.OVERWRITE, SchemaEvolution.STRICT:
raise_for_schema_mismatch(df=df, uc=table.columns)
df.write_csv(file=path)
return None

case FileType.CSV, WriteMode.OVERWRITE, SchemaEvolution.OVERWRITE:
df.write_csv(file=path)
try:
raise_for_schema_mismatch(df=df, uc=table.columns)
return None
except SchemaMismatchError:
return df_schema_to_uc_schema(df=df)

case WriteMode.OVERWRITE, FileType.AVRO, SchemaEvolution.OVERWRITE:
case FileType.AVRO, WriteMode.OVERWRITE, SchemaEvolution.STRICT:
raise_for_schema_mismatch(df=df, uc=table.columns)
df.write_avro(file=path)
return None

case FileType.AVRO, WriteMode.OVERWRITE, SchemaEvolution.OVERWRITE:
df.write_avro(file=path)
try:
raise_for_schema_mismatch(df=df, uc=table.columns)
return None
except SchemaMismatchError:
return df_schema_to_uc_schema(df=df)

case _, WriteMode.APPEND, _:
raise UnsupportedOperationError(
"Write mode APPEND is only supported for DELTA and partitioned PARQUET."
)

case _, _, SchemaEvolution.MERGE:
raise UnsupportedOperationError(
"Schema evolution MERGE is only supported for DELTA."
)

case _, _, SchemaEvolution.OVERWRITE:
raise UnsupportedOperationError(
"Schema evolution OVERWRITE is only supported when write mode is also OVERWRITE."
)

case _:
raise NotImplementedError
raise UnsupportedOperationError(
f"Unsupported parameters: {table.file_type}, {mode}, {schema_evolution}"
)
6 changes: 3 additions & 3 deletions uchelper/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,13 @@ def literal_to_writemode(lit: Literal["append", "overwrite"]) -> WriteMode:


def literal_to_schemaevolution(
lit: Literal["strict", "union", "overwrite"]
lit: Literal["strict", "merge", "overwrite"]
) -> SchemaEvolution:
match lit:
case "strict":
return SchemaEvolution.STRICT
case "union":
return SchemaEvolution.UNION
case "merge":
return SchemaEvolution.MERGE
case "overwrite":
return SchemaEvolution.OVERWRITE
case _:
Expand Down

0 comments on commit 4810e81

Please sign in to comment.