Skip to content

Commit

Permalink
Simplify conditional catalyst expressions generated by udf-compiler (#…
Browse files Browse the repository at this point in the history
…5315)

if (c) true else false => c
if (c) false else true => !c

Signed-off-by: Sean Lee <selee@nvidia.com>
  • Loading branch information
seanprime7 authored Apr 27, 2022
1 parent 45d6fcc commit 8877b33
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 4 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -254,7 +254,7 @@ object CatalystExpressionBuilder extends Logging {
val res = expr match {
case And(Literal.TrueLiteral, c) => simplifyExpr(c)
case And(c, Literal.TrueLiteral) => simplifyExpr(c)
case And(Literal.FalseLiteral, _) => Literal.FalseLiteral
case and: And if and.containsChild(Literal.FalseLiteral) => Literal.FalseLiteral
case And(c1@LessThan(s1, Literal(v1, t1)),
c2@LessThan(s2, Literal(v2, t2))) if s1 == s2 && t1 == t2 => {
t1 match {
Expand Down Expand Up @@ -346,7 +346,7 @@ object CatalystExpressionBuilder extends Logging {
}
}
case And(c1, c2) => And(simplifyExpr(c1), simplifyExpr(c2))
case Or(Literal.TrueLiteral, _) => Literal.TrueLiteral
case or: Or if or.containsChild(Literal.TrueLiteral) => Literal.TrueLiteral
case Or(Literal.FalseLiteral, c) => simplifyExpr(c)
case Or(c, Literal.FalseLiteral) => simplifyExpr(c)
case Or(c1@GreaterThan(s1, Literal(v1, t1)),
Expand Down Expand Up @@ -374,6 +374,7 @@ object CatalystExpressionBuilder extends Logging {
case Not(LessThanOrEqual(c1, c2)) => GreaterThan(c1, c2)
case Not(GreaterThan(c1, c2)) => LessThanOrEqual(c1, c2)
case Not(GreaterThanOrEqual(c1, c2)) => LessThan(c1, c2)
case Not(c) => Not(simplifyExpr(c))
case EqualTo(Literal(v1, _), Literal(v2, _)) =>
if (v1 == v2) Literal.TrueLiteral else Literal.FalseLiteral
case LessThan(If(c1,
Expand Down Expand Up @@ -424,6 +425,9 @@ object CatalystExpressionBuilder extends Logging {
}
case If(c, Repr.ArrayBuffer(t), Repr.ArrayBuffer(f)) => Repr.ArrayBuffer(If(c, t, f))
case If(c, Repr.StringBuilder(t), Repr.StringBuilder(f)) => Repr.StringBuilder(If(c, t, f))
case If(c, Literal.TrueLiteral, Literal.FalseLiteral) => c
case If(c, Literal.FalseLiteral, Literal.TrueLiteral) => Not(c)
case If(c, t, f) => If(simplifyExpr(c), simplifyExpr(t), simplifyExpr(f))
case _ => expr
}
logDebug(s"[CatalystExpressionBuilder] simplify: ${expr} ==> ${res}")
Expand Down
26 changes: 25 additions & 1 deletion udf-compiler/src/test/scala/com/nvidia/spark/OpcodeSuite.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -853,6 +853,30 @@ class OpcodeSuite extends FunSuite {
checkEquiv(result, ref)
}

// Checks that a UDF whose body is `if (c) true else false` produces the same rows as the
// bare condition `c`, and that no conditional is left in the analyzed plan.
test("Conditional simplification - if (c) true else false => c") {
  val myudf: (Int) => Boolean = i => {
    if (i < 20) true else false
  }
  val u = makeUdf(myudf)
  // 20 and 19 sit on either side of the `i < 20` boundary, so both outcomes are exercised.
  val dataset = List(20, 19).toDF("x")
  val result = dataset.withColumn("new", u(col("x")))
  val ref = dataset.withColumn("new", col("x") < 20)
  assert(udfIsCompiled(result))
  // Compare against the reference column as well: without this, `ref` was built but never
  // checked, so a wrong simplification (e.g. an inverted condition) would go unnoticed.
  checkEquiv(result, ref)
  // The string form of the analyzed plan must not mention "if" once the rule has applied.
  assert(!result.queryExecution.analyzed.toString.contains("if"))
}

// Checks that a UDF whose body is `if (c) false else true` produces the same rows as the
// negated condition, and that no conditional is left in the analyzed plan.
test("Conditional simplification - if (c) false else true => !c") {
  val negatedCheck: Int => Boolean = i => if (i < 20) false else true
  val compiledUdf = makeUdf(negatedCheck)
  // 20 and 19 sit on either side of the `i < 20` boundary.
  val input = List(20, 19).toDF("x")
  // Reference column: `x >= 20` is the negation of `x < 20`.
  val expected = input.withColumn("new", col("x") >= 20)
  val actual = input.withColumn("new", compiledUdf(col("x")))
  checkEquiv(actual, expected)
  // The string form of the analyzed plan must not mention "if".
  val analyzedPlan = actual.queryExecution.analyzed.toString
  assert(!analyzedPlan.contains("if"))
}

test("LDC_W opcode") {
val myudf: () => String = () => {
val myString : String = "a"
Expand Down

0 comments on commit 8877b33

Please sign in to comment.