Skip to content

Commit

Permalink
Add T_DIST scalar function
Browse files Browse the repository at this point in the history
  • Loading branch information
wendigo committed Jun 26, 2024
1 parent 88e4a5c commit b2e95ae
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@
import io.trino.operator.scalar.ConcatWsFunction;
import io.trino.operator.scalar.DataSizeFunctions;
import io.trino.operator.scalar.DateTimeFunctions;
import io.trino.operator.scalar.DistributionFunctions;
import io.trino.operator.scalar.EmptyMapConstructor;
import io.trino.operator.scalar.FailureFunction;
import io.trino.operator.scalar.FormatNumberFunction;
Expand Down Expand Up @@ -441,6 +442,7 @@ public static FunctionBundle create(FeaturesConfig featuresConfig, TypeOperators
.scalars(BitwiseFunctions.class)
.scalars(DateTimeFunctions.class)
.scalar(DateTimeFunctions.FromUnixtimeNanosDecimal.class)
.scalars(DistributionFunctions.class)
.scalars(JsonFunctions.class)
.scalars(JsonInputFunctions.class)
.scalars(JsonOutputFunctions.class)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.operator.scalar;

import io.trino.spi.TrinoException;
import io.trino.spi.function.ScalarFunction;
import io.trino.spi.function.SqlType;
import org.apache.commons.math3.distribution.TDistribution;

import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
import static io.trino.spi.type.StandardTypes.BOOLEAN;
import static io.trino.spi.type.StandardTypes.DOUBLE;
import static io.trino.spi.type.StandardTypes.INTEGER;

public class DistributionFunctions
{
private DistributionFunctions() {}

@ScalarFunction("t_dist")
@SqlType(DOUBLE)
public static double tDistribution(@SqlType(DOUBLE) double x, @SqlType(INTEGER) long degreesOfFreedom, @SqlType(BOOLEAN) boolean cumulative)
{
if (degreesOfFreedom < 1) {
throw new TrinoException(INVALID_FUNCTION_ARGUMENT, "degrees_of_freedom must be greater than or equal to 1");
}

TDistribution tDistribution = new TDistribution(degreesOfFreedom);
return cumulative ? tDistribution.cumulativeProbability(x) : tDistribution.density(x);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.operator.scalar;

import io.trino.sql.query.QueryAssertions;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.api.parallel.Execution;

import static io.trino.testing.assertions.TrinoExceptionAssert.assertTrinoExceptionThrownBy;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS;
import static org.junit.jupiter.api.parallel.ExecutionMode.CONCURRENT;

@TestInstance(PER_CLASS)
@Execution(CONCURRENT)
public class TestDistributionFunctions
{
private QueryAssertions assertions;

@BeforeAll
public void init()
{
assertions = new QueryAssertions();
}

@AfterAll
public void teardown()
{
assertions.close();
assertions = null;
}

@Test
public void testTDist()
{
assertThat(assertions.function("t_dist", "60", "1", "true"))
.isEqualTo(0.9946953263673767);
assertThat(assertions.function("t_dist", "8", "3", "false"))
.isEqualTo(7.369065209469264E-4);
assertTrinoExceptionThrownBy(assertions.function("t_dist", "60", "0", "true")::evaluate)
.hasMessage("degrees_of_freedom must be greater than or equal to 1");
}
}

0 comments on commit b2e95ae

Please sign in to comment.