diff --git a/core/trino-main/src/main/java/io/trino/metadata/SystemFunctionBundle.java b/core/trino-main/src/main/java/io/trino/metadata/SystemFunctionBundle.java index af37a8db58ade3..839bd4f6f4bc1d 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/SystemFunctionBundle.java +++ b/core/trino-main/src/main/java/io/trino/metadata/SystemFunctionBundle.java @@ -125,6 +125,7 @@ import io.trino.operator.scalar.ConcatWsFunction; import io.trino.operator.scalar.DataSizeFunctions; import io.trino.operator.scalar.DateTimeFunctions; +import io.trino.operator.scalar.DistributionFunctions; import io.trino.operator.scalar.EmptyMapConstructor; import io.trino.operator.scalar.FailureFunction; import io.trino.operator.scalar.FormatNumberFunction; @@ -441,6 +442,7 @@ public static FunctionBundle create(FeaturesConfig featuresConfig, TypeOperators .scalars(BitwiseFunctions.class) .scalars(DateTimeFunctions.class) .scalar(DateTimeFunctions.FromUnixtimeNanosDecimal.class) + .scalars(DistributionFunctions.class) .scalars(JsonFunctions.class) .scalars(JsonInputFunctions.class) .scalars(JsonOutputFunctions.class) diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/DistributionFunctions.java b/core/trino-main/src/main/java/io/trino/operator/scalar/DistributionFunctions.java new file mode 100644 index 00000000000000..76dfd3b93238dd --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/DistributionFunctions.java @@ -0,0 +1,41 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.scalar; + +import io.trino.spi.TrinoException; +import io.trino.spi.function.ScalarFunction; +import io.trino.spi.function.SqlType; +import org.apache.commons.math3.distribution.TDistribution; + +import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; +import static io.trino.spi.type.StandardTypes.BOOLEAN; +import static io.trino.spi.type.StandardTypes.DOUBLE; +import static io.trino.spi.type.StandardTypes.INTEGER; + +public class DistributionFunctions +{ + private DistributionFunctions() {} + + @ScalarFunction("t_dist") + @SqlType(DOUBLE) + public static double tDistribution(@SqlType(DOUBLE) double x, @SqlType(INTEGER) long degreesOfFreedom, @SqlType(BOOLEAN) boolean cumulative) + { + if (degreesOfFreedom < 1) { + throw new TrinoException(INVALID_FUNCTION_ARGUMENT, "degrees_of_freedom must be greater than or equal to 1"); + } + + TDistribution tDistribution = new TDistribution(degreesOfFreedom); + return cumulative ? tDistribution.cumulativeProbability(x) : tDistribution.density(x); + } +} diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/TestDistributionFunctions.java b/core/trino-main/src/test/java/io/trino/operator/scalar/TestDistributionFunctions.java new file mode 100644 index 00000000000000..d89a3ff07c84ad --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/TestDistributionFunctions.java @@ -0,0 +1,57 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.scalar; + +import io.trino.sql.query.QueryAssertions; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.parallel.Execution; + +import static io.trino.testing.assertions.TrinoExceptionAssert.assertTrinoExceptionThrownBy; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; +import static org.junit.jupiter.api.parallel.ExecutionMode.CONCURRENT; + +@TestInstance(PER_CLASS) +@Execution(CONCURRENT) +public class TestDistributionFunctions +{ + private QueryAssertions assertions; + + @BeforeAll + public void init() + { + assertions = new QueryAssertions(); + } + + @AfterAll + public void teardown() + { + assertions.close(); + assertions = null; + } + + @Test + public void testTDist() + { + assertThat(assertions.function("t_dist", "60", "1", "true")) + .isEqualTo(0.9946953263673767); + assertThat(assertions.function("t_dist", "8", "3", "false")) + .isEqualTo(7.369065209469264E-4); + assertTrinoExceptionThrownBy(assertions.function("t_dist", "60", "0", "true")::evaluate) + .hasMessage("degrees_of_freedom must be greater than or equal to 1"); + } +}