Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add experimental SIMD implementation of B-tree to round down dates #11194

Merged
merged 15 commits into from
Jan 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Request level coordinator slow logs ([#10650](https://github.com/opensearch-project/OpenSearch/pull/10650))
- Add template snippets support for field and target_field in KV ingest processor ([#10040](https://github.com/opensearch-project/OpenSearch/pull/10040))
- Allowing pipeline processors to access index mapping info by passing ingest service ref as part of the processor factory parameters ([#10307](https://github.com/opensearch-project/OpenSearch/pull/10307))
- Add experimental SIMD implementation of B-tree to round down dates ([#11194](https://github.com/opensearch-project/OpenSearch/issues/11194))
- Make number of segment metadata files in remote segment store configurable ([#11329](https://github.com/opensearch-project/OpenSearch/pull/11329))
- Allow changing number of replicas of searchable snapshot index ([#11317](https://github.com/opensearch-project/OpenSearch/pull/11317))
- Adding slf4j license header to LoggerMessageFormat.java ([#11069](https://github.com/opensearch-project/OpenSearch/pull/11069))
Expand Down
42 changes: 42 additions & 0 deletions benchmarks/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,45 @@ spotless {
targetExclude 'src/main/generated/**/*.java'
}
}

if (BuildParams.runtimeJavaVersion >= JavaVersion.VERSION_20) {
// Add support for incubator modules on supported Java versions.
run.jvmArgs += ['--add-modules=jdk.incubator.vector']
run.classpath += files(jar.archiveFile)
run.classpath -= sourceSets.main.output
evaluationDependsOn(':libs:opensearch-common')

sourceSets {
java20 {
java {
srcDirs = ['src/main/java20']
}
}
}

configurations {
java20Implementation.extendsFrom(implementation)
}

dependencies {
java20Implementation sourceSets.main.output
java20Implementation project(':libs:opensearch-common').sourceSets.java20.output
java20AnnotationProcessor "org.openjdk.jmh:jmh-generator-annprocess:$versions.jmh"
}

compileJava20Java {
targetCompatibility = JavaVersion.VERSION_20
options.compilerArgs.addAll(["-processor", "org.openjdk.jmh.generators.BenchmarkProcessor"])
}

jar {
metaInf {
into 'versions/20'
from sourceSets.java20.output
}
manifest.attributes('Multi-Release': 'true')
}

// classes generated by JMH can use all sorts of forbidden APIs but we have no influence at all and cannot exclude these classes
disableTasks('forbiddenApisJava20')
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import org.openjdk.jmh.infra.Blackhole;

import java.util.Random;
import java.util.function.Supplier;

@Fork(value = 3)
@Warmup(iterations = 3, time = 1)
Expand Down Expand Up @@ -83,17 +82,17 @@ public static class Options {
"256" })
public Integer size;

@Param({ "binary", "linear" })
@Param({ "binary", "linear", "btree" })
public String type;

@Param({ "uniform", "skewed_edge", "skewed_center" })
public String distribution;

public long[] queries;
public Supplier<Roundable> supplier;
public RoundableSupplier supplier;

@Setup
public void setup() {
public void setup() throws ClassNotFoundException {
Random random = new Random(size);
long[] values = new long[size];
for (int i = 1; i < values.length; i++) {
Expand Down Expand Up @@ -128,16 +127,7 @@ public void setup() {
throw new IllegalArgumentException("invalid distribution: " + distribution);
}

switch (type) {
case "binary":
supplier = () -> new BinarySearcher(values, size);
break;
case "linear":
supplier = () -> new BidirectionalLinearSearcher(values, size);
break;
default:
throw new IllegalArgumentException("invalid type: " + type);
}
supplier = new RoundableSupplier(type, values, size);
}

private static long nextPositiveLong(Random random) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.common.round;

import java.util.function.Supplier;

public class RoundableSupplier implements Supplier<Roundable> {
private final Supplier<Roundable> delegate;

RoundableSupplier(String type, long[] values, int size) throws ClassNotFoundException {
switch (type) {
case "binary":
delegate = () -> new BinarySearcher(values, size);
break;
case "linear":
delegate = () -> new BidirectionalLinearSearcher(values, size);
break;
case "btree":
throw new ClassNotFoundException("BtreeSearcher is not supported below JDK 20");
default:
throw new IllegalArgumentException("invalid type: " + type);
}
}

@Override
public Roundable get() {
return delegate.get();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.common.round;

import java.util.function.Supplier;

public class RoundableSupplier implements Supplier<Roundable> {
private final Supplier<Roundable> delegate;

RoundableSupplier(String type, long[] values, int size) {
switch (type) {
case "binary":
delegate = () -> new BinarySearcher(values, size);
break;
case "linear":
delegate = () -> new BidirectionalLinearSearcher(values, size);
break;
case "btree":
delegate = () -> new BtreeSearcher(values, size);
break;
default:
throw new IllegalArgumentException("invalid type: " + type);
}
}

@Override
public Roundable get() {
return delegate.get();
}
}
61 changes: 61 additions & 0 deletions libs/common/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,64 @@ tasks.named('forbiddenApisMain').configure {
// TODO: Need to decide how we want to handle for forbidden signatures with the changes to server
replaceSignatureFiles 'jdk-signatures'
}

// Add support for incubator modules on supported Java versions.
reta marked this conversation as resolved.
Show resolved Hide resolved
if (BuildParams.runtimeJavaVersion >= JavaVersion.VERSION_20) {
sourceSets {
java20 {
java {
srcDirs = ['src/main/java20']
}
}
}

configurations {
java20Implementation.extendsFrom(implementation)
}

dependencies {
java20Implementation sourceSets.main.output
}

compileJava20Java {
targetCompatibility = JavaVersion.VERSION_20
options.compilerArgs += ['--add-modules', 'jdk.incubator.vector']
options.compilerArgs -= '-Werror' // use of incubator modules is reported as a warning
}

jar {
metaInf {
into 'versions/20'
from sourceSets.java20.output
}
manifest.attributes('Multi-Release': 'true')
}

tasks.withType(Test).configureEach {
// Relying on the convention for Test.classpath in custom Test tasks has been deprecated
// and scheduled to be removed in Gradle 9.0. Below lines are added from the migration guide:
// https://docs.gradle.org/8.5/userguide/upgrading_version_8.html#test_task_default_classpath
testClassesDirs = testing.suites.test.sources.output.classesDirs
classpath = testing.suites.test.sources.runtimeClasspath

// Adds the multi-release JAR to the classpath when executing tests.
// This allows newer sources to be picked up at test runtime (if supported).
classpath += files(jar.archiveFile)
// Removes the "main" sources from the classpath to avoid JarHell problems as
// the multi-release JAR already contains those classes.
classpath -= sourceSets.main.output
}

tasks.register('roundableSimdTest', Test) {
group 'verification'
include '**/RoundableTests.class'
systemProperty 'opensearch.experimental.feature.simd.rounding.enabled', 'forced'
}

check.dependsOn(roundableSimdTest)

forbiddenApisJava20 {
ketanv3 marked this conversation as resolved.
Show resolved Hide resolved
failOnMissingClasses = false
ignoreSignaturesOfMissingClasses = true
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.common.round;

import org.opensearch.common.annotation.InternalApi;

import jdk.incubator.vector.LongVector;
import jdk.incubator.vector.Vector;
import jdk.incubator.vector.VectorOperators;
import jdk.incubator.vector.VectorSpecies;

/**
* It uses vectorized B-tree search to find the round-down point.
*
* @opensearch.internal
*/
@InternalApi
class BtreeSearcher implements Roundable {
reta marked this conversation as resolved.
Show resolved Hide resolved
private static final VectorSpecies<Long> LONG_VECTOR_SPECIES = LongVector.SPECIES_PREFERRED;
private static final int LANES = LONG_VECTOR_SPECIES.length();
private static final int SHIFT = log2(LANES);

private final long[] values;
private final long minValue;

BtreeSearcher(long[] values, int size) {
if (size <= 0) {
throw new IllegalArgumentException("at least one value must be present");
}

int blocks = (size + LANES - 1) / LANES; // number of blocks
int length = 1 + blocks * LANES; // size of the backing array (1-indexed)

this.minValue = values[0];
this.values = new long[length];
build(values, 0, size, this.values, 1);
}

/**
* Builds the B-tree memory layout.
* It builds the tree recursively, following an in-order traversal.
*
* <p>
* Each block stores 'lanes' values at indices {@code i, i + 1, ..., i + lanes - 1} where {@code i} is the
* starting offset. The starting offset of the root block is 1. The branching factor is (1 + lanes) so each
* block can have these many children. Given the starting offset {@code i} of a block, the starting offset
* of its k-th child (ranging from {@code 0, 1, ..., k}) can be computed as {@code i + ((i + k) << shift)}.
*
* @param src is the sorted input array
* @param i is the index in the input array to read the value from
* @param size the number of values in the input array
* @param dst is the output array
* @param j is the index in the output array to write the value to
* @return the next index 'i'
*/
private static int build(long[] src, int i, int size, long[] dst, int j) {
if (j < dst.length) {
for (int k = 0; k < LANES; k++) {
i = build(src, i, size, dst, j + ((j + k) << SHIFT));

// Fills the B-tree as a complete tree, i.e., all levels are completely filled,
// except the last level which is filled from left to right.
// The trick is to fill the destination array between indices 1...size (inclusive / 1-indexed)
// and pad the remaining array with +infinity.
dst[j + k] = (j + k <= size) ? src[i++] : Long.MAX_VALUE;
}
i = build(src, i, size, dst, j + ((j + LANES) << SHIFT));
}
return i;
}

@Override
public long floor(long key) {
Vector<Long> keyVector = LongVector.broadcast(LONG_VECTOR_SPECIES, key);
int i = 1, result = 1;

while (i < values.length) {
Vector<Long> valuesVector = LongVector.fromArray(LONG_VECTOR_SPECIES, values, i);
int j = i + valuesVector.compare(VectorOperators.GT, keyVector).firstTrue();
result = (j > i) ? j : result;
i += (j << SHIFT);
}

assert result > 1 : "key must be greater than or equal to " + minValue;
return values[result - 1];
}

private static int log2(int num) {
if ((num <= 0) || ((num & (num - 1)) != 0)) {
throw new IllegalArgumentException(num + " is not a positive power of 2");
}
return 32 - Integer.numberOfLeadingZeros(num - 1);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.common.round;

import org.opensearch.common.annotation.InternalApi;

import jdk.incubator.vector.LongVector;

/**
* Factory class to create and return the fastest implementation of {@link Roundable}.
*
* @opensearch.internal
*/
@InternalApi
public final class RoundableFactory {
ketanv3 marked this conversation as resolved.
Show resolved Hide resolved
/**
* The maximum limit up to which linear search is used, otherwise binary or B-tree search is used.
* This is because linear search is much faster on small arrays.
* Benchmark results: <a href="https://github.com/opensearch-project/OpenSearch/pull/9727">PR #9727</a>
*/
private static final int LINEAR_SEARCH_MAX_SIZE = 64;

/**
* Indicates whether the vectorized (SIMD) B-tree search implementation is to be used.
* It is true when either:
* 1. The feature flag is set to "forced", or
* 2. The platform has a minimum of 4 long vector lanes and the feature flag is set to "true".
*/
private static final boolean USE_BTREE_SEARCHER;

static {
String simdRoundingFeatureFlag = System.getProperty("opensearch.experimental.feature.simd.rounding.enabled");
reta marked this conversation as resolved.
Show resolved Hide resolved
USE_BTREE_SEARCHER = "forced".equalsIgnoreCase(simdRoundingFeatureFlag)
|| (LongVector.SPECIES_PREFERRED.length() >= 4 && "true".equalsIgnoreCase(simdRoundingFeatureFlag));
}

private RoundableFactory() {}

/**
* Creates and returns the fastest implementation of {@link Roundable}.
*/
public static Roundable create(long[] values, int size) {
if (size <= LINEAR_SEARCH_MAX_SIZE) {
return new BidirectionalLinearSearcher(values, size);
} else if (USE_BTREE_SEARCHER) {
return new BtreeSearcher(values, size);
} else {
return new BinarySearcher(values, size);
}
}
}
Loading
Loading