[SPARK-48374][PYTHON] Support additional PyArrow Table column types
### What changes were proposed in this pull request?
This is a small follow-up to apache#46529. It adds support for additional Arrow data types (a short usage sketch follows this list):
- fixed-size binary
- fixed-size list
- large list
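
A minimal sketch of the intended usage, assuming an active `SparkSession` named `spark` and PyArrow >= 14.0.0 (needed for the fixed-size list column):

```python
import pyarrow as pa

# Hypothetical table containing the three newly supported column types
table = pa.table({
    "fsb": pa.array([b"ab", b"cd"], type=pa.binary(2)),               # fixed-size binary
    "fsl": pa.array([[1, 2], [3, 4]], type=pa.list_(pa.int32(), 2)),  # fixed-size list
    "ll": pa.array([[1], [2, 3]], type=pa.large_list(pa.int64())),    # large list
})

# With this change these columns map to BinaryType, ArrayType(IntegerType),
# and ArrayType(LongType) respectively instead of raising an unsupported-type error.
df = spark.createDataFrame(table)
df.printSchema()
```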

### Why are the changes needed?
Users creating Spark DataFrames from PyArrow Tables will expect it to work when their Tables contain columns of these types.

### Does this PR introduce _any_ user-facing change?
It prevents an error when the user's PyArrow Table contains a column of one of these types. There are no other user-facing changes.

### How was this patch tested?
Tests are included.

### Was this patch authored or co-authored using generative AI tooling?
No

Closes apache#46688 from ianmcook/SPARK-48374.

Authored-by: Ian Cook <ianmcook@gmail.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
ianmcook authored and HyukjinKwon committed Jun 3, 2024
1 parent bc18701 commit e208b77
Showing 2 changed files with 60 additions and 0 deletions.
15 changes: 15 additions & 0 deletions python/pyspark/sql/pandas/types.py
@@ -52,6 +52,7 @@
_create_row,
)
from pyspark.errors import PySparkTypeError, UnsupportedOperationException, PySparkValueError
from pyspark.loose_version import LooseVersion

if TYPE_CHECKING:
import pandas as pd
@@ -242,6 +243,8 @@ def from_arrow_type(at: "pa.DataType", prefer_timestamp_ntz: bool = False) -> Da
spark_type = StringType()
elif types.is_binary(at):
spark_type = BinaryType()
elif types.is_fixed_size_binary(at):
spark_type = BinaryType()
elif types.is_large_binary(at):
spark_type = BinaryType()
elif types.is_date32(at):
@@ -254,6 +257,18 @@
spark_type = DayTimeIntervalType()
elif types.is_list(at):
spark_type = ArrayType(from_arrow_type(at.value_type, prefer_timestamp_ntz))
elif types.is_fixed_size_list(at):
import pyarrow as pa

if LooseVersion(pa.__version__) < LooseVersion("14.0.0"):
# PyArrow versions before 14.0.0 do not support casting FixedSizeListArray to ListArray
raise PySparkTypeError(
error_class="UNSUPPORTED_DATA_TYPE_FOR_ARROW_CONVERSION",
message_parameters={"data_type": str(at)},
)
spark_type = ArrayType(from_arrow_type(at.value_type, prefer_timestamp_ntz))
elif types.is_large_list(at):
spark_type = ArrayType(from_arrow_type(at.value_type, prefer_timestamp_ntz))
elif types.is_map(at):
spark_type = MapType(
from_arrow_type(at.key_type, prefer_timestamp_ntz),
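For illustration, a rough sketch of what the new branches in `from_arrow_type` return when called directly (this is an internal helper, so the import path and exact reprs may differ across versions):

```python
import pyarrow as pa
from pyspark.sql.pandas.types import from_arrow_type

# Fixed-size binary now maps to BinaryType, like binary and large binary.
from_arrow_type(pa.binary(4))               # -> BinaryType()

# Large list maps to ArrayType of the converted element type.
from_arrow_type(pa.large_list(pa.int64()))  # -> ArrayType(LongType(), True)

# Fixed-size list also maps to ArrayType, but only on PyArrow >= 14.0.0;
# on older versions the new branch raises PySparkTypeError instead.
from_arrow_type(pa.list_(pa.int32(), 2))    # -> ArrayType(IntegerType(), True)
```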
45 changes: 45 additions & 0 deletions python/pyspark/sql/tests/test_arrow.py
@@ -56,6 +56,7 @@
)
from pyspark.errors import ArithmeticException, PySparkTypeError, UnsupportedOperationException
from pyspark.util import is_remote_only
from pyspark.loose_version import LooseVersion

if have_pandas:
import pandas as pd
@@ -1656,6 +1657,50 @@ def test_negative_and_zero_batch_size(self):
pdf = pd.DataFrame({"a": [123]})
assert_frame_equal(pdf, self.spark.createDataFrame(pdf).toPandas())

def test_createDataFrame_arrow_large_string(self):
a = pa.array(["a"] * 5, type=pa.large_string())
t = pa.table([a], ["ls"])
df = self.spark.createDataFrame(t)
self.assertIsInstance(df.schema["ls"].dataType, StringType)

def test_createDataFrame_arrow_large_binary(self):
a = pa.array(["a"] * 5, type=pa.large_binary())
t = pa.table([a], ["lb"])
df = self.spark.createDataFrame(t)
self.assertIsInstance(df.schema["lb"].dataType, BinaryType)

def test_createDataFrame_arrow_large_list(self):
a = pa.array([[-1, 3]] * 5, type=pa.large_list(pa.int32()))
t = pa.table([a], ["ll"])
df = self.spark.createDataFrame(t)
self.assertIsInstance(df.schema["ll"].dataType, ArrayType)

def test_createDataFrame_arrow_large_list_int64_offset(self):
# Check for expected failure if the large list contains an index >= 2^31
a = pa.LargeListArray.from_arrays(
[0, 2**31], pa.NullArray.from_buffers(pa.null(), 2**31, [None])
)
t = pa.table([a], ["ll"])
with self.assertRaises(Exception):
self.spark.createDataFrame(t)

def test_createDataFrame_arrow_fixed_size_binary(self):
a = pa.array(["a"] * 5, type=pa.binary(1))
t = pa.table([a], ["fsb"])
df = self.spark.createDataFrame(t)
self.assertIsInstance(df.schema["fsb"].dataType, BinaryType)

def test_createDataFrame_arrow_fixed_size_list(self):
a = pa.array([[-1, 3]] * 5, type=pa.list_(pa.int32(), 2))
t = pa.table([a], ["fsl"])
if LooseVersion(pa.__version__) < LooseVersion("14.0.0"):
# PyArrow versions before 14.0.0 do not support casting FixedSizeListArray to ListArray
with self.assertRaises(PySparkTypeError):
df = self.spark.createDataFrame(t)
else:
df = self.spark.createDataFrame(t)
self.assertIsInstance(df.schema["fsl"].dataType, ArrayType)


@unittest.skipIf(
not have_pandas or not have_pyarrow,
