Skip to content

Commit

Permalink
[SPARK-50056][SQL] Codegen Support for ParseUrl (by Invoke & RuntimeR…
Browse files Browse the repository at this point in the history
…eplaceable)

### What changes were proposed in this pull request?
This PR adds `Codegen` support for `parse_url`.

### Why are the changes needed?
- Improve codegen coverage.
- Simplify the code.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Pass GA & existing UTs (e.g. UrlFunctionsSuite#`*parse_url*`)

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#48579 from panbingkun/SPARK-50056.

Authored-by: panbingkun <panbingkun@baidu.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
panbingkun authored and cloud-fan committed Nov 12, 2024
1 parent 4f95a7f commit 00bff28
Show file tree
Hide file tree
Showing 7 changed files with 194 additions and 145 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions.url

import java.net.{URI, URISyntaxException}
import java.util.regex.Pattern

import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.unsafe.types.UTF8String

/**
 * Evaluates `parse_url` for a single row.
 *
 * The constructor arguments let work be done once instead of on every call: a non-null
 * `url`, `extractPart` or `pattern` is converted (lazily, on first use) into a parsed URI,
 * a part-extraction function or a compiled regex respectively. A null argument means
 * "nothing to pre-compute"; the matching per-call argument of `evaluate` is used instead.
 *
 * @param url         url to parse once and reuse, or null
 * @param extractPart name of the part to extract, resolved once and reused, or null
 * @param pattern     query key to compile once into a regex and reuse, or null
 * @param failOnError forwarded to `getUrl`: whether a malformed url raises an error
 *                    rather than producing null
 */
case class ParseUrlEvaluator(
    url: UTF8String,
    extractPart: UTF8String,
    pattern: UTF8String,
    failOnError: Boolean) {

  import ParseUrlEvaluator._

  // Parsed form of the constructor url. Null when no url was supplied, or when `getUrl`
  // returned null for it; in both cases callers fall back to parsing the per-call url.
  private lazy val cachedUrl: URI =
    if (url == null) null else getUrl(url, failOnError)

  // Extraction function resolved from the constructor extractPart; null when none was supplied.
  private lazy val cachedExtractPartFunc: URI => String =
    if (extractPart == null) null else getExtractPartFunc(extractPart)

  // Regex compiled from the constructor pattern (the query key); null when none was supplied.
  private lazy val cachedPattern: Pattern =
    if (pattern == null) null else getPattern(pattern)

  /** Returns capture group 2 of the first match of `keyPattern` in `query`, or null if none. */
  private def extractValueFromQuery(query: UTF8String, keyPattern: Pattern): UTF8String = {
    val matcher = keyPattern.matcher(query.toString)
    if (!matcher.find()) null else UTF8String.fromString(matcher.group(2))
  }

  /** Applies the cached extraction function if present, else the one for `partName`. */
  private def extractFromUrl(uri: URI, partName: UTF8String): UTF8String = {
    val extract =
      if (cachedExtractPartFunc ne null) cachedExtractPartFunc else getExtractPartFunc(partName)
    UTF8String.fromString(extract(uri))
  }

  /**
   * Extracts `partName` from the cached URI when there is one, otherwise from `rawUrl`
   * parsed on the spot. Yields null when no URI could be obtained.
   */
  private def parseUrlWithoutKey(rawUrl: UTF8String, partName: UTF8String): UTF8String = {
    val uri = if (cachedUrl ne null) cachedUrl else getUrl(rawUrl, failOnError)
    if (uri eq null) null else extractFromUrl(uri, partName)
  }

  /** Two-argument form: returns the requested part of `url`. */
  final def evaluate(url: UTF8String, path: UTF8String): Any =
    parseUrlWithoutKey(url, path)

  /**
   * Three-argument form: returns the value associated with `key` in the query string.
   * Yields null unless `path` is QUERY, and also when the url has no query string or the
   * key pattern does not match it.
   */
  final def evaluate(url: UTF8String, path: UTF8String, key: UTF8String): Any = {
    if (path != QUERY) {
      null
    } else {
      val query = parseUrlWithoutKey(url, path)
      if (query eq null) {
        null
      } else {
        // Prefer the pre-compiled pattern; compile from the per-call key only without one.
        val keyPattern = if (cachedPattern ne null) cachedPattern else getPattern(key)
        extractValueFromQuery(query, keyPattern)
      }
    }
  }
}

/** Constants and stateless helpers shared by every [[ParseUrlEvaluator]] instance. */
object ParseUrlEvaluator {
  // Recognised part names; the requested part is compared against these values.
  private val HOST = UTF8String.fromString("HOST")
  private val PATH = UTF8String.fromString("PATH")
  private val QUERY = UTF8String.fromString("QUERY")
  private val REF = UTF8String.fromString("REF")
  private val PROTOCOL = UTF8String.fromString("PROTOCOL")
  private val FILE = UTF8String.fromString("FILE")
  private val AUTHORITY = UTF8String.fromString("AUTHORITY")
  private val USERINFO = UTF8String.fromString("USERINFO")

  // The key regex is <prefix><key><suffix>: the key must start the query or follow '&',
  // and the text after '=' up to the next '&' is captured.
  private val REGEX_PREFIX = "(&|^)"
  private val REGEX_SUFFIX = "=([^&]*)"

  /**
   * Compiles the regex used to look up `key` in a query string. The key is inserted
   * verbatim (it is not quoted), so `Pattern.compile` interprets any regex syntax in it.
   */
  private def getPattern(key: UTF8String): Pattern = {
    val regex = s"${REGEX_PREFIX}${key.toString}${REGEX_SUFFIX}"
    Pattern.compile(regex)
  }

  /**
   * Parses `url` into a URI. On a syntax error this throws the invalid-url query error
   * when `failOnError` is set, and returns null otherwise.
   */
  private def getUrl(url: UTF8String, failOnError: Boolean): URI = {
    try {
      new URI(url.toString)
    } catch {
      case e: URISyntaxException =>
        if (failOnError) throw QueryExecutionErrors.invalidUrlError(url, e) else null
    }
  }

  /**
   * Maps a part name to the function that reads that part from a URI. Unrecognised names
   * map to a function that always returns null.
   *
   * NOTE: the parts are read through java.net.URI accessors; a variant based on the
   * corresponding java.net.URL getters (via `toURL()`) was sketched here previously and
   * is not used.
   */
  private def getExtractPartFunc(partToExtract: UTF8String): URI => String = {
    partToExtract match {
      case HOST => (uri: URI) => uri.getHost
      case PATH => (uri: URI) => uri.getRawPath
      case QUERY => (uri: URI) => uri.getRawQuery
      case REF => (uri: URI) => uri.getRawFragment
      case PROTOCOL => (uri: URI) => uri.getScheme
      case FILE =>
        // FILE is the path, followed by "?" and the query when a query is present.
        (uri: URI) => {
          val rawQuery = uri.getRawQuery
          if (rawQuery eq null) uri.getRawPath else uri.getRawPath + "?" + rawQuery
        }
      case AUTHORITY => (uri: URI) => uri.getRawAuthority
      case USERINFO => (uri: URI) => uri.getRawUserInfo
      case _ => (_: URI) => null
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,18 @@

package org.apache.spark.sql.catalyst.expressions

import java.net.{URI, URISyntaxException, URLDecoder, URLEncoder}
import java.net.{URLDecoder, URLEncoder}
import java.nio.charset.StandardCharsets
import java.util.regex.Pattern

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.Cast._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke}
import org.apache.spark.sql.catalyst.expressions.url.ParseUrlEvaluator
import org.apache.spark.sql.catalyst.trees.UnaryLike
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.types.StringTypeWithCollation
import org.apache.spark.sql.types.{AbstractDataType, BooleanType, DataType}
import org.apache.spark.sql.types.{AbstractDataType, BooleanType, DataType, ObjectType}
import org.apache.spark.unsafe.types.UTF8String

// scalastyle:off line.size.limit
Expand Down Expand Up @@ -156,19 +154,6 @@ object UrlCodec {
}
}

object ParseUrl {
private val HOST = UTF8String.fromString("HOST")
private val PATH = UTF8String.fromString("PATH")
private val QUERY = UTF8String.fromString("QUERY")
private val REF = UTF8String.fromString("REF")
private val PROTOCOL = UTF8String.fromString("PROTOCOL")
private val FILE = UTF8String.fromString("FILE")
private val AUTHORITY = UTF8String.fromString("AUTHORITY")
private val USERINFO = UTF8String.fromString("USERINFO")
private val REGEXPREFIX = "(&|^)"
private val REGEXSUBFIX = "=([^&]*)"
}

// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "_FUNC_(url, partToExtract[, key]) - This is a special version of `parse_url` that performs the same operation, but returns a NULL value instead of raising an error if the parsing cannot be performed.",
Expand Down Expand Up @@ -215,8 +200,13 @@ case class TryParseUrl(params: Seq[Expression], replacement: Expression)
""",
since = "2.0.0",
group = "url_funcs")
case class ParseUrl(children: Seq[Expression], failOnError: Boolean = SQLConf.get.ansiEnabled)
extends Expression with ExpectsInputTypes with CodegenFallback {
case class ParseUrl(
children: Seq[Expression],
failOnError: Boolean = SQLConf.get.ansiEnabled)
extends Expression
with ExpectsInputTypes
with RuntimeReplaceable {

def this(children: Seq[Expression]) = this(children, SQLConf.get.ansiEnabled)

override def nullable: Boolean = true
Expand All @@ -225,29 +215,6 @@ case class ParseUrl(children: Seq[Expression], failOnError: Boolean = SQLConf.ge
override def dataType: DataType = SQLConf.get.defaultStringType
override def prettyName: String = "parse_url"

// If the url is a constant, cache the URL object so that we don't need to convert url
// from UTF8String to String to URL for every row.
@transient private lazy val cachedUrl = children(0) match {
case Literal(url: UTF8String, _) if url ne null => getUrl(url)
case _ => null
}

// If the key is a constant, cache the Pattern object so that we don't need to convert key
// from UTF8String to String to StringBuilder to String to Pattern for every row.
@transient private lazy val cachedPattern = children(2) match {
case Literal(key: UTF8String, _) if key ne null => getPattern(key)
case _ => null
}

// If the partToExtract is a constant, cache the Extract part function so that we don't need
// to check the partToExtract for every row.
@transient private lazy val cachedExtractPartFunc = children(1) match {
case Literal(part: UTF8String, _) => getExtractPartFunc(part)
case _ => null
}

import ParseUrl._

override def checkInputDataTypes(): TypeCheckResult = {
if (children.size > 3 || children.size < 2) {
throw QueryCompilationErrors.wrongNumArgsError(
Expand All @@ -258,108 +225,41 @@ case class ParseUrl(children: Seq[Expression], failOnError: Boolean = SQLConf.ge
}
}

private def getPattern(key: UTF8String): Pattern = {
Pattern.compile(REGEXPREFIX + key.toString + REGEXSUBFIX)
}

private def getUrl(url: UTF8String): URI = {
try {
new URI(url.toString)
} catch {
case e: URISyntaxException if failOnError =>
throw QueryExecutionErrors.invalidUrlError(url, e)
case _: URISyntaxException => null
}
}

private def getExtractPartFunc(partToExtract: UTF8String): URI => String = {

// partToExtract match {
// case HOST => _.toURL().getHost
// case PATH => _.toURL().getPath
// case QUERY => _.toURL().getQuery
// case REF => _.toURL().getRef
// case PROTOCOL => _.toURL().getProtocol
// case FILE => _.toURL().getFile
// case AUTHORITY => _.toURL().getAuthority
// case USERINFO => _.toURL().getUserInfo
// case _ => (url: URI) => null
// }

partToExtract match {
case HOST => _.getHost
case PATH => _.getRawPath
case QUERY => _.getRawQuery
case REF => _.getRawFragment
case PROTOCOL => _.getScheme
case FILE =>
(url: URI) =>
if (url.getRawQuery ne null) {
url.getRawPath + "?" + url.getRawQuery
} else {
url.getRawPath
}
case AUTHORITY => _.getRawAuthority
case USERINFO => _.getRawUserInfo
case _ => (url: URI) => null
}
}
override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): ParseUrl =
copy(children = newChildren)

private def extractValueFromQuery(query: UTF8String, pattern: Pattern): UTF8String = {
val m = pattern.matcher(query.toString)
if (m.find()) {
UTF8String.fromString(m.group(2))
} else {
null
}
// If the url is a constant, cache the URL object so that we don't need to convert url
// from UTF8String to String to URL for every row.
@transient private lazy val url = children.head match {
case Literal(url: UTF8String, _) if url ne null => url
case _ => null
}

private def extractFromUrl(url: URI, partToExtract: UTF8String): UTF8String = {
if (cachedExtractPartFunc ne null) {
UTF8String.fromString(cachedExtractPartFunc.apply(url))
} else {
UTF8String.fromString(getExtractPartFunc(partToExtract).apply(url))
}
// If the partToExtract is a constant, cache the Extract part function so that we don't need
// to check the partToExtract for every row.
@transient private lazy val extractPart = children(1) match {
case Literal(part: UTF8String, _) => part
case _ => null
}

private def parseUrlWithoutKey(url: UTF8String, partToExtract: UTF8String): UTF8String = {
if (cachedUrl ne null) {
extractFromUrl(cachedUrl, partToExtract)
} else {
val currentUrl = getUrl(url)
if (currentUrl ne null) {
extractFromUrl(currentUrl, partToExtract)
} else {
null
}
// If the key is a constant, cache the Pattern object so that we don't need to convert key
// from UTF8String to String to StringBuilder to String to Pattern for every row.
@transient private lazy val pattern = children.size match {
case 3 => children(2) match {
case Literal(key: UTF8String, _) if key ne null => key
case _ => null
}
case _ => null
}

override def eval(input: InternalRow): Any = {
val evaluated = children.map{e => e.eval(input).asInstanceOf[UTF8String]}
if (evaluated.contains(null)) return null
if (evaluated.size == 2) {
parseUrlWithoutKey(evaluated(0), evaluated(1))
} else {
// 3-arg, i.e. QUERY with key
assert(evaluated.size == 3)
if (evaluated(1) != QUERY) {
return null
}

val query = parseUrlWithoutKey(evaluated(0), evaluated(1))
if (query eq null) {
return null
}

if (cachedPattern ne null) {
extractValueFromQuery(query, cachedPattern)
} else {
extractValueFromQuery(query, getPattern(evaluated(2)))
}
}
}
@transient
private lazy val evaluator: ParseUrlEvaluator = ParseUrlEvaluator(
url, extractPart, pattern, failOnError)

override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): ParseUrl =
copy(children = newChildren)
override def replacement: Expression = Invoke(
Literal.create(evaluator, ObjectType(classOf[ParseUrlEvaluator])),
"evaluate",
dataType,
children,
children.map(_.dataType))
}
Original file line number Diff line number Diff line change
Expand Up @@ -1905,7 +1905,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
// exceptional cases
intercept[java.util.regex.PatternSyntaxException] {
evaluateWithoutCodegen(ParseUrl(Seq(Literal("http://spark.apache.org/path?"),
Literal("QUERY"), Literal("???"))))
Literal("QUERY"), Literal("???"))).replacement)
}

// arguments checking
Expand Down Expand Up @@ -1956,7 +1956,8 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
"inputType" -> "\"INT\"")))

// Test escaping of arguments
GenerateUnsafeProjection.generate(ParseUrl(Seq(Literal("\"quote"), Literal("\"quote"))) :: Nil)
GenerateUnsafeProjection.generate(
ParseUrl(Seq(Literal("\"quote"), Literal("\"quote"))).replacement :: Nil)
}

test("SPARK-33468: ParseUrl in ANSI mode should fail if input string is not a valid url") {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
Project [parse_url(g#0, g#0, false) AS parse_url(g, g)#0]
Project [invoke(ParseUrlEvaluator(null,null,null,false).evaluate(g#0, g#0)) AS parse_url(g, g)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Loading

0 comments on commit 00bff28

Please sign in to comment.