Commit
implement create_as_table
VillePuuska authored Jul 30, 2024
1 parent 0f9ab69 commit 0de1677
Showing 2 changed files with 139 additions and 22 deletions.
109 changes: 94 additions & 15 deletions tests/test_dataframes.py
@@ -1,7 +1,8 @@
import os
import tempfile
import polars as pl
from polars.testing import assert_frame_equal
from polars.testing import assert_frame_equal, assert_frame_not_equal
import deltalake
import random
import uuid
import string
@@ -60,23 +61,23 @@ def random_partitioned_df() -> pl.DataFrame:


@pytest.mark.parametrize(
"filetype",
"file_type",
[
FileType.DELTA,
FileType.PARQUET,
FileType.CSV,
FileType.AVRO,
],
)
def test_basic_dataframe_operations(client: UCClient, filetype: FileType):
def test_basic_dataframe_operations(client: UCClient, file_type: FileType):
assert client.health_check()

default_catalog = "unity"
default_schema = "default"
table_name = "test_table"

with tempfile.TemporaryDirectory() as tmpdir:
match filetype:
match file_type:
case FileType.DELTA:
filepath = tmpdir
case FileType.PARQUET:
@@ -119,7 +120,7 @@ def test_basic_dataframe_operations(client: UCClient, filetype: FileType):
catalog_name=default_catalog,
schema_name=default_schema,
table_type=TableType.EXTERNAL,
file_type=filetype,
file_type=file_type,
columns=columns,
storage_location=filepath,
)
@@ -142,14 +143,14 @@ def test_basic_dataframe_operations(client: UCClient, filetype: FileType):
)
assert_frame_equal(df, df_read, check_row_order=False)

if filetype != FileType.AVRO:
if file_type != FileType.AVRO:
df_scan = client.scan_table(
catalog=default_catalog, schema=default_schema, name=table_name
)
assert_frame_equal(pl.LazyFrame(df), df_scan, check_row_order=False)

# Test APPEND writes; only supported for DELTA
if filetype == FileType.DELTA:
if file_type == FileType.DELTA:
df2 = random_df()
client.write_table(
df2,
@@ -201,7 +202,7 @@ def test_basic_dataframe_operations(client: UCClient, filetype: FileType):
check_row_order=False,
)

if filetype != FileType.AVRO:
if file_type != FileType.AVRO:
df4_scan = client.scan_table(
catalog=default_catalog, schema=default_schema, name=table_name
)
@@ -241,29 +242,29 @@ def test_basic_dataframe_operations(client: UCClient, filetype: FileType):
check_row_order=False,
)

if filetype != FileType.AVRO:
if file_type != FileType.AVRO:
df5_scan = client.scan_table(
catalog=default_catalog, schema=default_schema, name=table_name
)
assert_frame_equal(pl.LazyFrame(df5), df5_scan, check_row_order=False)


@pytest.mark.parametrize(
"filetype",
"file_type",
[
FileType.DELTA,
FileType.PARQUET,
],
)
def test_partitioned_dataframe_operations(client: UCClient, filetype: FileType):
def test_partitioned_dataframe_operations(client: UCClient, file_type: FileType):
assert client.health_check()

default_catalog = "unity"
default_schema = "default"
table_name = "test_table"

with tempfile.TemporaryDirectory() as tmpdir:
match filetype:
match file_type:
case FileType.DELTA:
filepath = tmpdir
case FileType.PARQUET:
@@ -316,7 +317,7 @@ def test_partitioned_dataframe_operations(client: UCClient, filetype: FileType):
catalog_name=default_catalog,
schema_name=default_schema,
table_type=TableType.EXTERNAL,
file_type=filetype,
file_type=file_type,
columns=columns,
storage_location=filepath,
)
@@ -358,13 +359,13 @@ def test_partitioned_dataframe_operations(client: UCClient, filetype: FileType):
catalog=default_catalog, schema=default_schema, name=table_name
)
assert_frame_equal(pl.concat([df, df2]), df_read2, check_row_order=False)
if filetype == FileType.DELTA:
if file_type == FileType.DELTA:
assert_frame_equal(
pl.concat([df, df2]),
pl.read_delta(source=filepath),
check_row_order=False,
)
if filetype == FileType.PARQUET:
if file_type == FileType.PARQUET:
assert_frame_equal(
pl.concat([df, df2]),
pl.read_parquet(
@@ -458,3 +459,81 @@ def test_partitioned_dataframe_operations(client: UCClient, filetype: FileType):
catalog=default_catalog, schema=default_schema, name=table_name
)
assert_frame_equal(pl.LazyFrame(df5), df5_scan, check_row_order=False)


@pytest.mark.parametrize(
"file_type,partitioned",
[
(FileType.DELTA, False),
(FileType.PARQUET, False),
(FileType.DELTA, True),
(FileType.PARQUET, True),
(FileType.CSV, False),
(FileType.AVRO, False),
],
)
def test_create_as_table(client: UCClient, file_type: FileType, partitioned: bool):
assert client.health_check()

default_catalog = "unity"
default_schema = "default"
table_name = "test_table"

with tempfile.TemporaryDirectory() as tmpdir:
match file_type:
case FileType.DELTA:
filepath = tmpdir
case FileType.PARQUET:
filepath = os.path.join(tmpdir, table_name + ".parquet")
case FileType.CSV:
filepath = os.path.join(tmpdir, table_name + ".csv")
case FileType.AVRO:
filepath = os.path.join(tmpdir, table_name + ".avro")
case _:
raise NotImplementedError

if not partitioned:
df = random_df()
client.create_as_table(
df=df,
catalog=default_catalog,
schema=default_schema,
name=table_name,
file_type=file_type,
table_type=TableType.EXTERNAL,
location="file://" + filepath,
)
else:
df = random_partitioned_df()
client.create_as_table(
df=df,
catalog=default_catalog,
schema=default_schema,
name=table_name,
file_type=file_type,
table_type=TableType.EXTERNAL,
location="file://" + filepath,
partition_cols=["part1", "part2"],
)

# Test the written table is actually partitioned
if file_type == FileType.DELTA:
tbl_read = deltalake.DeltaTable(table_uri=filepath)
assert tbl_read.metadata().partition_columns == ["part1", "part2"]
elif file_type == FileType.PARQUET:
df_read = pl.read_parquet(
source=os.path.join(filepath, "**", "**", "*.parquet"),
hive_partitioning=True,
hive_schema={"part1": pl.Int64, "part2": pl.Int64},
)
assert_frame_equal(df, df_read, check_row_order=False)
assert_frame_not_equal(
df,
pl.read_parquet(source=filepath, hive_partitioning=False),
check_row_order=False,
)

df_read = client.read_table(
catalog=default_catalog, schema=default_schema, name=table_name
)
assert_frame_equal(df, df_read, check_row_order=False)
52 changes: 45 additions & 7 deletions uc_wrapper/client.py
@@ -1,7 +1,15 @@
import requests
import polars as pl
from .exceptions import UnsupportedOperationError
from .models import Catalog, Schema, Table, TableType, FileType
from .dataframe import WriteMode, SchemaEvolution, read_table, scan_table, write_table
from .dataframe import (
WriteMode,
SchemaEvolution,
read_table,
scan_table,
write_table,
df_schema_to_uc_schema,
)
from .uc_api_wrapper import (
create_catalog,
create_schema,
@@ -18,7 +26,6 @@
list_tables,
update_catalog,
update_schema,
DoesNotExistError,
)


@@ -275,16 +282,47 @@ def write_table(

def create_as_table(
self,
df: pl.DataFrame | pl.LazyFrame,
df: pl.DataFrame,
catalog: str,
schema: str,
name: str,
filetype: FileType = FileType.DELTA,
type: TableType = TableType.MANAGED,
file_type: FileType = FileType.DELTA,
table_type: TableType = TableType.MANAGED,
location: str | None = None,
partition_cols: list[str] | None = None,
) -> Table:
"""
Creates a new table in Unity Catalog with the schema of the Polars DataFrame or LazyFrame `df`
Creates a new table in Unity Catalog with the schema of the Polars DataFrame `df`
and writes `df` to the new table. Raises an AlreadyExistsError if the table already exists.
"""
raise NotImplementedError
if table_type == TableType.MANAGED:
raise UnsupportedOperationError("MANAGED tables are not yet supported.")
if table_type == TableType.EXTERNAL and location is None:
raise UnsupportedOperationError(
"To create an EXTERNAL table, you must specify a location to store it in."
)
if not location.startswith("file://"):
raise UnsupportedOperationError(
"Only local storage is supported. Hint: location must be of the form file://<absolute_path>, e.g. file:///home/me/ex-delta-table"
)
cols = df_schema_to_uc_schema(df=df)
if partition_cols is not None:
for i, col in enumerate(cols):
if col.name not in partition_cols:
continue
partition_ind = partition_cols.index(col.name)
cols[i].partition_index = partition_ind
table = Table(
name=name,
catalog_name=catalog,
schema_name=schema,
table_type=table_type,
file_type=file_type,
columns=cols,
storage_location=location,
)
table = self.create_table(table=table)
self.write_table(
df=df, catalog=catalog, schema=schema, name=name, mode=WriteMode.OVERWRITE
)
return table
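
For reference, a minimal usage sketch of the new create_as_table method (not part of the commit). It assumes UCClient is importable from uc_wrapper.client and FileType/TableType from uc_wrapper.models as laid out in the diffs above, that UCClient() can be constructed with its defaults against a running Unity Catalog server, and it uses hypothetical table and column names:

import tempfile

import polars as pl

from uc_wrapper.client import UCClient
from uc_wrapper.models import FileType, TableType

# Hypothetical example data; any Polars DataFrame with the desired schema works.
df = pl.DataFrame(
    {
        "id": [1, 2, 3],
        "part1": [0, 0, 1],
        "value": ["a", "b", "c"],
    }
)

client = UCClient()  # assumed default constructor; point it at your UC server as needed

with tempfile.TemporaryDirectory() as tmpdir:
    # Only EXTERNAL tables with a local file:// location are supported;
    # MANAGED tables raise UnsupportedOperationError.
    table = client.create_as_table(
        df=df,
        catalog="unity",
        schema="default",
        name="example_table",
        file_type=FileType.DELTA,
        table_type=TableType.EXTERNAL,
        location="file://" + tmpdir,
        partition_cols=["part1"],
    )

    # The table is registered in Unity Catalog and df is written to it,
    # so reading it back should return the same rows.
    df_read = client.read_table(catalog="unity", schema="default", name="example_table")
    print(df_read)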
