[SPARK-46565][PYTHON] Refine error classes and error messages for Python data sources

### What changes were proposed in this pull request?

This PR improves error classes and messages associated with Python data sources. It removes unnecessary error handling in Python and makes error class names more user-friendly.
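
Concretely, the Python worker now raises the consolidated error classes directly instead of wrapping failures in data-source-specific wrappers. A minimal sketch of how such an error is raised (illustrative values only, not actual worker code):
```
from pyspark.errors import PySparkRuntimeError

# Formerly PYTHON_DATA_SOURCE_READ_RETURN_SCHEMA_MISMATCH; renamed and
# consolidated into DATA_SOURCE_RETURN_SCHEMA_MISMATCH in this PR.
raise PySparkRuntimeError(
    error_class="DATA_SOURCE_RETURN_SCHEMA_MISMATCH",
    message_parameters={"expected": "2", "actual": "3"},  # illustrative values
)
```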

### Why are the changes needed?

To make the error messages clearer and more user-friendly. For instance, the current stack trace contains redundant information:
```
AnalysisException: [PYTHON_DATA_SOURCE_FAILED_TO_PLAN_IN_PYTHON] Failed to create Python data source instance in Python: Traceback (most recent call last):
 ...
pyspark.errors.exceptions.base.PySparkNotImplementedError: [NOT_IMPLEMENTED] schema is not implemented.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
    raise PySparkRuntimeError(
pyspark.errors.exceptions.base.PySparkRuntimeError: [PYTHON_DATA_SOURCE_METHOD_NOT_IMPLEMENTED] Unable to create the Python data source instance because the 'schema' method hasn't been implemented.
 SQLSTATE: 38000
```
After this PR, the `During handling of the above exception, another exception occurred:` section will no longer show up.
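
For reference, a minimal (hypothetical) data source that implements neither `schema()` nor `reader()` hits this code path; `MyDataSource` and the format name below are made up for illustration:
```
from pyspark.sql import SparkSession
from pyspark.sql.datasource import DataSource

spark = SparkSession.builder.getOrCreate()

class MyDataSource(DataSource):
    @classmethod
    def name(cls):
        return "my_source"
    # schema() and reader() are intentionally not overridden, so loading fails.

spark.dataSource.register(MyDataSource)
spark.read.format("my_source").load()  # raises the AnalysisException shown above
```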

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Existing unit tests

### Was this patch authored or co-authored using generative AI tooling?

No

Closes apache#44560 from allisonwang-db/spark-46565-pyds-error-msgs.

Authored-by: allisonwang-db <allison.wang@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
allisonwang-db authored and HyukjinKwon committed Jan 3, 2024
1 parent 31b3f81 commit bdb6172
Showing 9 changed files with 63 additions and 114 deletions.
4 changes: 2 additions & 2 deletions common/utils/src/main/resources/error/error-classes.json
@@ -3009,9 +3009,9 @@
],
"sqlState" : "42K0G"
},
"PYTHON_DATA_SOURCE_FAILED_TO_PLAN_IN_PYTHON" : {
"PYTHON_DATA_SOURCE_ERROR" : {
"message" : [
"Failed to <action> Python data source <type> in Python: <msg>"
"Failed to <action> Python data source <type>: <msg>"
],
"sqlState" : "38000"
},
4 changes: 2 additions & 2 deletions docs/sql-error-conditions.md
@@ -1808,11 +1808,11 @@ Unable to locate Message `<messageName>` in Descriptor.

Protobuf type not yet supported: `<protobufType>`.

### PYTHON_DATA_SOURCE_FAILED_TO_PLAN_IN_PYTHON
### PYTHON_DATA_SOURCE_ERROR

[SQLSTATE: 38000](sql-error-conditions-sqlstates.html#class-38-external-routine-exception)

Failed to `<action>` Python data source `<type>` in Python: `<msg>`
Failed to `<action>` Python data source `<type>`: `<msg>`

### RECURSIVE_PROTOBUF_SCHEMA

45 changes: 15 additions & 30 deletions python/pyspark/errors/error_classes.py
@@ -194,6 +194,21 @@
"Remote client cannot create a SparkContext. Create SparkSession instead."
]
},
"DATA_SOURCE_INVALID_RETURN_TYPE" : {
"message" : [
"Unsupported return type ('<type>') from Python data source '<name>'. Expected types: <supported_types>."
]
},
"DATA_SOURCE_RETURN_SCHEMA_MISMATCH" : {
"message" : [
"Return schema mismatch in the result from 'read' method. Expected: <expected> columns, Found: <actual> columns. Make sure the returned values match the required output schema."
]
},
"DATA_SOURCE_TYPE_MISMATCH" : {
"message" : [
"Expected <expected>, but got <actual>."
]
},
"DIFFERENT_PANDAS_DATAFRAME" : {
"message" : [
"DataFrames are not almost equal:",
@@ -747,36 +762,6 @@
"Pipe function `<func_name>` exited with error code <error_code>."
]
},
"PYTHON_DATA_SOURCE_CREATE_ERROR" : {
"message" : [
"Unable to create the Python data source <type>: <error>."
]
},
"PYTHON_DATA_SOURCE_METHOD_NOT_IMPLEMENTED" : {
"message" : [
"Unable to create the Python data source <type> because the '<method>' method hasn't been implemented."
]
},
"PYTHON_DATA_SOURCE_READ_INVALID_RETURN_TYPE" : {
"message" : [
"The data type of the returned value ('<type>') from the Python data source '<name>' is not supported. Supported types: <supported_types>."
]
},
"PYTHON_DATA_SOURCE_READ_RETURN_SCHEMA_MISMATCH" : {
"message" : [
"The number of columns in the result does not match the required schema. Expected column count: <expected>, Actual column count: <actual>. Please make sure the values returned by the 'read' method have the same number of columns as required by the output schema."
]
},
"PYTHON_DATA_SOURCE_TYPE_MISMATCH" : {
"message" : [
"Expected <expected>, but got <actual>."
]
},
"PYTHON_DATA_SOURCE_WRITE_ERROR" : {
"message" : [
"Unable to write to the Python data source: <error>."
]
},
"PYTHON_HASH_SEED_NOT_SET" : {
"message" : [
"Randomness of hash of string should be disabled via PYTHONHASHSEED."
12 changes: 4 additions & 8 deletions python/pyspark/sql/tests/test_python_datasource.py
@@ -154,7 +154,7 @@ def test_data_source_read_output_named_row_with_wrong_schema(self):
def test_data_source_read_output_none(self):
self.register_data_source(read_func=lambda schema, partition: None)
df = self.spark.read.format("test").load()
with self.assertRaisesRegex(PythonException, "PYTHON_DATA_SOURCE_READ_INVALID_RETURN_TYPE"):
with self.assertRaisesRegex(PythonException, "DATA_SOURCE_INVALID_RETURN_TYPE"):
assertDataFrameEqual(df, [])

def test_data_source_read_output_empty_iter(self):
@@ -186,22 +186,18 @@ def read_func(schema, partition):
def test_data_source_read_output_with_schema_mismatch(self):
self.register_data_source(read_func=lambda schema, partition: iter([(0, 1)]))
df = self.spark.read.format("test").schema("i int").load()
with self.assertRaisesRegex(
PythonException, "PYTHON_DATA_SOURCE_READ_RETURN_SCHEMA_MISMATCH"
):
with self.assertRaisesRegex(PythonException, "DATA_SOURCE_RETURN_SCHEMA_MISMATCH"):
df.collect()
self.register_data_source(
read_func=lambda schema, partition: iter([(0, 1)]), output="i int, j int, k int"
)
with self.assertRaisesRegex(
PythonException, "PYTHON_DATA_SOURCE_READ_RETURN_SCHEMA_MISMATCH"
):
with self.assertRaisesRegex(PythonException, "DATA_SOURCE_RETURN_SCHEMA_MISMATCH"):
df.collect()

def test_read_with_invalid_return_row_type(self):
self.register_data_source(read_func=lambda schema, partition: iter([1]))
df = self.spark.read.format("test").load()
with self.assertRaisesRegex(PythonException, "PYTHON_DATA_SOURCE_READ_INVALID_RETURN_TYPE"):
with self.assertRaisesRegex(PythonException, "DATA_SOURCE_INVALID_RETURN_TYPE"):
df.collect()

def test_in_memory_data_source(self):
46 changes: 12 additions & 34 deletions python/pyspark/sql/worker/plan_data_source_read.py
@@ -85,7 +85,7 @@ def main(infile: IO, outfile: IO) -> None:
data_source = read_command(pickleSer, infile)
if not isinstance(data_source, DataSource):
raise PySparkAssertionError(
error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
error_class="DATA_SOURCE_TYPE_MISMATCH",
message_parameters={
"expected": "a Python data source instance of type 'DataSource'",
"actual": f"'{type(data_source).__name__}'",
@@ -97,7 +97,7 @@ def main(infile: IO, outfile: IO) -> None:
input_schema = _parse_datatype_json_string(input_schema_json)
if not isinstance(input_schema, StructType):
raise PySparkAssertionError(
error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
error_class="DATA_SOURCE_TYPE_MISMATCH",
message_parameters={
"expected": "an input schema of type 'StructType'",
"actual": f"'{type(input_schema).__name__}'",
@@ -128,54 +128,32 @@ def main(infile: IO, outfile: IO) -> None:
)

# Instantiate data source reader.
try:
reader = data_source.reader(schema=schema)
except NotImplementedError:
raise PySparkRuntimeError(
error_class="PYTHON_DATA_SOURCE_METHOD_NOT_IMPLEMENTED",
message_parameters={"type": "reader", "method": "reader"},
)
except Exception as e:
raise PySparkRuntimeError(
error_class="PYTHON_DATA_SOURCE_CREATE_ERROR",
message_parameters={"type": "reader", "error": str(e)},
)
reader = data_source.reader(schema=schema)

# Get the partitions if any.
try:
partitions = reader.partitions()
if not isinstance(partitions, list):
raise PySparkRuntimeError(
error_class="PYTHON_DATA_SOURCE_CREATE_ERROR",
error_class="DATA_SOURCE_TYPE_MISMATCH",
message_parameters={
"type": "reader",
"error": (
"Expect 'partitions' to return a list, but got "
f"'{type(partitions).__name__}'"
),
"expected": "'partitions' to return a list",
"actual": f"'{type(partitions).__name__}'",
},
)
if not all(isinstance(p, InputPartition) for p in partitions):
partition_types = ", ".join([f"'{type(p).__name__}'" for p in partitions])
raise PySparkRuntimeError(
error_class="PYTHON_DATA_SOURCE_CREATE_ERROR",
error_class="DATA_SOURCE_TYPE_MISMATCH",
message_parameters={
"type": "reader",
"error": (
"All elements in 'partitions' should be of type "
f"'InputPartition', but got {partition_types}"
),
"expected": "all elements in 'partitions' to be of type 'InputPartition'",
"actual": partition_types,
},
)
if len(partitions) == 0:
partitions = [None] # type: ignore
except NotImplementedError:
partitions = [None] # type: ignore
except Exception as e:
raise PySparkRuntimeError(
error_class="PYTHON_DATA_SOURCE_CREATE_ERROR",
message_parameters={"type": "reader", "error": str(e)},
)

# Wrap the data source read logic in an mapInArrow UDF.
import pyarrow as pa
@@ -222,7 +200,7 @@ def data_source_read_func(iterator: Iterable[pa.RecordBatch]) -> Iterable[pa.Rec
# Validate the output iterator.
if not isinstance(output_iter, Iterator):
raise PySparkRuntimeError(
error_class="PYTHON_DATA_SOURCE_READ_INVALID_RETURN_TYPE",
error_class="DATA_SOURCE_INVALID_RETURN_TYPE",
message_parameters={
"type": type(output_iter).__name__,
"name": data_source.name(),
@@ -243,7 +221,7 @@ def batched(iterator: Iterator, n: int) -> Iterator:
# Validate the output row schema.
if hasattr(result, "__len__") and len(result) != num_cols:
raise PySparkRuntimeError(
error_class="PYTHON_DATA_SOURCE_READ_RETURN_SCHEMA_MISMATCH",
error_class="DATA_SOURCE_RETURN_SCHEMA_MISMATCH",
message_parameters={
"expected": str(num_cols),
"actual": str(len(result)),
@@ -253,7 +231,7 @@ def batched(iterator: Iterator, n: int) -> Iterator:
# Validate the output row type.
if not isinstance(result, (list, tuple)):
raise PySparkRuntimeError(
error_class="PYTHON_DATA_SOURCE_READ_INVALID_RETURN_TYPE",
error_class="DATA_SOURCE_INVALID_RETURN_TYPE",
message_parameters={
"type": type(result).__name__,
"name": data_source.name(),
34 changes: 12 additions & 22 deletions python/pyspark/sql/worker/write_into_data_source.py
@@ -84,7 +84,7 @@ def main(infile: IO, outfile: IO) -> None:
data_source_cls = read_command(pickleSer, infile)
if not (isinstance(data_source_cls, type) and issubclass(data_source_cls, DataSource)):
raise PySparkAssertionError(
error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
error_class="DATA_SOURCE_TYPE_MISMATCH",
message_parameters={
"expected": "a subclass of DataSource",
"actual": f"'{type(data_source_cls).__name__}'",
@@ -94,7 +94,7 @@ def main(infile: IO, outfile: IO) -> None:
# Check the name method is a class method.
if not inspect.ismethod(data_source_cls.name):
raise PySparkTypeError(
error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
error_class="DATA_SOURCE_TYPE_MISMATCH",
message_parameters={
"expected": "'name()' method to be a classmethod",
"actual": f"'{type(data_source_cls.name).__name__}'",
@@ -107,7 +107,7 @@ def main(infile: IO, outfile: IO) -> None:
# Check if the provider name matches the data source's name.
if provider.lower() != data_source_cls.name().lower():
raise PySparkAssertionError(
error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
error_class="DATA_SOURCE_TYPE_MISMATCH",
message_parameters={
"expected": f"provider with name {data_source_cls.name()}",
"actual": f"'{provider}'",
@@ -118,7 +118,7 @@ def main(infile: IO, outfile: IO) -> None:
schema = _parse_datatype_json_string(utf8_deserializer.loads(infile))
if not isinstance(schema, StructType):
raise PySparkAssertionError(
error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
error_class="DATA_SOURCE_TYPE_MISMATCH",
message_parameters={
"expected": "the schema to be a 'StructType'",
"actual": f"'{type(data_source_cls).__name__}'",
@@ -129,7 +129,7 @@ def main(infile: IO, outfile: IO) -> None:
return_type = _parse_datatype_json_string(utf8_deserializer.loads(infile))
if not isinstance(return_type, StructType):
raise PySparkAssertionError(
error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
error_class="DATA_SOURCE_TYPE_MISMATCH",
message_parameters={
"expected": "a return type of type 'StructType'",
"actual": f"'{type(return_type).__name__}'",
@@ -153,22 +153,10 @@ def main(infile: IO, outfile: IO) -> None:
overwrite = read_bool(infile)

# Instantiate a data source.
try:
data_source = data_source_cls(options=options)
except Exception as e:
raise PySparkRuntimeError(
error_class="PYTHON_DATA_SOURCE_CREATE_ERROR",
message_parameters={"type": "instance", "error": str(e)},
)
data_source = data_source_cls(options=options)

# Instantiate the data source writer.
try:
writer = data_source.writer(schema, overwrite)
except Exception as e:
raise PySparkRuntimeError(
error_class="PYTHON_DATA_SOURCE_CREATE_ERROR",
message_parameters={"type": "writer", "error": str(e)},
)
writer = data_source.writer(schema, overwrite)

# Create a function that can be used in mapInArrow.
import pyarrow as pa
@@ -193,10 +181,12 @@ def batch_to_rows() -> Iterator[Row]:
# Check the commit message has the right type.
if not isinstance(res, WriterCommitMessage):
raise PySparkRuntimeError(
error_class="PYTHON_DATA_SOURCE_WRITE_ERROR",
error_class="DATA_SOURCE_TYPE_MISMATCH",
message_parameters={
"error": f"return type of the `write` method must be "
f"an instance of WriterCommitMessage, but got {type(res)}"
"expected": (
"'WriterCommitMessage' as the return type of " "the `write` method"
),
"actual": type(res).__name__,
},
)

QueryCompilationErrors.scala
@@ -2040,9 +2040,9 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat
)
}

def failToPlanDataSourceError(action: String, tpe: String, msg: String): Throwable = {
def pythonDataSourceError(action: String, tpe: String, msg: String): Throwable = {
new AnalysisException(
errorClass = "PYTHON_DATA_SOURCE_FAILED_TO_PLAN_IN_PYTHON",
errorClass = "PYTHON_DATA_SOURCE_ERROR",
messageParameters = Map("action" -> action, "type" -> tpe, "msg" -> msg)
)
}
UserDefinedPythonDataSource.scala
@@ -459,7 +459,7 @@ class UserDefinedPythonDataSourceLookupRunner(lookupSources: PythonFunction)
val length = dataIn.readInt()
if (length == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
val msg = PythonWorkerUtils.readUTF(dataIn)
throw QueryCompilationErrors.failToPlanDataSourceError(
throw QueryCompilationErrors.pythonDataSourceError(
action = "lookup", tpe = "instance", msg = msg)
}

@@ -524,7 +524,7 @@ class UserDefinedPythonDataSourceRunner(
val length = dataIn.readInt()
if (length == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
val msg = PythonWorkerUtils.readUTF(dataIn)
throw QueryCompilationErrors.failToPlanDataSourceError(
throw QueryCompilationErrors.pythonDataSourceError(
action = "create", tpe = "instance", msg = msg)
}

@@ -587,7 +587,7 @@ class UserDefinedPythonDataSourceReadRunner(
val length = dataIn.readInt()
if (length == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
val msg = PythonWorkerUtils.readUTF(dataIn)
throw QueryCompilationErrors.failToPlanDataSourceError(
throw QueryCompilationErrors.pythonDataSourceError(
action = "plan", tpe = "read", msg = msg)
}

@@ -657,7 +657,7 @@ class UserDefinedPythonDataSourceWriteRunner(
val length = dataIn.readInt()
if (length == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
val msg = PythonWorkerUtils.readUTF(dataIn)
throw QueryCompilationErrors.failToPlanDataSourceError(
throw QueryCompilationErrors.pythonDataSourceError(
action = "plan", tpe = "write", msg = msg)
}

@@ -707,7 +707,7 @@ class UserDefinedPythonDataSourceCommitRunner(
val code = dataIn.readInt()
if (code == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
val msg = PythonWorkerUtils.readUTF(dataIn)
throw QueryCompilationErrors.failToPlanDataSourceError(
throw QueryCompilationErrors.pythonDataSourceError(
action = "commit or abort", tpe = "write", msg = msg)
}
assert(code == 0, s"Python commit job should run successfully, but got exit code: $code")