diff --git a/python/pyspark/sql/pandas/types.py b/python/pyspark/sql/pandas/types.py
index 30675d5550465..2475984fbc392 100644
--- a/python/pyspark/sql/pandas/types.py
+++ b/python/pyspark/sql/pandas/types.py
@@ -52,6 +52,7 @@
     _create_row,
 )
 from pyspark.errors import PySparkTypeError, UnsupportedOperationException, PySparkValueError
+from pyspark.loose_version import LooseVersion
 
 if TYPE_CHECKING:
     import pandas as pd
@@ -242,6 +243,8 @@ def from_arrow_type(at: "pa.DataType", prefer_timestamp_ntz: bool = False) -> DataType:
         spark_type = StringType()
     elif types.is_binary(at):
         spark_type = BinaryType()
+    elif types.is_fixed_size_binary(at):
+        spark_type = BinaryType()
     elif types.is_large_binary(at):
         spark_type = BinaryType()
     elif types.is_date32(at):
@@ -254,6 +257,18 @@ def from_arrow_type(at: "pa.DataType", prefer_timestamp_ntz: bool = False) -> DataType:
         spark_type = DayTimeIntervalType()
     elif types.is_list(at):
         spark_type = ArrayType(from_arrow_type(at.value_type, prefer_timestamp_ntz))
+    elif types.is_fixed_size_list(at):
+        import pyarrow as pa
+
+        if LooseVersion(pa.__version__) < LooseVersion("14.0.0"):
+            # PyArrow versions before 14.0.0 do not support casting FixedSizeListArray to ListArray
+            raise PySparkTypeError(
+                error_class="UNSUPPORTED_DATA_TYPE_FOR_ARROW_CONVERSION",
+                message_parameters={"data_type": str(at)},
+            )
+        spark_type = ArrayType(from_arrow_type(at.value_type, prefer_timestamp_ntz))
+    elif types.is_large_list(at):
+        spark_type = ArrayType(from_arrow_type(at.value_type, prefer_timestamp_ntz))
     elif types.is_map(at):
         spark_type = MapType(
             from_arrow_type(at.key_type, prefer_timestamp_ntz),
diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
index a2221f983694e..74d56491a29e7 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -56,6 +56,7 @@
 )
 from pyspark.errors import ArithmeticException, PySparkTypeError, UnsupportedOperationException
 from pyspark.util import is_remote_only
+from pyspark.loose_version import LooseVersion
 
 if have_pandas:
     import pandas as pd
@@ -1656,6 +1657,50 @@ def test_negative_and_zero_batch_size(self):
         pdf = pd.DataFrame({"a": [123]})
         assert_frame_equal(pdf, self.spark.createDataFrame(pdf).toPandas())
 
+    def test_createDataFrame_arrow_large_string(self):
+        a = pa.array(["a"] * 5, type=pa.large_string())
+        t = pa.table([a], ["ls"])
+        df = self.spark.createDataFrame(t)
+        self.assertIsInstance(df.schema["ls"].dataType, StringType)
+
+    def test_createDataFrame_arrow_large_binary(self):
+        a = pa.array(["a"] * 5, type=pa.large_binary())
+        t = pa.table([a], ["lb"])
+        df = self.spark.createDataFrame(t)
+        self.assertIsInstance(df.schema["lb"].dataType, BinaryType)
+
+    def test_createDataFrame_arrow_large_list(self):
+        a = pa.array([[-1, 3]] * 5, type=pa.large_list(pa.int32()))
+        t = pa.table([a], ["ll"])
+        df = self.spark.createDataFrame(t)
+        self.assertIsInstance(df.schema["ll"].dataType, ArrayType)
+
+    def test_createDataFrame_arrow_large_list_int64_offset(self):
+        # Check for expected failure if the large list contains an index >= 2^31
+        a = pa.LargeListArray.from_arrays(
+            [0, 2**31], pa.NullArray.from_buffers(pa.null(), 2**31, [None])
+        )
+        t = pa.table([a], ["ll"])
+        with self.assertRaises(Exception):
+            self.spark.createDataFrame(t)
+
+    def test_createDataFrame_arrow_fixed_size_binary(self):
+        a = pa.array(["a"] * 5, type=pa.binary(1))
+        t = pa.table([a], ["fsb"])
+        df = self.spark.createDataFrame(t)
+        self.assertIsInstance(df.schema["fsb"].dataType, BinaryType)
+
+    def test_createDataFrame_arrow_fixed_size_list(self):
+        a = pa.array([[-1, 3]] * 5, type=pa.list_(pa.int32(), 2))
+        t = pa.table([a], ["fsl"])
+        if LooseVersion(pa.__version__) < LooseVersion("14.0.0"):
+            # PyArrow versions before 14.0.0 do not support casting FixedSizeListArray to ListArray
+            with self.assertRaises(PySparkTypeError):
+                df = self.spark.createDataFrame(t)
+        else:
+            df = self.spark.createDataFrame(t)
+            self.assertIsInstance(df.schema["fsl"].dataType, ArrayType)
+
 
 @unittest.skipIf(
     not have_pandas or not have_pyarrow,
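
For anyone who wants to try the new mappings interactively, here is a minimal sketch (not part of the patch) that mirrors what the tests above assert. It assumes an active SparkSession bound to `spark`; the fixed-size-list case additionally assumes PyArrow >= 14.0.0, since older versions now hit the `PySparkTypeError` raised by the version gate in `from_arrow_type`.

```python
import pyarrow as pa

# Assumes `spark` is an existing SparkSession, e.g.:
#   from pyspark.sql import SparkSession
#   spark = SparkSession.builder.getOrCreate()

# Fixed-size binary maps to BinaryType, same as binary/large_binary.
fsb = pa.table([pa.array([b"ab", b"cd"], type=pa.binary(2))], ["fsb"])
print(spark.createDataFrame(fsb).schema["fsb"].dataType)  # BinaryType()

# Large list maps to ArrayType, same as a regular list.
ll = pa.table([pa.array([[1, 2], [3]], type=pa.large_list(pa.int64()))], ["ll"])
print(spark.createDataFrame(ll).schema["ll"].dataType)  # ArrayType(LongType(), True)

# Fixed-size list also maps to ArrayType, but only on PyArrow >= 14.0.0,
# where a FixedSizeListArray can be cast to a ListArray; on older versions
# this raises PySparkTypeError instead.
fsl = pa.table([pa.array([[1, 2], [3, 4]], type=pa.list_(pa.int64(), 2))], ["fsl"])
print(spark.createDataFrame(fsl).schema["fsl"].dataType)  # ArrayType(LongType(), True)
```

Note the behavior the int64-offset test pins down: large types keep the same Spark-side mappings as their 32-bit-offset counterparts, so a LargeListArray whose offsets exceed 2^31 - 1 is expected to fail the conversion outright rather than be silently truncated.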