From 00bff28cd66dc4d68de836c4290291ebc1df572b Mon Sep 17 00:00:00 2001 From: panbingkun Date: Tue, 12 Nov 2024 19:27:49 +0800 Subject: [PATCH] [SPARK-50056][SQL] Codegen Support for ParseUrl (by Invoke & RuntimeReplaceable) ### What changes were proposed in this pull request? The pr aims to add `Codegen` Support for `parse_url`. ### Why are the changes needed? - improve codegen coverage. - simplified code. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass GA & Existed UT (eg: UrlFunctionsSuite#`*parse_url*`) ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48579 from panbingkun/SPARK-50056. Authored-by: panbingkun Signed-off-by: Wenchen Fan --- .../url/UrlExpressionEvalUtils.scala | 148 +++++++++++++++ .../catalyst/expressions/urlExpressions.scala | 178 ++++-------------- .../expressions/StringExpressionsSuite.scala | 5 +- .../function_parse_url.explain | 2 +- .../function_parse_url_with_key.explain | 2 +- .../function_try_parse_url.explain | 2 +- .../function_try_parse_url_with_key.explain | 2 +- 7 files changed, 194 insertions(+), 145 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/url/UrlExpressionEvalUtils.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/url/UrlExpressionEvalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/url/UrlExpressionEvalUtils.scala new file mode 100644 index 0000000000000..1eaa25a6bf72c --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/url/UrlExpressionEvalUtils.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions.url + +import java.net.{URI, URISyntaxException} +import java.util.regex.Pattern + +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.unsafe.types.UTF8String + +case class ParseUrlEvaluator( + url: UTF8String, + extractPart: UTF8String, + pattern: UTF8String, + failOnError: Boolean) { + + import ParseUrlEvaluator._ + + private lazy val cachedUrl: URI = + if (url != null) getUrl(url, failOnError) else null + + private lazy val cachedExtractPartFunc: URI => String = + if (extractPart != null) getExtractPartFunc(extractPart) else null + + private lazy val cachedPattern: Pattern = + if (pattern != null) getPattern(pattern) else null + + private def extractValueFromQuery(query: UTF8String, pattern: Pattern): UTF8String = { + val m = pattern.matcher(query.toString) + if (m.find()) { + UTF8String.fromString(m.group(2)) + } else { + null + } + } + + private def extractFromUrl(url: URI, partToExtract: UTF8String): UTF8String = { + if (cachedExtractPartFunc ne null) { + UTF8String.fromString(cachedExtractPartFunc(url)) + } else { + UTF8String.fromString(getExtractPartFunc(partToExtract)(url)) + } + } + + private def parseUrlWithoutKey(url: UTF8String, partToExtract: UTF8String): UTF8String = { + if (cachedUrl ne null) { + extractFromUrl(cachedUrl, partToExtract) + } else { + val currentUrl = getUrl(url, failOnError) + if (currentUrl ne null) { + extractFromUrl(currentUrl, partToExtract) + } else { + null + } + } + } + + final def evaluate(url: UTF8String, path: UTF8String): Any = { + parseUrlWithoutKey(url, path) + } + + final def evaluate(url: UTF8String, path: UTF8String, key: UTF8String): Any = { + if (path != QUERY) return null + + val query = parseUrlWithoutKey(url, path) + if (query eq null) return null + + if (cachedPattern ne null) { + extractValueFromQuery(query, cachedPattern) + } else { + extractValueFromQuery(query, getPattern(key)) + } + } +} + +object ParseUrlEvaluator { + private val HOST = UTF8String.fromString("HOST") + private val PATH = UTF8String.fromString("PATH") + private val QUERY = UTF8String.fromString("QUERY") + private val REF = UTF8String.fromString("REF") + private val PROTOCOL = UTF8String.fromString("PROTOCOL") + private val FILE = UTF8String.fromString("FILE") + private val AUTHORITY = UTF8String.fromString("AUTHORITY") + private val USERINFO = UTF8String.fromString("USERINFO") + private val REGEXPREFIX = "(&|^)" + private val REGEXSUBFIX = "=([^&]*)" + + private def getPattern(key: UTF8String): Pattern = { + Pattern.compile(REGEXPREFIX + key.toString + REGEXSUBFIX) + } + + private def getUrl(url: UTF8String, failOnError: Boolean): URI = { + try { + new URI(url.toString) + } catch { + case e: URISyntaxException if failOnError => + throw QueryExecutionErrors.invalidUrlError(url, e) + case _: URISyntaxException => null + } + } + + private def getExtractPartFunc(partToExtract: UTF8String): URI => String = { + + // partToExtract match { + // case HOST => _.toURL().getHost + // case PATH => _.toURL().getPath + // case QUERY => _.toURL().getQuery + // case REF => _.toURL().getRef + // case PROTOCOL => _.toURL().getProtocol + // case FILE => _.toURL().getFile + // case AUTHORITY => _.toURL().getAuthority + // case USERINFO => _.toURL().getUserInfo + // case _ => (url: URI) => null + // } + + partToExtract match { + case HOST => _.getHost + case PATH => _.getRawPath + case QUERY => _.getRawQuery + case REF => _.getRawFragment + case PROTOCOL => _.getScheme + case FILE => + (url: URI) => + if (url.getRawQuery ne null) { + url.getRawPath + "?" + url.getRawQuery + } else { + url.getRawPath + } + case AUTHORITY => _.getRawAuthority + case USERINFO => _.getRawUserInfo + case _ => (_: URI) => null + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala index bf1a788554284..22dcd33937dfb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala @@ -17,20 +17,18 @@ package org.apache.spark.sql.catalyst.expressions -import java.net.{URI, URISyntaxException, URLDecoder, URLEncoder} +import java.net.{URLDecoder, URLEncoder} import java.nio.charset.StandardCharsets -import java.util.regex.Pattern -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.Cast._ -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke +import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke} +import org.apache.spark.sql.catalyst.expressions.url.ParseUrlEvaluator import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.types.StringTypeWithCollation -import org.apache.spark.sql.types.{AbstractDataType, BooleanType, DataType} +import org.apache.spark.sql.types.{AbstractDataType, BooleanType, DataType, ObjectType} import org.apache.spark.unsafe.types.UTF8String // scalastyle:off line.size.limit @@ -156,19 +154,6 @@ object UrlCodec { } } -object ParseUrl { - private val HOST = UTF8String.fromString("HOST") - private val PATH = UTF8String.fromString("PATH") - private val QUERY = UTF8String.fromString("QUERY") - private val REF = UTF8String.fromString("REF") - private val PROTOCOL = UTF8String.fromString("PROTOCOL") - private val FILE = UTF8String.fromString("FILE") - private val AUTHORITY = UTF8String.fromString("AUTHORITY") - private val USERINFO = UTF8String.fromString("USERINFO") - private val REGEXPREFIX = "(&|^)" - private val REGEXSUBFIX = "=([^&]*)" -} - // scalastyle:off line.size.limit @ExpressionDescription( usage = "_FUNC_(url, partToExtract[, key]) - This is a special version of `parse_url` that performs the same operation, but returns a NULL value instead of raising an error if the parsing cannot be performed.", @@ -215,8 +200,13 @@ case class TryParseUrl(params: Seq[Expression], replacement: Expression) """, since = "2.0.0", group = "url_funcs") -case class ParseUrl(children: Seq[Expression], failOnError: Boolean = SQLConf.get.ansiEnabled) - extends Expression with ExpectsInputTypes with CodegenFallback { +case class ParseUrl( + children: Seq[Expression], + failOnError: Boolean = SQLConf.get.ansiEnabled) + extends Expression + with ExpectsInputTypes + with RuntimeReplaceable { + def this(children: Seq[Expression]) = this(children, SQLConf.get.ansiEnabled) override def nullable: Boolean = true @@ -225,29 +215,6 @@ case class ParseUrl(children: Seq[Expression], failOnError: Boolean = SQLConf.ge override def dataType: DataType = SQLConf.get.defaultStringType override def prettyName: String = "parse_url" - // If the url is a constant, cache the URL object so that we don't need to convert url - // from UTF8String to String to URL for every row. - @transient private lazy val cachedUrl = children(0) match { - case Literal(url: UTF8String, _) if url ne null => getUrl(url) - case _ => null - } - - // If the key is a constant, cache the Pattern object so that we don't need to convert key - // from UTF8String to String to StringBuilder to String to Pattern for every row. - @transient private lazy val cachedPattern = children(2) match { - case Literal(key: UTF8String, _) if key ne null => getPattern(key) - case _ => null - } - - // If the partToExtract is a constant, cache the Extract part function so that we don't need - // to check the partToExtract for every row. - @transient private lazy val cachedExtractPartFunc = children(1) match { - case Literal(part: UTF8String, _) => getExtractPartFunc(part) - case _ => null - } - - import ParseUrl._ - override def checkInputDataTypes(): TypeCheckResult = { if (children.size > 3 || children.size < 2) { throw QueryCompilationErrors.wrongNumArgsError( @@ -258,108 +225,41 @@ case class ParseUrl(children: Seq[Expression], failOnError: Boolean = SQLConf.ge } } - private def getPattern(key: UTF8String): Pattern = { - Pattern.compile(REGEXPREFIX + key.toString + REGEXSUBFIX) - } - - private def getUrl(url: UTF8String): URI = { - try { - new URI(url.toString) - } catch { - case e: URISyntaxException if failOnError => - throw QueryExecutionErrors.invalidUrlError(url, e) - case _: URISyntaxException => null - } - } - - private def getExtractPartFunc(partToExtract: UTF8String): URI => String = { - - // partToExtract match { - // case HOST => _.toURL().getHost - // case PATH => _.toURL().getPath - // case QUERY => _.toURL().getQuery - // case REF => _.toURL().getRef - // case PROTOCOL => _.toURL().getProtocol - // case FILE => _.toURL().getFile - // case AUTHORITY => _.toURL().getAuthority - // case USERINFO => _.toURL().getUserInfo - // case _ => (url: URI) => null - // } - - partToExtract match { - case HOST => _.getHost - case PATH => _.getRawPath - case QUERY => _.getRawQuery - case REF => _.getRawFragment - case PROTOCOL => _.getScheme - case FILE => - (url: URI) => - if (url.getRawQuery ne null) { - url.getRawPath + "?" + url.getRawQuery - } else { - url.getRawPath - } - case AUTHORITY => _.getRawAuthority - case USERINFO => _.getRawUserInfo - case _ => (url: URI) => null - } - } + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): ParseUrl = + copy(children = newChildren) - private def extractValueFromQuery(query: UTF8String, pattern: Pattern): UTF8String = { - val m = pattern.matcher(query.toString) - if (m.find()) { - UTF8String.fromString(m.group(2)) - } else { - null - } + // If the url is a constant, cache the URL object so that we don't need to convert url + // from UTF8String to String to URL for every row. + @transient private lazy val url = children.head match { + case Literal(url: UTF8String, _) if url ne null => url + case _ => null } - private def extractFromUrl(url: URI, partToExtract: UTF8String): UTF8String = { - if (cachedExtractPartFunc ne null) { - UTF8String.fromString(cachedExtractPartFunc.apply(url)) - } else { - UTF8String.fromString(getExtractPartFunc(partToExtract).apply(url)) - } + // If the partToExtract is a constant, cache the Extract part function so that we don't need + // to check the partToExtract for every row. + @transient private lazy val extractPart = children(1) match { + case Literal(part: UTF8String, _) => part + case _ => null } - private def parseUrlWithoutKey(url: UTF8String, partToExtract: UTF8String): UTF8String = { - if (cachedUrl ne null) { - extractFromUrl(cachedUrl, partToExtract) - } else { - val currentUrl = getUrl(url) - if (currentUrl ne null) { - extractFromUrl(currentUrl, partToExtract) - } else { - null - } + // If the key is a constant, cache the Pattern object so that we don't need to convert key + // from UTF8String to String to StringBuilder to String to Pattern for every row. + @transient private lazy val pattern = children.size match { + case 3 => children(2) match { + case Literal(key: UTF8String, _) if key ne null => key + case _ => null } + case _ => null } - override def eval(input: InternalRow): Any = { - val evaluated = children.map{e => e.eval(input).asInstanceOf[UTF8String]} - if (evaluated.contains(null)) return null - if (evaluated.size == 2) { - parseUrlWithoutKey(evaluated(0), evaluated(1)) - } else { - // 3-arg, i.e. QUERY with key - assert(evaluated.size == 3) - if (evaluated(1) != QUERY) { - return null - } - - val query = parseUrlWithoutKey(evaluated(0), evaluated(1)) - if (query eq null) { - return null - } - - if (cachedPattern ne null) { - extractValueFromQuery(query, cachedPattern) - } else { - extractValueFromQuery(query, getPattern(evaluated(2))) - } - } - } + @transient + private lazy val evaluator: ParseUrlEvaluator = ParseUrlEvaluator( + url, extractPart, pattern, failOnError) - override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): ParseUrl = - copy(children = newChildren) + override def replacement: Expression = Invoke( + Literal.create(evaluator, ObjectType(classOf[ParseUrlEvaluator])), + "evaluate", + dataType, + children, + children.map(_.dataType)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index aa7eafeed485a..1687d614cc5eb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -1905,7 +1905,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { // exceptional cases intercept[java.util.regex.PatternSyntaxException] { evaluateWithoutCodegen(ParseUrl(Seq(Literal("http://spark.apache.org/path?"), - Literal("QUERY"), Literal("???")))) + Literal("QUERY"), Literal("???"))).replacement) } // arguments checking @@ -1956,7 +1956,8 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { "inputType" -> "\"INT\""))) // Test escaping of arguments - GenerateUnsafeProjection.generate(ParseUrl(Seq(Literal("\"quote"), Literal("\"quote"))) :: Nil) + GenerateUnsafeProjection.generate( + ParseUrl(Seq(Literal("\"quote"), Literal("\"quote"))).replacement :: Nil) } test("SPARK-33468: ParseUrl in ANSI mode should fail if input string is not a valid url") { diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_parse_url.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_parse_url.explain index 3c874b5c8b6a4..1f9f3df800b8f 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_parse_url.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_parse_url.explain @@ -1,2 +1,2 @@ -Project [parse_url(g#0, g#0, false) AS parse_url(g, g)#0] +Project [invoke(ParseUrlEvaluator(null,null,null,false).evaluate(g#0, g#0)) AS parse_url(g, g)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_parse_url_with_key.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_parse_url_with_key.explain index eba1c5c814fe3..900de9c243a83 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_parse_url_with_key.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_parse_url_with_key.explain @@ -1,2 +1,2 @@ -Project [parse_url(g#0, g#0, g#0, false) AS parse_url(g, g, g)#0] +Project [invoke(ParseUrlEvaluator(null,null,null,false).evaluate(g#0, g#0, g#0)) AS parse_url(g, g, g)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_try_parse_url.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_try_parse_url.explain index 2fbf751ecf193..87bad58090a08 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_try_parse_url.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_try_parse_url.explain @@ -1,2 +1,2 @@ -Project [parse_url(g#0, g#0, false) AS try_parse_url(g, g)#0] +Project [invoke(ParseUrlEvaluator(null,null,null,false).evaluate(g#0, g#0)) AS try_parse_url(g, g)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_try_parse_url_with_key.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_try_parse_url_with_key.explain index 74c4a4985acf2..aed35cfeb7009 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_try_parse_url_with_key.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_try_parse_url_with_key.explain @@ -1,2 +1,2 @@ -Project [parse_url(g#0, g#0, g#0, false) AS try_parse_url(g, g, g)#0] +Project [invoke(ParseUrlEvaluator(null,null,null,false).evaluate(g#0, g#0, g#0)) AS try_parse_url(g, g, g)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]