From 4c331b54c60d0bf88ed68b05b40db35126da8499 Mon Sep 17 00:00:00 2001 From: "opensearch-trigger-bot[bot]" <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Date: Thu, 18 Jan 2024 15:39:07 -0500 Subject: [PATCH] Add experimental SIMD implementation of B-tree to round down dates (#11194) (#11924) * Add experimental SIMD implementation of B-tree to round down dates * Use system properties in favor of feature flags to remove dependency on the server module * Removed Java 20 test sources to simplify builds * Remove the use of forbidden APIs in unit-tests * Migrate to the recommended usage for custom test task classpath * Switch benchmarks module to multi-release one * Add JMH annotation processing for JDK-20+ sources * Make JMH annotations consistent across sources * Improve execution of Roundable unit-tests * Revert "Improve execution of Roundable unit-tests" This reverts commit 2e82d0aa7b114000bcf58f9485377eb5705ba3b1. * Add 'forced' as a possible feature flag value to simplify the execution of unit-tests --------- (cherry picked from commit b424aafdffb09cbc56cbc436bb1f4073b2f9f079) Signed-off-by: Ketan Verma Signed-off-by: Andriy Redko Signed-off-by: github-actions[bot] Co-authored-by: github-actions[bot] Co-authored-by: Andriy Redko Co-authored-by: Andrew Ross --- CHANGELOG.md | 1 + benchmarks/build.gradle | 42 ++++++++ .../common/round/RoundableBenchmark.java | 18 +--- .../common/round/RoundableSupplier.java | 35 ++++++ .../common/round/RoundableSupplier.java | 36 +++++++ libs/common/build.gradle | 61 +++++++++++ .../common/round/BtreeSearcher.java | 100 ++++++++++++++++++ .../common/round/RoundableFactory.java | 57 ++++++++++ .../common/round/RoundableTests.java | 54 ++++++---- 9 files changed, 370 insertions(+), 34 deletions(-) create mode 100644 benchmarks/src/main/java/org/opensearch/common/round/RoundableSupplier.java create mode 100644 benchmarks/src/main/java20/org/opensearch/common/round/RoundableSupplier.java create mode 100644 libs/common/src/main/java20/org/opensearch/common/round/BtreeSearcher.java create mode 100644 libs/common/src/main/java20/org/opensearch/common/round/RoundableFactory.java diff --git a/CHANGELOG.md b/CHANGELOG.md index b2a324fb6ca79..e82877ff8d6bb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,6 +31,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Request level coordinator slow logs ([#11246](https://github.com/opensearch-project/OpenSearch/pull/11246)) - Add template snippets support for field and target_field in KV ingest processor ([#10040](https://github.com/opensearch-project/OpenSearch/pull/10040)) - Allowing pipeline processors to access index mapping info by passing ingest service ref as part of the processor factory parameters ([#10307](https://github.com/opensearch-project/OpenSearch/pull/10307)) +- Add experimental SIMD implementation of B-tree to round down dates ([#11194](https://github.com/opensearch-project/OpenSearch/issues/11194)) - Make number of segment metadata files in remote segment store configurable ([#11329](https://github.com/opensearch-project/OpenSearch/pull/11329)) - Allow changing number of replicas of searchable snapshot index ([#11317](https://github.com/opensearch-project/OpenSearch/pull/11317)) - [BWC and API enforcement] Introduce checks for enforcing the API restrictions ([#11175](https://github.com/opensearch-project/OpenSearch/pull/11175)) diff --git a/benchmarks/build.gradle b/benchmarks/build.gradle index 6b4634c7e791c..be4579b4e5324 100644 --- a/benchmarks/build.gradle +++ b/benchmarks/build.gradle @@ -84,3 +84,45 @@ spotless { targetExclude 'src/main/generated/**/*.java' } } + +if (BuildParams.runtimeJavaVersion >= JavaVersion.VERSION_20) { + // Add support for incubator modules on supported Java versions. + run.jvmArgs += ['--add-modules=jdk.incubator.vector'] + run.classpath += files(jar.archiveFile) + run.classpath -= sourceSets.main.output + evaluationDependsOn(':libs:opensearch-common') + + sourceSets { + java20 { + java { + srcDirs = ['src/main/java20'] + } + } + } + + configurations { + java20Implementation.extendsFrom(implementation) + } + + dependencies { + java20Implementation sourceSets.main.output + java20Implementation project(':libs:opensearch-common').sourceSets.java20.output + java20AnnotationProcessor "org.openjdk.jmh:jmh-generator-annprocess:$versions.jmh" + } + + compileJava20Java { + targetCompatibility = JavaVersion.VERSION_20 + options.compilerArgs.addAll(["-processor", "org.openjdk.jmh.generators.BenchmarkProcessor"]) + } + + jar { + metaInf { + into 'versions/20' + from sourceSets.java20.output + } + manifest.attributes('Multi-Release': 'true') + } + + // classes generated by JMH can use all sorts of forbidden APIs but we have no influence at all and cannot exclude these classes + disableTasks('forbiddenApisJava20') +} diff --git a/benchmarks/src/main/java/org/opensearch/common/round/RoundableBenchmark.java b/benchmarks/src/main/java/org/opensearch/common/round/RoundableBenchmark.java index 4e07af452968b..3909a3f4eb8fc 100644 --- a/benchmarks/src/main/java/org/opensearch/common/round/RoundableBenchmark.java +++ b/benchmarks/src/main/java/org/opensearch/common/round/RoundableBenchmark.java @@ -21,7 +21,6 @@ import org.openjdk.jmh.infra.Blackhole; import java.util.Random; -import java.util.function.Supplier; @Fork(value = 3) @Warmup(iterations = 3, time = 1) @@ -83,17 +82,17 @@ public static class Options { "256" }) public Integer size; - @Param({ "binary", "linear" }) + @Param({ "binary", "linear", "btree" }) public String type; @Param({ "uniform", "skewed_edge", "skewed_center" }) public String distribution; public long[] queries; - public Supplier supplier; + public RoundableSupplier supplier; @Setup - public void setup() { + public void setup() throws ClassNotFoundException { Random random = new Random(size); long[] values = new long[size]; for (int i = 1; i < values.length; i++) { @@ -128,16 +127,7 @@ public void setup() { throw new IllegalArgumentException("invalid distribution: " + distribution); } - switch (type) { - case "binary": - supplier = () -> new BinarySearcher(values, size); - break; - case "linear": - supplier = () -> new BidirectionalLinearSearcher(values, size); - break; - default: - throw new IllegalArgumentException("invalid type: " + type); - } + supplier = new RoundableSupplier(type, values, size); } private static long nextPositiveLong(Random random) { diff --git a/benchmarks/src/main/java/org/opensearch/common/round/RoundableSupplier.java b/benchmarks/src/main/java/org/opensearch/common/round/RoundableSupplier.java new file mode 100644 index 0000000000000..44ac42810996f --- /dev/null +++ b/benchmarks/src/main/java/org/opensearch/common/round/RoundableSupplier.java @@ -0,0 +1,35 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common.round; + +import java.util.function.Supplier; + +public class RoundableSupplier implements Supplier { + private final Supplier delegate; + + RoundableSupplier(String type, long[] values, int size) throws ClassNotFoundException { + switch (type) { + case "binary": + delegate = () -> new BinarySearcher(values, size); + break; + case "linear": + delegate = () -> new BidirectionalLinearSearcher(values, size); + break; + case "btree": + throw new ClassNotFoundException("BtreeSearcher is not supported below JDK 20"); + default: + throw new IllegalArgumentException("invalid type: " + type); + } + } + + @Override + public Roundable get() { + return delegate.get(); + } +} diff --git a/benchmarks/src/main/java20/org/opensearch/common/round/RoundableSupplier.java b/benchmarks/src/main/java20/org/opensearch/common/round/RoundableSupplier.java new file mode 100644 index 0000000000000..e81c1b137bd30 --- /dev/null +++ b/benchmarks/src/main/java20/org/opensearch/common/round/RoundableSupplier.java @@ -0,0 +1,36 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common.round; + +import java.util.function.Supplier; + +public class RoundableSupplier implements Supplier { + private final Supplier delegate; + + RoundableSupplier(String type, long[] values, int size) { + switch (type) { + case "binary": + delegate = () -> new BinarySearcher(values, size); + break; + case "linear": + delegate = () -> new BidirectionalLinearSearcher(values, size); + break; + case "btree": + delegate = () -> new BtreeSearcher(values, size); + break; + default: + throw new IllegalArgumentException("invalid type: " + type); + } + } + + @Override + public Roundable get() { + return delegate.get(); + } +} diff --git a/libs/common/build.gradle b/libs/common/build.gradle index 4f89b81636420..60bf488833393 100644 --- a/libs/common/build.gradle +++ b/libs/common/build.gradle @@ -43,3 +43,64 @@ tasks.named('forbiddenApisMain').configure { // TODO: Need to decide how we want to handle for forbidden signatures with the changes to server replaceSignatureFiles 'jdk-signatures' } + +// Add support for incubator modules on supported Java versions. +if (BuildParams.runtimeJavaVersion >= JavaVersion.VERSION_20) { + sourceSets { + java20 { + java { + srcDirs = ['src/main/java20'] + } + } + } + + configurations { + java20Implementation.extendsFrom(implementation) + } + + dependencies { + java20Implementation sourceSets.main.output + } + + compileJava20Java { + targetCompatibility = JavaVersion.VERSION_20 + options.compilerArgs += ['--add-modules', 'jdk.incubator.vector'] + options.compilerArgs -= '-Werror' // use of incubator modules is reported as a warning + } + + jar { + metaInf { + into 'versions/20' + from sourceSets.java20.output + } + manifest.attributes('Multi-Release': 'true') + } + + tasks.withType(Test).configureEach { + // Relying on the convention for Test.classpath in custom Test tasks has been deprecated + // and scheduled to be removed in Gradle 9.0. Below lines are added from the migration guide: + // https://docs.gradle.org/8.5/userguide/upgrading_version_8.html#test_task_default_classpath + testClassesDirs = testing.suites.test.sources.output.classesDirs + classpath = testing.suites.test.sources.runtimeClasspath + + // Adds the multi-release JAR to the classpath when executing tests. + // This allows newer sources to be picked up at test runtime (if supported). + classpath += files(jar.archiveFile) + // Removes the "main" sources from the classpath to avoid JarHell problems as + // the multi-release JAR already contains those classes. + classpath -= sourceSets.main.output + } + + tasks.register('roundableSimdTest', Test) { + group 'verification' + include '**/RoundableTests.class' + systemProperty 'opensearch.experimental.feature.simd.rounding.enabled', 'forced' + } + + check.dependsOn(roundableSimdTest) + + forbiddenApisJava20 { + failOnMissingClasses = false + ignoreSignaturesOfMissingClasses = true + } +} diff --git a/libs/common/src/main/java20/org/opensearch/common/round/BtreeSearcher.java b/libs/common/src/main/java20/org/opensearch/common/round/BtreeSearcher.java new file mode 100644 index 0000000000000..626fb6e6b810e --- /dev/null +++ b/libs/common/src/main/java20/org/opensearch/common/round/BtreeSearcher.java @@ -0,0 +1,100 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common.round; + +import org.opensearch.common.annotation.InternalApi; + +import jdk.incubator.vector.LongVector; +import jdk.incubator.vector.Vector; +import jdk.incubator.vector.VectorOperators; +import jdk.incubator.vector.VectorSpecies; + +/** + * It uses vectorized B-tree search to find the round-down point. + * + * @opensearch.internal + */ +@InternalApi +class BtreeSearcher implements Roundable { + private static final VectorSpecies LONG_VECTOR_SPECIES = LongVector.SPECIES_PREFERRED; + private static final int LANES = LONG_VECTOR_SPECIES.length(); + private static final int SHIFT = log2(LANES); + + private final long[] values; + private final long minValue; + + BtreeSearcher(long[] values, int size) { + if (size <= 0) { + throw new IllegalArgumentException("at least one value must be present"); + } + + int blocks = (size + LANES - 1) / LANES; // number of blocks + int length = 1 + blocks * LANES; // size of the backing array (1-indexed) + + this.minValue = values[0]; + this.values = new long[length]; + build(values, 0, size, this.values, 1); + } + + /** + * Builds the B-tree memory layout. + * It builds the tree recursively, following an in-order traversal. + * + *

+ * Each block stores 'lanes' values at indices {@code i, i + 1, ..., i + lanes - 1} where {@code i} is the + * starting offset. The starting offset of the root block is 1. The branching factor is (1 + lanes) so each + * block can have these many children. Given the starting offset {@code i} of a block, the starting offset + * of its k-th child (ranging from {@code 0, 1, ..., k}) can be computed as {@code i + ((i + k) << shift)}. + * + * @param src is the sorted input array + * @param i is the index in the input array to read the value from + * @param size the number of values in the input array + * @param dst is the output array + * @param j is the index in the output array to write the value to + * @return the next index 'i' + */ + private static int build(long[] src, int i, int size, long[] dst, int j) { + if (j < dst.length) { + for (int k = 0; k < LANES; k++) { + i = build(src, i, size, dst, j + ((j + k) << SHIFT)); + + // Fills the B-tree as a complete tree, i.e., all levels are completely filled, + // except the last level which is filled from left to right. + // The trick is to fill the destination array between indices 1...size (inclusive / 1-indexed) + // and pad the remaining array with +infinity. + dst[j + k] = (j + k <= size) ? src[i++] : Long.MAX_VALUE; + } + i = build(src, i, size, dst, j + ((j + LANES) << SHIFT)); + } + return i; + } + + @Override + public long floor(long key) { + Vector keyVector = LongVector.broadcast(LONG_VECTOR_SPECIES, key); + int i = 1, result = 1; + + while (i < values.length) { + Vector valuesVector = LongVector.fromArray(LONG_VECTOR_SPECIES, values, i); + int j = i + valuesVector.compare(VectorOperators.GT, keyVector).firstTrue(); + result = (j > i) ? j : result; + i += (j << SHIFT); + } + + assert result > 1 : "key must be greater than or equal to " + minValue; + return values[result - 1]; + } + + private static int log2(int num) { + if ((num <= 0) || ((num & (num - 1)) != 0)) { + throw new IllegalArgumentException(num + " is not a positive power of 2"); + } + return 32 - Integer.numberOfLeadingZeros(num - 1); + } +} diff --git a/libs/common/src/main/java20/org/opensearch/common/round/RoundableFactory.java b/libs/common/src/main/java20/org/opensearch/common/round/RoundableFactory.java new file mode 100644 index 0000000000000..c5ec45bceb5bd --- /dev/null +++ b/libs/common/src/main/java20/org/opensearch/common/round/RoundableFactory.java @@ -0,0 +1,57 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common.round; + +import org.opensearch.common.annotation.InternalApi; + +import jdk.incubator.vector.LongVector; + +/** + * Factory class to create and return the fastest implementation of {@link Roundable}. + * + * @opensearch.internal + */ +@InternalApi +public final class RoundableFactory { + /** + * The maximum limit up to which linear search is used, otherwise binary or B-tree search is used. + * This is because linear search is much faster on small arrays. + * Benchmark results: PR #9727 + */ + private static final int LINEAR_SEARCH_MAX_SIZE = 64; + + /** + * Indicates whether the vectorized (SIMD) B-tree search implementation is to be used. + * It is true when either: + * 1. The feature flag is set to "forced", or + * 2. The platform has a minimum of 4 long vector lanes and the feature flag is set to "true". + */ + private static final boolean USE_BTREE_SEARCHER; + + static { + String simdRoundingFeatureFlag = System.getProperty("opensearch.experimental.feature.simd.rounding.enabled"); + USE_BTREE_SEARCHER = "forced".equalsIgnoreCase(simdRoundingFeatureFlag) + || (LongVector.SPECIES_PREFERRED.length() >= 4 && "true".equalsIgnoreCase(simdRoundingFeatureFlag)); + } + + private RoundableFactory() {} + + /** + * Creates and returns the fastest implementation of {@link Roundable}. + */ + public static Roundable create(long[] values, int size) { + if (size <= LINEAR_SEARCH_MAX_SIZE) { + return new BidirectionalLinearSearcher(values, size); + } else if (USE_BTREE_SEARCHER) { + return new BtreeSearcher(values, size); + } else { + return new BinarySearcher(values, size); + } + } +} diff --git a/libs/common/src/test/java/org/opensearch/common/round/RoundableTests.java b/libs/common/src/test/java/org/opensearch/common/round/RoundableTests.java index ae9f629c59024..ad19f456b0df4 100644 --- a/libs/common/src/test/java/org/opensearch/common/round/RoundableTests.java +++ b/libs/common/src/test/java/org/opensearch/common/round/RoundableTests.java @@ -12,15 +12,31 @@ public class RoundableTests extends OpenSearchTestCase { - public void testFloor() { - int size = randomIntBetween(1, 256); - long[] values = new long[size]; - for (int i = 1; i < values.length; i++) { - values[i] = values[i - 1] + (randomNonNegativeLong() % 200) + 1; - } + public void testRoundingEmptyArray() { + Throwable throwable = assertThrows(IllegalArgumentException.class, () -> RoundableFactory.create(new long[0], 0)); + assertEquals("at least one value must be present", throwable.getMessage()); + } + + public void testRoundingSmallArray() { + int size = randomIntBetween(1, 64); + long[] values = randomArrayOfSortedValues(size); + Roundable roundable = RoundableFactory.create(values, size); + + assertEquals("BidirectionalLinearSearcher", roundable.getClass().getSimpleName()); + assertRounding(roundable, values, size); + } - Roundable[] impls = { new BinarySearcher(values, size), new BidirectionalLinearSearcher(values, size) }; + public void testRoundingLargeArray() { + int size = randomIntBetween(65, 256); + long[] values = randomArrayOfSortedValues(size); + Roundable roundable = RoundableFactory.create(values, size); + boolean useBtreeSearcher = "forced".equalsIgnoreCase(System.getProperty("opensearch.experimental.feature.simd.rounding.enabled")); + assertEquals(useBtreeSearcher ? "BtreeSearcher" : "BinarySearcher", roundable.getClass().getSimpleName()); + assertRounding(roundable, values, size); + } + + private void assertRounding(Roundable roundable, long[] values, int size) { for (int i = 0; i < 100000; i++) { // Index of the expected round-down point. int idx = randomIntBetween(0, size - 1); @@ -35,23 +51,21 @@ public void testFloor() { // round-down point, which will still floor to the same value. long key = expected + (randomNonNegativeLong() % delta); - for (Roundable roundable : impls) { - assertEquals(expected, roundable.floor(key)); - } + assertEquals(expected, roundable.floor(key)); } + + Throwable throwable = assertThrows(AssertionError.class, () -> roundable.floor(values[0] - 1)); + assertEquals("key must be greater than or equal to " + values[0], throwable.getMessage()); } - public void testFailureCases() { - Throwable throwable; + private static long[] randomArrayOfSortedValues(int size) { + int capacity = size + randomInt(20); // May be slightly more than the size. + long[] values = new long[capacity]; - throwable = assertThrows(IllegalArgumentException.class, () -> new BinarySearcher(new long[0], 0)); - assertEquals("at least one value must be present", throwable.getMessage()); - throwable = assertThrows(IllegalArgumentException.class, () -> new BidirectionalLinearSearcher(new long[0], 0)); - assertEquals("at least one value must be present", throwable.getMessage()); + for (int i = 1; i < size; i++) { + values[i] = values[i - 1] + (randomNonNegativeLong() % 200) + 1; + } - throwable = assertThrows(AssertionError.class, () -> new BinarySearcher(new long[] { 100 }, 1).floor(50)); - assertEquals("key must be greater than or equal to 100", throwable.getMessage()); - throwable = assertThrows(AssertionError.class, () -> new BidirectionalLinearSearcher(new long[] { 100 }, 1).floor(50)); - assertEquals("key must be greater than or equal to 100", throwable.getMessage()); + return values; } }