Skip to content

Commit

Permalink
Rewrites PLUS and BITWISE_AND implementations using new modeling
Browse files Browse the repository at this point in the history
Updates how function instances are chosen and carried
  • Loading branch information
johnedquinn committed Nov 14, 2024
1 parent 894e443 commit aaa2a0c
Show file tree
Hide file tree
Showing 39 changed files with 756 additions and 311 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ internal class StandardCompiler(strategies: List<Strategy>) : PartiQLCompiler {
// Compile the candidates
val candidates = Array(functions.size) {
val fn = functions[it]
val fnArity = fn.parameters.size
val fnArity = fn.getParameters().size
if (arity == -1) {
// set first
arity = fnArity
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import org.partiql.eval.internal.operator.rex.ExprCallDynamic.Candidate
import org.partiql.eval.internal.operator.rex.ExprCallDynamic.CoercionFamily.DYNAMIC
import org.partiql.eval.internal.operator.rex.ExprCallDynamic.CoercionFamily.UNKNOWN
import org.partiql.spi.function.Function
import org.partiql.spi.function.Parameter
import org.partiql.spi.value.Datum
import org.partiql.types.PType
import org.partiql.value.PartiQLValue
Expand All @@ -31,7 +32,7 @@ import org.partiql.value.PartiQLValue
*/
internal class ExprCallDynamic(
private val name: String,
private val functions: Array<Function.Instance>,
private val functions: Array<Function>,
private val args: Array<ExprValue>
) : ExprValue {

Expand All @@ -48,7 +49,7 @@ internal class ExprCallDynamic(
*
* TODO actually make this an array instead of lists.
*/
private val paramTypes: List<List<PType>> = functions.map { c -> c.parameters.toList() }
private val paramTypes: List<List<Parameter>> = functions.map { c -> c.getParameters().toList() }

/**
* @property paramFamilies is a two-dimensional array.
Expand All @@ -58,7 +59,7 @@ internal class ExprCallDynamic(
*
* TODO actually make this an array instead of lists.
*/
private val paramFamilies: List<List<CoercionFamily>> = functions.map { c -> c.parameters.map { p -> family(p.kind) } }
private val paramFamilies: List<List<CoercionFamily>> = functions.map { c -> c.getParameters().map { p -> family(p.getType().kind) } }

/**
* A memoization cache for the [match] function.
Expand Down Expand Up @@ -92,7 +93,7 @@ internal class ExprCallDynamic(
for (paramIndex in paramIndices) {
val argType = args[paramIndex]
val paramType = paramTypes[candidateIndex][paramIndex]
if (paramType == argType) { currentExactMatches++ }
if (paramType.getMatch(argType) == argType) { currentExactMatches++ }
val argFamily = argFamilies[paramIndex]
val paramFamily = paramFamilies[candidateIndex][paramIndex]
if (paramFamily != argFamily && argFamily != CoercionFamily.UNKNOWN && paramFamily != CoercionFamily.DYNAMIC) { return@forEach }
Expand All @@ -102,7 +103,10 @@ internal class ExprCallDynamic(
exactMatches = currentExactMatches
}
}
return if (currentMatch == null) null else Candidate(functions[currentMatch!!])
return if (currentMatch == null) null else {
val instance = functions[currentMatch!!].getInstance(args.toTypedArray()) ?: return null
Candidate(instance)
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ import kotlin.test.assertNotNull
@OptIn(PartiQLValueExperimental::class)
class PartiQLEvaluatorTest {

@ParameterizedTest
@MethodSource("plusTestCases")
@Execution(ExecutionMode.CONCURRENT)
fun plusTests(tc: SuccessTestCase) = tc.assert()

@ParameterizedTest
@MethodSource("sanityTestsCases")
@Execution(ExecutionMode.CONCURRENT)
Expand Down Expand Up @@ -81,6 +86,74 @@ class PartiQLEvaluatorTest {

companion object {

// Result precision: max(s1, s2) + max(p1 - s1, p2 - s2) + 1
// Result scale: max(s1, s2)
@JvmStatic
fun plusTestCases() = listOf(
SuccessTestCase(
input = """
-- DEC(1, 0) + DEC(6, 5)
-- P = 5 + MAX(1, 1) + 1 = 7
-- S = MAX(0, 5) = 5
1 + 2.00000;
""".trimIndent(),
mode = Mode.STRICT(),
expected = Datum.decimal(BigDecimal.valueOf(300000, 5), 7, 5),
jvmEquality = true
),
SuccessTestCase(
input = """
-- DEC(2, 1) + DEC(6, 5)
-- P = 5 + MAX(1, 1) + 1 = 7
-- S = MAX(1, 5) = 5
1.0 + 2.00000;
""".trimIndent(),
mode = Mode.STRICT(),
expected = Datum.decimal(BigDecimal.valueOf(300000, 5), 7, 5),
jvmEquality = true
),
SuccessTestCase(
input = """
-- DEC(5, 4) + DEC(6, 5)
-- P = 5 + MAX(1, 1) + 1 = 7
-- S = MAX(4, 5) = 5
1.0000 + 2.00000;
""".trimIndent(),
mode = Mode.STRICT(),
expected = Datum.decimal(BigDecimal.valueOf(300000, 5), 7, 5),
jvmEquality = true
),
SuccessTestCase(
input = """
-- DEC(7, 4) + DEC(13, 7)
-- P = 7 + MAX(3, 6) + 1 = 14
-- S = MAX(4, 7) = 7
234.0000 + 456789.0000000;
""".trimIndent(),
mode = Mode.STRICT(),
expected = Datum.decimal(BigDecimal.valueOf(457023), 14, 7),
jvmEquality = true
),
SuccessTestCase(
input = """
-- This shows that the value, while dynamic, still produces the right precision/scale
-- DEC(7, 4) + DEC(13, 7)
-- P = 7 + MAX(3, 6) + 1 = 14
-- S = MAX(4, 7) = 7
234.0000 + dynamic_decimal;
""".trimIndent(),
mode = Mode.STRICT(),
expected = Datum.decimal(BigDecimal.valueOf(457023), 14, 7),
globals = listOf(
SuccessTestCase.Global(
"dynamic_decimal",
"456789.0000000"
)
),
jvmEquality = true
),
)

@JvmStatic
fun castTestCases() = listOf(
SuccessTestCase(
Expand Down Expand Up @@ -1297,13 +1370,21 @@ class PartiQLEvaluatorTest {
)
}

public class SuccessTestCase @OptIn(PartiQLValueExperimental::class) constructor(
public class SuccessTestCase(
val input: String,
val expected: PartiQLValue,
val expected: Datum,
val mode: Mode = Mode.PERMISSIVE(),
val globals: List<Global> = emptyList(),
val jvmEquality: Boolean = false
) {

constructor(
input: String,
expected: PartiQLValue,
mode: Mode = Mode.PERMISSIVE(),
globals: List<Global> = emptyList(),
) : this(input, Datum.of(expected), mode, globals)

private val compiler = PartiQLCompiler.standard()
private val parser = PartiQLParser.standard()
private val planner = PartiQLPlanner.standard()
Expand Down Expand Up @@ -1340,24 +1421,22 @@ class PartiQLEvaluatorTest {
.build()
val plan = planner.plan(statement, session).plan
val result = compiler.prepare(plan, mode).execute()
val output = result.toPartiQLValue() // TODO: Assert directly on Datum
assert(expected == output) {
comparisonString(expected, output, plan)
val comparison = when (jvmEquality) {
true -> expected == result
false -> Datum.comparator().compare(expected, result) == 0
}
assert(comparison) {
comparisonString(expected, result, plan)
}
}

@OptIn(PartiQLValueExperimental::class)
private fun comparisonString(expected: PartiQLValue, actual: PartiQLValue, plan: Plan): String {
val expectedBuffer = ByteArrayOutputStream()
val expectedWriter = PartiQLValueIonWriterBuilder.standardIonTextBuilder().build(expectedBuffer)
expectedWriter.append(expected)
private fun comparisonString(expected: Datum, actual: Datum, plan: Plan): String {
return buildString {
// TODO pretty-print V1 plans!
appendLine(plan)
appendLine("Expected : $expectedBuffer")
expectedBuffer.reset()
expectedWriter.append(actual)
appendLine("Actual : $expectedBuffer")
// TODO: Add DatumWriter
appendLine("Expected : $expected")
appendLine("Actual : $actual")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import org.junit.jupiter.params.provider.MethodSource
import org.partiql.eval.Environment
import org.partiql.eval.internal.helpers.ValueUtility.check
import org.partiql.spi.function.Function
import org.partiql.spi.function.Parameter
import org.partiql.spi.value.Datum
import org.partiql.spi.value.Datum.array
import org.partiql.spi.value.Datum.bag
Expand Down Expand Up @@ -62,12 +63,30 @@ class ExprCallDynamicTest {
)

@OptIn(PartiQLValueExperimental::class)
internal val functions: Array<Function.Instance> = params.mapIndexed { index, it ->
object : Function.Instance(
returns = PType.integer(),
parameters = arrayOf(it.first.toPType(), it.second.toPType())
) {
override fun invoke(args: Array<Datum>): Datum = integer(index)
internal val functions: Array<Function> = params.mapIndexed { index, it ->
object : Function {

override fun getName(): String {
return "example"
}

override fun getParameters(): Array<Parameter> {
return arrayOf(Parameter("lhs", it.first.toPType()), Parameter("rhs", it.second.toPType()))
}

override fun getReturnType(args: Array<PType>): PType {
return PType.integer()
}

override fun getInstance(args: Array<PType>): Function.Instance {
return object : Function.Instance(
name = "example",
returns = PType.integer(),
parameters = arrayOf(it.first.toPType(), it.second.toPType())
) {
override fun invoke(args: Array<Datum>): Datum = integer(index)
}
}
}
}.toTypedArray()
}
Expand Down
2 changes: 2 additions & 0 deletions partiql-plan/api/partiql-plan.api
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ public abstract interface class org/partiql/plan/builder/PlanFactory {
public abstract fun rexBag (Ljava/util/Collection;Lorg/partiql/plan/rex/RexType;)Lorg/partiql/plan/rex/RexBag;
public abstract fun rexCall (Lorg/partiql/spi/function/Function$Instance;Ljava/util/List;)Lorg/partiql/plan/rex/RexCall;
public abstract fun rexCallDynamic (Ljava/lang/String;Ljava/util/List;Ljava/util/List;)Lorg/partiql/plan/rex/RexCallDynamic;
public abstract fun rexCallDynamic (Ljava/lang/String;Ljava/util/List;Ljava/util/List;Lorg/partiql/types/PType;)Lorg/partiql/plan/rex/RexCallDynamic;
public abstract fun rexCase (Ljava/util/List;Lorg/partiql/plan/rex/Rex;)Lorg/partiql/plan/rex/RexCase;
public abstract fun rexCase (Ljava/util/List;Lorg/partiql/plan/rex/Rex;Lorg/partiql/plan/rex/RexType;)Lorg/partiql/plan/rex/RexCase;
public abstract fun rexCase (Lorg/partiql/plan/rex/Rex;Ljava/util/List;Lorg/partiql/plan/rex/Rex;)Lorg/partiql/plan/rex/RexCase;
Expand Down Expand Up @@ -344,6 +345,7 @@ public final class org/partiql/plan/builder/PlanFactory$DefaultImpls {
public static fun rexBag (Lorg/partiql/plan/builder/PlanFactory;Ljava/util/Collection;Lorg/partiql/plan/rex/RexType;)Lorg/partiql/plan/rex/RexBag;
public static fun rexCall (Lorg/partiql/plan/builder/PlanFactory;Lorg/partiql/spi/function/Function$Instance;Ljava/util/List;)Lorg/partiql/plan/rex/RexCall;
public static fun rexCallDynamic (Lorg/partiql/plan/builder/PlanFactory;Ljava/lang/String;Ljava/util/List;Ljava/util/List;)Lorg/partiql/plan/rex/RexCallDynamic;
public static fun rexCallDynamic (Lorg/partiql/plan/builder/PlanFactory;Ljava/lang/String;Ljava/util/List;Ljava/util/List;Lorg/partiql/types/PType;)Lorg/partiql/plan/rex/RexCallDynamic;
public static fun rexCase (Lorg/partiql/plan/builder/PlanFactory;Ljava/util/List;Lorg/partiql/plan/rex/Rex;)Lorg/partiql/plan/rex/RexCase;
public static fun rexCase (Lorg/partiql/plan/builder/PlanFactory;Ljava/util/List;Lorg/partiql/plan/rex/Rex;Lorg/partiql/plan/rex/RexType;)Lorg/partiql/plan/rex/RexCase;
public static fun rexCase (Lorg/partiql/plan/builder/PlanFactory;Lorg/partiql/plan/rex/Rex;Ljava/util/List;Lorg/partiql/plan/rex/Rex;)Lorg/partiql/plan/rex/RexCase;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -391,13 +391,27 @@ public interface PlanFactory {
/**
* Create a [RexCallDynamic] instance.
*
* @param functions
* @param args
* @return
* @param name TODO
* @param functions TODO
* @param args TODO
* @return TODO
*/
public fun rexCallDynamic(name: String, functions: List<Function.Instance>, args: List<Rex>): RexCallDynamic =
public fun rexCallDynamic(name: String, functions: List<Function>, args: List<Rex>): RexCallDynamic =
RexCallDynamicImpl(name, functions, args)

/**
* Create a [RexCallDynamic] instance.
*
* @param name TODO
* @param functions TODO
* @param args TODO
* @param type specifies the output type of the dynamic dispatch. This may be specified if all candidate functions
* return the same type.
* @return TODO
*/
public fun rexCallDynamic(name: String, functions: List<Function>, args: List<Rex>, type: PType): RexCallDynamic =
RexCallDynamicImpl(name, functions, args, type)

/**
* Create a [RexCase] instance for a searched case-when with dynamic type.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package org.partiql.plan.rex

import org.partiql.plan.Visitor
import org.partiql.spi.function.Function
import org.partiql.types.PType

/**
* Logical operator for a dynamic dispatch call.
Expand All @@ -16,7 +17,7 @@ public interface RexCallDynamic : Rex {
/**
* Returns the functions to dispatch to.
*/
public fun getFunctions(): List<Function.Instance>
public fun getFunctions(): List<Function>

/**
* Returns the list of function arguments.
Expand All @@ -33,16 +34,17 @@ public interface RexCallDynamic : Rex {
*/
internal class RexCallDynamicImpl(
private var name: String,
private var functions: List<Function.Instance>,
private var functions: List<Function>,
private var args: List<Rex>,
type: PType = PType.dynamic()
) : RexCallDynamic {

// DO NOT USE FINAL
private var _type: RexType = RexType.dynamic()
private var _type: RexType = RexType(type)

override fun getName(): String = name

override fun getFunctions(): List<Function.Instance> = functions
override fun getFunctions(): List<Function> = functions

override fun getArgs(): List<Rex> = args

Expand Down
Loading

0 comments on commit aaa2a0c

Please sign in to comment.