From aaa2a0cdead3757daaeb039de7a46a681e6ab83b Mon Sep 17 00:00:00 2001 From: John Ed Quinn Date: Thu, 14 Nov 2024 08:30:21 -0800 Subject: [PATCH] Rewrites PLUS and BITWISE_AND implementations using new modeling Updates how function instances are chosen and carried --- .../internal/compiler/StandardCompiler.kt | 2 +- .../internal/operator/rex/ExprCallDynamic.kt | 14 +- .../eval/internal/PartiQLEvaluatorTest.kt | 107 ++++++++-- .../operator/rex/ExprCallDynamicTest.kt | 31 ++- partiql-plan/api/partiql-plan.api | 2 + .../org/partiql/plan/builder/PlanFactory.kt | 22 +- .../org/partiql/plan/rex/RexCallDynamic.kt | 10 +- .../planner/internal/CoercionFamily.kt | 87 ++++++++ .../org/partiql/planner/internal/Env.kt | 15 +- .../org/partiql/planner/internal/FnMatch.kt | 4 +- .../partiql/planner/internal/FnResolver.kt | 48 +++-- .../org/partiql/planner/internal/ir/Nodes.kt | 3 +- .../internal/transforms/PlanTransform.kt | 7 +- .../internal/transforms/RexConverter.kt | 16 +- .../planner/internal/typer/PlanTyper.kt | 12 +- .../main/resources/partiql_plan_internal.ion | 3 +- .../internal/typer/PlanTyperTestsPorted.kt | 24 ++- .../typer/operator/OpArithmeticTest.kt | 2 + .../typer/operator/OpBitwiseAndTest.kt | 4 +- .../internal/typer/predicate/OpBetweenTest.kt | 1 + partiql-spi/api/partiql-spi.api | 13 +- .../partiql/spi/errors/PErrorException.java | 10 + .../partiql/spi/errors/PErrorListener.java | 6 + .../java/org/partiql/spi/value/Datum.java | 7 +- .../org/partiql/spi/value/DatumDecimal.java | 22 ++ .../org/partiql/spi/function/Builtins.kt | 15 +- .../org/partiql/spi/function/Function.kt | 35 ++- .../org/partiql/spi/function/Parameter.kt | 4 +- .../builtins/ArithmeticDiadicOperator.kt | 157 ++++++++++++++ .../spi/function/builtins/FnBitwiseAnd.kt | 131 +++++------- .../spi/function/builtins/FnCollAgg.kt | 1 + .../spi/function/builtins/FnIsMissing.kt | 1 + .../partiql/spi/function/builtins/FnIsNull.kt | 1 + .../partiql/spi/function/builtins/FnNot.kt | 1 + .../partiql/spi/function/builtins/FnPlus.kt | 202 +++++++----------- .../spi/function/builtins/TypePrecedence.kt | 39 ++++ .../org/partiql/spi/internal/SqlTypeFamily.kt | 2 +- .../org/partiql/spi/internal/SqlTypes.kt | 2 +- .../main/java/org/partiql/types/PType.java | 4 +- 39 files changed, 756 insertions(+), 311 deletions(-) create mode 100644 partiql-planner/src/main/kotlin/org/partiql/planner/internal/CoercionFamily.kt create mode 100644 partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/ArithmeticDiadicOperator.kt create mode 100644 partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/TypePrecedence.kt diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/compiler/StandardCompiler.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/compiler/StandardCompiler.kt index 1d9b20d15d..c12e3a3f78 100644 --- a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/compiler/StandardCompiler.kt +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/compiler/StandardCompiler.kt @@ -350,7 +350,7 @@ internal class StandardCompiler(strategies: List) : PartiQLCompiler { // Compile the candidates val candidates = Array(functions.size) { val fn = functions[it] - val fnArity = fn.parameters.size + val fnArity = fn.getParameters().size if (arity == -1) { // set first arity = fnArity diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprCallDynamic.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprCallDynamic.kt index 362cff1771..a2067124c7 100644 --- a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprCallDynamic.kt +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprCallDynamic.kt @@ -8,6 +8,7 @@ import org.partiql.eval.internal.operator.rex.ExprCallDynamic.Candidate import org.partiql.eval.internal.operator.rex.ExprCallDynamic.CoercionFamily.DYNAMIC import org.partiql.eval.internal.operator.rex.ExprCallDynamic.CoercionFamily.UNKNOWN import org.partiql.spi.function.Function +import org.partiql.spi.function.Parameter import org.partiql.spi.value.Datum import org.partiql.types.PType import org.partiql.value.PartiQLValue @@ -31,7 +32,7 @@ import org.partiql.value.PartiQLValue */ internal class ExprCallDynamic( private val name: String, - private val functions: Array, + private val functions: Array, private val args: Array ) : ExprValue { @@ -48,7 +49,7 @@ internal class ExprCallDynamic( * * TODO actually make this an array instead of lists. */ - private val paramTypes: List> = functions.map { c -> c.parameters.toList() } + private val paramTypes: List> = functions.map { c -> c.getParameters().toList() } /** * @property paramFamilies is a two-dimensional array. @@ -58,7 +59,7 @@ internal class ExprCallDynamic( * * TODO actually make this an array instead of lists. */ - private val paramFamilies: List> = functions.map { c -> c.parameters.map { p -> family(p.kind) } } + private val paramFamilies: List> = functions.map { c -> c.getParameters().map { p -> family(p.getType().kind) } } /** * A memoization cache for the [match] function. @@ -92,7 +93,7 @@ internal class ExprCallDynamic( for (paramIndex in paramIndices) { val argType = args[paramIndex] val paramType = paramTypes[candidateIndex][paramIndex] - if (paramType == argType) { currentExactMatches++ } + if (paramType.getMatch(argType) == argType) { currentExactMatches++ } val argFamily = argFamilies[paramIndex] val paramFamily = paramFamilies[candidateIndex][paramIndex] if (paramFamily != argFamily && argFamily != CoercionFamily.UNKNOWN && paramFamily != CoercionFamily.DYNAMIC) { return@forEach } @@ -102,7 +103,10 @@ internal class ExprCallDynamic( exactMatches = currentExactMatches } } - return if (currentMatch == null) null else Candidate(functions[currentMatch!!]) + return if (currentMatch == null) null else { + val instance = functions[currentMatch!!].getInstance(args.toTypedArray()) ?: return null + Candidate(instance) + } } /** diff --git a/partiql-eval/src/test/kotlin/org/partiql/eval/internal/PartiQLEvaluatorTest.kt b/partiql-eval/src/test/kotlin/org/partiql/eval/internal/PartiQLEvaluatorTest.kt index fa5bb5eced..4f566019f8 100644 --- a/partiql-eval/src/test/kotlin/org/partiql/eval/internal/PartiQLEvaluatorTest.kt +++ b/partiql-eval/src/test/kotlin/org/partiql/eval/internal/PartiQLEvaluatorTest.kt @@ -44,6 +44,11 @@ import kotlin.test.assertNotNull @OptIn(PartiQLValueExperimental::class) class PartiQLEvaluatorTest { + @ParameterizedTest + @MethodSource("plusTestCases") + @Execution(ExecutionMode.CONCURRENT) + fun plusTests(tc: SuccessTestCase) = tc.assert() + @ParameterizedTest @MethodSource("sanityTestsCases") @Execution(ExecutionMode.CONCURRENT) @@ -81,6 +86,74 @@ class PartiQLEvaluatorTest { companion object { + // Result precision: max(s1, s2) + max(p1 - s1, p2 - s2) + 1 + // Result scale: max(s1, s2) + @JvmStatic + fun plusTestCases() = listOf( + SuccessTestCase( + input = """ + -- DEC(1, 0) + DEC(6, 5) + -- P = 5 + MAX(1, 1) + 1 = 7 + -- S = MAX(0, 5) = 5 + 1 + 2.00000; + """.trimIndent(), + mode = Mode.STRICT(), + expected = Datum.decimal(BigDecimal.valueOf(300000, 5), 7, 5), + jvmEquality = true + ), + SuccessTestCase( + input = """ + -- DEC(2, 1) + DEC(6, 5) + -- P = 5 + MAX(1, 1) + 1 = 7 + -- S = MAX(1, 5) = 5 + 1.0 + 2.00000; + """.trimIndent(), + mode = Mode.STRICT(), + expected = Datum.decimal(BigDecimal.valueOf(300000, 5), 7, 5), + jvmEquality = true + ), + SuccessTestCase( + input = """ + -- DEC(5, 4) + DEC(6, 5) + -- P = 5 + MAX(1, 1) + 1 = 7 + -- S = MAX(4, 5) = 5 + 1.0000 + 2.00000; + """.trimIndent(), + mode = Mode.STRICT(), + expected = Datum.decimal(BigDecimal.valueOf(300000, 5), 7, 5), + jvmEquality = true + ), + SuccessTestCase( + input = """ + -- DEC(7, 4) + DEC(13, 7) + -- P = 7 + MAX(3, 6) + 1 = 14 + -- S = MAX(4, 7) = 7 + 234.0000 + 456789.0000000; + """.trimIndent(), + mode = Mode.STRICT(), + expected = Datum.decimal(BigDecimal.valueOf(457023), 14, 7), + jvmEquality = true + ), + SuccessTestCase( + input = """ + -- This shows that the value, while dynamic, still produces the right precision/scale + -- DEC(7, 4) + DEC(13, 7) + -- P = 7 + MAX(3, 6) + 1 = 14 + -- S = MAX(4, 7) = 7 + 234.0000 + dynamic_decimal; + """.trimIndent(), + mode = Mode.STRICT(), + expected = Datum.decimal(BigDecimal.valueOf(457023), 14, 7), + globals = listOf( + SuccessTestCase.Global( + "dynamic_decimal", + "456789.0000000" + ) + ), + jvmEquality = true + ), + ) + @JvmStatic fun castTestCases() = listOf( SuccessTestCase( @@ -1297,13 +1370,21 @@ class PartiQLEvaluatorTest { ) } - public class SuccessTestCase @OptIn(PartiQLValueExperimental::class) constructor( + public class SuccessTestCase( val input: String, - val expected: PartiQLValue, + val expected: Datum, val mode: Mode = Mode.PERMISSIVE(), val globals: List = emptyList(), + val jvmEquality: Boolean = false ) { + constructor( + input: String, + expected: PartiQLValue, + mode: Mode = Mode.PERMISSIVE(), + globals: List = emptyList(), + ) : this(input, Datum.of(expected), mode, globals) + private val compiler = PartiQLCompiler.standard() private val parser = PartiQLParser.standard() private val planner = PartiQLPlanner.standard() @@ -1340,24 +1421,22 @@ class PartiQLEvaluatorTest { .build() val plan = planner.plan(statement, session).plan val result = compiler.prepare(plan, mode).execute() - val output = result.toPartiQLValue() // TODO: Assert directly on Datum - assert(expected == output) { - comparisonString(expected, output, plan) + val comparison = when (jvmEquality) { + true -> expected == result + false -> Datum.comparator().compare(expected, result) == 0 + } + assert(comparison) { + comparisonString(expected, result, plan) } } - @OptIn(PartiQLValueExperimental::class) - private fun comparisonString(expected: PartiQLValue, actual: PartiQLValue, plan: Plan): String { - val expectedBuffer = ByteArrayOutputStream() - val expectedWriter = PartiQLValueIonWriterBuilder.standardIonTextBuilder().build(expectedBuffer) - expectedWriter.append(expected) + private fun comparisonString(expected: Datum, actual: Datum, plan: Plan): String { return buildString { // TODO pretty-print V1 plans! appendLine(plan) - appendLine("Expected : $expectedBuffer") - expectedBuffer.reset() - expectedWriter.append(actual) - appendLine("Actual : $expectedBuffer") + // TODO: Add DatumWriter + appendLine("Expected : $expected") + appendLine("Actual : $actual") } } diff --git a/partiql-eval/src/test/kotlin/org/partiql/eval/internal/operator/rex/ExprCallDynamicTest.kt b/partiql-eval/src/test/kotlin/org/partiql/eval/internal/operator/rex/ExprCallDynamicTest.kt index 4aa0f48921..72735cc10e 100644 --- a/partiql-eval/src/test/kotlin/org/partiql/eval/internal/operator/rex/ExprCallDynamicTest.kt +++ b/partiql-eval/src/test/kotlin/org/partiql/eval/internal/operator/rex/ExprCallDynamicTest.kt @@ -8,6 +8,7 @@ import org.junit.jupiter.params.provider.MethodSource import org.partiql.eval.Environment import org.partiql.eval.internal.helpers.ValueUtility.check import org.partiql.spi.function.Function +import org.partiql.spi.function.Parameter import org.partiql.spi.value.Datum import org.partiql.spi.value.Datum.array import org.partiql.spi.value.Datum.bag @@ -62,12 +63,30 @@ class ExprCallDynamicTest { ) @OptIn(PartiQLValueExperimental::class) - internal val functions: Array = params.mapIndexed { index, it -> - object : Function.Instance( - returns = PType.integer(), - parameters = arrayOf(it.first.toPType(), it.second.toPType()) - ) { - override fun invoke(args: Array): Datum = integer(index) + internal val functions: Array = params.mapIndexed { index, it -> + object : Function { + + override fun getName(): String { + return "example" + } + + override fun getParameters(): Array { + return arrayOf(Parameter("lhs", it.first.toPType()), Parameter("rhs", it.second.toPType())) + } + + override fun getReturnType(args: Array): PType { + return PType.integer() + } + + override fun getInstance(args: Array): Function.Instance { + return object : Function.Instance( + name = "example", + returns = PType.integer(), + parameters = arrayOf(it.first.toPType(), it.second.toPType()) + ) { + override fun invoke(args: Array): Datum = integer(index) + } + } } }.toTypedArray() } diff --git a/partiql-plan/api/partiql-plan.api b/partiql-plan/api/partiql-plan.api index 78f758078b..e50b891f5d 100644 --- a/partiql-plan/api/partiql-plan.api +++ b/partiql-plan/api/partiql-plan.api @@ -276,6 +276,7 @@ public abstract interface class org/partiql/plan/builder/PlanFactory { public abstract fun rexBag (Ljava/util/Collection;Lorg/partiql/plan/rex/RexType;)Lorg/partiql/plan/rex/RexBag; public abstract fun rexCall (Lorg/partiql/spi/function/Function$Instance;Ljava/util/List;)Lorg/partiql/plan/rex/RexCall; public abstract fun rexCallDynamic (Ljava/lang/String;Ljava/util/List;Ljava/util/List;)Lorg/partiql/plan/rex/RexCallDynamic; + public abstract fun rexCallDynamic (Ljava/lang/String;Ljava/util/List;Ljava/util/List;Lorg/partiql/types/PType;)Lorg/partiql/plan/rex/RexCallDynamic; public abstract fun rexCase (Ljava/util/List;Lorg/partiql/plan/rex/Rex;)Lorg/partiql/plan/rex/RexCase; public abstract fun rexCase (Ljava/util/List;Lorg/partiql/plan/rex/Rex;Lorg/partiql/plan/rex/RexType;)Lorg/partiql/plan/rex/RexCase; public abstract fun rexCase (Lorg/partiql/plan/rex/Rex;Ljava/util/List;Lorg/partiql/plan/rex/Rex;)Lorg/partiql/plan/rex/RexCase; @@ -344,6 +345,7 @@ public final class org/partiql/plan/builder/PlanFactory$DefaultImpls { public static fun rexBag (Lorg/partiql/plan/builder/PlanFactory;Ljava/util/Collection;Lorg/partiql/plan/rex/RexType;)Lorg/partiql/plan/rex/RexBag; public static fun rexCall (Lorg/partiql/plan/builder/PlanFactory;Lorg/partiql/spi/function/Function$Instance;Ljava/util/List;)Lorg/partiql/plan/rex/RexCall; public static fun rexCallDynamic (Lorg/partiql/plan/builder/PlanFactory;Ljava/lang/String;Ljava/util/List;Ljava/util/List;)Lorg/partiql/plan/rex/RexCallDynamic; + public static fun rexCallDynamic (Lorg/partiql/plan/builder/PlanFactory;Ljava/lang/String;Ljava/util/List;Ljava/util/List;Lorg/partiql/types/PType;)Lorg/partiql/plan/rex/RexCallDynamic; public static fun rexCase (Lorg/partiql/plan/builder/PlanFactory;Ljava/util/List;Lorg/partiql/plan/rex/Rex;)Lorg/partiql/plan/rex/RexCase; public static fun rexCase (Lorg/partiql/plan/builder/PlanFactory;Ljava/util/List;Lorg/partiql/plan/rex/Rex;Lorg/partiql/plan/rex/RexType;)Lorg/partiql/plan/rex/RexCase; public static fun rexCase (Lorg/partiql/plan/builder/PlanFactory;Lorg/partiql/plan/rex/Rex;Ljava/util/List;Lorg/partiql/plan/rex/Rex;)Lorg/partiql/plan/rex/RexCase; diff --git a/partiql-plan/src/main/kotlin/org/partiql/plan/builder/PlanFactory.kt b/partiql-plan/src/main/kotlin/org/partiql/plan/builder/PlanFactory.kt index b542bfd297..0c445feb29 100644 --- a/partiql-plan/src/main/kotlin/org/partiql/plan/builder/PlanFactory.kt +++ b/partiql-plan/src/main/kotlin/org/partiql/plan/builder/PlanFactory.kt @@ -391,13 +391,27 @@ public interface PlanFactory { /** * Create a [RexCallDynamic] instance. * - * @param functions - * @param args - * @return + * @param name TODO + * @param functions TODO + * @param args TODO + * @return TODO */ - public fun rexCallDynamic(name: String, functions: List, args: List): RexCallDynamic = + public fun rexCallDynamic(name: String, functions: List, args: List): RexCallDynamic = RexCallDynamicImpl(name, functions, args) + /** + * Create a [RexCallDynamic] instance. + * + * @param name TODO + * @param functions TODO + * @param args TODO + * @param type specifies the output type of the dynamic dispatch. This may be specified if all candidate functions + * return the same type. + * @return TODO + */ + public fun rexCallDynamic(name: String, functions: List, args: List, type: PType): RexCallDynamic = + RexCallDynamicImpl(name, functions, args, type) + /** * Create a [RexCase] instance for a searched case-when with dynamic type. * diff --git a/partiql-plan/src/main/kotlin/org/partiql/plan/rex/RexCallDynamic.kt b/partiql-plan/src/main/kotlin/org/partiql/plan/rex/RexCallDynamic.kt index af9fece070..2c85dee791 100644 --- a/partiql-plan/src/main/kotlin/org/partiql/plan/rex/RexCallDynamic.kt +++ b/partiql-plan/src/main/kotlin/org/partiql/plan/rex/RexCallDynamic.kt @@ -2,6 +2,7 @@ package org.partiql.plan.rex import org.partiql.plan.Visitor import org.partiql.spi.function.Function +import org.partiql.types.PType /** * Logical operator for a dynamic dispatch call. @@ -16,7 +17,7 @@ public interface RexCallDynamic : Rex { /** * Returns the functions to dispatch to. */ - public fun getFunctions(): List + public fun getFunctions(): List /** * Returns the list of function arguments. @@ -33,16 +34,17 @@ public interface RexCallDynamic : Rex { */ internal class RexCallDynamicImpl( private var name: String, - private var functions: List, + private var functions: List, private var args: List, + type: PType = PType.dynamic() ) : RexCallDynamic { // DO NOT USE FINAL - private var _type: RexType = RexType.dynamic() + private var _type: RexType = RexType(type) override fun getName(): String = name - override fun getFunctions(): List = functions + override fun getFunctions(): List = functions override fun getArgs(): List = args diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/CoercionFamily.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/CoercionFamily.kt new file mode 100644 index 0000000000..303c465a08 --- /dev/null +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/CoercionFamily.kt @@ -0,0 +1,87 @@ +package org.partiql.planner.internal + +import org.partiql.types.PType + +/** + * This represents SQL:1999 Section 4.1.2 "Type conversions and mixing of data types" and breaks down the different + * coercion groups. + * + * TODO: [UNKNOWN] should likely be removed in the future. However, it is needed due to literal nulls and missings. + * TODO: [DYNAMIC] should likely be removed in the future. This is currently only kept to map function signatures. + */ +internal enum class CoercionFamily { + NUMBER, + STRING, + BINARY, + BOOLEAN, + STRUCTURE, + DATE, + TIME, + TIMESTAMP, + COLLECTION, + UNKNOWN, + DYNAMIC; + + companion object { + + /** + * Gets the coercion family for the given [PType.Kind]. + * + * @see CoercionFamily + * @see PType.Kind + * @see family + */ + @JvmStatic + fun family(type: PType.Kind): CoercionFamily { + return when (type) { + PType.Kind.TINYINT -> NUMBER + PType.Kind.SMALLINT -> NUMBER + PType.Kind.INTEGER -> NUMBER + PType.Kind.NUMERIC -> NUMBER + PType.Kind.BIGINT -> NUMBER + PType.Kind.REAL -> NUMBER + PType.Kind.DOUBLE -> NUMBER + PType.Kind.DECIMAL -> NUMBER + PType.Kind.DECIMAL_ARBITRARY -> NUMBER + PType.Kind.STRING -> STRING + PType.Kind.BOOL -> BOOLEAN + PType.Kind.TIMEZ -> TIME + PType.Kind.TIME -> TIME + PType.Kind.TIMESTAMPZ -> TIMESTAMP + PType.Kind.TIMESTAMP -> TIMESTAMP + PType.Kind.DATE -> DATE + PType.Kind.STRUCT -> STRUCTURE + PType.Kind.ARRAY -> COLLECTION + PType.Kind.SEXP -> COLLECTION + PType.Kind.BAG -> COLLECTION + PType.Kind.ROW -> STRUCTURE + PType.Kind.CHAR -> STRING + PType.Kind.VARCHAR -> STRING + PType.Kind.DYNAMIC -> DYNAMIC // TODO: REMOVE + PType.Kind.SYMBOL -> STRING + PType.Kind.BLOB -> BINARY + PType.Kind.CLOB -> STRING + PType.Kind.UNKNOWN -> UNKNOWN // TODO: REMOVE + PType.Kind.VARIANT -> UNKNOWN // TODO: HANDLE VARIANT + } + } + + /** + * Determines if the [from] type can be coerced to the [to] type. + * + * @see CoercionFamily + * @see PType + * @see family + */ + @JvmStatic + fun canCoerce(from: PType, to: PType): Boolean { + if (from.kind == PType.Kind.UNKNOWN) { + return true + } + if (to.kind == PType.Kind.DYNAMIC) { + return true + } + return family(from.kind) == family(to.kind) + } + } +} diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/Env.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/Env.kt index 488a98cab1..da8e6ed8fb 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/Env.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/Env.kt @@ -134,7 +134,6 @@ internal class Env(private val session: Session) { // 2. Search along the PATH. // TODO - val match = FnResolver.resolve(variants, args.map { it.type }) // If Type mismatch, then we return a missingOp whose trace is all possible candidates. if (match == null) { @@ -148,30 +147,24 @@ internal class Env(private val session: Session) { fn = refFn( catalog = catalog.getName(), name = Name.of(name), - signature = it.function, + signature = it, ), - coercions = it.mapping.toList(), + coercions = emptyList(), // TODO: Remove this from the plan ) } // Rewrite as a dynamic call to be typed by PlanTyper Rex(CompilerType(PType.dynamic()), Rex.Op.Call.Dynamic(args, candidates)) } is FnMatch.Static -> { - // Create an internal typed reference - val ref = refFn( - catalog = catalog.getName(), - name = Name.of(name), - signature = match.function, - ) // Apply the coercions as explicit casts val coercions: List = args.mapIndexed { i, arg -> when (val cast = match.mapping[i]) { null -> arg - else -> Rex(CompilerType(PType.dynamic()), Rex.Op.Cast.Resolved(cast, arg)) + else -> Rex(cast.target, Rex.Op.Cast.Resolved(cast, arg)) } } // Rewrite as a static call to be typed by PlanTyper - Rex(CompilerType(PType.dynamic()), Rex.Op.Call.Static(ref, coercions)) + Rex(CompilerType(PType.dynamic()), Rex.Op.Call.Static(match.function, coercions)) } } } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/FnMatch.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/FnMatch.kt index 31cb15d437..726026d6bb 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/FnMatch.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/FnMatch.kt @@ -16,7 +16,7 @@ internal sealed class FnMatch { * @property mapping */ class Static( - val function: Function, + val function: Function.Instance, val mapping: Array, ) : FnMatch() { @@ -51,5 +51,5 @@ internal sealed class FnMatch { * * @property candidates Ordered list of potentially applicable functions to dispatch dynamically. */ - class Dynamic(val candidates: List) : FnMatch() + class Dynamic(val candidates: List) : FnMatch() } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/FnResolver.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/FnResolver.kt index 43253d3aa3..e773886b6f 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/FnResolver.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/FnResolver.kt @@ -42,15 +42,15 @@ internal object FnResolver { // 1. Look for exact match for (candidate in candidates) { if (candidate.matchesExactly(args)) { - return FnMatch.Static(candidate, arrayOfNulls(args.size)) + val fn = candidate.getInstance(args.toTypedArray()) ?: error("This shouldn't have happened. Matching exactly should produce a function instance.") + return FnMatch.Static(fn, arrayOfNulls(args.size)) } } // 2. If there are DYNAMIC arguments, return all candidates val isDynamic = args.any { it.kind == Kind.DYNAMIC } if (isDynamic) { - val matches = match(candidates, args).ifEmpty { return null } - val orderedMatches = matches.sortedWith(MatchResultComparator).map { it.match } + val orderedMatches = candidates.sortedWith(FnComparator) return FnMatch.Dynamic(orderedMatches) } @@ -62,13 +62,17 @@ internal object FnResolver { // 3. Discard functions that cannot be matched (via implicit coercion or exact matches) val invocableMatches = match(candidates, args).ifEmpty { return null } if (invocableMatches.size == 1) { - return invocableMatches.first().match + val match = invocableMatches.first() + val fn = match.match.getInstance(args.toTypedArray()) ?: return null + return FnMatch.Static(fn, match.mapping) } // 4. Run through all candidates and keep those with the most exact matches on input types. val matches = matchOn(invocableMatches) { it.numberOfExactInputTypes } if (matches.size == 1) { - return matches.first().match + val match = matches.first() + val fn = match.match.getInstance(args.toTypedArray()) ?: return null + return FnMatch.Static(fn, match.mapping) } // TODO: Do we care about preferred types? This is a PostgreSQL concept. @@ -76,7 +80,10 @@ internal object FnResolver { // 6. Find the highest precedence one. NOTE: This is a remnant of the previous implementation. Whether we want // to keep this is up to us. - return matches.sortedWith(MatchResultComparator).first().match + val match = matches.sortedWith(MatchResultComparator).first() + val fn = match.match + val instance = fn.getInstance(args.toTypedArray()) ?: return null + return FnMatch.Static(instance, match.mapping) } /** @@ -117,11 +124,12 @@ internal object FnResolver { * Check if this function accepts the exact input argument types. Assume same arity. */ private fun Function.matchesExactly(args: List): Boolean { - val parameters = getParameters() + val instance = getInstance(args.toTypedArray()) ?: return false + val parameters = instance.parameters for (i in args.indices) { val a = args[i] val p = parameters[i] - if (p.getMatch(a) != a) return false + if (p != a) return false } return true } @@ -133,7 +141,8 @@ internal object FnResolver { * @return */ private fun Function.match(args: List): MatchResult? { - val parameters = getParameters() + val instance = this.getInstance(args.toTypedArray()) ?: return null + val parameters = instance.parameters val mapping = arrayOfNulls(args.size) var exactInputTypes = 0 for (i in args.indices) { @@ -143,31 +152,34 @@ internal object FnResolver { } // check match val p = parameters[i] - val m = p.getMatch(a) when { - m == null -> return null // short-circuit - m == a -> exactInputTypes++ - else -> mapping[i] = coercion(a, m) + p == a -> exactInputTypes++ + else -> mapping[i] = coercion(a, p) ?: return null } } return MatchResult( - FnMatch.Static(this, mapping), + this, + mapping, exactInputTypes, ) } - private fun coercion(arg: PType, target: PType): Ref.Cast { - return Ref.Cast(arg.toCType(), target.toCType(), Ref.Cast.Safety.COERCION, true) + private fun coercion(arg: PType, target: PType): Ref.Cast? { + return when (CoercionFamily.canCoerce(arg, target)) { + true -> Ref.Cast(arg.toCType(), target.toCType(), Ref.Cast.Safety.COERCION, true) + false -> return null + } } private class MatchResult( - val match: FnMatch.Static, + val match: Function, + val mapping: Array, val numberOfExactInputTypes: Int, ) private object MatchResultComparator : Comparator { override fun compare(o1: MatchResult, o2: MatchResult): Int { - return FnComparator.reversed().compare(o1.match.function, o2.match.function) + return FnComparator.reversed().compare(o1.match, o2.match) } } } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/Nodes.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/Nodes.kt index ccf470c489..ad5f69a85b 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/Nodes.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/Nodes.kt @@ -442,12 +442,11 @@ internal data class Rex( } internal data class Static( - @JvmField internal val fn: Ref.Fn, + @JvmField internal val fn: Function.Instance, @JvmField internal val args: List, ) : Call() { public override val children: List by lazy { val kids = mutableListOf() - kids.add(fn) kids.addAll(args) kids.filterNotNull() } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt index 66257edaeb..ddf95bf2d4 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt @@ -163,15 +163,14 @@ internal class PlanTransform(private val flags: Set) { override fun visitRexOpCallDynamic(node: IRex.Op.Call.Dynamic, ctx: PType): Any { // TODO add argument types and move to plan typer!! - val fns = node.candidates.map { it.fn.signature.getInstance(emptyArray()) } val args = node.args.map { visitRex(it, ctx) } + val fns = node.candidates.map { it.fn.signature } // TODO assert on function name in plan typer .. here is not the place. - return factory.rexCallDynamic("unknown", fns, args) + return factory.rexCallDynamic("unknown", fns, args, ctx) } override fun visitRexOpCallStatic(node: IRex.Op.Call.Static, ctx: PType): Any { - // TODO add argument types and move to PlanTyper!! - val fn = node.fn.signature.getInstance(emptyArray()) + val fn = node.fn val args = node.args.map { visitRex(it, ctx) } return factory.rexCall(fn, args) } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RexConverter.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RexConverter.kt index d4f751b1a9..a79f206453 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RexConverter.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RexConverter.kt @@ -86,7 +86,9 @@ import org.partiql.planner.internal.typer.CompilerType import org.partiql.planner.internal.typer.PlanTyper.Companion.toCType import org.partiql.spi.catalog.Identifier import org.partiql.types.PType +import org.partiql.value.DecimalValue import org.partiql.value.MissingValue +import org.partiql.value.PartiQLValue import org.partiql.value.PartiQLValueExperimental import org.partiql.value.StringValue import org.partiql.value.boolValue @@ -126,7 +128,7 @@ internal object RexConverter { override fun visitExprLit(node: ExprLit, context: Env): Rex { val type = CompilerType( - _delegate = node.value.type.toPType(), + _delegate = node.value.toPType(), isNullValue = node.value.isNull, isMissingValue = node.value is MissingValue ) @@ -134,6 +136,18 @@ internal object RexConverter { return rex(type, op) } + private fun PartiQLValue.toPType(): PType { + if (this.isNull) { + return this.type.toPType() + } + return when (this) { + is DecimalValue -> { + PType.decimal(this.value!!.precision(), this.value!!.scale()) + } + else -> this.type.toPType() + } + } + /** * TODO PartiQLValue will be replaced by Datum (i.e. IonDatum) is a subsequent PR. */ diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt index 5e963effa3..9107eed317 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt @@ -781,15 +781,12 @@ internal class PlanTyper(private val env: Env, config: Context) { else -> it } } - // TODO pass argument types to compute the return type. - val returnType = node.fn.signature.getReturnType(emptyArray()) + val instance = node.fn + val returnType: PType = instance.returns // Check if any arg is always missing val argIsAlwaysMissing = args.any { it.type.isMissingValue } - // TODO REMOVE ME !!! THIS IS A HACK (: - val instance = node.fn.signature.getInstance(emptyArray()) - if (argIsAlwaysMissing && instance.isMissingCall) { return errorRexAndReport(_listener, PErrors.alwaysMissing(null)) } @@ -807,8 +804,9 @@ internal class PlanTyper(private val env: Env, config: Context) { */ override fun visitRexOpCallDynamic(node: Rex.Op.Call.Dynamic, ctx: CompilerType?): Rex { // TODO pass argument types to compute the return type + val argTypes = node.args.map { it.type } val types = node.candidates - .map { it.fn.signature.getReturnType(emptyArray()) } + .mapNotNull { it.fn.signature.getInstance(argTypes.toTypedArray())?.returns } .toMutableSet() // TODO: Should this always be DYNAMIC? return Rex(type = CompilerType(anyOf(types) ?: PType.dynamic()), op = node) @@ -1154,7 +1152,7 @@ internal class PlanTyper(private val env: Env, config: Context) { if (firstBranchCondition !is Rex.Op.Call.Static) { return null } - if (!firstBranchCondition.fn.signature.getName().equals("is_struct", ignoreCase = true)) { + if (!firstBranchCondition.fn.name.equals("is_struct", ignoreCase = true)) { return null } val firstBranchResultType = firstBranch.rex.type diff --git a/partiql-planner/src/main/resources/partiql_plan_internal.ion b/partiql-planner/src/main/resources/partiql_plan_internal.ion index 3a6243ad7e..efc9b45604 100644 --- a/partiql-planner/src/main/resources/partiql_plan_internal.ion +++ b/partiql-planner/src/main/resources/partiql_plan_internal.ion @@ -6,6 +6,7 @@ imports::{ partiql_value_type::'org.partiql.planner.internal.typer.CompilerType', static_type::'org.partiql.planner.internal.typer.CompilerType', fn_signature::'org.partiql.spi.function.Function', + fn_instance::'org.partiql.spi.function.Function.Instance', agg_signature::'org.partiql.spi.function.Aggregation', table::'org.partiql.spi.catalog.Table', ], @@ -127,7 +128,7 @@ rex::{ }, static::{ - fn: '.ref.fn', + fn: fn_instance, args: list::[rex], }, diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt index 490a64ee73..d0117b5ad2 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt @@ -870,7 +870,7 @@ internal class PlanTyperTestsPorted { StructType( fields = mapOf( "a" to StaticType.INT4, - "b" to StaticType.DECIMAL, + "b" to DecimalType(DecimalType.PrecisionScaleConstraint.Constrained(2, 1)), ), contentClosed = true, constraints = setOf( @@ -888,7 +888,7 @@ internal class PlanTyperTestsPorted { StructType( fields = mapOf( "a" to StaticType.INT4, - "b" to StaticType.DECIMAL, + "b" to DecimalType(DecimalType.PrecisionScaleConstraint.Constrained(2, 1)), ), contentClosed = true, constraints = setOf( @@ -905,7 +905,7 @@ internal class PlanTyperTestsPorted { expected = BagType( StructType( fields = listOf( - StructType.Field("b", StaticType.DECIMAL), + StructType.Field("b", DecimalType(DecimalType.PrecisionScaleConstraint.Constrained(2, 1))), StructType.Field("a", StaticType.INT4), ), contentClosed = true, @@ -924,7 +924,7 @@ internal class PlanTyperTestsPorted { StructType( fields = listOf( StructType.Field("a", StaticType.INT4), - StructType.Field("a", StaticType.DECIMAL), + StructType.Field("a", DecimalType(DecimalType.PrecisionScaleConstraint.Constrained(2, 1))), ), contentClosed = true, constraints = setOf( @@ -942,7 +942,7 @@ internal class PlanTyperTestsPorted { StructType( fields = listOf( StructType.Field("a", StaticType.INT4), - StructType.Field("a", StaticType.DECIMAL), + StructType.Field("a", DecimalType(DecimalType.PrecisionScaleConstraint.Constrained(2, 1))), ), contentClosed = true, constraints = setOf( @@ -970,7 +970,7 @@ internal class PlanTyperTestsPorted { StructType( fields = listOf( StructType.Field("a", StaticType.INT4), - StructType.Field("a", StaticType.DECIMAL), + StructType.Field("a", DecimalType(DecimalType.PrecisionScaleConstraint.Constrained(2, 1))), StructType.Field("a", StaticType.STRING), ), contentClosed = true, @@ -989,7 +989,7 @@ internal class PlanTyperTestsPorted { StructType( fields = listOf( StructType.Field("a", StaticType.INT4), - StructType.Field("a", StaticType.DECIMAL), + StructType.Field("a", DecimalType(DecimalType.PrecisionScaleConstraint.Constrained(2, 1))), ), contentClosed = true, constraints = setOf( @@ -2777,7 +2777,7 @@ internal class PlanTyperTestsPorted { SuccessTestCase( key = PartiQLTest.Key("basics", "case-when-43"), catalog = "pql", - expected = StaticType.DECIMAL + expected = DecimalType(DecimalType.PrecisionScaleConstraint.Constrained(38, 5)), ), SuccessTestCase( key = PartiQLTest.Key("basics", "case-when-44"), @@ -3147,7 +3147,7 @@ internal class PlanTyperTestsPorted { query = """ { 'aBc': 1, 'AbC': 2.0 }['AbC']; """, - expected = StaticType.DECIMAL + expected = DecimalType(DecimalType.PrecisionScaleConstraint.Constrained(2, 1)), ), // This should fail because the Spec says tuple indexing MUST use a literal string or explicit cast. ErrorTestCase( @@ -3315,7 +3315,7 @@ internal class PlanTyperTestsPorted { expected = BagType( StructType( fields = mapOf( - "a" to StaticType.DECIMAL, + "a" to DecimalType(DecimalType.PrecisionScaleConstraint.Constrained(2, 1)), "c" to StaticType.INT8, "s" to StaticType.DECIMAL, "m" to StaticType.DECIMAL, @@ -3882,6 +3882,10 @@ internal class PlanTyperTestsPorted { when (val statement = plan.getOperation()) { is org.partiql.plan.Operation.Query -> { assert(collector.problems.isEmpty()) { + // Throw internal error for debugging + collector.problems.firstOrNull { it.code() == PError.INTERNAL_ERROR }?.let { pError -> + pError.getOrNull("CAUSE", Throwable::class.java)?.let { throw it } + } buildString { appendLine(collector.problems.toString()) appendLine() diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/operator/OpArithmeticTest.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/operator/OpArithmeticTest.kt index c9bbbea6e3..49cfec6a5f 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/operator/OpArithmeticTest.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/operator/OpArithmeticTest.kt @@ -37,6 +37,8 @@ class OpArithmeticTest : PartiQLTyperTestBase() { val arg1 = args[1] val output = when { arg0 == arg1 -> arg1 + arg0 == StaticType.DECIMAL && arg1 == StaticType.FLOAT -> arg1 // TODO: The cast table is wrong. Honestly, it should be deleted. + arg1 == StaticType.DECIMAL && arg0 == StaticType.FLOAT -> arg0 // TODO: The cast table is wrong castTable(arg1, arg0) == CastType.COERCION -> arg0 castTable(arg0, arg1) == CastType.COERCION -> arg1 else -> error("Arguments do not conform to parameters. Args: $args") diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/operator/OpBitwiseAndTest.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/operator/OpBitwiseAndTest.kt index 6973cf5de4..63860ce265 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/operator/OpBitwiseAndTest.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/operator/OpBitwiseAndTest.kt @@ -34,8 +34,8 @@ class OpBitwiseAndTest : PartiQLTyperTestBase() { val arg1 = args[1] val output = when { arg0 !in allIntType && arg1 !in allIntType -> StaticType.INT - arg0 in allIntType && arg1 !in allIntType -> arg0 - arg0 !in allIntType && arg1 in allIntType -> arg1 + arg0 in allIntType && arg1 !in allIntType -> StaticType.INT + arg0 !in allIntType && arg1 in allIntType -> StaticType.INT arg0 == arg1 -> arg1 castTable(arg1, arg0) == CastType.COERCION -> arg0 castTable(arg0, arg1) == CastType.COERCION -> arg1 diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/predicate/OpBetweenTest.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/predicate/OpBetweenTest.kt index 578153335f..4cd4d07db3 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/predicate/OpBetweenTest.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/predicate/OpBetweenTest.kt @@ -13,6 +13,7 @@ import java.util.stream.Stream // TODO: Finalize the semantics for Between operator when operands contain MISSING // For now, Between propagates MISSING. class OpBetweenTest : PartiQLTyperTestBase() { + @TestFactory fun between(): Stream { val tests = listOf( diff --git a/partiql-spi/api/partiql-spi.api b/partiql-spi/api/partiql-spi.api index 060a880fb8..c288970495 100644 --- a/partiql-spi/api/partiql-spi.api +++ b/partiql-spi/api/partiql-spi.api @@ -343,6 +343,7 @@ public final class org/partiql/spi/errors/PError : org/partiql/spi/Enum { public class org/partiql/spi/errors/PErrorException : org/partiql/spi/errors/PErrorListenerException { public field error Lorg/partiql/spi/errors/PError; public fun (Lorg/partiql/spi/errors/PError;)V + public fun (Lorg/partiql/spi/errors/PError;Ljava/lang/Throwable;)V public fun equals (Ljava/lang/Object;)Z public fun hashCode ()I public fun toString ()Ljava/lang/String; @@ -407,10 +408,13 @@ public final class org/partiql/spi/function/Aggregation$DefaultImpls { public abstract interface class org/partiql/spi/function/Function : org/partiql/spi/function/Routine { public static final field Companion Lorg/partiql/spi/function/Function$Companion; public abstract fun getInstance ([Lorg/partiql/types/PType;)Lorg/partiql/spi/function/Function$Instance; + public static fun instance (Ljava/lang/String;[Lorg/partiql/spi/function/Parameter;Lorg/partiql/types/PType;ZZLkotlin/jvm/functions/Function1;)Lorg/partiql/spi/function/Function$Instance; public static fun static (Ljava/lang/String;[Lorg/partiql/spi/function/Parameter;Lorg/partiql/types/PType;ZZLkotlin/jvm/functions/Function1;)Lorg/partiql/spi/function/Function; } public final class org/partiql/spi/function/Function$Companion { + public final fun instance (Ljava/lang/String;[Lorg/partiql/spi/function/Parameter;Lorg/partiql/types/PType;ZZLkotlin/jvm/functions/Function1;)Lorg/partiql/spi/function/Function$Instance; + public static synthetic fun instance$default (Lorg/partiql/spi/function/Function$Companion;Ljava/lang/String;[Lorg/partiql/spi/function/Parameter;Lorg/partiql/types/PType;ZZLkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lorg/partiql/spi/function/Function$Instance; public final fun static (Ljava/lang/String;[Lorg/partiql/spi/function/Parameter;Lorg/partiql/types/PType;ZZLkotlin/jvm/functions/Function1;)Lorg/partiql/spi/function/Function; public static synthetic fun static$default (Lorg/partiql/spi/function/Function$Companion;Ljava/lang/String;[Lorg/partiql/spi/function/Parameter;Lorg/partiql/types/PType;ZZLkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lorg/partiql/spi/function/Function; } @@ -423,10 +427,11 @@ public final class org/partiql/spi/function/Function$DefaultImpls { public abstract class org/partiql/spi/function/Function$Instance { public final field isMissingCall Z public final field isNullCall Z + public final field name Ljava/lang/String; public final field parameters [Lorg/partiql/types/PType; public final field returns Lorg/partiql/types/PType; - public fun ([Lorg/partiql/types/PType;Lorg/partiql/types/PType;ZZ)V - public synthetic fun ([Lorg/partiql/types/PType;Lorg/partiql/types/PType;ZZILkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun (Ljava/lang/String;[Lorg/partiql/types/PType;Lorg/partiql/types/PType;ZZ)V + public synthetic fun (Ljava/lang/String;[Lorg/partiql/types/PType;Lorg/partiql/types/PType;ZZILkotlin/jvm/internal/DefaultConstructorMarker;)V public abstract fun invoke ([Lorg/partiql/spi/value/Datum;)Lorg/partiql/spi/value/Datum; } @@ -438,14 +443,14 @@ public final class org/partiql/spi/function/Parameter { public final fun getMatch (Lorg/partiql/types/PType;)Lorg/partiql/types/PType; public final fun getName ()Ljava/lang/String; public final fun getType ()Lorg/partiql/types/PType; - public static final fun numeric (Ljava/lang/String;)Lorg/partiql/spi/function/Parameter; + public static final fun number (Ljava/lang/String;)Lorg/partiql/spi/function/Parameter; public static final fun text (Ljava/lang/String;)Lorg/partiql/spi/function/Parameter; } public final class org/partiql/spi/function/Parameter$Companion { public final fun collection (Ljava/lang/String;)Lorg/partiql/spi/function/Parameter; public final fun dynamic (Ljava/lang/String;)Lorg/partiql/spi/function/Parameter; - public final fun numeric (Ljava/lang/String;)Lorg/partiql/spi/function/Parameter; + public final fun number (Ljava/lang/String;)Lorg/partiql/spi/function/Parameter; public final fun text (Ljava/lang/String;)Lorg/partiql/spi/function/Parameter; } diff --git a/partiql-spi/src/main/java/org/partiql/spi/errors/PErrorException.java b/partiql-spi/src/main/java/org/partiql/spi/errors/PErrorException.java index 91f1e0de50..1761b26b33 100644 --- a/partiql-spi/src/main/java/org/partiql/spi/errors/PErrorException.java +++ b/partiql-spi/src/main/java/org/partiql/spi/errors/PErrorException.java @@ -22,6 +22,16 @@ public PErrorException(@NotNull PError error) { this.error = error; } + /** + * Creates an exception that holds an error. + * @param error the error that is wrapped + * @param cause the cause of the error + */ + public PErrorException(@NotNull PError error, @NotNull Throwable cause) { + super(cause); + this.error = error; + } + @Override public String toString() { return "ErrorException{" + diff --git a/partiql-spi/src/main/java/org/partiql/spi/errors/PErrorListener.java b/partiql-spi/src/main/java/org/partiql/spi/errors/PErrorListener.java index 75c44958ef..b25f359084 100644 --- a/partiql-spi/src/main/java/org/partiql/spi/errors/PErrorListener.java +++ b/partiql-spi/src/main/java/org/partiql/spi/errors/PErrorListener.java @@ -26,6 +26,12 @@ public interface PErrorListener { static PErrorListener abortOnError() { return error -> { if (error.severity.code() == Severity.ERROR) { + if (error.code() == PError.INTERNAL_ERROR) { + Throwable cause = error.getOrNull("CAUSE", Throwable.class); + if (cause != null) { + throw new PErrorException(error, cause); + } + } throw new PErrorException(error); } }; diff --git a/partiql-spi/src/main/java/org/partiql/spi/value/Datum.java b/partiql-spi/src/main/java/org/partiql/spi/value/Datum.java index 6eb39ca689..dd39e4d9ca 100644 --- a/partiql-spi/src/main/java/org/partiql/spi/value/Datum.java +++ b/partiql-spi/src/main/java/org/partiql/spi/value/Datum.java @@ -6,6 +6,7 @@ import org.jetbrains.annotations.Nullable; import org.partiql.errors.DataException; import org.partiql.types.PType; +import org.partiql.value.DecimalValue; import org.partiql.value.PartiQL; import org.partiql.value.PartiQLValue; import org.partiql.value.PartiQLValueType; @@ -513,7 +514,8 @@ static Datum of(PartiQLValue value) { return new DatumDouble(Objects.requireNonNull(FLOAT64Value.getValue())); case DECIMAL: org.partiql.value.DecimalValue DECIMALValue = (org.partiql.value.DecimalValue) value; - return new DatumDecimal(Objects.requireNonNull(DECIMALValue.getValue()), PType.decimal()); + BigDecimal bigDecimal = Objects.requireNonNull(DECIMALValue.getValue()); + return Datum.decimal(bigDecimal, bigDecimal.precision(), bigDecimal.scale()); case CHAR: org.partiql.value.CharValue CHARValue = (org.partiql.value.CharValue) value; String charString = Objects.requireNonNull(CHARValue.getValue()).toString(); @@ -531,7 +533,8 @@ static Datum of(PartiQLValue value) { throw new UnsupportedOperationException(); case DECIMAL_ARBITRARY: org.partiql.value.DecimalValue DECIMAL_ARBITRARYValue = (org.partiql.value.DecimalValue) value; - return new DatumDecimal(Objects.requireNonNull(DECIMAL_ARBITRARYValue.getValue()), PType.decimal()); + BigDecimal bigDecimal2 = Objects.requireNonNull(DECIMAL_ARBITRARYValue.getValue()); + return Datum.decimal(bigDecimal2, bigDecimal2.precision(), bigDecimal2.scale()); case ANY: default: throw new NotImplementedError(); diff --git a/partiql-spi/src/main/java/org/partiql/spi/value/DatumDecimal.java b/partiql-spi/src/main/java/org/partiql/spi/value/DatumDecimal.java index 1835207f4b..c5976c9ff9 100644 --- a/partiql-spi/src/main/java/org/partiql/spi/value/DatumDecimal.java +++ b/partiql-spi/src/main/java/org/partiql/spi/value/DatumDecimal.java @@ -4,6 +4,7 @@ import org.partiql.types.PType; import java.math.BigDecimal; +import java.util.Objects; /** * This shall always be package-private (internal). @@ -36,4 +37,25 @@ public BigDecimal getBigDecimal() { public PType getType() { return _type; } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (!(o instanceof Datum)) return false; + Datum data = (Datum) o; + return Objects.equals(_type, data.getType()) && Objects.equals(_value, data.getBigDecimal()); + } + + @Override + public int hashCode() { + return Objects.hash(_value, _type); + } + + @Override + public String toString() { + return "DatumDecimal{" + + "_value=" + _value + + ", _type=" + _type + + '}'; + } } diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/function/Builtins.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/function/Builtins.kt index df926ab019..d0d0ba3765 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/function/Builtins.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/function/Builtins.kt @@ -45,11 +45,7 @@ internal object Builtins { Fn_BIT_LENGTH__STRING__INT32, Fn_BIT_LENGTH__CLOB__INT32, Fn_BIT_LENGTH__SYMBOL__INT32, - Fn_BITWISE_AND__INT8_INT8__INT8, - Fn_BITWISE_AND__INT16_INT16__INT16, - Fn_BITWISE_AND__INT32_INT32__INT32, - Fn_BITWISE_AND__INT64_INT64__INT64, - Fn_BITWISE_AND__INT_INT__INT, + FnBitwiseAnd, Fn_CARDINALITY__BAG__INT32, Fn_CARDINALITY__LIST__INT32, Fn_CARDINALITY__SEXP__INT32, @@ -347,14 +343,7 @@ internal object Builtins { Fn_OCTET_LENGTH__STRING__INT32, Fn_OCTET_LENGTH__CLOB__INT32, Fn_OCTET_LENGTH__SYMBOL__INT32, - Fn_PLUS__INT8_INT8__INT8, - Fn_PLUS__INT16_INT16__INT16, - Fn_PLUS__INT32_INT32__INT32, - Fn_PLUS__INT64_INT64__INT64, - Fn_PLUS__INT_INT__INT, - Fn_PLUS__FLOAT32_FLOAT32__FLOAT32, - Fn_PLUS__FLOAT64_FLOAT64__FLOAT64, - Fn_PLUS__DECIMAL_ARBITRARY_DECIMAL_ARBITRARY__DECIMAL_ARBITRARY, + FnPlus, Fn_POS__INT8__INT8, Fn_POS__INT16__INT16, Fn_POS__INT32__INT32, diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/function/Function.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/function/Function.kt index 2941429b37..64cc0a649c 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/function/Function.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/function/Function.kt @@ -11,7 +11,7 @@ public interface Function : Routine { /** * Returns an invocable implementation. Optional. */ - public fun getInstance(args: Array): Instance { + public fun getInstance(args: Array): Instance? { throw Error("Function ${getName()} has no implementations.") } @@ -23,6 +23,7 @@ public interface Function : Routine { * @see Function.getInstance */ public abstract class Instance( + @JvmField public val name: String, @JvmField public val parameters: Array, @JvmField public val returns: PType, @JvmField public val isNullCall: Boolean = true, @@ -42,6 +43,37 @@ public interface Function : Routine { */ public companion object { + /** + * TODO INTERNALIZE TO SPI AND REPLACE WITH A BUILDER (OR SOMETHING..) + * + * @param name + * @param parameters + * @param returns + * @param isNullCall + * @param isMissingCall + * @param invoke + * @return + */ + @JvmStatic + public fun instance( + name: String, + parameters: Array, + returns: PType, + isNullCall: Boolean = true, + isMissingCall: Boolean = true, + invoke: (Array) -> Datum, + ): Instance { + return object : Instance( + name, + Array(parameters.size) { parameters[it].getType() }, + returns, + isNullCall, + isMissingCall, + ) { + override fun invoke(args: Array): Datum = invoke(args) + } + } + /** * TODO INTERNALIZE TO SPI AND REPLACE WITH A BUILDER (OR SOMETHING..) * @@ -64,6 +96,7 @@ public interface Function : Routine { ): Function = _Function( name, parameters, returns, object : Instance( + name, Array(parameters.size) { parameters[it].getType() }, returns, isNullCall, diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/function/Parameter.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/function/Parameter.kt index 36710acbb2..35cad050bf 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/function/Parameter.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/function/Parameter.kt @@ -94,10 +94,10 @@ public class Parameter private constructor( public fun text(name: String): Parameter = Parameter(name, SqlTypeFamily.TEXT, false) /** - * Create a numeric [Parameter]. + * Create a number [Parameter]. */ @JvmStatic - public fun numeric(name: String): Parameter = Parameter(name, SqlTypeFamily.NUMERIC, false) + public fun number(name: String): Parameter = Parameter(name, SqlTypeFamily.NUMBER, false) /** * Create a collection [Parameter]. diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/ArithmeticDiadicOperator.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/ArithmeticDiadicOperator.kt new file mode 100644 index 0000000000..7704261742 --- /dev/null +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/ArithmeticDiadicOperator.kt @@ -0,0 +1,157 @@ +package org.partiql.spi.function.builtins + +import org.partiql.spi.function.Function +import org.partiql.spi.function.Parameter +import org.partiql.spi.internal.SqlTypeFamily +import org.partiql.spi.value.Datum +import org.partiql.types.PType + +/** + * This carries along with it a static table containing a mapping between the input types and the implementation. + */ +internal abstract class ArithmeticDiadicOperator : Function { + + companion object { + val allowed = SqlTypeFamily.NUMBER.members + setOf(PType.Kind.UNKNOWN) + } + + override fun getInstance(args: Array): Function.Instance? { + if (!allowed.contains(args[0].kind) || !allowed.contains(args[1].kind)) { + return null + } + val lhs = args[0].let { + when (it.kind) { + PType.Kind.DECIMAL_ARBITRARY -> PType.decimal(38, 0) // TODO: Remove decimal arbitrary + else -> it + } + } + val rhs = args[1].let { + when (it.kind) { + PType.Kind.DECIMAL_ARBITRARY -> PType.decimal(38, 0) // TODO: Remove decimal arbitrary + else -> it + } + } + val lhsPrecedence = TYPE_PRECEDENCE[lhs.kind] ?: throw IllegalArgumentException("Type not supported -- LHS = $lhs") + val rhsPrecedence = TYPE_PRECEDENCE[rhs.kind] ?: throw IllegalArgumentException("Type not supported -- RHS = $rhs") + val (newLhs, newRhs) = when (lhsPrecedence.compareTo(rhsPrecedence)) { + -1 -> (rhs to rhs) + 0 -> (lhs to rhs) + else -> (lhs to lhs) + } + val invocation = lookupTable[lhs.kind.ordinal][rhs.kind.ordinal] + return invocation.invoke(newLhs, newRhs) + } + + /** + * @param integerLhs TODO + * @param integerRhs TODO + * @return TODO + */ + abstract fun getIntegerInstance(integerLhs: PType, integerRhs: PType): Function.Instance + + /** + * @param tinyIntLhs TODO + * @param tinyIntRhs TODO + * @return TODO + */ + abstract fun getTinyIntInstance(tinyIntLhs: PType, tinyIntRhs: PType): Function.Instance + + /** + * @param smallIntLhs TODO + * @param smallIntRhs TODO + * @return TODO + */ + abstract fun getSmallIntInstance(smallIntLhs: PType, smallIntRhs: PType): Function.Instance + + /** + * @param bigIntLhs TODO + * @param bigIntRhs TODO + * @return TODO + */ + abstract fun getBigIntInstance(bigIntLhs: PType, bigIntRhs: PType): Function.Instance + + /** + * TODO: This will soon be removed. + * @param numericLhs TODO + * @param numericRhs TODO + * @return TODO + */ + abstract fun getNumericInstance(numericLhs: PType, numericRhs: PType): Function.Instance + + /** + * @param v1 TODO + * @param v2 TODO + * @return TODO + */ + abstract fun getDecimalInstance(v1: PType, v2: PType): Function.Instance + + /** + * @param realLhs TODO + * @param realRhs TODO + * @return TODO + */ + abstract fun getRealInstance(realLhs: PType, realRhs: PType): Function.Instance + + /** + * @param doubleLhs TODO + * @param doubleRhs TODO + * @return TODO + */ + abstract fun getDoubleInstance(doubleLhs: PType, doubleRhs: PType): Function.Instance + + override fun getParameters(): Array { + return arrayOf( + Parameter.number("lhs"), + Parameter.number("rhs"), + ) + } + + override fun getReturnType(args: Array): PType { + return getInstance(args)?.returns ?: PType.dynamic() // TODO: Do we need this method? + } + + private val lookupTable: Array Function.Instance?>> = Array(PType.Kind.entries.size) { + Array(PType.Kind.entries.size) { + { _, _ -> null } + } + } + + private fun fillTable(lhs: PType.Kind, rhs: PType.Kind, instance: (PType, PType) -> Function.Instance) { + lookupTable[lhs.ordinal][rhs.ordinal] = instance + } + + private fun fillTable(highPrecedence: PType.Kind, instance: (PType, PType) -> Function.Instance) { + val numbers = SqlTypeFamily.NUMBER.members + setOf(PType.Kind.UNKNOWN) + numbers.filter { + (TYPE_PRECEDENCE[highPrecedence]!! > TYPE_PRECEDENCE[it]!!) + }.forEach { + fillTable(highPrecedence, it) { lhs, _ -> instance(lhs, lhs) } + fillTable(it, highPrecedence) { _, rhs -> instance(rhs, rhs) } + } + fillTable(highPrecedence, highPrecedence) { lhs, rhs -> instance(lhs, rhs) } + } + + init { + fillTable(PType.Kind.TINYINT) { lhs, rhs -> getTinyIntInstance(lhs, rhs) } + fillTable(PType.Kind.SMALLINT) { lhs, rhs -> getSmallIntInstance(lhs, rhs) } + fillTable(PType.Kind.INTEGER) { lhs, rhs -> getIntegerInstance(lhs, rhs) } + fillTable(PType.Kind.BIGINT) { lhs, rhs -> getBigIntInstance(lhs, rhs) } + fillTable(PType.Kind.DECIMAL) { lhs, rhs -> getDecimalInstance(lhs, rhs) } + fillTable(PType.Kind.DECIMAL_ARBITRARY) { lhs, rhs -> getDecimalInstance(lhs, rhs) } // TODO: Remove this + fillTable(PType.Kind.NUMERIC) { lhs, rhs -> getNumericInstance(lhs, rhs) } // TODO: Remove this + fillTable(PType.Kind.REAL) { lhs, rhs -> getRealInstance(lhs, rhs) } + fillTable(PType.Kind.DOUBLE) { lhs, rhs -> getDoubleInstance(lhs, rhs) } + } + + protected fun basic(arg: PType, invocation: (Array) -> Datum): Function.Instance { + return Function.instance( + name = getName(), + returns = arg, + parameters = arrayOf( + Parameter("lhs", arg), + Parameter("rhs", arg), + ), + invoke = invocation + ) + } +} diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnBitwiseAnd.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnBitwiseAnd.kt index 1357a1de43..021a7bf084 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnBitwiseAnd.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnBitwiseAnd.kt @@ -4,82 +4,65 @@ package org.partiql.spi.function.builtins import org.partiql.spi.function.Function -import org.partiql.spi.function.Parameter import org.partiql.spi.value.Datum import org.partiql.types.PType import kotlin.experimental.and -internal val Fn_BITWISE_AND__INT8_INT8__INT8 = Function.static( - - name = "bitwise_and", - returns = PType.tinyint(), - parameters = arrayOf( - Parameter("lhs", PType.tinyint()), - Parameter("rhs", PType.tinyint()), - ), - -) { args -> - @Suppress("DEPRECATION") val arg0 = args[0].byte - @Suppress("DEPRECATION") val arg1 = args[1].byte - Datum.tinyint(arg0 and arg1) -} - -internal val Fn_BITWISE_AND__INT16_INT16__INT16 = Function.static( - - name = "bitwise_and", - returns = PType.smallint(), - parameters = arrayOf( - Parameter("lhs", PType.smallint()), - Parameter("rhs", PType.smallint()), - ), - -) { args -> - val arg0 = args[0].short - val arg1 = args[1].short - Datum.smallint(arg0 and arg1) -} - -internal val Fn_BITWISE_AND__INT32_INT32__INT32 = Function.static( - - name = "bitwise_and", - returns = PType.integer(), - parameters = arrayOf( - Parameter("lhs", PType.integer()), - Parameter("rhs", PType.integer()), - ), - -) { args -> - val arg0 = args[0].int - val arg1 = args[1].int - Datum.integer(arg0 and arg1) -} - -internal val Fn_BITWISE_AND__INT64_INT64__INT64 = Function.static( - - name = "bitwise_and", - returns = PType.bigint(), - parameters = arrayOf( - Parameter("lhs", PType.bigint()), - Parameter("rhs", PType.bigint()), - ), - -) { args -> - val arg0 = args[0].long - val arg1 = args[1].long - Datum.bigint(arg0 and arg1) -} - -internal val Fn_BITWISE_AND__INT_INT__INT = Function.static( - - name = "bitwise_and", - returns = PType.numeric(), - parameters = arrayOf( - @Suppress("DEPRECATION") Parameter("lhs", PType.numeric()), - @Suppress("DEPRECATION") Parameter("rhs", PType.numeric()), - ), - -) { args -> - val arg0 = args[0].bigInteger - val arg1 = args[1].bigInteger - Datum.numeric(arg0 and arg1) +internal object FnBitwiseAnd : ArithmeticDiadicOperator() { + override fun getName(): String { + return "bitwise_and" + } + + override fun getTinyIntInstance(tinyIntLhs: PType, tinyIntRhs: PType): Function.Instance { + return basic(PType.tinyint()) { args -> + @Suppress("DEPRECATION") val arg0 = args[0].byte + @Suppress("DEPRECATION") val arg1 = args[1].byte + Datum.tinyint(arg0 and arg1) + } + } + + override fun getSmallIntInstance(smallIntLhs: PType, smallIntRhs: PType): Function.Instance { + return basic(PType.smallint()) { args -> + val arg0 = args[0].short + val arg1 = args[1].short + Datum.smallint(arg0 and arg1) + } + } + + override fun getIntegerInstance(integerLhs: PType, integerRhs: PType): Function.Instance { + return basic(PType.integer()) { args -> + val arg0 = args[0].int + val arg1 = args[1].int + Datum.integer(arg0 and arg1) + } + } + + override fun getBigIntInstance(bigIntLhs: PType, bigIntRhs: PType): Function.Instance { + return basic(PType.bigint()) { args -> + val arg0 = args[0].long + val arg1 = args[1].long + Datum.bigint(arg0 and arg1) + } + } + + // TODO: Probably remove this if we don't expose NUMERIC + override fun getNumericInstance(numericLhs: PType, numericRhs: PType): Function.Instance { + return basic(PType.numeric()) { args -> + val arg0 = args[0].bigInteger + val arg1 = args[1].bigInteger + Datum.numeric(arg0 and arg1) + } + } + + override fun getDecimalInstance(v1: PType, v2: PType): Function.Instance { + return getNumericInstance(v1, v2) + } + + override fun getRealInstance(realLhs: PType, realRhs: PType): Function.Instance { + return getNumericInstance(realLhs, realRhs) + } + + override fun getDoubleInstance(doubleLhs: PType, doubleRhs: PType): Function.Instance { + return getNumericInstance(doubleLhs, doubleRhs) + } } diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnCollAgg.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnCollAgg.kt index e75f212cb2..09a09f782e 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnCollAgg.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnCollAgg.kt @@ -32,6 +32,7 @@ internal abstract class Fn_COLL_AGG__BAG__ANY( override fun getInstance(args: Array): Function.Instance = instance private val instance = object : Function.Instance( + name, parameters = arrayOf(PType.bag()), returns = PType.dynamic(), ) { diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnIsMissing.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnIsMissing.kt index 62bd9af453..7acb5188cb 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnIsMissing.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnIsMissing.kt @@ -23,6 +23,7 @@ internal val Fn_IS_MISSING__ANY__BOOL = object : Function { * IS MISSING implementation. */ private var instance = object : Function.Instance( + name, parameters = arrayOf(PType.dynamic()), returns = PType.bool(), isNullCall = false, diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnIsNull.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnIsNull.kt index 1d5c35bf47..d3d258db86 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnIsNull.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnIsNull.kt @@ -23,6 +23,7 @@ internal val Fn_IS_NULL__ANY__BOOL = object : Function { * IS NULL implementation. */ private var instance = object : Function.Instance( + name, parameters = arrayOf(PType.dynamic()), returns = PType.bool(), isNullCall = false, diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnNot.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnNot.kt index 0da34462f7..c605f78000 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnNot.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnNot.kt @@ -17,6 +17,7 @@ internal val Fn_NOT__BOOL__BOOL = object : Function { private var returns = PType.bool() private var instance = object : Function.Instance( + name, parameters = arrayOf(PType.dynamic()), returns = PType.bool(), isNullCall = true, diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnPlus.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnPlus.kt index 0af98f16a6..e66a7eabbf 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnPlus.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnPlus.kt @@ -8,123 +8,87 @@ import org.partiql.spi.function.Parameter import org.partiql.spi.value.Datum import org.partiql.types.PType -// TODO: Handle Overflow -internal val Fn_PLUS__INT8_INT8__INT8 = Function.static( - - name = "plus", - returns = PType.tinyint(), - parameters = arrayOf( - Parameter("lhs", PType.tinyint()), - Parameter("rhs", PType.tinyint()), - ), - -) { args -> - @Suppress("DEPRECATION") val arg0 = args[0].byte - @Suppress("DEPRECATION") val arg1 = args[1].byte - Datum.tinyint((arg0 + arg1).toByte()) -} - -internal val Fn_PLUS__INT16_INT16__INT16 = Function.static( - - name = "plus", - returns = PType.smallint(), - parameters = arrayOf( - Parameter("lhs", PType.smallint()), - Parameter("rhs", PType.smallint()), - ), - -) { args -> - val arg0 = args[0].short - val arg1 = args[1].short - Datum.smallint((arg0 + arg1).toShort()) -} - -internal val Fn_PLUS__INT32_INT32__INT32 = Function.static( - - name = "plus", - returns = PType.integer(), - parameters = arrayOf( - Parameter("lhs", PType.integer()), - Parameter("rhs", PType.integer()), - ), - -) { args -> - val arg0 = args[0].int - val arg1 = args[1].int - Datum.integer(arg0 + arg1) -} - -internal val Fn_PLUS__INT64_INT64__INT64 = Function.static( - - name = "plus", - returns = PType.bigint(), - parameters = arrayOf( - Parameter("lhs", PType.bigint()), - Parameter("rhs", PType.bigint()), - ), - -) { args -> - val arg0 = args[0].long - val arg1 = args[1].long - Datum.bigint(arg0 + arg1) -} - -internal val Fn_PLUS__INT_INT__INT = Function.static( - - name = "plus", - returns = PType.numeric(), - parameters = arrayOf( - @Suppress("DEPRECATION") Parameter("lhs", PType.numeric()), - @Suppress("DEPRECATION") Parameter("rhs", PType.numeric()), - ), - -) { args -> - val arg0 = args[0].bigInteger - val arg1 = args[1].bigInteger - Datum.numeric(arg0 + arg1) -} - -internal val Fn_PLUS__DECIMAL_ARBITRARY_DECIMAL_ARBITRARY__DECIMAL_ARBITRARY = Function.static( - - name = "plus", - returns = PType.decimal(), - parameters = arrayOf( - @Suppress("DEPRECATION") Parameter("lhs", PType.decimal()), - @Suppress("DEPRECATION") Parameter("rhs", PType.decimal()), - ), - -) { args -> - val arg0 = args[0].bigDecimal - val arg1 = args[1].bigDecimal - Datum.decimal(arg0 + arg1) -} - -internal val Fn_PLUS__FLOAT32_FLOAT32__FLOAT32 = Function.static( - - name = "plus", - returns = PType.real(), - parameters = arrayOf( - Parameter("lhs", PType.real()), - Parameter("rhs", PType.real()), - ), - -) { args -> - val arg0 = args[0].float - val arg1 = args[1].float - Datum.real(arg0 + arg1) -} - -internal val Fn_PLUS__FLOAT64_FLOAT64__FLOAT64 = Function.static( - - name = "plus", - returns = PType.doublePrecision(), - parameters = arrayOf( - Parameter("lhs", PType.doublePrecision()), - Parameter("rhs", PType.doublePrecision()), - ), - -) { args -> - val arg0 = args[0].double - val arg1 = args[1].double - Datum.doublePrecision(arg0 + arg1) +internal object FnPlus : ArithmeticDiadicOperator() { + override fun getName(): String { + return "plus" + } + + override fun getTinyIntInstance(tinyIntLhs: PType, tinyIntRhs: PType): Function.Instance { + return basic(PType.tinyint()) { args -> + @Suppress("DEPRECATION") val arg0 = args[0].byte + @Suppress("DEPRECATION") val arg1 = args[1].byte + Datum.tinyint((arg0 + arg1).toByte()) + } + } + + override fun getSmallIntInstance(smallIntLhs: PType, smallIntRhs: PType): Function.Instance { + return basic(PType.smallint()) { args -> + val arg0 = args[0].short + val arg1 = args[1].short + Datum.smallint((arg0 + arg1).toShort()) + } + } + + override fun getIntegerInstance(integerLhs: PType, integerRhs: PType): Function.Instance { + return basic(PType.integer()) { args -> + val arg0 = args[0].int + val arg1 = args[1].int + Datum.integer(arg0 + arg1) + } + } + + override fun getBigIntInstance(bigIntLhs: PType, bigIntRhs: PType): Function.Instance { + return basic(PType.bigint()) { args -> + val arg0 = args[0].long + val arg1 = args[1].long + Datum.bigint(arg0 + arg1) + } + } + + // TODO: Probably remove this if we don't expose NUMERIC + override fun getNumericInstance(numericLhs: PType, numericRhs: PType): Function.Instance { + return basic(PType.numeric()) { args -> + val arg0 = args[0].bigInteger + val arg1 = args[1].bigInteger + Datum.numeric(arg0 + arg1) + } + } + + /** + * Precision and scale calculation: + * P = max(s1, s2) + max(p1 - s1, p2 - s2) + 1 + * S = max(s1, s2) + */ + override fun getDecimalInstance(v1: PType, v2: PType): Function.Instance { + val p = Math.min(38, Math.max(v1.scale, v2.scale) + Math.max(v1.precision - v1.scale, v2.precision - v2.scale) + 1) + val s = Math.min(38, Math.max(v1.scale, v2.scale)) + return Function.instance( + name = "plus", + returns = PType.decimal(p, s), + parameters = arrayOf( + Parameter("lhs", v1), + Parameter("rhs", v2), + ) + ) { args -> + val arg0 = args[0].bigDecimal + val arg1 = args[1].bigDecimal + Datum.decimal(arg0 + arg1, p, s) + } + } + + override fun getRealInstance(realLhs: PType, realRhs: PType): Function.Instance { + return basic(PType.real()) { args -> + val arg0 = args[0].float + val arg1 = args[1].float + Datum.real(arg0 + arg1) + } + } + + override fun getDoubleInstance(doubleLhs: PType, doubleRhs: PType): Function.Instance { + return basic(PType.doublePrecision()) { args -> + val arg0 = args[0].double + val arg1 = args[1].double + Datum.doublePrecision(arg0 + arg1) + } + } } diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/TypePrecedence.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/TypePrecedence.kt new file mode 100644 index 0000000000..6ce73b6c3e --- /dev/null +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/TypePrecedence.kt @@ -0,0 +1,39 @@ +package org.partiql.spi.function.builtins + +import org.partiql.types.PType.Kind + +/** + * @return the precedence of the types for the PartiQL comparator. + * @see .TYPE_PRECEDENCE + */ +@Suppress("deprecation") +internal val TYPE_PRECEDENCE: Map = listOf( + Kind.UNKNOWN, + Kind.BOOL, + Kind.TINYINT, + Kind.SMALLINT, + Kind.INTEGER, + Kind.BIGINT, + Kind.NUMERIC, + Kind.DECIMAL, + Kind.REAL, + Kind.DOUBLE, + Kind.DECIMAL_ARBITRARY, // Arbitrary precision decimal has a higher precedence than FLOAT + Kind.CHAR, + Kind.VARCHAR, + Kind.SYMBOL, + Kind.STRING, + Kind.CLOB, + Kind.BLOB, + Kind.DATE, + Kind.TIME, + Kind.TIMEZ, + Kind.TIMESTAMP, + Kind.TIMESTAMPZ, + Kind.ARRAY, + Kind.SEXP, + Kind.BAG, + Kind.ROW, + Kind.STRUCT, + Kind.DYNAMIC +).mapIndexed { precedence, type -> type to precedence }.toMap() diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/internal/SqlTypeFamily.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/internal/SqlTypeFamily.kt index a421f02cba..7c624a2f06 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/internal/SqlTypeFamily.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/internal/SqlTypeFamily.kt @@ -61,7 +61,7 @@ internal class SqlTypeFamily private constructor( ) @JvmStatic - val NUMERIC = SqlTypeFamily( + val NUMBER = SqlTypeFamily( preferred = PType.decimal(), members = setOf( Kind.TINYINT, diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/internal/SqlTypes.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/internal/SqlTypes.kt index 68818c85d9..550c751f60 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/internal/SqlTypes.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/internal/SqlTypes.kt @@ -119,7 +119,7 @@ internal object SqlTypes { * ``` */ private fun areAssignableNumberTypes(input: PType, target: PType): Boolean { - return input in SqlTypeFamily.NUMERIC && target in SqlTypeFamily.NUMERIC + return input in SqlTypeFamily.NUMBER && target in SqlTypeFamily.NUMBER } /** diff --git a/partiql-types/src/main/java/org/partiql/types/PType.java b/partiql-types/src/main/java/org/partiql/types/PType.java index 93f8869ce9..31b0f53a61 100644 --- a/partiql-types/src/main/java/org/partiql/types/PType.java +++ b/partiql-types/src/main/java/org/partiql/types/PType.java @@ -525,12 +525,12 @@ static PType numeric() { } /** - * @return a PartiQL decimal (arbitrary precision/scale) type + * @return a PartiQL decimal type * @deprecated this API is experimental and is subject to modification/deletion without prior notice. */ @NotNull static PType decimal() { - return new PTypePrimitive(Kind.DECIMAL_ARBITRARY); + return new PTypeDecimal(38, 0); } /**