Support RaiseError [databricks] #5540
Conversation
// Take the first one as the error message
val msg = input.copyToHost().getUTF8String(0).toString
Is it possible to copy only the first row to get the error message, instead of copying the whole column vector?
This is on an error case; the entire job is going to fail, so I am not too concerned with failing faster than the CPU. Yes, it would be nice to not copy everything. You can do that with getScalarElement, which should keep the code small and clean:
withResource(input.getScalarElement(0)) { scalarMsg =>
if (!scalarMsg.isValid()) {
throw new RuntimeException()
} else {
throw new RuntimeException(scalarMsg.getJavaString())
}
}
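The `withResource` wrapper in the suggestion above guarantees the scalar is closed even though every path throws. A rough Python analogue of that close-on-both-paths behavior, using `contextlib.closing`; `FakeScalar` is a made-up stand-in for a cudf `Scalar`, not the real API:

```python
import contextlib

class FakeScalar:
    """Illustrative stand-in for a cudf Scalar; None models a null value."""
    def __init__(self, value):
        self.value = value
        self.closed = False
    def is_valid(self):
        return self.value is not None
    def close(self):
        self.closed = True

def raise_from_scalar(scalar):
    # contextlib.closing calls scalar.close() on exit, even when we raise,
    # mirroring what withResource does for the Scala Scalar.
    with contextlib.closing(scalar):
        if not scalar.is_valid():
            raise RuntimeError()          # null message: raise with no text
        raise RuntimeError(scalar.value)  # otherwise propagate the message
```

The point of the pattern is that resource cleanup does not depend on which branch throws.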
Good suggestion. Done.
Signed-off-by: Bobby Wang <wbo4958@gmail.com>
build
override def hasSideEffects: Boolean = true

override protected def doColumnar(input: GpuColumnVector): ColumnVector = {
  if (input == null || input.getRowCount <= 0) {
input should never be null. If it is, that is an internal error. I am okay with what you are doing, but it would be good to know that something unexpected happened.
If input has no rows in it, then I don't want to throw an exception; just return an empty ColumnVector. This I can see actually happening if you have an IF/ELSE to check for error cases. I don't know if there are any corner cases where nothing matched and we got an empty ColumnVector, but I can see it happening.
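The behavior settled on here can be modeled in plain Python. This is only a sketch of the agreed semantics; `raise_error_semantics` and its list-of-rows input are illustrative, not the plugin's Scala/GPU code:

```python
def raise_error_semantics(rows):
    """Model of raise_error over a string column, given as a Python list."""
    if len(rows) == 0:
        # Empty input (e.g. no row matched an IF/ELSE branch):
        # return an empty result instead of raising.
        return []
    # Non-empty input: the first row becomes the error message.
    raise RuntimeError(rows[0])
```

So an empty batch flows through quietly, and any non-empty batch fails the job with the first row as the message.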
Yeah, I just changed this and added the related tests. Thx
// Take the first one as the error message
val msg = input.copyToHost().getUTF8String(0).toString
This leaks the host column vector, and I don't know if it is going to do the right thing if that first string is null. We should have an explicit test for when the first element is null; I think we will get an assertion error if assertions are turned on.
build
override protected def doColumnar(input: GpuColumnVector): ColumnVector = {
  if (input.getRowCount <= 0) {
    // For the case: when(condition, raise_error())
It would be nice to cover this specific case of raise_error() in the python tests; it doesn't seem like we are.
Actually, according to Spark, I don't think this raise_error() (no args) is possible:
pyspark.sql.utils.AnalysisException: Invalid number of arguments for function raise_error. Expected: 1; Found: 0; line 1 pos 7
This is possible and should be tested, e.g.:
>>> import pyspark.sql.functions as f
>>> df = spark.range(0)
>>> df.count()
0
>>> df.select(f.raise_error(f.col("id"))).explain()
== Physical Plan ==
*(1) Project [raise_error(cast(id#12L as string), NullType) AS raise_error(id)#20]
+- *(1) Range (0, 0, step=1, splits=12)
>>> df.select(f.raise_error(f.col("id"))).collect()
[]
Yeah, raise_error needs to accept the parameter and I just updated the comment.
> This is possible and should be tested, e.g.: ...
Done
LGTM
)
),
expr[RaiseError](
  "throw exception",
Nit: All other descriptions start with a capital letter and are a bit more descriptive, as seen in the generated configs.md docs.
- "throw exception",
+ "Throws an exception",
Thx, Done
lambda spark : unary_op_df(spark, short_gen, num_slices=2).select(
    f.raise_error(f.col('a'))).collect(),
conf={},
error_message="java.lang.RuntimeException")
The test should verify that we are properly conveying the specified error message into the exception, rather than just checking for the same exception type, and it should check both the null first element and non-null first element scenarios.
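The kind of assertion requested here can be sketched in plain Python. `run_query` is a hypothetical stand-in for collecting the raise_error query; the real tests would go through the repo's Python test harness instead:

```python
def run_query(msg):
    # Stand-in for df.select(f.raise_error(...)).collect(), which fails
    # the job with the given message.
    raise RuntimeError(msg)

def assert_error_message(expected, fn):
    """Assert fn raises and that the message text is actually conveyed."""
    try:
        fn()
    except RuntimeError as e:
        assert expected in str(e), f"expected {expected!r} in {str(e)!r}"
        return
    raise AssertionError("no exception was raised")

assert_error_message("custom error", lambda: run_query("custom error"))
```

Matching on the message text, not just the exception type, is what catches a GPU implementation that raises the right exception with the wrong (or dropped) message.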
Done
build
build |
This PR adds GpuRaiseError to replace the RaiseError expression. It fixes #5507.