diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/tdigest/SortBench.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/tdigest/SortBench.java index 862479571113d..423db48337586 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/tdigest/SortBench.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/tdigest/SortBench.java @@ -22,6 +22,9 @@ package org.elasticsearch.benchmark.tdigest; import org.elasticsearch.tdigest.Sort; +import org.elasticsearch.tdigest.arrays.TDigestDoubleArray; +import org.elasticsearch.tdigest.arrays.TDigestIntArray; +import org.elasticsearch.tdigest.arrays.WrapperTDigestArrays; import org.openjdk.jmh.annotations.Benchmark; import org.openjdk.jmh.annotations.BenchmarkMode; import org.openjdk.jmh.annotations.Fork; @@ -35,7 +38,6 @@ import org.openjdk.jmh.annotations.Threads; import org.openjdk.jmh.annotations.Warmup; -import java.util.Arrays; import java.util.Random; import java.util.concurrent.TimeUnit; @@ -49,7 +51,7 @@ @State(Scope.Thread) public class SortBench { private final int size = 100000; - private final double[] values = new double[size]; + private final TDigestDoubleArray values = WrapperTDigestArrays.INSTANCE.newDoubleArray(size); @Param({ "0", "1", "-1" }) public int sortDirection; @@ -58,22 +60,22 @@ public class SortBench { public void setup() { Random prng = new Random(999983); for (int i = 0; i < size; i++) { - values[i] = prng.nextDouble(); + values.set(i, prng.nextDouble()); } if (sortDirection > 0) { - Arrays.sort(values); + values.sort(); } else if (sortDirection < 0) { - Arrays.sort(values); - Sort.reverse(values, 0, values.length); + values.sort(); + Sort.reverse(values, 0, values.size()); } } @Benchmark - public void quicksort() { - int[] order = new int[size]; + public void stableSort() { + TDigestIntArray order = WrapperTDigestArrays.INSTANCE.newIntArray(size); for (int i = 0; i < size; i++) { - order[i] = i; + order.set(i, i); } - Sort.sort(order, values, null, values.length); + Sort.stableSort(order, values, values.size()); } } diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/tdigest/TDigestBench.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/tdigest/TDigestBench.java index abf08395f90a6..58bb5b07d22cd 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/tdigest/TDigestBench.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/tdigest/TDigestBench.java @@ -21,9 +21,9 @@ package org.elasticsearch.benchmark.tdigest; -import org.elasticsearch.tdigest.AVLTreeDigest; import org.elasticsearch.tdigest.MergingDigest; import org.elasticsearch.tdigest.TDigest; +import org.elasticsearch.tdigest.arrays.WrapperTDigestArrays; import org.openjdk.jmh.annotations.Benchmark; import org.openjdk.jmh.annotations.BenchmarkMode; import org.openjdk.jmh.annotations.Fork; @@ -61,13 +61,19 @@ public enum TDigestFactory { MERGE { @Override TDigest create(double compression) { - return new MergingDigest(compression, (int) (10 * compression)); + return new MergingDigest(WrapperTDigestArrays.INSTANCE, compression, (int) (10 * compression)); } }, AVL_TREE { @Override TDigest create(double compression) { - return new AVLTreeDigest(compression); + return TDigest.createAvlTreeDigest(WrapperTDigestArrays.INSTANCE, compression); + } + }, + HYBRID { + @Override + TDigest create(double compression) { + return TDigest.createHybridDigest(WrapperTDigestArrays.INSTANCE, compression); } }; @@ -77,7 +83,7 @@ TDigest create(double compression) { @Param({ "100", "300" }) double 
compression; - @Param({ "MERGE", "AVL_TREE" }) + @Param({ "MERGE", "AVL_TREE", "HYBRID" }) TDigestFactory tdigestFactory; @Param({ "NORMAL", "GAUSSIAN" }) diff --git a/distribution/tools/geoip-cli/src/test/java/org/elasticsearch/geoip/GeoIpCliTests.java b/distribution/tools/geoip-cli/src/test/java/org/elasticsearch/geoip/GeoIpCliTests.java index 5b3fee1d9e49f..7daec3365c379 100644 --- a/distribution/tools/geoip-cli/src/test/java/org/elasticsearch/geoip/GeoIpCliTests.java +++ b/distribution/tools/geoip-cli/src/test/java/org/elasticsearch/geoip/GeoIpCliTests.java @@ -20,7 +20,6 @@ import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; -import java.io.BufferedInputStream; import java.io.IOException; import java.nio.file.Files; import java.nio.file.Path; @@ -106,9 +105,7 @@ private void verifyOverview() throws Exception { private void verifyTarball(Map data) throws Exception { for (String tgz : List.of("a.tgz", "b.tgz")) { try ( - TarArchiveInputStream tis = new TarArchiveInputStream( - new GZIPInputStream(new BufferedInputStream(Files.newInputStream(target.resolve(tgz)))) - ) + TarArchiveInputStream tis = new TarArchiveInputStream(new GZIPInputStream(Files.newInputStream(target.resolve(tgz)), 8192)) ) { TarArchiveEntry entry = tis.getNextTarEntry(); assertNotNull(entry); diff --git a/docs/changelog/112677.yaml b/docs/changelog/112677.yaml new file mode 100644 index 0000000000000..89662236c6ca5 --- /dev/null +++ b/docs/changelog/112677.yaml @@ -0,0 +1,5 @@ +pr: 112677 +summary: Stream OpenAI Completion +area: Machine Learning +type: enhancement +issues: [] diff --git a/docs/changelog/112973.yaml b/docs/changelog/112973.yaml new file mode 100644 index 0000000000000..3ba86a31334ff --- /dev/null +++ b/docs/changelog/112973.yaml @@ -0,0 +1,5 @@ +pr: 112973 +summary: Fix verbose get data stream API not requiring extra privileges +area: Data streams +type: bug +issues: [] diff --git a/docs/reference/rest-api/security/get-service-accounts.asciidoc b/docs/reference/rest-api/security/get-service-accounts.asciidoc index 526c6e65ccf33..3a14278fb4cfb 100644 --- a/docs/reference/rest-api/security/get-service-accounts.asciidoc +++ b/docs/reference/rest-api/security/get-service-accounts.asciidoc @@ -66,7 +66,8 @@ GET /_security/service/elastic/fleet-server "cluster": [ "monitor", "manage_own_api_key", - "read_fleet_secrets" + "read_fleet_secrets", + "cluster:admin/xpack/connector/*" ], "indices": [ { @@ -238,6 +239,35 @@ GET /_security/service/elastic/fleet-server "auto_configure" ], "allow_restricted_indices": false + }, + { + "names": [ + ".elastic-connectors*" + ], + "privileges": [ + "read", + "write", + "monitor", + "create_index", + "auto_configure", + "maintenance" + ], + "allow_restricted_indices": false + }, + { + "names": [ + "content-*", + ".search-acl-filter-*" + ], + "privileges": [ + "read", + "write", + "monitor", + "create_index", + "auto_configure", + "maintenance" + ], + "allow_restricted_indices": false } ], "applications": [ diff --git a/docs/reference/snapshot-restore/repository-source-only.asciidoc b/docs/reference/snapshot-restore/repository-source-only.asciidoc index 07ddedd197931..04e53c42aff9d 100644 --- a/docs/reference/snapshot-restore/repository-source-only.asciidoc +++ b/docs/reference/snapshot-restore/repository-source-only.asciidoc @@ -18,7 +18,7 @@ stream or index. 
================================================== Source-only snapshots are only supported if the `_source` field is enabled and no source-filtering is applied. -When you restore a source-only snapshot: +As a result, indices adopting synthetic source cannot be restored. When you restore a source-only snapshot: * The restored index is read-only and can only serve `match_all` search or scroll requests to enable reindexing. diff --git a/libs/tdigest/src/main/java/module-info.java b/libs/tdigest/src/main/java/module-info.java index 994ff41187221..8edaff3f31d8c 100644 --- a/libs/tdigest/src/main/java/module-info.java +++ b/libs/tdigest/src/main/java/module-info.java @@ -19,4 +19,5 @@ module org.elasticsearch.tdigest { exports org.elasticsearch.tdigest; + exports org.elasticsearch.tdigest.arrays; } diff --git a/libs/tdigest/src/main/java/org/elasticsearch/tdigest/AVLGroupTree.java b/libs/tdigest/src/main/java/org/elasticsearch/tdigest/AVLGroupTree.java index 12b2a29d3e034..8528db2128729 100644 --- a/libs/tdigest/src/main/java/org/elasticsearch/tdigest/AVLGroupTree.java +++ b/libs/tdigest/src/main/java/org/elasticsearch/tdigest/AVLGroupTree.java @@ -21,8 +21,11 @@ package org.elasticsearch.tdigest; +import org.elasticsearch.tdigest.arrays.TDigestArrays; +import org.elasticsearch.tdigest.arrays.TDigestDoubleArray; +import org.elasticsearch.tdigest.arrays.TDigestLongArray; + import java.util.AbstractCollection; -import java.util.Arrays; import java.util.Iterator; /** @@ -32,20 +35,20 @@ final class AVLGroupTree extends AbstractCollection { /* For insertions into the tree */ private double centroid; private long count; - private double[] centroids; - private long[] counts; - private long[] aggregatedCounts; + private final TDigestDoubleArray centroids; + private final TDigestLongArray counts; + private final TDigestLongArray aggregatedCounts; private final IntAVLTree tree; - AVLGroupTree() { - tree = new IntAVLTree() { + AVLGroupTree(TDigestArrays arrays) { + tree = new IntAVLTree(arrays) { @Override protected void resize(int newCapacity) { super.resize(newCapacity); - centroids = Arrays.copyOf(centroids, newCapacity); - counts = Arrays.copyOf(counts, newCapacity); - aggregatedCounts = Arrays.copyOf(aggregatedCounts, newCapacity); + centroids.resize(newCapacity); + counts.resize(newCapacity); + aggregatedCounts.resize(newCapacity); } @Override @@ -56,13 +59,13 @@ protected void merge(int node) { @Override protected void copy(int node) { - centroids[node] = centroid; - counts[node] = count; + centroids.set(node, centroid); + counts.set(node, count); } @Override protected int compare(int node) { - if (centroid < centroids[node]) { + if (centroid < centroids.get(node)) { return -1; } else { // upon equality, the newly added node is considered greater @@ -73,13 +76,13 @@ protected int compare(int node) { @Override protected void fixAggregates(int node) { super.fixAggregates(node); - aggregatedCounts[node] = counts[node] + aggregatedCounts[left(node)] + aggregatedCounts[right(node)]; + aggregatedCounts.set(node, counts.get(node) + aggregatedCounts.get(left(node)) + aggregatedCounts.get(right(node))); } }; - centroids = new double[tree.capacity()]; - counts = new long[tree.capacity()]; - aggregatedCounts = new long[tree.capacity()]; + centroids = arrays.newDoubleArray(tree.capacity()); + counts = arrays.newLongArray(tree.capacity()); + aggregatedCounts = arrays.newLongArray(tree.capacity()); } /** @@ -107,14 +110,14 @@ public int next(int node) { * Return the mean for the provided node. 
*/ public double mean(int node) { - return centroids[node]; + return centroids.get(node); } /** * Return the count for the provided node. */ public long count(int node) { - return counts[node]; + return counts.get(node); } /** @@ -167,7 +170,7 @@ public int floorSum(long sum) { int floor = IntAVLTree.NIL; for (int node = tree.root(); node != IntAVLTree.NIL;) { final int left = tree.left(node); - final long leftCount = aggregatedCounts[left]; + final long leftCount = aggregatedCounts.get(left); if (leftCount <= sum) { floor = node; sum -= leftCount + count(node); @@ -199,11 +202,11 @@ public int last() { */ public long headSum(int node) { final int left = tree.left(node); - long sum = aggregatedCounts[left]; + long sum = aggregatedCounts.get(left); for (int n = node, p = tree.parent(node); p != IntAVLTree.NIL; n = p, p = tree.parent(n)) { if (n == tree.right(p)) { final int leftP = tree.left(p); - sum += counts[p] + aggregatedCounts[leftP]; + sum += counts.get(p) + aggregatedCounts.get(leftP); } } return sum; @@ -243,7 +246,7 @@ public void remove() { * Return the total count of points that have been added to the tree. */ public long sum() { - return aggregatedCounts[tree.root()]; + return aggregatedCounts.get(tree.root()); } void checkBalance() { @@ -255,7 +258,9 @@ void checkAggregates() { } private void checkAggregates(int node) { - assert aggregatedCounts[node] == counts[node] + aggregatedCounts[tree.left(node)] + aggregatedCounts[tree.right(node)]; + assert aggregatedCounts.get(node) == counts.get(node) + aggregatedCounts.get(tree.left(node)) + aggregatedCounts.get( + tree.right(node) + ); if (node != IntAVLTree.NIL) { checkAggregates(tree.left(node)); checkAggregates(tree.right(node)); diff --git a/libs/tdigest/src/main/java/org/elasticsearch/tdigest/AVLTreeDigest.java b/libs/tdigest/src/main/java/org/elasticsearch/tdigest/AVLTreeDigest.java index deb3407565f36..c28f86b9b8edc 100644 --- a/libs/tdigest/src/main/java/org/elasticsearch/tdigest/AVLTreeDigest.java +++ b/libs/tdigest/src/main/java/org/elasticsearch/tdigest/AVLTreeDigest.java @@ -21,6 +21,8 @@ package org.elasticsearch.tdigest; +import org.elasticsearch.tdigest.arrays.TDigestArrays; + import java.util.Collection; import java.util.Collections; import java.util.Iterator; @@ -29,6 +31,8 @@ import static org.elasticsearch.tdigest.IntAVLTree.NIL; public class AVLTreeDigest extends AbstractTDigest { + private final TDigestArrays arrays; + final Random gen = new Random(); private final double compression; private AVLGroupTree summary; @@ -46,9 +50,10 @@ public class AVLTreeDigest extends AbstractTDigest { * quantiles. Conversely, you should expect to track about 5 N centroids for this * accuracy. 
*/ - public AVLTreeDigest(double compression) { + AVLTreeDigest(TDigestArrays arrays, double compression) { + this.arrays = arrays; this.compression = compression; - summary = new AVLGroupTree(); + summary = new AVLGroupTree(arrays); } /** @@ -149,7 +154,7 @@ public void compress() { needsCompression = false; AVLGroupTree centroids = summary; - this.summary = new AVLGroupTree(); + this.summary = new AVLGroupTree(arrays); final int[] nodes = new int[centroids.size()]; nodes[0] = centroids.first(); diff --git a/libs/tdigest/src/main/java/org/elasticsearch/tdigest/Dist.java b/libs/tdigest/src/main/java/org/elasticsearch/tdigest/Dist.java index 087deaedc7d75..02fb7a8376aa4 100644 --- a/libs/tdigest/src/main/java/org/elasticsearch/tdigest/Dist.java +++ b/libs/tdigest/src/main/java/org/elasticsearch/tdigest/Dist.java @@ -21,6 +21,8 @@ package org.elasticsearch.tdigest; +import org.elasticsearch.tdigest.arrays.TDigestDoubleArray; + import java.util.List; import java.util.function.Function; @@ -102,6 +104,10 @@ public static double cdf(final double x, List data) { return cdf(x, data.size(), data::get); } + public static double cdf(final double x, TDigestDoubleArray data) { + return cdf(x, data.size(), data::get); + } + private static double quantile(final double q, final int length, Function elementGetter) { if (length == 0) { return Double.NaN; @@ -133,4 +139,8 @@ public static double quantile(final double q, double[] data) { public static double quantile(final double q, List data) { return quantile(q, data.size(), data::get); } + + public static double quantile(final double q, TDigestDoubleArray data) { + return quantile(q, data.size(), data::get); + } } diff --git a/libs/tdigest/src/main/java/org/elasticsearch/tdigest/HybridDigest.java b/libs/tdigest/src/main/java/org/elasticsearch/tdigest/HybridDigest.java index 07a12381e2a71..c28a99fbd6d44 100644 --- a/libs/tdigest/src/main/java/org/elasticsearch/tdigest/HybridDigest.java +++ b/libs/tdigest/src/main/java/org/elasticsearch/tdigest/HybridDigest.java @@ -19,6 +19,8 @@ package org.elasticsearch.tdigest; +import org.elasticsearch.tdigest.arrays.TDigestArrays; + import java.util.Collection; /** @@ -32,6 +34,8 @@ */ public class HybridDigest extends AbstractTDigest { + private final TDigestArrays arrays; + // See MergingDigest's compression param. private final double compression; @@ -39,7 +43,7 @@ public class HybridDigest extends AbstractTDigest { private final long maxSortingSize; // This is set to null when the implementation switches to MergingDigest. - private SortingDigest sortingDigest = new SortingDigest(); + private SortingDigest sortingDigest; // This gets initialized when the implementation switches to MergingDigest. 
private MergingDigest mergingDigest; @@ -51,9 +55,11 @@ public class HybridDigest extends AbstractTDigest { * @param compression The compression factor for the MergingDigest * @param maxSortingSize The sample size limit for switching from a {@link SortingDigest} to a {@link MergingDigest} implementation */ - HybridDigest(double compression, long maxSortingSize) { + HybridDigest(TDigestArrays arrays, double compression, long maxSortingSize) { + this.arrays = arrays; this.compression = compression; this.maxSortingSize = maxSortingSize; + this.sortingDigest = new SortingDigest(arrays); } /** @@ -62,11 +68,11 @@ public class HybridDigest extends AbstractTDigest { * * @param compression The compression factor for the MergingDigest */ - HybridDigest(double compression) { + HybridDigest(TDigestArrays arrays, double compression) { // The default maxSortingSize is calculated so that the SortingDigest will have comparable size with the MergingDigest // at the point where implementations switch, e.g. for default compression 100 SortingDigest allocates ~16kB and MergingDigest // allocates ~15kB. - this(compression, Math.round(compression) * 20); + this(arrays, compression, Math.round(compression) * 20); } @Override @@ -98,9 +104,9 @@ public void reserve(long size) { // Check if we need to switch implementations. assert sortingDigest != null; if (sortingDigest.size() + size >= maxSortingSize) { - mergingDigest = new MergingDigest(compression); - for (double value : sortingDigest.values) { - mergingDigest.add(value); + mergingDigest = new MergingDigest(arrays, compression); + for (int i = 0; i < sortingDigest.values.size(); i++) { + mergingDigest.add(sortingDigest.values.get(i)); } mergingDigest.reserve(size); // Release the allocated SortingDigest. diff --git a/libs/tdigest/src/main/java/org/elasticsearch/tdigest/IntAVLTree.java b/libs/tdigest/src/main/java/org/elasticsearch/tdigest/IntAVLTree.java index f2f71b13a2eef..cda8aecdb2ccc 100644 --- a/libs/tdigest/src/main/java/org/elasticsearch/tdigest/IntAVLTree.java +++ b/libs/tdigest/src/main/java/org/elasticsearch/tdigest/IntAVLTree.java @@ -21,6 +21,10 @@ package org.elasticsearch.tdigest; +import org.elasticsearch.tdigest.arrays.TDigestArrays; +import org.elasticsearch.tdigest.arrays.TDigestByteArray; +import org.elasticsearch.tdigest.arrays.TDigestIntArray; + import java.util.Arrays; /** @@ -30,7 +34,6 @@ * identifiers as indices. */ abstract class IntAVLTree { - /** * We use 0 instead of -1 so that left(NIL) works without * condition. @@ -44,22 +47,22 @@ static int oversize(int size) { private final NodeAllocator nodeAllocator; private int root; - private int[] parent; - private int[] left; - private int[] right; - private byte[] depth; + private final TDigestIntArray parent; + private final TDigestIntArray left; + private final TDigestIntArray right; + private final TDigestByteArray depth; - IntAVLTree(int initialCapacity) { + IntAVLTree(TDigestArrays arrays, int initialCapacity) { nodeAllocator = new NodeAllocator(); root = NIL; - parent = new int[initialCapacity]; - left = new int[initialCapacity]; - right = new int[initialCapacity]; - depth = new byte[initialCapacity]; + parent = arrays.newIntArray(initialCapacity); + left = arrays.newIntArray(initialCapacity); + right = arrays.newIntArray(initialCapacity); + depth = arrays.newByteArray(initialCapacity); } - IntAVLTree() { - this(16); + IntAVLTree(TDigestArrays arrays) { + this(arrays, 16); } /** @@ -74,7 +77,7 @@ public int root() { * can hold. 
*/ public int capacity() { - return parent.length; + return parent.size(); } /** @@ -82,10 +85,10 @@ public int capacity() { * newCapacity (excluded). */ protected void resize(int newCapacity) { - parent = Arrays.copyOf(parent, newCapacity); - left = Arrays.copyOf(left, newCapacity); - right = Arrays.copyOf(right, newCapacity); - depth = Arrays.copyOf(depth, newCapacity); + parent.resize(newCapacity); + left.resize(newCapacity); + right.resize(newCapacity); + depth.resize(newCapacity); } /** @@ -99,28 +102,28 @@ public int size() { * Return the parent of the provided node. */ public int parent(int node) { - return parent[node]; + return parent.get(node); } /** * Return the left child of the provided node. */ public int left(int node) { - return left[node]; + return left.get(node); } /** * Return the right child of the provided node. */ public int right(int node) { - return right[node]; + return right.get(node); } /** * Return the depth nodes that are stored below node including itself. */ public int depth(int node) { - return depth[node]; + return depth.get(node); } /** @@ -493,23 +496,23 @@ private void rotateRight(int n) { private void parent(int node, int parent) { assert node != NIL; - this.parent[node] = parent; + this.parent.set(node, parent); } private void left(int node, int left) { assert node != NIL; - this.left[node] = left; + this.left.set(node, left); } private void right(int node, int right) { assert node != NIL; - this.right[node] = right; + this.right.set(node, right); } private void depth(int node, int depth) { assert node != NIL; assert depth >= 0 && depth <= Byte.MAX_VALUE; - this.depth[node] = (byte) depth; + this.depth.set(node, (byte) depth); } void checkBalance(int node) { diff --git a/libs/tdigest/src/main/java/org/elasticsearch/tdigest/MergingDigest.java b/libs/tdigest/src/main/java/org/elasticsearch/tdigest/MergingDigest.java index fc22bda52e104..1649af041ee19 100644 --- a/libs/tdigest/src/main/java/org/elasticsearch/tdigest/MergingDigest.java +++ b/libs/tdigest/src/main/java/org/elasticsearch/tdigest/MergingDigest.java @@ -21,6 +21,10 @@ package org.elasticsearch.tdigest; +import org.elasticsearch.tdigest.arrays.TDigestArrays; +import org.elasticsearch.tdigest.arrays.TDigestDoubleArray; +import org.elasticsearch.tdigest.arrays.TDigestIntArray; + import java.util.AbstractCollection; import java.util.Collection; import java.util.Iterator; @@ -62,6 +66,8 @@ * what the AVLTreeDigest uses and no dynamic allocation is required at all. 
*/ public class MergingDigest extends AbstractTDigest { + private final TDigestArrays arrays; + private int mergeCount = 0; private final double publicCompression; @@ -70,26 +76,26 @@ public class MergingDigest extends AbstractTDigest { // points to the first unused centroid private int lastUsedCell; - // sum_i weight[i] See also unmergedWeight + // sum_i weight.get(i) See also unmergedWeight private double totalWeight = 0; // number of points that have been added to each merged centroid - private final double[] weight; + private final TDigestDoubleArray weight; // mean of points added to each merged centroid - private final double[] mean; + private final TDigestDoubleArray mean; - // sum_i tempWeight[i] + // sum_i tempWeight.get(i) private double unmergedWeight = 0; // this is the index of the next temporary centroid // this is a more Java-like convention than lastUsedCell uses private int tempUsed = 0; - private final double[] tempWeight; - private final double[] tempMean; + private final TDigestDoubleArray tempWeight; + private final TDigestDoubleArray tempMean; // array used for sorting the temp centroids. This is a field // to avoid allocations during operation - private final int[] order; + private final TDigestIntArray order; // if true, alternate upward and downward merge passes public boolean useAlternatingSort = true; @@ -109,8 +115,8 @@ public class MergingDigest extends AbstractTDigest { * * @param compression The compression factor */ - public MergingDigest(double compression) { - this(compression, -1); + public MergingDigest(TDigestArrays arrays, double compression) { + this(arrays, compression, -1); } /** @@ -119,9 +125,9 @@ public MergingDigest(double compression) { * @param compression Compression factor for t-digest. Same as 1/\delta in the paper. * @param bufferSize How many samples to retain before merging. */ - public MergingDigest(double compression, int bufferSize) { + public MergingDigest(TDigestArrays arrays, double compression, int bufferSize) { // we can guarantee that we only need ceiling(compression). 
- this(compression, bufferSize, -1); + this(arrays, compression, bufferSize, -1); } /** @@ -131,7 +137,9 @@ public MergingDigest(double compression, int bufferSize) { * @param bufferSize Number of temporary centroids * @param size Size of main buffer */ - public MergingDigest(double compression, int bufferSize, int size) { + public MergingDigest(TDigestArrays arrays, double compression, int bufferSize, int size) { + this.arrays = arrays; + // ensure compression >= 10 // default size = 2 * ceil(compression) // default bufferSize = 5 * size @@ -205,12 +213,12 @@ public MergingDigest(double compression, int bufferSize, int size) { bufferSize = 2 * size; } - weight = new double[size]; - mean = new double[size]; + weight = arrays.newDoubleArray(size); + mean = arrays.newDoubleArray(size); - tempWeight = new double[bufferSize]; - tempMean = new double[bufferSize]; - order = new int[bufferSize]; + tempWeight = arrays.newDoubleArray(bufferSize); + tempMean = arrays.newDoubleArray(bufferSize); + order = arrays.newIntArray(bufferSize); lastUsedCell = 0; } @@ -218,12 +226,12 @@ public MergingDigest(double compression, int bufferSize, int size) { @Override public void add(double x, long w) { checkValue(x); - if (tempUsed >= tempWeight.length - lastUsedCell - 1) { + if (tempUsed >= tempWeight.size() - lastUsedCell - 1) { mergeNewValues(); } int where = tempUsed++; - tempWeight[where] = w; - tempMean[where] = x; + tempWeight.set(where, w); + tempMean.set(where, x); unmergedWeight += w; if (x < min) { min = x; @@ -252,22 +260,22 @@ private void mergeNewValues(double compression) { } private void merge( - double[] incomingMean, - double[] incomingWeight, + TDigestDoubleArray incomingMean, + TDigestDoubleArray incomingWeight, int incomingCount, - int[] incomingOrder, + TDigestIntArray incomingOrder, double unmergedWeight, boolean runBackwards, double compression ) { // when our incoming buffer fills up, we combine our existing centroids with the incoming data, // and then reduce the centroids by merging if possible - System.arraycopy(mean, 0, incomingMean, incomingCount, lastUsedCell); - System.arraycopy(weight, 0, incomingWeight, incomingCount, lastUsedCell); + incomingMean.set(incomingCount, mean, 0, lastUsedCell); + incomingWeight.set(incomingCount, weight, 0, lastUsedCell); incomingCount += lastUsedCell; if (incomingOrder == null) { - incomingOrder = new int[incomingCount]; + incomingOrder = arrays.newIntArray(incomingCount); } Sort.stableSort(incomingOrder, incomingMean, incomingCount); @@ -280,8 +288,8 @@ private void merge( // start by copying the least incoming value to the normal buffer lastUsedCell = 0; - mean[lastUsedCell] = incomingMean[incomingOrder[0]]; - weight[lastUsedCell] = incomingWeight[incomingOrder[0]]; + mean.set(lastUsedCell, incomingMean.get(incomingOrder.get(0))); + weight.set(lastUsedCell, incomingWeight.get(incomingOrder.get(0))); double wSoFar = 0; // weight will contain all zeros after this loop @@ -290,8 +298,8 @@ private void merge( double k1 = scale.k(0, normalizer); double wLimit = totalWeight * scale.q(k1 + 1, normalizer); for (int i = 1; i < incomingCount; i++) { - int ix = incomingOrder[i]; - double proposedWeight = weight[lastUsedCell] + incomingWeight[ix]; + int ix = incomingOrder.get(i); + double proposedWeight = weight.get(lastUsedCell) + incomingWeight.get(ix); double projectedW = wSoFar + proposedWeight; boolean addThis; if (useWeightLimit) { @@ -305,7 +313,7 @@ private void merge( // force first and last centroid to never merge addThis = false; } - if 
(lastUsedCell == mean.length - 1) { + if (lastUsedCell == mean.size() - 1) { // use the last centroid, there's no more addThis = true; } @@ -313,22 +321,26 @@ private void merge( if (addThis) { // next point will fit // so merge into existing centroid - weight[lastUsedCell] += incomingWeight[ix]; - mean[lastUsedCell] = mean[lastUsedCell] + (incomingMean[ix] - mean[lastUsedCell]) * incomingWeight[ix] - / weight[lastUsedCell]; - incomingWeight[ix] = 0; + weight.set(lastUsedCell, weight.get(lastUsedCell) + incomingWeight.get(ix)); + mean.set( + lastUsedCell, + mean.get(lastUsedCell) + (incomingMean.get(ix) - mean.get(lastUsedCell)) * incomingWeight.get(ix) / weight.get( + lastUsedCell + ) + ); + incomingWeight.set(ix, 0); } else { // didn't fit ... move to next output, copy out first centroid - wSoFar += weight[lastUsedCell]; + wSoFar += weight.get(lastUsedCell); if (useWeightLimit == false) { k1 = scale.k(wSoFar / totalWeight, normalizer); wLimit = totalWeight * scale.q(k1 + 1, normalizer); } lastUsedCell++; - mean[lastUsedCell] = incomingMean[ix]; - weight[lastUsedCell] = incomingWeight[ix]; - incomingWeight[ix] = 0; + mean.set(lastUsedCell, incomingMean.get(ix)); + weight.set(lastUsedCell, incomingWeight.get(ix)); + incomingWeight.set(ix, 0); } } // points to next empty cell @@ -337,7 +349,7 @@ private void merge( // sanity check double sum = 0; for (int i = 0; i < lastUsedCell; i++) { - sum += weight[i]; + sum += weight.get(i); } assert sum == totalWeight; if (runBackwards) { @@ -345,8 +357,8 @@ private void merge( Sort.reverse(weight, 0, lastUsedCell); } if (totalWeight > 0) { - min = Math.min(min, mean[0]); - max = Math.max(max, mean[lastUsedCell - 1]); + min = Math.min(min, mean.get(0)); + max = Math.max(max, mean.get(lastUsedCell - 1)); } } @@ -387,8 +399,8 @@ public double cdf(double x) { // we have one or more centroids == x, treat them as one // dw will accumulate the weight of all of the centroids at x double dw = 0; - for (int i = 0; i < lastUsedCell && Double.compare(mean[i], x) == 0; i++) { - dw += weight[i]; + for (int i = 0; i < lastUsedCell && Double.compare(mean.get(i), x) == 0; i++) { + dw += weight.get(i); } return dw / 2.0 / size(); } @@ -398,31 +410,32 @@ public double cdf(double x) { } if (x == max) { double dw = 0; - for (int i = lastUsedCell - 1; i >= 0 && Double.compare(mean[i], x) == 0; i--) { - dw += weight[i]; + for (int i = lastUsedCell - 1; i >= 0 && Double.compare(mean.get(i), x) == 0; i--) { + dw += weight.get(i); } return (size() - dw / 2.0) / size(); } // initially, we set left width equal to right width - double left = (mean[1] - mean[0]) / 2; + double left = (mean.get(1) - mean.get(0)) / 2; double weightSoFar = 0; for (int i = 0; i < lastUsedCell - 1; i++) { - double right = (mean[i + 1] - mean[i]) / 2; - if (x < mean[i] + right) { - double value = (weightSoFar + weight[i] * interpolate(x, mean[i] - left, mean[i] + right)) / size(); + double right = (mean.get(i + 1) - mean.get(i)) / 2; + if (x < mean.get(i) + right) { + double value = (weightSoFar + weight.get(i) * interpolate(x, mean.get(i) - left, mean.get(i) + right)) / size(); return Math.max(value, 0.0); } - weightSoFar += weight[i]; + weightSoFar += weight.get(i); left = right; } // for the last element, assume right width is same as left int lastOffset = lastUsedCell - 1; - double right = (mean[lastOffset] - mean[lastOffset - 1]) / 2; - if (x < mean[lastOffset] + right) { - return (weightSoFar + weight[lastOffset] * interpolate(x, mean[lastOffset] - right, mean[lastOffset] + right)) / size(); + 
double right = (mean.get(lastOffset) - mean.get(lastOffset - 1)) / 2; + if (x < mean.get(lastOffset) + right) { + return (weightSoFar + weight.get(lastOffset) * interpolate(x, mean.get(lastOffset) - right, mean.get(lastOffset) + right)) + / size(); } return 1; } @@ -440,7 +453,7 @@ public double quantile(double q) { return Double.NaN; } else if (lastUsedCell == 1) { // with one data point, all quantiles lead to Rome - return mean[0]; + return mean.get(0); } // we know that there are at least two centroids now @@ -458,40 +471,40 @@ public double quantile(double q) { return max; } - double weightSoFar = weight[0] / 2; + double weightSoFar = weight.get(0) / 2; // if the left centroid has more than one sample, we still know // that one sample occurred at min so we can do some interpolation - if (weight[0] > 1 && index < weightSoFar) { + if (weight.get(0) > 1 && index < weightSoFar) { // there is a single sample at min so we interpolate with less weight - return weightedAverage(min, weightSoFar - index, mean[0], index); + return weightedAverage(min, weightSoFar - index, mean.get(0), index); } // if the right-most centroid has more than one sample, we still know // that one sample occurred at max so we can do some interpolation - if (weight[n - 1] > 1 && totalWeight - index <= weight[n - 1] / 2) { - return max - (totalWeight - index - 1) / (weight[n - 1] / 2 - 1) * (max - mean[n - 1]); + if (weight.get(n - 1) > 1 && totalWeight - index <= weight.get(n - 1) / 2) { + return max - (totalWeight - index - 1) / (weight.get(n - 1) / 2 - 1) * (max - mean.get(n - 1)); } // in between extremes we interpolate between centroids for (int i = 0; i < n - 1; i++) { - double dw = (weight[i] + weight[i + 1]) / 2; + double dw = (weight.get(i) + weight.get(i + 1)) / 2; if (weightSoFar + dw > index) { // centroids i and i+1 bracket our current point double z1 = index - weightSoFar; double z2 = weightSoFar + dw - index; - return weightedAverage(mean[i], z2, mean[i + 1], z1); + return weightedAverage(mean.get(i), z2, mean.get(i + 1), z1); } weightSoFar += dw; } - assert weight[n - 1] >= 1; - assert index >= totalWeight - weight[n - 1]; + assert weight.get(n - 1) >= 1; + assert index >= totalWeight - weight.get(n - 1); // Interpolate between the last mean and the max. 
double z1 = index - weightSoFar; - double z2 = weight[n - 1] / 2.0 - z1; - return weightedAverage(mean[n - 1], z1, max, z2); + double z2 = weight.get(n - 1) / 2.0 - z1; + return weightedAverage(mean.get(n - 1), z1, max, z2); } @Override @@ -518,7 +531,7 @@ public boolean hasNext() { @Override public Centroid next() { - Centroid rc = new Centroid(mean[i], (long) weight[i]); + Centroid rc = new Centroid(mean.get(i), (long) weight.get(i)); i++; return rc; } @@ -553,7 +566,7 @@ public void setScaleFunction(ScaleFunction scaleFunction) { @Override public int byteSize() { - return 48 + 8 * (mean.length + weight.length + tempMean.length + tempWeight.length) + 4 * order.length; + return 48 + 8 * (mean.size() + weight.size() + tempMean.size() + tempWeight.size()) + 4 * order.size(); } @Override diff --git a/libs/tdigest/src/main/java/org/elasticsearch/tdigest/Sort.java b/libs/tdigest/src/main/java/org/elasticsearch/tdigest/Sort.java index c62ae54f93c2c..1c54b0327690c 100644 --- a/libs/tdigest/src/main/java/org/elasticsearch/tdigest/Sort.java +++ b/libs/tdigest/src/main/java/org/elasticsearch/tdigest/Sort.java @@ -21,7 +21,9 @@ package org.elasticsearch.tdigest; -import java.util.Arrays; +import org.elasticsearch.tdigest.arrays.TDigestDoubleArray; +import org.elasticsearch.tdigest.arrays.TDigestIntArray; + import java.util.Random; /** @@ -37,206 +39,14 @@ public class Sort { * @param values The values to sort. * @param n The number of values to sort */ - public static void stableSort(int[] order, double[] values, int n) { + public static void stableSort(TDigestIntArray order, TDigestDoubleArray values, int n) { for (int i = 0; i < n; i++) { - order[i] = i; + order.set(i, i); } stableQuickSort(order, values, 0, n, 64); stableInsertionSort(order, values, 0, n, 64); } - /** - * Two-key quick sort on (values, weights) using an index array - * - * @param order Indexes into values - * @param values The values to sort. 
- * @param weights The secondary sort key - * @param n The number of values to sort - * @return true if the values were already sorted - */ - public static boolean sort(int[] order, double[] values, double[] weights, int n) { - if (weights == null) { - weights = Arrays.copyOf(values, values.length); - } - boolean r = sort(order, values, weights, 0, n); - // now adjust all runs with equal value so that bigger weights are nearer - // the median - double medianWeight = 0; - for (int i = 0; i < n; i++) { - medianWeight += weights[i]; - } - medianWeight = medianWeight / 2; - int i = 0; - double soFar = 0; - double nextGroup = 0; - while (i < n) { - int j = i; - while (j < n && values[order[j]] == values[order[i]]) { - double w = weights[order[j]]; - nextGroup += w; - j++; - } - if (j > i + 1) { - if (soFar >= medianWeight) { - // entire group is in last half, reverse the order - reverse(order, i, j - i); - } else if (nextGroup > medianWeight) { - // group straddles the median, but not necessarily evenly - // most elements are probably unit weight if there are many - double[] scratch = new double[j - i]; - - double netAfter = nextGroup + soFar - 2 * medianWeight; - // heuristically adjust weights to roughly balance around median - double max = weights[order[j - 1]]; - for (int k = j - i - 1; k >= 0; k--) { - double weight = weights[order[i + k]]; - if (netAfter < 0) { - // sort in normal order - scratch[k] = weight; - netAfter += weight; - } else { - // sort reversed, but after normal items - scratch[k] = 2 * max + 1 - weight; - netAfter -= weight; - } - } - // sort these balanced weights - int[] sub = new int[j - i]; - sort(sub, scratch, scratch, 0, j - i); - int[] tmp = Arrays.copyOfRange(order, i, j); - for (int k = 0; k < j - i; k++) { - order[i + k] = tmp[sub[k]]; - } - } - } - soFar = nextGroup; - i = j; - } - return r; - } - - /** - * Two-key quick sort on (values, weights) using an index array - * - * @param order Indexes into values - * @param values The values to sort - * @param weights The weights that define the secondary ordering - * @param start The first element to sort - * @param n The number of values to sort - * @return True if the values were in order without sorting - */ - private static boolean sort(int[] order, double[] values, double[] weights, int start, int n) { - boolean inOrder = true; - for (int i = start; i < start + n; i++) { - if (inOrder && i < start + n - 1) { - inOrder = values[i] < values[i + 1] || (values[i] == values[i + 1] && weights[i] <= weights[i + 1]); - } - order[i] = i; - } - if (inOrder) { - return true; - } - quickSort(order, values, weights, start, start + n, 64); - insertionSort(order, values, weights, start, start + n, 64); - return false; - } - - /** - * Standard two-key quick sort on (values, weights) except that sorting is done on an index array - * rather than the values themselves - * - * @param order The pre-allocated index array - * @param values The values to sort - * @param weights The weights (secondary key) - * @param start The beginning of the values to sort - * @param end The value after the last value to sort - * @param limit The minimum size to recurse down to. 
- */ - private static void quickSort(int[] order, double[] values, double[] weights, int start, int end, int limit) { - // the while loop implements tail-recursion to avoid excessive stack calls on nasty cases - while (end - start > limit) { - - // pivot by a random element - int pivotIndex = start + prng.nextInt(end - start); - double pivotValue = values[order[pivotIndex]]; - double pivotWeight = weights[order[pivotIndex]]; - - // move pivot to beginning of array - swap(order, start, pivotIndex); - - // we use a three way partition because many duplicate values is an important case - - int low = start + 1; // low points to first value not known to be equal to pivotValue - int high = end; // high points to first value > pivotValue - int i = low; // i scans the array - while (i < high) { - // invariant: (values,weights)[order[k]] == (pivotValue, pivotWeight) for k in [0..low) - // invariant: (values,weights)[order[k]] < (pivotValue, pivotWeight) for k in [low..i) - // invariant: (values,weights)[order[k]] > (pivotValue, pivotWeight) for k in [high..end) - // in-loop: i < high - // in-loop: low < high - // in-loop: i >= low - double vi = values[order[i]]; - double wi = weights[order[i]]; - if (vi == pivotValue && wi == pivotWeight) { - if (low != i) { - swap(order, low, i); - } else { - i++; - } - low++; - } else if (vi > pivotValue || (vi == pivotValue && wi > pivotWeight)) { - high--; - swap(order, i, high); - } else { - // vi < pivotValue || (vi == pivotValue && wi < pivotWeight) - i++; - } - } - // invariant: (values,weights)[order[k]] == (pivotValue, pivotWeight) for k in [0..low) - // invariant: (values,weights)[order[k]] < (pivotValue, pivotWeight) for k in [low..i) - // invariant: (values,weights)[order[k]] > (pivotValue, pivotWeight) for k in [high..end) - // assert i == high || low == high therefore, we are done with partition - - // at this point, i==high, from [start,low) are == pivot, [low,high) are < and [high,end) are > - // we have to move the values equal to the pivot into the middle. To do this, we swap pivot - // values into the top end of the [low,high) range stopping when we run out of destinations - // or when we run out of values to copy - int from = start; - int to = high - 1; - for (i = 0; from < low && to >= low; i++) { - swap(order, from++, to--); - } - if (from == low) { - // ran out of things to copy. This means that the last destination is the boundary - low = to + 1; - } else { - // ran out of places to copy to. This means that there are uncopied pivots and the - // boundary is at the beginning of those - low = from; - } - - // checkPartition(order, values, pivotValue, start, low, high, end); - - // now recurse, but arrange it so we handle the longer limit by tail recursion - // we have to sort the pivot values because they may have different weights - // we can't do that, however until we know how much weight is in the left and right - if (low - start < end - high) { - // left side is smaller - quickSort(order, values, weights, start, low, limit); - - // this is really a way to do - // quickSort(order, values, high, end, limit); - start = high; - } else { - quickSort(order, values, weights, high, end, limit); - // this is really a way to do - // quickSort(order, values, start, low, limit); - end = low; - } - } - } - /** * Stabilized quick sort on an index array. This is a normal quick sort that uses the * original index as a secondary key. 
Since we are really just sorting an index array @@ -248,14 +58,14 @@ private static void quickSort(int[] order, double[] values, double[] weights, in * @param end The value after the last value to sort * @param limit The minimum size to recurse down to. */ - private static void stableQuickSort(int[] order, double[] values, int start, int end, int limit) { + private static void stableQuickSort(TDigestIntArray order, TDigestDoubleArray values, int start, int end, int limit) { // the while loop implements tail-recursion to avoid excessive stack calls on nasty cases while (end - start > limit) { // pivot by a random element int pivotIndex = start + prng.nextInt(end - start); - double pivotValue = values[order[pivotIndex]]; - int pv = order[pivotIndex]; + double pivotValue = values.get(order.get(pivotIndex)); + int pv = order.get(pivotIndex); // move pivot to beginning of array swap(order, start, pivotIndex); @@ -272,8 +82,8 @@ private static void stableQuickSort(int[] order, double[] values, int start, int // in-loop: i < high // in-loop: low < high // in-loop: i >= low - double vi = values[order[i]]; - int pi = order[i]; + double vi = values.get(order.get(i)); + int pi = order.get(i); if (vi == pivotValue && pi == pv) { if (low != i) { swap(order, low, i); @@ -333,247 +143,10 @@ private static void stableQuickSort(int[] order, double[] values, int start, int } } - /** - * Quick sort in place of several paired arrays. On return, - * keys[...] is in order and the values[] arrays will be - * reordered as well in the same way. - * - * @param key Values to sort on - * @param values The auxiliary values to sort. - */ - public static void sort(double[] key, double[]... values) { - sort(key, 0, key.length, values); - } - - /** - * Quick sort using an index array. On return, - * values[order[i]] is in order as i goes start..n - * @param key Values to sort on - * @param start The first element to sort - * @param n The number of values to sort - * @param values The auxiliary values to sort. - */ - public static void sort(double[] key, int start, int n, double[]... values) { - quickSort(key, values, start, start + n, 8); - insertionSort(key, values, start, start + n, 8); - } - - /** - * Standard quick sort except that sorting rearranges parallel arrays - * - * @param key Values to sort on - * @param values The auxiliary values to sort. - * @param start The beginning of the values to sort - * @param end The value after the last value to sort - * @param limit The minimum size to recurse down to. 
- */ - private static void quickSort(double[] key, double[][] values, int start, int end, int limit) { - // the while loop implements tail-recursion to avoid excessive stack calls on nasty cases - while (end - start > limit) { - - // median of three values for the pivot - int a = start; - int b = (start + end) / 2; - int c = end - 1; - - int pivotIndex; - double pivotValue; - double va = key[a]; - double vb = key[b]; - double vc = key[c]; - - if (va > vb) { - if (vc > va) { - // vc > va > vb - pivotIndex = a; - pivotValue = va; - } else { - // va > vb, va >= vc - if (vc < vb) { - // va > vb > vc - pivotIndex = b; - pivotValue = vb; - } else { - // va >= vc >= vb - pivotIndex = c; - pivotValue = vc; - } - } - } else { - // vb >= va - if (vc > vb) { - // vc > vb >= va - pivotIndex = b; - pivotValue = vb; - } else { - // vb >= va, vb >= vc - if (vc < va) { - // vb >= va > vc - pivotIndex = a; - pivotValue = va; - } else { - // vb >= vc >= va - pivotIndex = c; - pivotValue = vc; - } - } - } - - // move pivot to beginning of array - swap(start, pivotIndex, key, values); - - // we use a three way partition because many duplicate values is an important case - - int low = start + 1; // low points to first value not known to be equal to pivotValue - int high = end; // high points to first value > pivotValue - int i = low; // i scans the array - while (i < high) { - // invariant: values[order[k]] == pivotValue for k in [0..low) - // invariant: values[order[k]] < pivotValue for k in [low..i) - // invariant: values[order[k]] > pivotValue for k in [high..end) - // in-loop: i < high - // in-loop: low < high - // in-loop: i >= low - double vi = key[i]; - if (vi == pivotValue) { - if (low != i) { - swap(low, i, key, values); - } else { - i++; - } - low++; - } else if (vi > pivotValue) { - high--; - swap(i, high, key, values); - } else { - // vi < pivotValue - i++; - } - } - // invariant: values[order[k]] == pivotValue for k in [0..low) - // invariant: values[order[k]] < pivotValue for k in [low..i) - // invariant: values[order[k]] > pivotValue for k in [high..end) - // assert i == high || low == high therefore, we are done with partition - - // at this point, i==high, from [start,low) are == pivot, [low,high) are < and [high,end) are > - // we have to move the values equal to the pivot into the middle. To do this, we swap pivot - // values into the top end of the [low,high) range stopping when we run out of destinations - // or when we run out of values to copy - int from = start; - int to = high - 1; - for (i = 0; from < low && to >= low; i++) { - swap(from++, to--, key, values); - } - if (from == low) { - // ran out of things to copy. This means that the last destination is the boundary - low = to + 1; - } else { - // ran out of places to copy to. This means that there are uncopied pivots and the - // boundary is at the beginning of those - low = from; - } - - // checkPartition(order, values, pivotValue, start, low, high, end); - - // now recurse, but arrange it so we handle the longer limit by tail recursion - if (low - start < end - high) { - quickSort(key, values, start, low, limit); - - // this is really a way to do - // quickSort(order, values, high, end, limit); - start = high; - } else { - quickSort(key, values, high, end, limit); - // this is really a way to do - // quickSort(order, values, start, low, limit); - end = low; - } - } - } - - /** - * Limited range insertion sort. We assume that no element has to move more than limit steps - * because quick sort has done its thing. 
This version works on parallel arrays of keys and values. - * - * @param key The array of keys - * @param values The values we are sorting - * @param start The starting point of the sort - * @param end The ending point of the sort - * @param limit The largest amount of disorder - */ - private static void insertionSort(double[] key, double[][] values, int start, int end, int limit) { - // loop invariant: all values start ... i-1 are ordered - for (int i = start + 1; i < end; i++) { - double v = key[i]; - int m = Math.max(i - limit, start); - for (int j = i; j >= m; j--) { - if (j == m || key[j - 1] <= v) { - if (j < i) { - System.arraycopy(key, j, key, j + 1, i - j); - key[j] = v; - for (double[] value : values) { - double tmp = value[i]; - System.arraycopy(value, j, value, j + 1, i - j); - value[j] = tmp; - } - } - break; - } - } - } - } - - private static void swap(int[] order, int i, int j) { - int t = order[i]; - order[i] = order[j]; - order[j] = t; - } - - private static void swap(int i, int j, double[] key, double[]... values) { - double t = key[i]; - key[i] = key[j]; - key[j] = t; - - for (int k = 0; k < values.length; k++) { - t = values[k][i]; - values[k][i] = values[k][j]; - values[k][j] = t; - } - } - - /** - * Limited range insertion sort with primary and secondary key. We assume that no - * element has to move more than limit steps because quick sort has done its thing. - * - * If weights (the secondary key) is null, then only the primary key is used. - * - * This sort is inherently stable. - * - * @param order The permutation index - * @param values The values we are sorting - * @param weights The secondary key for sorting - * @param start Where to start the sort - * @param n How many elements to sort - * @param limit The largest amount of disorder - */ - private static void insertionSort(int[] order, double[] values, double[] weights, int start, int n, int limit) { - for (int i = start + 1; i < n; i++) { - int t = order[i]; - double v = values[order[i]]; - double w = weights == null ? 
0 : weights[order[i]]; - int m = Math.max(i - limit, start); - // values in [start, i) are ordered - // scan backwards to find where to stick t - for (int j = i; j >= m; j--) { - if (j == 0 || values[order[j - 1]] < v || (values[order[j - 1]] == v && (weights == null || weights[order[j - 1]] <= w))) { - if (j < i) { - System.arraycopy(order, j, order, j + 1, i - j); - order[j] = t; - } - break; - } - } - } + private static void swap(TDigestIntArray order, int i, int j) { + int t = order.get(i); + order.set(i, order.get(j)); + order.set(j, t); } /** @@ -587,19 +160,19 @@ private static void insertionSort(int[] order, double[] values, double[] weights * @param n How many elements to sort * @param limit The largest amount of disorder */ - private static void stableInsertionSort(int[] order, double[] values, int start, int n, int limit) { + private static void stableInsertionSort(TDigestIntArray order, TDigestDoubleArray values, int start, int n, int limit) { for (int i = start + 1; i < n; i++) { - int t = order[i]; - double v = values[order[i]]; - int vi = order[i]; + int t = order.get(i); + double v = values.get(order.get(i)); + int vi = order.get(i); int m = Math.max(i - limit, start); // values in [start, i) are ordered // scan backwards to find where to stick t for (int j = i; j >= m; j--) { - if (j == 0 || values[order[j - 1]] < v || (values[order[j - 1]] == v && (order[j - 1] <= vi))) { + if (j == 0 || values.get(order.get(j - 1)) < v || (values.get(order.get(j - 1)) == v && (order.get(j - 1) <= vi))) { if (j < i) { - System.arraycopy(order, j, order, j + 1, i - j); - order[j] = t; + order.set(j + 1, order, j, i - j); + order.set(j, t); } break; } @@ -608,41 +181,32 @@ private static void stableInsertionSort(int[] order, double[] values, int start, } /** - * Reverses an array in-place. - * - * @param order The array to reverse - */ - public static void reverse(int[] order) { - reverse(order, 0, order.length); - } - - /** - * Reverses part of an array. See {@link #reverse(int[])} + * Reverses part of an array. * * @param order The array containing the data to reverse. * @param offset Where to start reversing. * @param length How many elements to reverse */ - public static void reverse(int[] order, int offset, int length) { + public static void reverse(TDigestIntArray order, int offset, int length) { for (int i = 0; i < length / 2; i++) { - int t = order[offset + i]; - order[offset + i] = order[offset + length - i - 1]; - order[offset + length - i - 1] = t; + int t = order.get(offset + i); + order.set(offset + i, order.get(offset + length - i - 1)); + order.set(offset + length - i - 1, t); } } /** - * Reverses part of an array. See {@link #reverse(int[])} + * Reverses part of an array. * * @param order The array containing the data to reverse. * @param offset Where to start reversing. 
* @param length How many elements to reverse */ - public static void reverse(double[] order, int offset, int length) { + public static void reverse(TDigestDoubleArray order, int offset, int length) { for (int i = 0; i < length / 2; i++) { - double t = order[offset + i]; - order[offset + i] = order[offset + length - i - 1]; - order[offset + length - i - 1] = t; + double t = order.get(offset + i); + order.set(offset + i, order.get(offset + length - i - 1)); + order.set(offset + length - i - 1, t); } } } diff --git a/libs/tdigest/src/main/java/org/elasticsearch/tdigest/SortingDigest.java b/libs/tdigest/src/main/java/org/elasticsearch/tdigest/SortingDigest.java index 92f770cbb7569..94b5c667e0672 100644 --- a/libs/tdigest/src/main/java/org/elasticsearch/tdigest/SortingDigest.java +++ b/libs/tdigest/src/main/java/org/elasticsearch/tdigest/SortingDigest.java @@ -19,10 +19,11 @@ package org.elasticsearch.tdigest; +import org.elasticsearch.tdigest.arrays.TDigestArrays; +import org.elasticsearch.tdigest.arrays.TDigestDoubleArray; + import java.util.AbstractCollection; -import java.util.ArrayList; import java.util.Collection; -import java.util.Collections; import java.util.Iterator; /** @@ -31,17 +32,20 @@ * samples, at the expense of allocating much more memory. */ public class SortingDigest extends AbstractTDigest { - // Tracks all samples. Gets sorted on quantile and cdf calls. - final ArrayList values = new ArrayList<>(); + final TDigestDoubleArray values; // Indicates if all values have been sorted. private boolean isSorted = true; + public SortingDigest(TDigestArrays arrays) { + values = arrays.newDoubleArray(0); + } + @Override public void add(double x, long w) { checkValue(x); - isSorted = isSorted && (values.isEmpty() || values.get(values.size() - 1) <= x); + isSorted = isSorted && (values.size() == 0 || values.get(values.size() - 1) <= x); for (int i = 0; i < w; i++) { values.add(x); } @@ -52,7 +56,7 @@ public void add(double x, long w) { @Override public void compress() { if (isSorted == false) { - Collections.sort(values); + values.sort(); isSorted = true; } } diff --git a/libs/tdigest/src/main/java/org/elasticsearch/tdigest/TDigest.java b/libs/tdigest/src/main/java/org/elasticsearch/tdigest/TDigest.java index 296ed57a4d960..4e79f9e68cd02 100644 --- a/libs/tdigest/src/main/java/org/elasticsearch/tdigest/TDigest.java +++ b/libs/tdigest/src/main/java/org/elasticsearch/tdigest/TDigest.java @@ -21,6 +21,8 @@ package org.elasticsearch.tdigest; +import org.elasticsearch.tdigest.arrays.TDigestArrays; + import java.util.Collection; import java.util.Locale; @@ -48,8 +50,8 @@ public abstract class TDigest { * The number of centroids retained will be a smallish (usually less than 10) multiple of this number. * @return the MergingDigest */ - public static TDigest createMergingDigest(double compression) { - return new MergingDigest(compression); + public static TDigest createMergingDigest(TDigestArrays arrays, double compression) { + return new MergingDigest(arrays, compression); } /** @@ -61,8 +63,8 @@ public static TDigest createMergingDigest(double compression) { * The number of centroids retained will be a smallish (usually less than 10) multiple of this number. 
* @return the AvlTreeDigest */ - public static TDigest createAvlTreeDigest(double compression) { - return new AVLTreeDigest(compression); + public static TDigest createAvlTreeDigest(TDigestArrays arrays, double compression) { + return new AVLTreeDigest(arrays, compression); } /** @@ -71,8 +73,8 @@ public static TDigest createAvlTreeDigest(double compression) { * * @return the SortingDigest */ - public static TDigest createSortingDigest() { - return new SortingDigest(); + public static TDigest createSortingDigest(TDigestArrays arrays) { + return new SortingDigest(arrays); } /** @@ -84,8 +86,8 @@ public static TDigest createSortingDigest() { * The number of centroids retained will be a smallish (usually less than 10) multiple of this number. * @return the HybridDigest */ - public static TDigest createHybridDigest(double compression) { - return new HybridDigest(compression); + public static TDigest createHybridDigest(TDigestArrays arrays, double compression) { + return new HybridDigest(arrays, compression); } /** diff --git a/libs/tdigest/src/main/java/org/elasticsearch/tdigest/arrays/TDigestArrays.java b/libs/tdigest/src/main/java/org/elasticsearch/tdigest/arrays/TDigestArrays.java new file mode 100644 index 0000000000000..5e15c4c82f796 --- /dev/null +++ b/libs/tdigest/src/main/java/org/elasticsearch/tdigest/arrays/TDigestArrays.java @@ -0,0 +1,35 @@ +/* + * Licensed to Elasticsearch B.V. under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch B.V. licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + * This project is based on a modification of https://github.com/tdunning/t-digest which is licensed under the Apache 2.0 License. + */ + +package org.elasticsearch.tdigest.arrays; + +/** + * Minimal interface for BigArrays-like classes used within TDigest. + */ +public interface TDigestArrays { + TDigestDoubleArray newDoubleArray(int initialSize); + + TDigestIntArray newIntArray(int initialSize); + + TDigestLongArray newLongArray(int initialSize); + + TDigestByteArray newByteArray(int initialSize); +} diff --git a/libs/tdigest/src/main/java/org/elasticsearch/tdigest/arrays/TDigestByteArray.java b/libs/tdigest/src/main/java/org/elasticsearch/tdigest/arrays/TDigestByteArray.java new file mode 100644 index 0000000000000..481dde9784008 --- /dev/null +++ b/libs/tdigest/src/main/java/org/elasticsearch/tdigest/arrays/TDigestByteArray.java @@ -0,0 +1,38 @@ +/* + * Licensed to Elasticsearch B.V. under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch B.V. licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + * This project is based on a modification of https://github.com/tdunning/t-digest which is licensed under the Apache 2.0 License. + */ + +package org.elasticsearch.tdigest.arrays; + +/** + * Minimal interface for ByteArray-like classes used within TDigest. + */ +public interface TDigestByteArray { + int size(); + + byte get(int index); + + void set(int index, byte value); + + /** + * Resizes the array. If the new size is bigger than the current size, the new elements are set to 0. + */ + void resize(int newSize); +} diff --git a/libs/tdigest/src/main/java/org/elasticsearch/tdigest/arrays/TDigestDoubleArray.java b/libs/tdigest/src/main/java/org/elasticsearch/tdigest/arrays/TDigestDoubleArray.java new file mode 100644 index 0000000000000..92530db5e7dc4 --- /dev/null +++ b/libs/tdigest/src/main/java/org/elasticsearch/tdigest/arrays/TDigestDoubleArray.java @@ -0,0 +1,58 @@ +/* + * Licensed to Elasticsearch B.V. under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch B.V. licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + * This project is based on a modification of https://github.com/tdunning/t-digest which is licensed under the Apache 2.0 License. + */ + +package org.elasticsearch.tdigest.arrays; + +/** + * Minimal interface for DoubleArray-like classes used within TDigest. + */ +public interface TDigestDoubleArray { + int size(); + + double get(int index); + + void set(int index, double value); + + void add(double value); + + void ensureCapacity(int requiredCapacity); + + /** + * Resizes the array. If the new size is bigger than the current size, the new elements are set to 0. + */ + void resize(int newSize); + + /** + * Copies {@code len} elements from {@code buf} to this array. + */ + default void set(int index, TDigestDoubleArray buf, int offset, int len) { + assert index >= 0 && index + len <= this.size(); + assert buf != this : "This method doesn't ensure that the copy from itself will be correct"; + for (int i = len - 1; i >= 0; i--) { + this.set(index + i, buf.get(offset + i)); + } + } + + /** + * Sorts the array in place in ascending order. 
+ */ + void sort(); +} diff --git a/libs/tdigest/src/main/java/org/elasticsearch/tdigest/arrays/TDigestIntArray.java b/libs/tdigest/src/main/java/org/elasticsearch/tdigest/arrays/TDigestIntArray.java new file mode 100644 index 0000000000000..c944a4f8faf07 --- /dev/null +++ b/libs/tdigest/src/main/java/org/elasticsearch/tdigest/arrays/TDigestIntArray.java @@ -0,0 +1,53 @@ +/* + * Licensed to Elasticsearch B.V. under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch B.V. licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + * This project is based on a modification of https://github.com/tdunning/t-digest which is licensed under the Apache 2.0 License. + */ + +package org.elasticsearch.tdigest.arrays; + +/** + * Minimal interface for IntArray-like classes used within TDigest. + */ +public interface TDigestIntArray { + int size(); + + int get(int index); + + void set(int index, int value); + + /** + * Resizes the array. If the new size is bigger than the current size, the new elements are set to 0. + */ + void resize(int newSize); + + /** + * Copies {@code len} elements from {@code buf} to this array. + *
<p>
+ * As this method will be used to insert elements from itself in an insertion sort, + * the copy must be made in reverse order, from offset+len-1 to offset. + *
</p>
+ */ + default void set(int index, TDigestIntArray buf, int offset, int len) { + assert index >= 0 && index + len <= this.size(); + assert buf != this || index >= offset : "To set to itself, the destination index must be greater than or equal to the source offset"; + for (int i = len - 1; i >= 0; i--) { + this.set(index + i, buf.get(offset + i)); + } + } +} diff --git a/libs/tdigest/src/main/java/org/elasticsearch/tdigest/arrays/TDigestLongArray.java b/libs/tdigest/src/main/java/org/elasticsearch/tdigest/arrays/TDigestLongArray.java new file mode 100644 index 0000000000000..7e75dd512e86d --- /dev/null +++ b/libs/tdigest/src/main/java/org/elasticsearch/tdigest/arrays/TDigestLongArray.java @@ -0,0 +1,38 @@ +/* + * Licensed to Elasticsearch B.V. under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch B.V. licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + * This project is based on a modification of https://github.com/tdunning/t-digest which is licensed under the Apache 2.0 License. + */ + +package org.elasticsearch.tdigest.arrays; + +/** + * Minimal interface for LongArray-like classes used within TDigest. + */ +public interface TDigestLongArray { + int size(); + + long get(int index); + + void set(int index, long value); + + /** + * Resizes the array. If the new size is bigger than the current size, the new elements are set to 0. + */ + void resize(int newSize); +} diff --git a/libs/tdigest/src/main/java/org/elasticsearch/tdigest/arrays/WrapperTDigestArrays.java b/libs/tdigest/src/main/java/org/elasticsearch/tdigest/arrays/WrapperTDigestArrays.java new file mode 100644 index 0000000000000..ce2dd4f8d8e1d --- /dev/null +++ b/libs/tdigest/src/main/java/org/elasticsearch/tdigest/arrays/WrapperTDigestArrays.java @@ -0,0 +1,258 @@ +/* + * Licensed to Elasticsearch B.V. under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch B.V. licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + * This project is based on a modification of https://github.com/tdunning/t-digest which is licensed under the Apache 2.0 License. + */ + +package org.elasticsearch.tdigest.arrays; + +import java.util.Arrays; + +/** + * Temporary TDigestArrays with raw arrays. + * + *
<p>
+ * Remove this class once a proper BigArrays-backed implementation is available. + *
</p>
+ */ +public class WrapperTDigestArrays implements TDigestArrays { + + public static final WrapperTDigestArrays INSTANCE = new WrapperTDigestArrays(); + + private WrapperTDigestArrays() {} + + @Override + public WrapperTDigestDoubleArray newDoubleArray(int initialCapacity) { + return new WrapperTDigestDoubleArray(initialCapacity); + } + + @Override + public WrapperTDigestIntArray newIntArray(int initialSize) { + return new WrapperTDigestIntArray(initialSize); + } + + @Override + public TDigestLongArray newLongArray(int initialSize) { + return new WrapperTDigestLongArray(initialSize); + } + + @Override + public TDigestByteArray newByteArray(int initialSize) { + return new WrapperTDigestByteArray(initialSize); + } + + public WrapperTDigestDoubleArray newDoubleArray(double[] array) { + return new WrapperTDigestDoubleArray(array); + } + + public WrapperTDigestIntArray newIntArray(int[] array) { + return new WrapperTDigestIntArray(array); + } + + public static class WrapperTDigestDoubleArray implements TDigestDoubleArray { + private double[] array; + private int size; + + public WrapperTDigestDoubleArray(int initialSize) { + this(new double[initialSize]); + } + + public WrapperTDigestDoubleArray(double[] array) { + this.array = array; + this.size = array.length; + } + + @Override + public int size() { + return size; + } + + @Override + public double get(int index) { + assert index >= 0 && index < size; + return array[index]; + } + + @Override + public void set(int index, double value) { + assert index >= 0 && index < size; + array[index] = value; + } + + @Override + public void add(double value) { + ensureCapacity(size + 1); + array[size++] = value; + } + + @Override + public void sort() { + Arrays.sort(array, 0, size); + } + + @Override + public void ensureCapacity(int requiredCapacity) { + if (requiredCapacity > array.length) { + int newSize = array.length + (array.length >> 1); + if (newSize < requiredCapacity) { + newSize = requiredCapacity; + } + double[] newArray = new double[newSize]; + System.arraycopy(array, 0, newArray, 0, size); + array = newArray; + } + } + + @Override + public void resize(int newSize) { + if (newSize > array.length) { + array = Arrays.copyOf(array, newSize); + } + if (newSize > size) { + Arrays.fill(array, size, newSize, 0); + } + size = newSize; + } + } + + public static class WrapperTDigestIntArray implements TDigestIntArray { + private int[] array; + private int size; + + public WrapperTDigestIntArray(int initialSize) { + this(new int[initialSize]); + } + + public WrapperTDigestIntArray(int[] array) { + this.array = array; + this.size = array.length; + } + + @Override + public int size() { + return size; + } + + @Override + public int get(int index) { + assert index >= 0 && index < size; + return array[index]; + } + + @Override + public void set(int index, int value) { + assert index >= 0 && index < size; + array[index] = value; + } + + @Override + public void resize(int newSize) { + if (newSize > array.length) { + array = Arrays.copyOf(array, newSize); + } + if (newSize > size) { + Arrays.fill(array, size, newSize, 0); + } + size = newSize; + } + } + + public static class WrapperTDigestLongArray implements TDigestLongArray { + private long[] array; + private int size; + + public WrapperTDigestLongArray(int initialSize) { + this(new long[initialSize]); + } + + public WrapperTDigestLongArray(long[] array) { + this.array = array; + this.size = array.length; + } + + @Override + public int size() { + return size; + } + + @Override + public long get(int index) { + 
assert index >= 0 && index < size; + return array[index]; + } + + @Override + public void set(int index, long value) { + assert index >= 0 && index < size; + array[index] = value; + } + + @Override + public void resize(int newSize) { + if (newSize > array.length) { + array = Arrays.copyOf(array, newSize); + } + if (newSize > size) { + Arrays.fill(array, size, newSize, 0); + } + size = newSize; + } + } + + public static class WrapperTDigestByteArray implements TDigestByteArray { + private byte[] array; + private int size; + + public WrapperTDigestByteArray(int initialSize) { + this(new byte[initialSize]); + } + + public WrapperTDigestByteArray(byte[] array) { + this.array = array; + this.size = array.length; + } + + @Override + public int size() { + return size; + } + + @Override + public byte get(int index) { + assert index >= 0 && index < size; + return array[index]; + } + + @Override + public void set(int index, byte value) { + assert index >= 0 && index < size; + array[index] = value; + } + + @Override + public void resize(int newSize) { + if (newSize > array.length) { + array = Arrays.copyOf(array, newSize); + } + if (newSize > size) { + Arrays.fill(array, size, newSize, (byte) 0); + } + size = newSize; + } + } +} diff --git a/libs/tdigest/src/test/java/org/elasticsearch/tdigest/AVLGroupTreeTests.java b/libs/tdigest/src/test/java/org/elasticsearch/tdigest/AVLGroupTreeTests.java index 972a9e3b36878..71be849f401f4 100644 --- a/libs/tdigest/src/test/java/org/elasticsearch/tdigest/AVLGroupTreeTests.java +++ b/libs/tdigest/src/test/java/org/elasticsearch/tdigest/AVLGroupTreeTests.java @@ -21,12 +21,13 @@ package org.elasticsearch.tdigest; +import org.elasticsearch.tdigest.arrays.WrapperTDigestArrays; import org.elasticsearch.test.ESTestCase; public class AVLGroupTreeTests extends ESTestCase { public void testSimpleAdds() { - AVLGroupTree x = new AVLGroupTree(); + AVLGroupTree x = new AVLGroupTree(WrapperTDigestArrays.INSTANCE); assertEquals(IntAVLTree.NIL, x.floor(34)); assertEquals(IntAVLTree.NIL, x.first()); assertEquals(IntAVLTree.NIL, x.last()); @@ -45,7 +46,7 @@ public void testSimpleAdds() { } public void testBalancing() { - AVLGroupTree x = new AVLGroupTree(); + AVLGroupTree x = new AVLGroupTree(WrapperTDigestArrays.INSTANCE); for (int i = 0; i < 101; i++) { x.add(new Centroid(i)); } @@ -59,7 +60,7 @@ public void testBalancing() { public void testFloor() { // mostly tested in other tests - AVLGroupTree x = new AVLGroupTree(); + AVLGroupTree x = new AVLGroupTree(WrapperTDigestArrays.INSTANCE); for (int i = 0; i < 101; i++) { x.add(new Centroid(i / 2)); } @@ -72,7 +73,7 @@ public void testFloor() { } public void testHeadSum() { - AVLGroupTree x = new AVLGroupTree(); + AVLGroupTree x = new AVLGroupTree(WrapperTDigestArrays.INSTANCE); for (int i = 0; i < 1000; ++i) { x.add(randomDouble(), randomIntBetween(1, 10)); } @@ -87,7 +88,7 @@ public void testHeadSum() { } public void testFloorSum() { - AVLGroupTree x = new AVLGroupTree(); + AVLGroupTree x = new AVLGroupTree(WrapperTDigestArrays.INSTANCE); int total = 0; for (int i = 0; i < 1000; ++i) { int count = randomIntBetween(1, 10); diff --git a/libs/tdigest/src/test/java/org/elasticsearch/tdigest/AVLTreeDigestTests.java b/libs/tdigest/src/test/java/org/elasticsearch/tdigest/AVLTreeDigestTests.java index 7fd3e58da04f9..3cd89de4746f1 100644 --- a/libs/tdigest/src/test/java/org/elasticsearch/tdigest/AVLTreeDigestTests.java +++ b/libs/tdigest/src/test/java/org/elasticsearch/tdigest/AVLTreeDigestTests.java @@ -21,11 +21,13 @@ package 
org.elasticsearch.tdigest; +import org.elasticsearch.tdigest.arrays.WrapperTDigestArrays; + public class AVLTreeDigestTests extends TDigestTests { protected DigestFactory factory(final double compression) { return () -> { - AVLTreeDigest digest = new AVLTreeDigest(compression); + AVLTreeDigest digest = new AVLTreeDigest(WrapperTDigestArrays.INSTANCE, compression); digest.setRandomSeed(randomLong()); return digest; }; diff --git a/libs/tdigest/src/test/java/org/elasticsearch/tdigest/AlternativeMergeTests.java b/libs/tdigest/src/test/java/org/elasticsearch/tdigest/AlternativeMergeTests.java index b9d36f4c945fe..4b95e9c0ee695 100644 --- a/libs/tdigest/src/test/java/org/elasticsearch/tdigest/AlternativeMergeTests.java +++ b/libs/tdigest/src/test/java/org/elasticsearch/tdigest/AlternativeMergeTests.java @@ -21,6 +21,7 @@ package org.elasticsearch.tdigest; +import org.elasticsearch.tdigest.arrays.WrapperTDigestArrays; import org.elasticsearch.test.ESTestCase; import java.util.ArrayList; @@ -36,8 +37,8 @@ public class AlternativeMergeTests extends ESTestCase { public void testMerges() { for (int n : new int[] { 100, 1000, 10000, 100000 }) { for (double compression : new double[] { 50, 100, 200, 400 }) { - MergingDigest mergingDigest = new MergingDigest(compression); - AVLTreeDigest treeDigest = new AVLTreeDigest(compression); + MergingDigest mergingDigest = new MergingDigest(WrapperTDigestArrays.INSTANCE, compression); + AVLTreeDigest treeDigest = new AVLTreeDigest(WrapperTDigestArrays.INSTANCE, compression); List<Double> data = new ArrayList<>(); Random gen = random(); for (int i = 0; i < n; i++) { diff --git a/libs/tdigest/src/test/java/org/elasticsearch/tdigest/BigCountTestsMergingDigestTests.java b/libs/tdigest/src/test/java/org/elasticsearch/tdigest/BigCountTestsMergingDigestTests.java index d11cc76820823..25cd1af05a0ba 100644 --- a/libs/tdigest/src/test/java/org/elasticsearch/tdigest/BigCountTestsMergingDigestTests.java +++ b/libs/tdigest/src/test/java/org/elasticsearch/tdigest/BigCountTestsMergingDigestTests.java @@ -21,9 +21,11 @@ package org.elasticsearch.tdigest; +import org.elasticsearch.tdigest.arrays.WrapperTDigestArrays; + public class BigCountTestsMergingDigestTests extends BigCountTests { @Override public TDigest createDigest() { - return new MergingDigest(100); + return new MergingDigest(WrapperTDigestArrays.INSTANCE, 100); } } diff --git a/libs/tdigest/src/test/java/org/elasticsearch/tdigest/BigCountTestsTreeDigestTests.java b/libs/tdigest/src/test/java/org/elasticsearch/tdigest/BigCountTestsTreeDigestTests.java index 765b7c98d7df4..a2cdf49d8f8ad 100644 --- a/libs/tdigest/src/test/java/org/elasticsearch/tdigest/BigCountTestsTreeDigestTests.java +++ b/libs/tdigest/src/test/java/org/elasticsearch/tdigest/BigCountTestsTreeDigestTests.java @@ -21,9 +21,11 @@ package org.elasticsearch.tdigest; +import org.elasticsearch.tdigest.arrays.WrapperTDigestArrays; + public class BigCountTestsTreeDigestTests extends BigCountTests { @Override public TDigest createDigest() { - return new AVLTreeDigest(100); + return new AVLTreeDigest(WrapperTDigestArrays.INSTANCE, 100); } } diff --git a/libs/tdigest/src/test/java/org/elasticsearch/tdigest/ComparisonTests.java b/libs/tdigest/src/test/java/org/elasticsearch/tdigest/ComparisonTests.java index 61a546fe3dd3d..f5df0c2f86ea1 100644 --- a/libs/tdigest/src/test/java/org/elasticsearch/tdigest/ComparisonTests.java +++ b/libs/tdigest/src/test/java/org/elasticsearch/tdigest/ComparisonTests.java @@ -21,6 +21,7 @@ package org.elasticsearch.tdigest; +import
org.elasticsearch.tdigest.arrays.WrapperTDigestArrays; import org.elasticsearch.test.ESTestCase; import java.util.Arrays; @@ -39,10 +40,10 @@ public class ComparisonTests extends ESTestCase { private void loadData(Supplier<Double> sampleGenerator) { final int COMPRESSION = 100; - avlTreeDigest = TDigest.createAvlTreeDigest(COMPRESSION); - mergingDigest = TDigest.createMergingDigest(COMPRESSION); - sortingDigest = TDigest.createSortingDigest(); - hybridDigest = TDigest.createHybridDigest(COMPRESSION); + avlTreeDigest = TDigest.createAvlTreeDigest(WrapperTDigestArrays.INSTANCE, COMPRESSION); + mergingDigest = TDigest.createMergingDigest(WrapperTDigestArrays.INSTANCE, COMPRESSION); + sortingDigest = TDigest.createSortingDigest(WrapperTDigestArrays.INSTANCE); + hybridDigest = TDigest.createHybridDigest(WrapperTDigestArrays.INSTANCE, COMPRESSION); samples = new double[SAMPLE_COUNT]; for (int i = 0; i < SAMPLE_COUNT; i++) { diff --git a/libs/tdigest/src/test/java/org/elasticsearch/tdigest/HybridDigestTests.java b/libs/tdigest/src/test/java/org/elasticsearch/tdigest/HybridDigestTests.java index 019dbdd830182..01b3dc8f5da2a 100644 --- a/libs/tdigest/src/test/java/org/elasticsearch/tdigest/HybridDigestTests.java +++ b/libs/tdigest/src/test/java/org/elasticsearch/tdigest/HybridDigestTests.java @@ -24,6 +24,6 @@ public class HybridDigestTests extends TDigestTests { protected DigestFactory factory(final double compression) { - return () -> new HybridDigest(compression); + return () -> new HybridDigest(arrays(), compression); } } diff --git a/libs/tdigest/src/test/java/org/elasticsearch/tdigest/IntAVLTreeTests.java b/libs/tdigest/src/test/java/org/elasticsearch/tdigest/IntAVLTreeTests.java index 733639978593f..58c91ae6e03e6 100644 --- a/libs/tdigest/src/test/java/org/elasticsearch/tdigest/IntAVLTreeTests.java +++ b/libs/tdigest/src/test/java/org/elasticsearch/tdigest/IntAVLTreeTests.java @@ -21,6 +21,7 @@ package org.elasticsearch.tdigest; +import org.elasticsearch.tdigest.arrays.WrapperTDigestArrays; import org.elasticsearch.test.ESTestCase; import java.util.Arrays; @@ -38,6 +39,7 @@ static class IntegerBag extends IntAVLTree { int[] counts; IntegerBag() { + super(WrapperTDigestArrays.INSTANCE); values = new int[capacity()]; counts = new int[capacity()]; } diff --git a/libs/tdigest/src/test/java/org/elasticsearch/tdigest/MedianTests.java b/libs/tdigest/src/test/java/org/elasticsearch/tdigest/MedianTests.java index b95c81b3c6144..dd455b307344e 100644 --- a/libs/tdigest/src/test/java/org/elasticsearch/tdigest/MedianTests.java +++ b/libs/tdigest/src/test/java/org/elasticsearch/tdigest/MedianTests.java @@ -21,13 +21,14 @@ package org.elasticsearch.tdigest; +import org.elasticsearch.tdigest.arrays.WrapperTDigestArrays; import org.elasticsearch.test.ESTestCase; public class MedianTests extends ESTestCase { public void testAVL() { double[] data = new double[] { 7, 15, 36, 39, 40, 41 }; - TDigest digest = new AVLTreeDigest(100); + TDigest digest = new AVLTreeDigest(WrapperTDigestArrays.INSTANCE, 100); for (double value : data) { digest.add(value); } @@ -38,7 +39,7 @@ public void testAVL() { public void testMergingDigest() { double[] data = new double[] { 7, 15, 36, 39, 40, 41 }; - TDigest digest = new MergingDigest(100); + TDigest digest = new MergingDigest(WrapperTDigestArrays.INSTANCE, 100); for (double value : data) { digest.add(value); } @@ -49,7 +50,7 @@ public void testMergingDigest() { public void testSortingDigest() { double[] data = new double[] { 7, 15, 36, 39, 40, 41 }; - TDigest digest = new
SortingDigest(); + TDigest digest = new SortingDigest(WrapperTDigestArrays.INSTANCE); for (double value : data) { digest.add(value); } @@ -60,7 +61,7 @@ public void testSortingDigest() { public void testHybridDigest() { double[] data = new double[] { 7, 15, 36, 39, 40, 41 }; - TDigest digest = new HybridDigest(100); + TDigest digest = new HybridDigest(WrapperTDigestArrays.INSTANCE, 100); for (double value : data) { digest.add(value); } diff --git a/libs/tdigest/src/test/java/org/elasticsearch/tdigest/MergingDigestTests.java b/libs/tdigest/src/test/java/org/elasticsearch/tdigest/MergingDigestTests.java index 9fadf2218f203..263d0fe920208 100644 --- a/libs/tdigest/src/test/java/org/elasticsearch/tdigest/MergingDigestTests.java +++ b/libs/tdigest/src/test/java/org/elasticsearch/tdigest/MergingDigestTests.java @@ -33,18 +33,19 @@ public class MergingDigestTests extends TDigestTests { protected DigestFactory factory(final double compression) { - return () -> new MergingDigest(compression); + + return () -> new MergingDigest(arrays(), compression); } public void testNanDueToBadInitialization() { int compression = 100; int factor = 5; - MergingDigest md = new MergingDigest(compression, (factor + 1) * compression, compression); + MergingDigest md = new MergingDigest(arrays(), compression, (factor + 1) * compression, compression); final int M = 10; List<MergingDigest> mds = new ArrayList<>(); for (int i = 0; i < M; ++i) { - mds.add(new MergingDigest(compression, (factor + 1) * compression, compression)); + mds.add(new MergingDigest(arrays(), compression, (factor + 1) * compression, compression)); } // Fill all digests with values (0,10,20,...,80). @@ -107,7 +108,7 @@ public void testSingleMultiRange() { * Make sure that the first and last centroids have unit weight */ public void testSingletonsAtEnds() { - TDigest d = new MergingDigest(50); + TDigest d = new MergingDigest(arrays(), 50); Random gen = random(); double[] data = new double[100]; for (int i = 0; i < data.length; i++) { @@ -132,7 +133,7 @@ public void testSingletonsAtEnds() { * Verify centroid sizes. */ public void testFill() { - MergingDigest x = new MergingDigest(300); + MergingDigest x = new MergingDigest(arrays(), 300); Random gen = random(); ScaleFunction scale = x.getScaleFunction(); double compression = x.compression(); @@ -153,7 +154,7 @@ public void testFill() { } public void testLargeInputSmallCompression() { - MergingDigest td = new MergingDigest(10); + MergingDigest td = new MergingDigest(arrays(), 10); for (int i = 0; i < 10_000_000; i++) { td.add(between(0, 3_600_000)); } diff --git a/libs/tdigest/src/test/java/org/elasticsearch/tdigest/SortTests.java b/libs/tdigest/src/test/java/org/elasticsearch/tdigest/SortTests.java index 7b0d867d21205..7327dfb5aac3c 100644 --- a/libs/tdigest/src/test/java/org/elasticsearch/tdigest/SortTests.java +++ b/libs/tdigest/src/test/java/org/elasticsearch/tdigest/SortTests.java @@ -21,58 +21,58 @@ package org.elasticsearch.tdigest; +import org.elasticsearch.tdigest.arrays.TDigestIntArray; +import org.elasticsearch.tdigest.arrays.WrapperTDigestArrays; import org.elasticsearch.test.ESTestCase; import java.util.HashMap; -import java.util.Locale; import java.util.Map; import java.util.Random; public class SortTests extends ESTestCase { - public void testReverse() { - int[] x = new int[0]; + TDigestIntArray x = WrapperTDigestArrays.INSTANCE.newIntArray(0); // don't crash with no input - Sort.reverse(x); + Sort.reverse(x, 0, x.size()); // reverse stuff!
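The migrated assertions in SortTests pin down the new Sort.reverse contract: it now always takes an explicit offset and length over a wrapped array. A small sketch restating the same call pattern, grounded in the test's own expected values:

```java
import org.elasticsearch.tdigest.Sort;
import org.elasticsearch.tdigest.arrays.TDigestIntArray;
import org.elasticsearch.tdigest.arrays.WrapperTDigestArrays;

class ReverseSketch {
    static void demo() {
        TDigestIntArray x = WrapperTDigestArrays.INSTANCE.newIntArray(new int[] { 1, 2, 3, 4, 5 });
        Sort.reverse(x, 0, x.size()); // whole array -> 5, 4, 3, 2, 1
        Sort.reverse(x, 1, 3);        // slice at indices 1..3 -> 5, 2, 3, 4, 1
        Sort.reverse(x, 3, 0);        // zero length: a no-op, as the test asserts
    }
}
```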
- x = new int[] { 1, 2, 3, 4, 5 }; - Sort.reverse(x); + x = WrapperTDigestArrays.INSTANCE.newIntArray(new int[] { 1, 2, 3, 4, 5 }); + Sort.reverse(x, 0, x.size()); for (int i = 0; i < 5; i++) { - assertEquals(5 - i, x[i]); + assertEquals(5 - i, x.get(i)); } // reverse some stuff back Sort.reverse(x, 1, 3); - assertEquals(5, x[0]); - assertEquals(2, x[1]); - assertEquals(3, x[2]); - assertEquals(4, x[3]); - assertEquals(1, x[4]); + assertEquals(5, x.get(0)); + assertEquals(2, x.get(1)); + assertEquals(3, x.get(2)); + assertEquals(4, x.get(3)); + assertEquals(1, x.get(4)); // another no-op Sort.reverse(x, 3, 0); - assertEquals(5, x[0]); - assertEquals(2, x[1]); - assertEquals(3, x[2]); - assertEquals(4, x[3]); - assertEquals(1, x[4]); - - x = new int[] { 1, 2, 3, 4, 5, 6 }; - Sort.reverse(x); + assertEquals(5, x.get(0)); + assertEquals(2, x.get(1)); + assertEquals(3, x.get(2)); + assertEquals(4, x.get(3)); + assertEquals(1, x.get(4)); + + x = WrapperTDigestArrays.INSTANCE.newIntArray(new int[] { 1, 2, 3, 4, 5, 6 }); + Sort.reverse(x, 0, x.size()); for (int i = 0; i < 6; i++) { - assertEquals(6 - i, x[i]); + assertEquals(6 - i, x.get(i)); } } public void testEmpty() { - Sort.sort(new int[] {}, new double[] {}, null, 0); + sort(new int[0], new double[0], 0); } public void testOne() { int[] order = new int[1]; - Sort.sort(order, new double[] { 1 }, new double[] { 1 }, 1); + sort(order, new double[] { 1 }, 1); assertEquals(0, order[0]); } @@ -80,7 +80,7 @@ public void testIdentical() { int[] order = new int[6]; double[] values = new double[6]; - Sort.sort(order, values, null, values.length); + sort(order, values, values.length); checkOrder(order, values); } @@ -92,58 +92,10 @@ public void testRepeated() { values[i] = Math.rint(10 * ((double) i / n)) / 10.0; } - Sort.sort(order, values, null, values.length); + sort(order, values, values.length); checkOrder(order, values); } - public void testRepeatedSortByWeight() { - // this needs to be long enough to force coverage of both quicksort and insertion sort - // (i.e. >64) - int n = 125; - int[] order = new int[n]; - double[] values = new double[n]; - double[] weights = new double[n]; - double totalWeight = 0; - - // generate evenly distributed values and weights - for (int i = 0; i < n; i++) { - int k = ((i + 5) * 37) % n; - values[i] = Math.floor(k / 25.0); - weights[i] = (k % 25) + 1; - totalWeight += weights[i]; - } - - // verify: test weights should be evenly distributed - double[] tmp = new double[5]; - for (int i = 0; i < n; i++) { - tmp[(int) values[i]] += weights[i]; - } - for (double v : tmp) { - assertEquals(totalWeight / tmp.length, v, 0); - } - - // now sort ... 
- Sort.sort(order, values, weights, n); - - // and verify our somewhat unusual ordering of the result - // within the first two quintiles, value is constant, weights increase within each quintile - int delta = order.length / 5; - double sum = checkSubOrder(0.0, order, values, weights, 0, delta, 1); - assertEquals(totalWeight * 0.2, sum, 0); - sum = checkSubOrder(sum, order, values, weights, delta, 2 * delta, 1); - assertEquals(totalWeight * 0.4, sum, 0); - - // in the middle quintile, weights go up and then down after the median - sum = checkMidOrder(totalWeight / 2, sum, order, values, weights, 2 * delta, 3 * delta); - assertEquals(totalWeight * 0.6, sum, 0); - - // in the last two quintiles, weights decrease - sum = checkSubOrder(sum, order, values, weights, 3 * delta, 4 * delta, -1); - assertEquals(totalWeight * 0.8, sum, 0); - sum = checkSubOrder(sum, order, values, weights, 4 * delta, 5 * delta, -1); - assertEquals(totalWeight, sum, 0); - } - public void testStableSort() { // this needs to be long enough to force coverage of both quicksort and insertion sort // (i.e. >64) @@ -172,7 +124,7 @@ public void testStableSort() { } // now sort ... - Sort.stableSort(order, values, n); + sort(order, values, n); // and verify stability of the ordering // values must be in order and they must appear in their original ordering @@ -184,37 +136,6 @@ public void testStableSort() { } } - private double checkMidOrder(double medianWeight, double sofar, int[] order, double[] values, double[] weights, int start, int end) { - double value = values[order[start]]; - double last = 0; - assertTrue(sofar < medianWeight); - for (int i = start; i < end; i++) { - assertEquals(value, values[order[i]], 0); - double w = weights[order[i]]; - assertTrue(w > 0); - if (sofar > medianWeight) { - w = 2 * medianWeight - w; - } - assertTrue(w >= last); - sofar += weights[order[i]]; - } - assertTrue(sofar > medianWeight); - return sofar; - } - - private double checkSubOrder(double sofar, int[] order, double[] values, double[] weights, int start, int end, int ordering) { - double lastWeight = weights[order[start]] * ordering; - double value = values[order[start]]; - for (int i = start; i < end; i++) { - assertEquals(value, values[order[i]], 0); - double newOrderedWeight = weights[order[i]] * ordering; - assertTrue(newOrderedWeight >= lastWeight); - lastWeight = newOrderedWeight; - sofar += weights[order[i]]; - } - return sofar; - } - public void testShort() { int[] order = new int[6]; double[] values = new double[6]; @@ -224,19 +145,19 @@ public void testShort() { values[i] = 1; } - Sort.sort(order, values, null, values.length); + sort(order, values, values.length); checkOrder(order, values); values[0] = 0.8; values[1] = 0.3; - Sort.sort(order, values, null, values.length); + sort(order, values, values.length); checkOrder(order, values); values[5] = 1.5; values[4] = 1.2; - Sort.sort(order, values, null, values.length); + sort(order, values, values.length); checkOrder(order, values); } @@ -246,7 +167,7 @@ public void testLonger() { for (int i = 0; i < 20; i++) { values[i] = (i * 13) % 20; } - Sort.sort(order, values, null, values.length); + sort(order, values, values.length); checkOrder(order, values); } @@ -270,35 +191,10 @@ public void testMultiPivots() { values[24] = 25; values[26] = 25; - Sort.sort(order, values, null, values.length); + sort(order, values, values.length); checkOrder(order, values); } - public void testMultiPivotsInPlace() { - // more pivots than low split on first pass - // multiple pivots, but more low 
data on second part of recursion - double[] keys = new double[30]; - for (int i = 0; i < 9; i++) { - keys[i] = i + 20 * (i % 2); - } - - for (int i = 9; i < 20; i++) { - keys[i] = 10; - } - - for (int i = 20; i < 30; i++) { - keys[i] = i - 20 * (i % 2); - } - keys[29] = 29; - keys[24] = 25; - keys[26] = 25; - - double[] v = valuesFromKeys(keys, 0); - - Sort.sort(keys, v); - checkOrder(keys, 0, keys.length, v); - } - public void testRandomized() { Random rand = random(); @@ -309,98 +205,11 @@ public void testRandomized() { values[i] = rand.nextDouble(); } - Sort.sort(order, values, null, values.length); + sort(order, values, values.length); checkOrder(order, values); } } - public void testRandomizedShortSort() { - Random rand = random(); - - for (int k = 0; k < 100; k++) { - double[] keys = new double[30]; - for (int i = 0; i < 10; i++) { - keys[i] = i; - } - for (int i = 10; i < 20; i++) { - keys[i] = rand.nextDouble(); - } - for (int i = 20; i < 30; i++) { - keys[i] = i; - } - double[] v0 = valuesFromKeys(keys, 0); - double[] v1 = valuesFromKeys(keys, 1); - - Sort.sort(keys, 10, 10, v0, v1); - checkOrder(keys, 10, 10, v0, v1); - checkValues(keys, 0, keys.length, v0, v1); - for (int i = 0; i < 10; i++) { - assertEquals(i, keys[i], 0); - } - for (int i = 20; i < 30; i++) { - assertEquals(i, keys[i], 0); - } - } - } - - /** - * Generates a vector of values corresponding to a vector of keys. - * - * @param keys A vector of keys - * @param k Which value vector to generate - * @return The new vector containing frac(key_i * 3 * 5^k) - */ - private double[] valuesFromKeys(double[] keys, int k) { - double[] r = new double[keys.length]; - double scale = 3; - for (int i = 0; i < k; i++) { - scale = scale * 5; - } - for (int i = 0; i < keys.length; i++) { - r[i] = fractionalPart(keys[i] * scale); - } - return r; - } - - /** - * Verifies that keys are in order and that each value corresponds to the keys - * - * @param key Array of keys - * @param start The starting offset of keys and values to check - * @param length The number of keys and values to check - * @param values Arrays of associated values. Value_{ki} = frac(key_i * 3 * 5^k) - */ - private void checkOrder(double[] key, int start, int length, double[]... values) { - assert start + length <= key.length; - - for (int i = start; i < start + length - 1; i++) { - assertTrue(String.format(Locale.ROOT, "bad ordering at %d, %f > %f", i, key[i], key[i + 1]), key[i] <= key[i + 1]); - } - - checkValues(key, start, length, values); - } - - private void checkValues(double[] key, int start, int length, double[]... 
values) { - double scale = 3; - for (int k = 0; k < values.length; k++) { - double[] v = values[k]; - assertEquals(key.length, v.length); - for (int i = start; i < length; i++) { - assertEquals( - String.format(Locale.ROOT, "value %d not correlated, key=%.5f, k=%d, v=%.5f", i, key[i], k, values[k][i]), - fractionalPart(key[i] * scale), - values[k][i], - 0 - ); - } - scale = scale * 5; - } - } - - private double fractionalPart(double v) { - return v - Math.floor(v); - } - private void checkOrder(int[] order, double[] values) { double previous = -Double.MAX_VALUE; Map<Integer, Integer> counts = new HashMap<>(); @@ -418,4 +227,11 @@ private void checkOrder(int[] order, double[] values) { assertEquals(1, entry.getValue().intValue()); } } + + private void sort(int[] order, double[] values, int n) { + var wrappedOrder = WrapperTDigestArrays.INSTANCE.newIntArray(order); + var wrappedValues = WrapperTDigestArrays.INSTANCE.newDoubleArray(values); + + Sort.stableSort(wrappedOrder, wrappedValues, n); + } } diff --git a/libs/tdigest/src/test/java/org/elasticsearch/tdigest/SortingDigestTests.java b/libs/tdigest/src/test/java/org/elasticsearch/tdigest/SortingDigestTests.java index 1c1dbbfa28ae9..2478e85421f07 100644 --- a/libs/tdigest/src/test/java/org/elasticsearch/tdigest/SortingDigestTests.java +++ b/libs/tdigest/src/test/java/org/elasticsearch/tdigest/SortingDigestTests.java @@ -24,7 +24,7 @@ public class SortingDigestTests extends TDigestTests { protected DigestFactory factory(final double compression) { - return SortingDigest::new; + return () -> new SortingDigest(arrays()); } // Make this test a noop to avoid OOMs. diff --git a/libs/tdigest/src/test/java/org/elasticsearch/tdigest/TDigestTests.java b/libs/tdigest/src/test/java/org/elasticsearch/tdigest/TDigestTests.java index 815346100532c..43f1e36afb314 100644 --- a/libs/tdigest/src/test/java/org/elasticsearch/tdigest/TDigestTests.java +++ b/libs/tdigest/src/test/java/org/elasticsearch/tdigest/TDigestTests.java @@ -21,6 +21,8 @@ package org.elasticsearch.tdigest; +import org.elasticsearch.tdigest.arrays.TDigestArrays; +import org.elasticsearch.tdigest.arrays.WrapperTDigestArrays; import org.elasticsearch.test.ESTestCase; import java.util.ArrayList; @@ -542,4 +544,8 @@ public void testMonotonicity() { lastQuantile = q; } } + + protected static TDigestArrays arrays() { + return WrapperTDigestArrays.INSTANCE; + } } diff --git a/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/DataStreamWithSecurityIT.java b/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/DataStreamWithSecurityIT.java new file mode 100644 index 0000000000000..2ba373945ad50 --- /dev/null +++ b/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/DataStreamWithSecurityIT.java @@ -0,0 +1,90 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1".
+ */ + +package org.elasticsearch.datastreams; + +import org.apache.http.HttpHost; +import org.elasticsearch.client.Request; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.test.cluster.ElasticsearchCluster; +import org.elasticsearch.test.cluster.FeatureFlag; +import org.elasticsearch.test.cluster.local.distribution.DistributionType; +import org.elasticsearch.test.cluster.util.resource.Resource; +import org.elasticsearch.test.rest.ESRestTestCase; +import org.junit.ClassRule; + +public class DataStreamWithSecurityIT extends ESRestTestCase { + + private static final String PASSWORD = "secret-test-password"; + private static final String DATA_STREAM_NAME = "my-ds"; + + @ClassRule + public static ElasticsearchCluster cluster = ElasticsearchCluster.local() + .distribution(DistributionType.DEFAULT) + .feature(FeatureFlag.FAILURE_STORE_ENABLED) + .setting("xpack.watcher.enabled", "false") + .setting("xpack.ml.enabled", "false") + .setting("xpack.security.enabled", "true") + .setting("xpack.security.transport.ssl.enabled", "false") + .setting("xpack.security.http.ssl.enabled", "false") + .user("test_admin", PASSWORD, "superuser", false) + .user("limited_user", PASSWORD, "only_get", false) + .rolesFile(Resource.fromClasspath("roles.yml")) + .build(); + + @Override + protected String getTestRestCluster() { + return cluster.getHttpAddresses(); + } + + @Override + protected Settings restClientSettings() { + // If this test is running in a test framework that handles its own authorization, we don't want to overwrite it. + if (super.restClientSettings().keySet().contains(ThreadContext.PREFIX + ".Authorization")) { + return super.restClientSettings(); + } else { + // Note: We use the admin user because the other one is too unprivileged, so it breaks the initialization of the test + String token = basicAuthHeaderValue("test_admin", new SecureString(PASSWORD.toCharArray())); + return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", token).build(); + } + } + + private Settings simpleUserRestClientSettings() { + // Note: This user is assigned the role "only_get". That role is defined in roles.yml. 
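basicAuthHeaderValue is the test-framework helper for HTTP Basic credentials. A hedged plain-Java equivalent of what it produces, for readers tracing what actually goes over the wire (the class and method names here are illustrative, not the framework's):

```java
import java.nio.charset.StandardCharsets;
import java.util.Base64;

class BasicAuthSketch {
    // Roughly what basicAuthHeaderValue("limited_user", password) evaluates to.
    static String basicAuth(String user, String password) {
        byte[] raw = (user + ":" + password).getBytes(StandardCharsets.UTF_8);
        return "Basic " + Base64.getEncoder().encodeToString(raw);
    }
}
```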
+ String token = basicAuthHeaderValue("limited_user", new SecureString(PASSWORD.toCharArray())); + return Settings.builder().put(super.restClientSettings()).put(ThreadContext.PREFIX + ".Authorization", token).build(); + } + + public void testGetDataStreamWithoutPermission() throws Exception { + Request putComposableIndexTemplateRequest = new Request("POST", "/_index_template/my-ds-template"); + putComposableIndexTemplateRequest.setJsonEntity(""" + { + "index_patterns": ["my-ds*"], + "data_stream": {} + } + """); + assertOK(adminClient().performRequest(putComposableIndexTemplateRequest)); + assertOK(adminClient().performRequest(new Request("PUT", "/_data_stream/" + DATA_STREAM_NAME))); + Request createDocRequest = new Request("POST", "/" + DATA_STREAM_NAME + "/_doc"); + createDocRequest.setJsonEntity("{ \"@timestamp\": \"2022-01-01\", \"message\": \"foo\" }"); + assertOK(adminClient().performRequest(createDocRequest)); + + // Both the verbose and non-verbose versions should work with the "simple" user + try (var simpleUserClient = buildClient(simpleUserRestClientSettings(), getClusterHosts().toArray(new HttpHost[0]))) { + Request getDs = new Request("GET", "/_data_stream"); + assertOK(simpleUserClient.performRequest(getDs)); + + Request getDsVerbose = new Request("GET", "/_data_stream?verbose=true"); + assertOK(simpleUserClient.performRequest(getDsVerbose)); + } + } + +} diff --git a/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/LogsDataStreamRestIT.java b/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/LogsDataStreamRestIT.java index f62fa83b4e111..f95815d1daff9 100644 --- a/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/LogsDataStreamRestIT.java +++ b/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/LogsDataStreamRestIT.java @@ -9,16 +9,23 @@ package org.elasticsearch.datastreams; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.client.methods.HttpPut; import org.elasticsearch.client.Request; import org.elasticsearch.client.Response; import org.elasticsearch.client.ResponseException; import org.elasticsearch.client.RestClient; import org.elasticsearch.common.network.InetAddresses; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.time.DateFormatter; import org.elasticsearch.common.time.FormatNames; +import org.elasticsearch.index.IndexMode; +import org.elasticsearch.repositories.fs.FsRepository; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.cluster.ElasticsearchCluster; import org.elasticsearch.test.cluster.local.distribution.DistributionType; import org.elasticsearch.test.rest.ESRestTestCase; +import org.hamcrest.Matchers; import org.junit.Before; import org.junit.ClassRule; @@ -41,6 +48,7 @@ public class LogsDataStreamRestIT extends ESRestTestCase { public static ElasticsearchCluster cluster = ElasticsearchCluster.local() .distribution(DistributionType.DEFAULT) .setting("xpack.security.enabled", "false") + .setting("xpack.license.self_generated.type", "trial") .build(); @Override @@ -102,7 +110,7 @@ private static void waitForLogs(RestClient client) throws Exception { } }"""; - private static final String STANDARD_TEMPLATE = """ + private static final String LOGS_STANDARD_INDEX_MODE = """ { "index_patterns": [ "logs-*-*" ], "data_stream": {}, @@ -135,6 +143,39 @@ private static void waitForLogs(RestClient client) throws Exception { } }"""; + private static final String STANDARD_TEMPLATE = """ + { + 
"index_patterns": [ "standard-*-*" ], + "data_stream": {}, + "priority": 201, + "template": { + "settings": { + "index": { + "mode": "standard" + } + }, + "mappings": { + "properties": { + "@timestamp" : { + "type": "date" + }, + "host.name": { + "type": "keyword" + }, + "pid": { + "type": "long" + }, + "method": { + "type": "keyword" + }, + "ip_address": { + "type": "ip" + } + } + } + } + }"""; + private static final String TIME_SERIES_TEMPLATE = """ { "index_patterns": [ "logs-*-*" ], @@ -203,7 +244,7 @@ public void testLogsIndexing() throws IOException { randomLongBetween(1_000_000L, 2_000_000L) ) ); - assertDataStreamBackingIndexMode("logsdb", 0); + assertDataStreamBackingIndexMode("logsdb", 0, DATA_STREAM_NAME); rolloverDataStream(client, DATA_STREAM_NAME); indexDocument( client, @@ -218,7 +259,7 @@ public void testLogsIndexing() throws IOException { randomLongBetween(1_000_000L, 2_000_000L) ) ); - assertDataStreamBackingIndexMode("logsdb", 1); + assertDataStreamBackingIndexMode("logsdb", 1, DATA_STREAM_NAME); } public void testLogsStandardIndexModeSwitch() throws IOException { @@ -237,9 +278,9 @@ public void testLogsStandardIndexModeSwitch() throws IOException { randomLongBetween(1_000_000L, 2_000_000L) ) ); - assertDataStreamBackingIndexMode("logsdb", 0); + assertDataStreamBackingIndexMode("logsdb", 0, DATA_STREAM_NAME); - putTemplate(client, "custom-template", STANDARD_TEMPLATE); + putTemplate(client, "custom-template", LOGS_STANDARD_INDEX_MODE); rolloverDataStream(client, DATA_STREAM_NAME); indexDocument( client, @@ -254,7 +295,7 @@ public void testLogsStandardIndexModeSwitch() throws IOException { randomLongBetween(1_000_000L, 2_000_000L) ) ); - assertDataStreamBackingIndexMode("standard", 1); + assertDataStreamBackingIndexMode("standard", 1, DATA_STREAM_NAME); putTemplate(client, "custom-template", LOGS_TEMPLATE); rolloverDataStream(client, DATA_STREAM_NAME); @@ -271,7 +312,7 @@ public void testLogsStandardIndexModeSwitch() throws IOException { randomLongBetween(1_000_000L, 2_000_000L) ) ); - assertDataStreamBackingIndexMode("logsdb", 2); + assertDataStreamBackingIndexMode("logsdb", 2, DATA_STREAM_NAME); } public void testLogsTimeSeriesIndexModeSwitch() throws IOException { @@ -290,7 +331,7 @@ public void testLogsTimeSeriesIndexModeSwitch() throws IOException { randomLongBetween(1_000_000L, 2_000_000L) ) ); - assertDataStreamBackingIndexMode("logsdb", 0); + assertDataStreamBackingIndexMode("logsdb", 0, DATA_STREAM_NAME); putTemplate(client, "custom-template", TIME_SERIES_TEMPLATE); rolloverDataStream(client, DATA_STREAM_NAME); @@ -307,7 +348,7 @@ public void testLogsTimeSeriesIndexModeSwitch() throws IOException { randomLongBetween(1_000_000L, 2_000_000L) ) ); - assertDataStreamBackingIndexMode("time_series", 1); + assertDataStreamBackingIndexMode("time_series", 1, DATA_STREAM_NAME); putTemplate(client, "custom-template", LOGS_TEMPLATE); rolloverDataStream(client, DATA_STREAM_NAME); @@ -324,11 +365,193 @@ public void testLogsTimeSeriesIndexModeSwitch() throws IOException { randomLongBetween(1_000_000L, 2_000_000L) ) ); - assertDataStreamBackingIndexMode("logsdb", 2); + assertDataStreamBackingIndexMode("logsdb", 2, DATA_STREAM_NAME); + } + + public void testLogsDBToStandardReindex() throws IOException { + // LogsDB data stream + putTemplate(client, "logs-template", LOGS_TEMPLATE); + createDataStream(client, "logs-apache-kafka"); + + // Standard data stream + putTemplate(client, "standard-template", STANDARD_TEMPLATE); + createDataStream(client, "standard-apache-kafka"); + + // 
Index some documents in the LogsDB index + for (int i = 0; i < 10; i++) { + indexDocument( + client, + "logs-apache-kafka", + document( + Instant.now().plusSeconds(10), + randomAlphaOfLength(10), + randomNonNegativeLong(), + randomFrom("PUT", "POST", "GET"), + randomAlphaOfLength(64), + randomIp(randomBoolean()), + randomLongBetween(1_000_000L, 2_000_000L) + ) + ); + } + assertDataStreamBackingIndexMode("logsdb", 0, "logs-apache-kafka"); + assertDocCount(client, "logs-apache-kafka", 10); + + // Reindex a LogsDB data stream into a standard data stream + final Request reindexRequest = new Request("POST", "/_reindex?refresh=true"); + reindexRequest.setJsonEntity(""" + { + "source": { + "index": "logs-apache-kafka" + }, + "dest": { + "index": "standard-apache-kafka", + "op_type": "create" + } + } + """); + assertOK(client.performRequest(reindexRequest)); + assertDataStreamBackingIndexMode("standard", 0, "standard-apache-kafka"); + assertDocCount(client, "standard-apache-kafka", 10); + } + + public void testStandardToLogsDBReindex() throws IOException { + // LogsDB data stream + putTemplate(client, "logs-template", LOGS_TEMPLATE); + createDataStream(client, "logs-apache-kafka"); + + // Standard data stream + putTemplate(client, "standard-template", STANDARD_TEMPLATE); + createDataStream(client, "standard-apache-kafka"); + + // Index some documents in a standard index + for (int i = 0; i < 10; i++) { + indexDocument( + client, + "standard-apache-kafka", + document( + Instant.now().plusSeconds(10), + randomAlphaOfLength(10), + randomNonNegativeLong(), + randomFrom("PUT", "POST", "GET"), + randomAlphaOfLength(64), + randomIp(randomBoolean()), + randomLongBetween(1_000_000L, 2_000_000L) + ) + ); + } + assertDataStreamBackingIndexMode("standard", 0, "standard-apache-kafka"); + assertDocCount(client, "standard-apache-kafka", 10); + + // Reindex a standard data stream into a LogsDB data stream + final Request reindexRequest = new Request("POST", "/_reindex?refresh=true"); + reindexRequest.setJsonEntity(""" + { + "source": { + "index": "standard-apache-kafka" + }, + "dest": { + "index": "logs-apache-kafka", + "op_type": "create" + } + } + """); + assertOK(client.performRequest(reindexRequest)); + assertDataStreamBackingIndexMode("logsdb", 0, "logs-apache-kafka"); + assertDocCount(client, "logs-apache-kafka", 10); + } + + public void testLogsDBSnapshotCreateRestoreMount() throws IOException { + final String repository = randomAlphaOfLength(10).toLowerCase(Locale.ROOT); + registerRepository(repository, FsRepository.TYPE, Settings.builder().put("location", randomAlphaOfLength(6))); + + final String index = randomAlphaOfLength(12).toLowerCase(Locale.ROOT); + createIndex(client, index, Settings.builder().put("index.mode", IndexMode.LOGSDB.getName()).build()); + + for (int i = 0; i < 10; i++) { + indexDocument( + client, + index, + document( + Instant.now().plusSeconds(10), + randomAlphaOfLength(10), + randomNonNegativeLong(), + randomFrom("PUT", "POST", "GET"), + randomAlphaOfLength(64), + randomIp(randomBoolean()), + randomLongBetween(1_000_000L, 2_000_000L) + ) + ); + } + + final String snapshot = randomAlphaOfLength(8).toLowerCase(Locale.ROOT); + deleteSnapshot(repository, snapshot, true); + createSnapshot(client, repository, snapshot, true, index); + wipeDataStreams(); + wipeAllIndices(); + restoreSnapshot(client, repository, snapshot, true, index); + + final String restoreIndex = randomAlphaOfLength(7).toLowerCase(Locale.ROOT); + final Request mountRequest = new Request("POST", "/_snapshot/" + 
repository + '/' + snapshot + "/_mount"); + mountRequest.addParameter("wait_for_completion", "true"); + mountRequest.setJsonEntity("{\"index\": \"" + index + "\",\"renamed_index\": \"" + restoreIndex + "\"}"); + + assertOK(client.performRequest(mountRequest)); + assertDocCount(client, restoreIndex, 10); + assertThat(getSettings(client, restoreIndex).get("index.mode"), Matchers.equalTo(IndexMode.LOGSDB.getName())); + } + + // NOTE: this test will fail on snapshot creation after fixing + // https://github.com/elastic/elasticsearch/issues/112735 + public void testLogsDBSourceOnlySnapshotCreation() throws IOException { + final String repository = randomAlphaOfLength(10).toLowerCase(Locale.ROOT); + registerRepository(repository, FsRepository.TYPE, Settings.builder().put("location", randomAlphaOfLength(6))); + // A source-only repository delegates storage to another repository + final String sourceOnlyRepository = randomAlphaOfLength(10).toLowerCase(Locale.ROOT); + registerRepository( + sourceOnlyRepository, + "source", + Settings.builder().put("delegate_type", FsRepository.TYPE).put("location", repository) + ); + + final String index = randomAlphaOfLength(12).toLowerCase(Locale.ROOT); + createIndex(client, index, Settings.builder().put("index.mode", IndexMode.LOGSDB.getName()).build()); + + for (int i = 0; i < 10; i++) { + indexDocument( + client, + index, + document( + Instant.now().plusSeconds(10), + randomAlphaOfLength(10), + randomNonNegativeLong(), + randomFrom("PUT", "POST", "GET"), + randomAlphaOfLength(64), + randomIp(randomBoolean()), + randomLongBetween(1_000_000L, 2_000_000L) + ) + ); + } + + final String snapshot = randomAlphaOfLength(8).toLowerCase(Locale.ROOT); + deleteSnapshot(sourceOnlyRepository, snapshot, true); + createSnapshot(client, sourceOnlyRepository, snapshot, true, index); + wipeDataStreams(); + wipeAllIndices(); + // Can't snapshot _source only on an index that has incomplete source ie. 
has _source disabled or filters the source + final ResponseException responseException = expectThrows( + ResponseException.class, + () -> restoreSnapshot(client, sourceOnlyRepository, snapshot, true, index) + ); + assertThat(responseException.getMessage(), Matchers.containsString("wasn't fully snapshotted")); + } + + private static void registerRepository(final String repository, final String type, final Settings.Builder settings) throws IOException { + registerRepository(repository, type, false, settings.build()); } - private void assertDataStreamBackingIndexMode(final String indexMode, int backingIndex) throws IOException { - assertThat(getSettings(client, getWriteBackingIndex(client, DATA_STREAM_NAME, backingIndex)).get("index.mode"), is(indexMode)); + private void assertDataStreamBackingIndexMode(final String indexMode, int backingIndex, final String dataStreamName) + throws IOException { + assertThat(getSettings(client, getWriteBackingIndex(client, dataStreamName, backingIndex)).get("index.mode"), is(indexMode)); } private String document( @@ -364,8 +587,8 @@ private static void putTemplate(final RestClient client, final String templateNa assertOK(client.performRequest(request)); } - private static void indexDocument(final RestClient client, String dataStreamName, String doc) throws IOException { - final Request request = new Request("POST", "/" + dataStreamName + "/_doc?refresh=true"); + private static void indexDocument(final RestClient client, String indexOrDataStream, String doc) throws IOException { + final Request request = new Request("POST", "/" + indexOrDataStream + "/_doc?refresh=true"); request.setJsonEntity(doc); final Response response = client.performRequest(request); assertOK(response); @@ -393,4 +616,46 @@ private static Map getSettings(final RestClient client, final St final Request request = new Request("GET", "/" + indexName + "/_settings?flat_settings"); return ((Map<String, Map<String, Object>>) entityAsMap(client.performRequest(request)).get(indexName)).get("settings"); } + + private static void createSnapshot( + RestClient restClient, + String repository, + String snapshot, + boolean waitForCompletion, + final String... indices + ) throws IOException { + final Request request = new Request(HttpPut.METHOD_NAME, "_snapshot/" + repository + '/' + snapshot); + request.addParameter("wait_for_completion", Boolean.toString(waitForCompletion)); + request.setJsonEntity(""" + { "indices": "$indices" } + """.replace("$indices", String.join(",", indices))); + + final Response response = restClient.performRequest(request); + assertThat( + "Failed to create snapshot [" + snapshot + "] in repository [" + repository + "]: " + response, + response.getStatusLine().getStatusCode(), + equalTo(RestStatus.OK.getStatus()) + ); + } + + private static void restoreSnapshot( + final RestClient client, + final String repository, + String snapshot, + boolean waitForCompletion, + final String...
indices + ) throws IOException { + final Request request = new Request(HttpPost.METHOD_NAME, "_snapshot/" + repository + '/' + snapshot + "/_restore"); + request.addParameter("wait_for_completion", Boolean.toString(waitForCompletion)); + request.setJsonEntity(""" + { "indices": "$indices" } + """.replace("$indices", String.join(",", indices))); + + final Response response = client.performRequest(request); + assertThat( + "Failed to restore snapshot [" + snapshot + "] from repository [" + repository + "]: " + response, + response.getStatusLine().getStatusCode(), + equalTo(RestStatus.OK.getStatus()) + ); + } } diff --git a/modules/data-streams/src/javaRestTest/resources/roles.yml b/modules/data-streams/src/javaRestTest/resources/roles.yml index 74c238fdae4f2..bc38f60849e4e 100644 --- a/modules/data-streams/src/javaRestTest/resources/roles.yml +++ b/modules/data-streams/src/javaRestTest/resources/roles.yml @@ -16,4 +16,10 @@ under_privilged: - read - write - view_index_metadata +only_get: + indices: + - names: [ 'my-ds*' ] + privileges: + - read + - view_index_metadata no_privilege: diff --git a/modules/data-streams/src/main/java/org/elasticsearch/datastreams/action/TransportGetDataStreamsAction.java b/modules/data-streams/src/main/java/org/elasticsearch/datastreams/action/TransportGetDataStreamsAction.java index a22bfd61fa3ca..ffa2447f5f5aa 100644 --- a/modules/data-streams/src/main/java/org/elasticsearch/datastreams/action/TransportGetDataStreamsAction.java +++ b/modules/data-streams/src/main/java/org/elasticsearch/datastreams/action/TransportGetDataStreamsAction.java @@ -19,6 +19,7 @@ import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.master.TransportMasterNodeReadAction; import org.elasticsearch.client.internal.Client; +import org.elasticsearch.client.internal.OriginSettingClient; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.block.ClusterBlockException; import org.elasticsearch.cluster.block.ClusterBlockLevel; @@ -90,7 +91,7 @@ public TransportGetDataStreamsAction( this.systemIndices = systemIndices; this.globalRetentionSettings = globalRetentionSettings; clusterSettings = clusterService.getClusterSettings(); - this.client = client; + this.client = new OriginSettingClient(client, "stack"); } @Override diff --git a/modules/ingest-geoip/src/internalClusterTest/java/org/elasticsearch/ingest/geoip/GeoIpDownloaderIT.java b/modules/ingest-geoip/src/internalClusterTest/java/org/elasticsearch/ingest/geoip/GeoIpDownloaderIT.java index a8dae37075959..6942cc3733d1e 100644 --- a/modules/ingest-geoip/src/internalClusterTest/java/org/elasticsearch/ingest/geoip/GeoIpDownloaderIT.java +++ b/modules/ingest-geoip/src/internalClusterTest/java/org/elasticsearch/ingest/geoip/GeoIpDownloaderIT.java @@ -9,7 +9,7 @@ package org.elasticsearch.ingest.geoip; -import com.maxmind.geoip2.DatabaseReader; +import com.maxmind.db.Reader; import org.apache.lucene.search.TotalHits; import org.elasticsearch.action.ingest.SimulateDocumentBaseResult; @@ -738,8 +738,8 @@ private void deleteDatabasesInConfigDirectory() throws Exception { @SuppressForbidden(reason = "Maxmind API requires java.io.File") private void parseDatabase(Path tempFile) throws IOException { - try (DatabaseReader databaseReader = new DatabaseReader.Builder(tempFile.toFile()).build()) { - assertNotNull(databaseReader.getMetadata()); + try (Reader reader = new Reader(tempFile.toFile())) { + assertNotNull(reader.getMetadata()); } } diff --git
a/modules/ingest-geoip/src/internalClusterTest/java/org/elasticsearch/ingest/geoip/ReloadingDatabasesWhilePerformingGeoLookupsIT.java b/modules/ingest-geoip/src/internalClusterTest/java/org/elasticsearch/ingest/geoip/ReloadingDatabasesWhilePerformingGeoLookupsIT.java index ff6ea3de1ab7b..2c7d5fbcc56b7 100644 --- a/modules/ingest-geoip/src/internalClusterTest/java/org/elasticsearch/ingest/geoip/ReloadingDatabasesWhilePerformingGeoLookupsIT.java +++ b/modules/ingest-geoip/src/internalClusterTest/java/org/elasticsearch/ingest/geoip/ReloadingDatabasesWhilePerformingGeoLookupsIT.java @@ -13,7 +13,6 @@ import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.service.ClusterService; -import org.elasticsearch.common.network.InetAddresses; import org.elasticsearch.common.util.concurrent.AtomicArray; import org.elasticsearch.core.IOUtils; import org.elasticsearch.index.VersionType; @@ -206,10 +205,10 @@ private static DatabaseNodeService createRegistry(Path geoIpConfigDir, Path geoI private static void lazyLoadReaders(DatabaseNodeService databaseNodeService) throws IOException { if (databaseNodeService.get("GeoLite2-City.mmdb") != null) { databaseNodeService.get("GeoLite2-City.mmdb").getDatabaseType(); - databaseNodeService.get("GeoLite2-City.mmdb").getCity(InetAddresses.forString("2.125.160.216")); + databaseNodeService.get("GeoLite2-City.mmdb").getCity("2.125.160.216"); } databaseNodeService.get("GeoLite2-City-Test.mmdb").getDatabaseType(); - databaseNodeService.get("GeoLite2-City-Test.mmdb").getCity(InetAddresses.forString("2.125.160.216")); + databaseNodeService.get("GeoLite2-City-Test.mmdb").getCity("2.125.160.216"); } } diff --git a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/DatabaseNodeService.java b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/DatabaseNodeService.java index 0579aeb1b5353..2114ebf9f2f05 100644 --- a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/DatabaseNodeService.java +++ b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/DatabaseNodeService.java @@ -36,7 +36,6 @@ import org.elasticsearch.search.SearchHit; import org.elasticsearch.watcher.ResourceWatcherService; -import java.io.BufferedInputStream; import java.io.Closeable; import java.io.IOException; import java.io.UncheckedIOException; @@ -87,7 +86,7 @@ * if there is an old instance of this database then that is closed. * 4) Cleanup locally loaded databases that are no longer mentioned in {@link GeoIpTaskState}. */ -public final class DatabaseNodeService implements GeoIpDatabaseProvider, Closeable { +public final class DatabaseNodeService implements IpDatabaseProvider, Closeable { private static final Logger logger = LogManager.getLogger(DatabaseNodeService.class); @@ -221,7 +220,7 @@ DatabaseReaderLazyLoader getDatabaseReaderLazyLoader(String name) { } @Override - public GeoIpDatabase getDatabase(String name) { + public IpDatabase getDatabase(String name) { return getDatabaseReaderLazyLoader(name); } @@ -380,11 +379,7 @@ void retrieveAndUpdateDatabase(String databaseName, GeoIpTaskState.Metadata meta Path databaseFile = geoipTmpDirectory.resolve(databaseName); // tarball contains .mmdb, LICENSE.txt, COPYRIGHTS.txt and optional README.txt files. 
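The BufferedInputStream wrapper being dropped in the extraction code just below is redundant because GZIPInputStream maintains its own internal inflater buffer, sized by its second constructor argument. A minimal standalone sketch of the same read pattern, assuming an illustrative helper name (countDecompressedBytes is not part of the change):

    import java.io.IOException;
    import java.io.InputStream;
    import java.nio.file.Files;
    import java.nio.file.Path;
    import java.util.zip.GZIPInputStream;

    class GzipReadSketch {
        // The 8192-byte argument sizes GZIPInputStream's internal buffer,
        // so an extra BufferedInputStream wrapper adds nothing.
        static long countDecompressedBytes(Path gzFile) throws IOException {
            try (InputStream is = new GZIPInputStream(Files.newInputStream(gzFile), 8192)) {
                byte[] buf = new byte[8192];
                long total = 0;
                int n;
                while ((n = is.read(buf)) != -1) {
                    total += n;
                }
                return total;
            }
        }
    }
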
// we store mmdb file as is and prepend database name to all other entries to avoid conflicts - try ( - TarInputStream is = new TarInputStream( - new GZIPInputStream(new BufferedInputStream(Files.newInputStream(databaseTmpGzFile)), 8192) - ) - ) { + try (TarInputStream is = new TarInputStream(new GZIPInputStream(Files.newInputStream(databaseTmpGzFile), 8192))) { TarInputStream.TarEntry entry; while ((entry = is.getNextEntry()) != null) { // there might be ./ entry in tar, we should skip it diff --git a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/DatabaseReaderLazyLoader.java b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/DatabaseReaderLazyLoader.java index f1594ddaf5144..60e78138f5c74 100644 --- a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/DatabaseReaderLazyLoader.java +++ b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/DatabaseReaderLazyLoader.java @@ -9,9 +9,10 @@ package org.elasticsearch.ingest.geoip; +import com.maxmind.db.DatabaseRecord; +import com.maxmind.db.Network; import com.maxmind.db.NoCache; import com.maxmind.db.Reader; -import com.maxmind.geoip2.DatabaseReader; import com.maxmind.geoip2.model.AbstractResponse; import com.maxmind.geoip2.model.AnonymousIpResponse; import com.maxmind.geoip2.model.AsnResponse; @@ -25,18 +26,23 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.lucene.util.SetOnce; +import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.common.CheckedBiFunction; import org.elasticsearch.common.CheckedSupplier; +import org.elasticsearch.common.network.InetAddresses; +import org.elasticsearch.common.network.NetworkAddress; import org.elasticsearch.core.Booleans; import org.elasticsearch.core.IOUtils; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.SuppressForbidden; import java.io.Closeable; +import java.io.File; import java.io.IOException; import java.net.InetAddress; import java.nio.file.Files; import java.nio.file.Path; +import java.util.List; import java.util.Objects; import java.util.Optional; import java.util.concurrent.atomic.AtomicInteger; @@ -45,7 +51,7 @@ * Facilitates lazy loading of the database reader, so that when the geoip plugin is installed, but not used, * no memory is being wasted on the database reader. 
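The class javadoc above describes the lazy-loading contract, and the get() method further down implements it with double-checked locking around a SetOnce. A compact sketch of the same idiom, using a plain volatile field instead of SetOnce (LazyLoaderSketch and its names are illustrative):

    import java.util.function.Supplier;

    class LazyLoaderSketch<T> {
        private final Supplier<T> loader;
        private volatile T instance;

        LazyLoaderSketch(Supplier<T> loader) {
            this.loader = loader;
        }

        // Double-checked locking: the unsynchronized first read keeps the
        // common (already-initialized) path lock-free, while the synchronized
        // block guarantees the loader runs at most once.
        T get() {
            T local = instance;
            if (local == null) {
                synchronized (this) {
                    local = instance;
                    if (local == null) {
                        local = loader.get();
                        instance = local;
                    }
                }
            }
            return local;
        }
    }
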
*/ -class DatabaseReaderLazyLoader implements GeoIpDatabase, Closeable { +class DatabaseReaderLazyLoader implements IpDatabase, Closeable { private static final boolean LOAD_DATABASE_ON_HEAP = Booleans.parseBoolean(System.getProperty("es.geoip.load_db_on_heap", "false")); @@ -54,8 +60,8 @@ class DatabaseReaderLazyLoader implements GeoIpDatabase, Closeable { private final String md5; private final GeoIpCache cache; private final Path databasePath; - private final CheckedSupplier<DatabaseReader, IOException> loader; - final SetOnce<DatabaseReader> databaseReader; + private final CheckedSupplier<Reader, IOException> loader; + final SetOnce<Reader> databaseReader; // cache the database type so that we do not re-read it on every pipeline execution final SetOnce<String> databaseType; @@ -92,50 +98,90 @@ public final String getDatabaseType() throws IOException { @Nullable @Override - public CityResponse getCity(InetAddress ipAddress) { - return getResponse(ipAddress, DatabaseReader::tryCity); + public CityResponse getCity(String ipAddress) { + return getResponse(ipAddress, (reader, ip) -> lookup(reader, ip, CityResponse.class, CityResponse::new)); } @Nullable @Override - public CountryResponse getCountry(InetAddress ipAddress) { - return getResponse(ipAddress, DatabaseReader::tryCountry); + public CountryResponse getCountry(String ipAddress) { + return getResponse(ipAddress, (reader, ip) -> lookup(reader, ip, CountryResponse.class, CountryResponse::new)); } @Nullable @Override - public AsnResponse getAsn(InetAddress ipAddress) { - return getResponse(ipAddress, DatabaseReader::tryAsn); + public AsnResponse getAsn(String ipAddress) { + return getResponse( + ipAddress, + (reader, ip) -> lookup( + reader, + ip, + AsnResponse.class, + (response, responseIp, network, locales) -> new AsnResponse(response, responseIp, network) + ) + ); } @Nullable @Override - public AnonymousIpResponse getAnonymousIp(InetAddress ipAddress) { - return getResponse(ipAddress, DatabaseReader::tryAnonymousIp); + public AnonymousIpResponse getAnonymousIp(String ipAddress) { + return getResponse( + ipAddress, + (reader, ip) -> lookup( + reader, + ip, + AnonymousIpResponse.class, + (response, responseIp, network, locales) -> new AnonymousIpResponse(response, responseIp, network) + ) + ); } @Nullable @Override - public ConnectionTypeResponse getConnectionType(InetAddress ipAddress) { - return getResponse(ipAddress, DatabaseReader::tryConnectionType); + public ConnectionTypeResponse getConnectionType(String ipAddress) { + return getResponse( + ipAddress, + (reader, ip) -> lookup( + reader, + ip, + ConnectionTypeResponse.class, + (response, responseIp, network, locales) -> new ConnectionTypeResponse(response, responseIp, network) + ) + ); } @Nullable @Override - public DomainResponse getDomain(InetAddress ipAddress) { - return getResponse(ipAddress, DatabaseReader::tryDomain); + public DomainResponse getDomain(String ipAddress) { + return getResponse( + ipAddress, + (reader, ip) -> lookup( + reader, + ip, + DomainResponse.class, + (response, responseIp, network, locales) -> new DomainResponse(response, responseIp, network) + ) + ); } @Nullable @Override - public EnterpriseResponse getEnterprise(InetAddress ipAddress) { - return getResponse(ipAddress, DatabaseReader::tryEnterprise); + public EnterpriseResponse getEnterprise(String ipAddress) { + return getResponse(ipAddress, (reader, ip) -> lookup(reader, ip, EnterpriseResponse.class, EnterpriseResponse::new)); } @Nullable @Override - public IspResponse getIsp(InetAddress ipAddress) { - return getResponse(ipAddress, DatabaseReader::tryIsp); + public IspResponse
getIsp(String ipAddress) { + return getResponse( + ipAddress, + (reader, ip) -> lookup( + reader, + ip, + IspResponse.class, + (response, responseIp, network, locales) -> new IspResponse(response, responseIp, network) + ) + ); } boolean preLookup() { @@ -155,19 +201,19 @@ int current() { @Nullable private <T extends AbstractResponse> T getResponse( - InetAddress ipAddress, - CheckedBiFunction<DatabaseReader, InetAddress, Optional<T>, Exception> responseProvider + String ipAddress, + CheckedBiFunction<Reader, String, Optional<T>, Exception> responseProvider ) { return cache.putIfAbsent(ipAddress, databasePath.toString(), ip -> { try { return responseProvider.apply(get(), ipAddress).orElse(null); } catch (Exception e) { - throw new RuntimeException(e); + throw ExceptionsHelper.convertToRuntime(e); } }); } - DatabaseReader get() throws IOException { + Reader get() throws IOException { if (databaseReader.get() == null) { synchronized (databaseReader) { if (databaseReader.get() == null) { @@ -206,21 +252,32 @@ protected void doClose() throws IOException { } } - private static CheckedSupplier<DatabaseReader, IOException> createDatabaseLoader(Path databasePath) { + private static CheckedSupplier<Reader, IOException> createDatabaseLoader(Path databasePath) { return () -> { - DatabaseReader.Builder builder = createDatabaseBuilder(databasePath).withCache(NoCache.getInstance()); - if (LOAD_DATABASE_ON_HEAP) { - builder.fileMode(Reader.FileMode.MEMORY); - } else { - builder.fileMode(Reader.FileMode.MEMORY_MAPPED); - } - return builder.build(); + Reader.FileMode mode = LOAD_DATABASE_ON_HEAP ? Reader.FileMode.MEMORY : Reader.FileMode.MEMORY_MAPPED; + return new Reader(pathToFile(databasePath), mode, NoCache.getInstance()); }; } @SuppressForbidden(reason = "Maxmind API requires java.io.File") - private static DatabaseReader.Builder createDatabaseBuilder(Path databasePath) { - return new DatabaseReader.Builder(databasePath.toFile()); + private static File pathToFile(Path databasePath) { + return databasePath.toFile(); + } + + @FunctionalInterface + private interface ResponseBuilder<RESPONSE extends AbstractResponse> { + RESPONSE build(RESPONSE response, String responseIp, Network network, List<String> locales); } + private <RESPONSE extends AbstractResponse> Optional<RESPONSE> lookup(Reader reader, String ip, Class<RESPONSE> clazz, ResponseBuilder<RESPONSE> builder) + throws IOException { + InetAddress inetAddress = InetAddresses.forString(ip); + DatabaseRecord<RESPONSE> record = reader.getRecord(inetAddress, clazz); + RESPONSE result = record.getData(); + if (result == null) { + return Optional.empty(); + } else { + return Optional.of(builder.build(result, NetworkAddress.format(inetAddress), record.getNetwork(), List.of("en"))); + } + } } diff --git a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/GeoIpCache.java b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/GeoIpCache.java index 2102deb30cd87..c5c7c6175c5f3 100644 --- a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/GeoIpCache.java +++ b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/GeoIpCache.java @@ -9,14 +9,12 @@ package org.elasticsearch.ingest.geoip; import com.maxmind.db.NodeCache; -import com.maxmind.geoip2.model.AbstractResponse; import org.elasticsearch.common.cache.Cache; import org.elasticsearch.common.cache.CacheBuilder; import org.elasticsearch.core.TimeValue; import org.elasticsearch.ingest.geoip.stats.CacheStats; -import java.net.InetAddress; import java.nio.file.Path; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Function; @@ -36,15 +34,15 @@ final class GeoIpCache { * something not being in the cache because the data doesn't exist in the database.
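The NO_RESULT sentinel this javadoc describes distinguishes a cached negative lookup from a key that has never been looked up, and the cache key pairs the IP with the database path because the same IP can yield different records per database. A self-contained sketch of both ideas, using a plain HashMap in place of the module's bounded Cache (all names here are illustrative):

    import java.util.HashMap;
    import java.util.Map;
    import java.util.function.Function;

    class NegativeCacheSketch {
        // Sentinel marking "lookup performed, no data" as distinct from "never looked up".
        private static final Object NO_RESULT = new Object();
        // Composite key: the same IP can resolve differently per database file.
        private record Key(String ip, String databasePath) {}

        private final Map<Key, Object> cache = new HashMap<>();

        @SuppressWarnings("unchecked")
        <T> T putIfAbsent(String ip, String databasePath, Function<String, T> retrieve) {
            Key key = new Key(ip, databasePath);
            Object cached = cache.get(key);
            if (cached == null) {                 // never seen before: do the lookup
                cached = retrieve.apply(ip);
                cache.put(key, cached == null ? NO_RESULT : cached);
            }
            return cached == NO_RESULT ? null : (T) cached;
        }
    }

On the second call for the same key the retrieve function is never invoked, even when the first lookup found nothing, which is exactly the property the test below (testCachesNoResult) asserts.
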
*/ - static final AbstractResponse NO_RESULT = new AbstractResponse() { + static final Object NO_RESULT = new Object() { @Override public String toString() { - return "AbstractResponse[NO_RESULT]"; + return "NO_RESULT"; } }; private final LongSupplier relativeNanoTimeProvider; - private final Cache<CacheKey, AbstractResponse> cache; + private final Cache<CacheKey, Object> cache; private final AtomicLong hitsTimeInNanos = new AtomicLong(0); private final AtomicLong missesTimeInNanos = new AtomicLong(0); @@ -54,7 +52,7 @@ public String toString() { throw new IllegalArgumentException("geoip max cache size must be 0 or greater"); } this.relativeNanoTimeProvider = relativeNanoTimeProvider; - this.cache = CacheBuilder.<CacheKey, AbstractResponse>builder().setMaximumWeight(maxSize).build(); + this.cache = CacheBuilder.<CacheKey, Object>builder().setMaximumWeight(maxSize).build(); } GeoIpCache(long maxSize) { @@ -62,16 +60,12 @@ public String toString() { } @SuppressWarnings("unchecked") - <T extends AbstractResponse> T putIfAbsent( - InetAddress ip, - String databasePath, - Function<InetAddress, AbstractResponse> retrieveFunction - ) { + <T> T putIfAbsent(String ip, String databasePath, Function<String, Object> retrieveFunction) { // can't use cache.computeIfAbsent due to the elevated permissions for the jackson (run via the cache loader) CacheKey cacheKey = new CacheKey(ip, databasePath); long cacheStart = relativeNanoTimeProvider.getAsLong(); // intentionally non-locking for simplicity...it's OK if we re-put the same key/value in the cache during a race condition. - AbstractResponse response = cache.get(cacheKey); + Object response = cache.get(cacheKey); long cacheRequestTime = relativeNanoTimeProvider.getAsLong() - cacheStart; // populate the cache for this key, if necessary @@ -98,7 +92,7 @@ <T extends AbstractResponse> T putIfAbsent( } // only useful for testing - AbstractResponse get(InetAddress ip, String databasePath) { + Object get(String ip, String databasePath) { CacheKey cacheKey = new CacheKey(ip, databasePath); return cache.get(cacheKey); } @@ -141,5 +135,5 @@ public CacheStats getCacheStats() { * path is needed to be included in the cache key. For example, if we only used the IP address as the key the City and ASN the same * IP may be in both with different values and we need to cache both.
*/ - private record CacheKey(InetAddress ip, String databasePath) {} + private record CacheKey(String ip, String databasePath) {} } diff --git a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/GeoIpProcessor.java b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/GeoIpProcessor.java index 98ac2ff7d1044..c770169975404 100644 --- a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/GeoIpProcessor.java +++ b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/GeoIpProcessor.java @@ -29,8 +29,6 @@ import org.elasticsearch.common.CheckedSupplier; import org.elasticsearch.common.logging.DeprecationCategory; import org.elasticsearch.common.logging.DeprecationLogger; -import org.elasticsearch.common.network.InetAddresses; -import org.elasticsearch.common.network.NetworkAddress; import org.elasticsearch.core.Assertions; import org.elasticsearch.ingest.AbstractProcessor; import org.elasticsearch.ingest.IngestDocument; @@ -38,7 +36,6 @@ import org.elasticsearch.ingest.geoip.Database.Property; import java.io.IOException; -import java.net.InetAddress; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -62,7 +59,7 @@ public final class GeoIpProcessor extends AbstractProcessor { private final String field; private final Supplier isValid; private final String targetField; - private final CheckedSupplier supplier; + private final CheckedSupplier supplier; private final Set properties; private final boolean ignoreMissing; private final boolean firstOnly; @@ -85,7 +82,7 @@ public final class GeoIpProcessor extends AbstractProcessor { final String tag, final String description, final String field, - final CheckedSupplier supplier, + final CheckedSupplier supplier, final Supplier isValid, final String targetField, final Set properties, @@ -121,8 +118,8 @@ public IngestDocument execute(IngestDocument ingestDocument) throws IOException throw new IllegalArgumentException("field [" + field + "] is null, cannot extract geoip information."); } - GeoIpDatabase geoIpDatabase = this.supplier.get(); - if (geoIpDatabase == null) { + IpDatabase ipDatabase = this.supplier.get(); + if (ipDatabase == null) { if (ignoreMissing == false) { tag(ingestDocument, databaseFile); } @@ -131,7 +128,7 @@ public IngestDocument execute(IngestDocument ingestDocument) throws IOException try { if (ip instanceof String ipString) { - Map geoData = getGeoData(geoIpDatabase, ipString); + Map geoData = getGeoData(ipDatabase, ipString); if (geoData.isEmpty() == false) { ingestDocument.setFieldValue(targetField, geoData); } @@ -142,7 +139,7 @@ public IngestDocument execute(IngestDocument ingestDocument) throws IOException if (ipAddr instanceof String == false) { throw new IllegalArgumentException("array in field [" + field + "] should only contain strings"); } - Map geoData = getGeoData(geoIpDatabase, (String) ipAddr); + Map geoData = getGeoData(ipDatabase, (String) ipAddr); if (geoData.isEmpty()) { geoDataList.add(null); continue; @@ -161,29 +158,28 @@ public IngestDocument execute(IngestDocument ingestDocument) throws IOException throw new IllegalArgumentException("field [" + field + "] should contain only string or array of strings"); } } finally { - geoIpDatabase.release(); + ipDatabase.release(); } return ingestDocument; } - private Map getGeoData(GeoIpDatabase geoIpDatabase, String ip) throws IOException { - final String databaseType = geoIpDatabase.getDatabaseType(); + private Map getGeoData(IpDatabase ipDatabase, String ipAddress) throws 
IOException { + final String databaseType = ipDatabase.getDatabaseType(); final Database database; try { database = Database.getDatabase(databaseType, databaseFile); } catch (IllegalArgumentException e) { throw new ElasticsearchParseException(e.getMessage(), e); } - final InetAddress ipAddress = InetAddresses.forString(ip); return switch (database) { - case City -> retrieveCityGeoData(geoIpDatabase, ipAddress); - case Country -> retrieveCountryGeoData(geoIpDatabase, ipAddress); - case Asn -> retrieveAsnGeoData(geoIpDatabase, ipAddress); - case AnonymousIp -> retrieveAnonymousIpGeoData(geoIpDatabase, ipAddress); - case ConnectionType -> retrieveConnectionTypeGeoData(geoIpDatabase, ipAddress); - case Domain -> retrieveDomainGeoData(geoIpDatabase, ipAddress); - case Enterprise -> retrieveEnterpriseGeoData(geoIpDatabase, ipAddress); - case Isp -> retrieveIspGeoData(geoIpDatabase, ipAddress); + case City -> retrieveCityGeoData(ipDatabase, ipAddress); + case Country -> retrieveCountryGeoData(ipDatabase, ipAddress); + case Asn -> retrieveAsnGeoData(ipDatabase, ipAddress); + case AnonymousIp -> retrieveAnonymousIpGeoData(ipDatabase, ipAddress); + case ConnectionType -> retrieveConnectionTypeGeoData(ipDatabase, ipAddress); + case Domain -> retrieveDomainGeoData(ipDatabase, ipAddress); + case Enterprise -> retrieveEnterpriseGeoData(ipDatabase, ipAddress); + case Isp -> retrieveIspGeoData(ipDatabase, ipAddress); }; } @@ -208,8 +204,8 @@ Set getProperties() { return properties; } - private Map retrieveCityGeoData(GeoIpDatabase geoIpDatabase, InetAddress ipAddress) { - CityResponse response = geoIpDatabase.getCity(ipAddress); + private Map retrieveCityGeoData(IpDatabase ipDatabase, String ipAddress) { + CityResponse response = ipDatabase.getCity(ipAddress); if (response == null) { return Map.of(); } @@ -222,7 +218,7 @@ private Map retrieveCityGeoData(GeoIpDatabase geoIpDatabase, Ine Map geoData = new HashMap<>(); for (Property property : this.properties) { switch (property) { - case IP -> geoData.put("ip", NetworkAddress.format(ipAddress)); + case IP -> geoData.put("ip", response.getTraits().getIpAddress()); case COUNTRY_ISO_CODE -> { String countryIsoCode = country.getIsoCode(); if (countryIsoCode != null) { @@ -290,8 +286,8 @@ private Map retrieveCityGeoData(GeoIpDatabase geoIpDatabase, Ine return geoData; } - private Map retrieveCountryGeoData(GeoIpDatabase geoIpDatabase, InetAddress ipAddress) { - CountryResponse response = geoIpDatabase.getCountry(ipAddress); + private Map retrieveCountryGeoData(IpDatabase ipDatabase, String ipAddress) { + CountryResponse response = ipDatabase.getCountry(ipAddress); if (response == null) { return Map.of(); } @@ -301,7 +297,7 @@ private Map retrieveCountryGeoData(GeoIpDatabase geoIpDatabase, Map geoData = new HashMap<>(); for (Property property : this.properties) { switch (property) { - case IP -> geoData.put("ip", NetworkAddress.format(ipAddress)); + case IP -> geoData.put("ip", response.getTraits().getIpAddress()); case COUNTRY_ISO_CODE -> { String countryIsoCode = country.getIsoCode(); if (countryIsoCode != null) { @@ -331,8 +327,8 @@ private Map retrieveCountryGeoData(GeoIpDatabase geoIpDatabase, return geoData; } - private Map retrieveAsnGeoData(GeoIpDatabase geoIpDatabase, InetAddress ipAddress) { - AsnResponse response = geoIpDatabase.getAsn(ipAddress); + private Map retrieveAsnGeoData(IpDatabase ipDatabase, String ipAddress) { + AsnResponse response = ipDatabase.getAsn(ipAddress); if (response == null) { return Map.of(); } @@ -343,7 +339,7 @@ private Map 
retrieveAsnGeoData(GeoIpDatabase geoIpDatabase, Inet Map geoData = new HashMap<>(); for (Property property : this.properties) { switch (property) { - case IP -> geoData.put("ip", NetworkAddress.format(ipAddress)); + case IP -> geoData.put("ip", response.getIpAddress()); case ASN -> { if (asn != null) { geoData.put("asn", asn); @@ -364,8 +360,8 @@ private Map retrieveAsnGeoData(GeoIpDatabase geoIpDatabase, Inet return geoData; } - private Map retrieveAnonymousIpGeoData(GeoIpDatabase geoIpDatabase, InetAddress ipAddress) { - AnonymousIpResponse response = geoIpDatabase.getAnonymousIp(ipAddress); + private Map retrieveAnonymousIpGeoData(IpDatabase ipDatabase, String ipAddress) { + AnonymousIpResponse response = ipDatabase.getAnonymousIp(ipAddress); if (response == null) { return Map.of(); } @@ -380,7 +376,7 @@ private Map retrieveAnonymousIpGeoData(GeoIpDatabase geoIpDataba Map geoData = new HashMap<>(); for (Property property : this.properties) { switch (property) { - case IP -> geoData.put("ip", NetworkAddress.format(ipAddress)); + case IP -> geoData.put("ip", response.getIpAddress()); case HOSTING_PROVIDER -> { geoData.put("hosting_provider", isHostingProvider); } @@ -404,8 +400,8 @@ private Map retrieveAnonymousIpGeoData(GeoIpDatabase geoIpDataba return geoData; } - private Map retrieveConnectionTypeGeoData(GeoIpDatabase geoIpDatabase, InetAddress ipAddress) { - ConnectionTypeResponse response = geoIpDatabase.getConnectionType(ipAddress); + private Map retrieveConnectionTypeGeoData(IpDatabase ipDatabase, String ipAddress) { + ConnectionTypeResponse response = ipDatabase.getConnectionType(ipAddress); if (response == null) { return Map.of(); } @@ -415,7 +411,7 @@ private Map retrieveConnectionTypeGeoData(GeoIpDatabase geoIpDat Map geoData = new HashMap<>(); for (Property property : this.properties) { switch (property) { - case IP -> geoData.put("ip", NetworkAddress.format(ipAddress)); + case IP -> geoData.put("ip", response.getIpAddress()); case CONNECTION_TYPE -> { if (connectionType != null) { geoData.put("connection_type", connectionType.toString()); @@ -426,8 +422,8 @@ private Map retrieveConnectionTypeGeoData(GeoIpDatabase geoIpDat return geoData; } - private Map retrieveDomainGeoData(GeoIpDatabase geoIpDatabase, InetAddress ipAddress) { - DomainResponse response = geoIpDatabase.getDomain(ipAddress); + private Map retrieveDomainGeoData(IpDatabase ipDatabase, String ipAddress) { + DomainResponse response = ipDatabase.getDomain(ipAddress); if (response == null) { return Map.of(); } @@ -437,7 +433,7 @@ private Map retrieveDomainGeoData(GeoIpDatabase geoIpDatabase, I Map geoData = new HashMap<>(); for (Property property : this.properties) { switch (property) { - case IP -> geoData.put("ip", NetworkAddress.format(ipAddress)); + case IP -> geoData.put("ip", response.getIpAddress()); case DOMAIN -> { if (domain != null) { geoData.put("domain", domain); @@ -448,8 +444,8 @@ private Map retrieveDomainGeoData(GeoIpDatabase geoIpDatabase, I return geoData; } - private Map retrieveEnterpriseGeoData(GeoIpDatabase geoIpDatabase, InetAddress ipAddress) { - EnterpriseResponse response = geoIpDatabase.getEnterprise(ipAddress); + private Map retrieveEnterpriseGeoData(IpDatabase ipDatabase, String ipAddress) { + EnterpriseResponse response = ipDatabase.getEnterprise(ipAddress); if (response == null) { return Map.of(); } @@ -485,7 +481,7 @@ private Map retrieveEnterpriseGeoData(GeoIpDatabase geoIpDatabas Map geoData = new HashMap<>(); for (Property property : this.properties) { switch (property) { - case 
IP -> geoData.put("ip", NetworkAddress.format(ipAddress)); + case IP -> geoData.put("ip", response.getTraits().getIpAddress()); case COUNTRY_ISO_CODE -> { String countryIsoCode = country.getIsoCode(); if (countryIsoCode != null) { @@ -621,8 +617,8 @@ private Map retrieveEnterpriseGeoData(GeoIpDatabase geoIpDatabas return geoData; } - private Map retrieveIspGeoData(GeoIpDatabase geoIpDatabase, InetAddress ipAddress) { - IspResponse response = geoIpDatabase.getIsp(ipAddress); + private Map retrieveIspGeoData(IpDatabase ipDatabase, String ipAddress) { + IspResponse response = ipDatabase.getIsp(ipAddress); if (response == null) { return Map.of(); } @@ -638,7 +634,7 @@ private Map retrieveIspGeoData(GeoIpDatabase geoIpDatabase, Inet Map geoData = new HashMap<>(); for (Property property : this.properties) { switch (property) { - case IP -> geoData.put("ip", NetworkAddress.format(ipAddress)); + case IP -> geoData.put("ip", response.getIpAddress()); case ASN -> { if (asn != null) { geoData.put("asn", asn); @@ -680,23 +676,23 @@ private Map retrieveIspGeoData(GeoIpDatabase geoIpDatabase, Inet } /** - * Retrieves and verifies a {@link GeoIpDatabase} instance for each execution of the {@link GeoIpProcessor}. Guards against missing + * Retrieves and verifies a {@link IpDatabase} instance for each execution of the {@link GeoIpProcessor}. Guards against missing * custom databases, and ensures that database instances are of the proper type before use. */ - public static final class DatabaseVerifyingSupplier implements CheckedSupplier { - private final GeoIpDatabaseProvider geoIpDatabaseProvider; + public static final class DatabaseVerifyingSupplier implements CheckedSupplier { + private final IpDatabaseProvider ipDatabaseProvider; private final String databaseFile; private final String databaseType; - public DatabaseVerifyingSupplier(GeoIpDatabaseProvider geoIpDatabaseProvider, String databaseFile, String databaseType) { - this.geoIpDatabaseProvider = geoIpDatabaseProvider; + public DatabaseVerifyingSupplier(IpDatabaseProvider ipDatabaseProvider, String databaseFile, String databaseType) { + this.ipDatabaseProvider = ipDatabaseProvider; this.databaseFile = databaseFile; this.databaseType = databaseType; } @Override - public GeoIpDatabase get() throws IOException { - GeoIpDatabase loader = geoIpDatabaseProvider.getDatabase(databaseFile); + public IpDatabase get() throws IOException { + IpDatabase loader = ipDatabaseProvider.getDatabase(databaseFile); if (loader == null) { return null; } @@ -716,10 +712,10 @@ public GeoIpDatabase get() throws IOException { public static final class Factory implements Processor.Factory { - private final GeoIpDatabaseProvider geoIpDatabaseProvider; + private final IpDatabaseProvider ipDatabaseProvider; - public Factory(GeoIpDatabaseProvider geoIpDatabaseProvider) { - this.geoIpDatabaseProvider = geoIpDatabaseProvider; + public Factory(IpDatabaseProvider ipDatabaseProvider) { + this.ipDatabaseProvider = ipDatabaseProvider; } @Override @@ -746,8 +742,8 @@ public Processor create( deprecationLogger.warn(DeprecationCategory.OTHER, "default_databases_message", DEFAULT_DATABASES_DEPRECATION_MESSAGE); } - GeoIpDatabase geoIpDatabase = geoIpDatabaseProvider.getDatabase(databaseFile); - if (geoIpDatabase == null) { + IpDatabase ipDatabase = ipDatabaseProvider.getDatabase(databaseFile); + if (ipDatabase == null) { // It's possible that the database could be downloaded via the GeoipDownloader process and could become available // at a later moment, so a processor impl is returned that 
tags documents instead. If a database cannot be sourced then the // processor will continue to tag documents with a warning until it is remediated by providing a database or changing the @@ -757,9 +753,9 @@ public Processor create( final String databaseType; try { - databaseType = geoIpDatabase.getDatabaseType(); + databaseType = ipDatabase.getDatabaseType(); } finally { - geoIpDatabase.release(); + ipDatabase.release(); } final Database database; @@ -779,8 +775,8 @@ public Processor create( processorTag, description, ipField, - new DatabaseVerifyingSupplier(geoIpDatabaseProvider, databaseFile, databaseType), - () -> geoIpDatabaseProvider.isValid(databaseFile), + new DatabaseVerifyingSupplier(ipDatabaseProvider, databaseFile, databaseType), + () -> ipDatabaseProvider.isValid(databaseFile), targetField, properties, ignoreMissing, diff --git a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/GeoIpDatabase.java b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/IpDatabase.java similarity index 82% rename from modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/GeoIpDatabase.java rename to modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/IpDatabase.java index 6dc005db83097..eb6200374a11a 100644 --- a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/GeoIpDatabase.java +++ b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/IpDatabase.java @@ -21,12 +21,11 @@ import org.elasticsearch.core.Nullable; import java.io.IOException; -import java.net.InetAddress; /** - * Provides a uniform interface for interacting with various GeoIP databases. + * Provides a uniform interface for interacting with various ip databases. */ -public interface GeoIpDatabase { +public interface IpDatabase { /** * @return the database type as it is detailed in the database file metadata @@ -40,7 +39,7 @@ public interface GeoIpDatabase { * @throws UnsupportedOperationException may be thrown if the implementation does not support retrieving city data */ @Nullable - CityResponse getCity(InetAddress ipAddress); + CityResponse getCity(String ipAddress); /** * @param ipAddress the IP address to look up @@ -48,7 +47,7 @@ public interface GeoIpDatabase { * @throws UnsupportedOperationException may be thrown if the implementation does not support retrieving country data */ @Nullable - CountryResponse getCountry(InetAddress ipAddress); + CountryResponse getCountry(String ipAddress); /** * @param ipAddress the IP address to look up @@ -57,22 +56,22 @@ public interface GeoIpDatabase { * @throws UnsupportedOperationException may be thrown if the implementation does not support retrieving ASN data */ @Nullable - AsnResponse getAsn(InetAddress ipAddress); + AsnResponse getAsn(String ipAddress); @Nullable - AnonymousIpResponse getAnonymousIp(InetAddress ipAddress); + AnonymousIpResponse getAnonymousIp(String ipAddress); @Nullable - ConnectionTypeResponse getConnectionType(InetAddress ipAddress); + ConnectionTypeResponse getConnectionType(String ipAddress); @Nullable - DomainResponse getDomain(InetAddress ipAddress); + DomainResponse getDomain(String ipAddress); @Nullable - EnterpriseResponse getEnterprise(InetAddress ipAddress); + EnterpriseResponse getEnterprise(String ipAddress); @Nullable - IspResponse getIsp(InetAddress ipAddress); + IspResponse getIsp(String ipAddress); /** * Releases the current database object. Called after processing a single document. 
Databases should be closed or returned to a diff --git a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/GeoIpDatabaseProvider.java b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/IpDatabaseProvider.java similarity index 88% rename from modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/GeoIpDatabaseProvider.java rename to modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/IpDatabaseProvider.java index 2e7a9acc0e69a..9438bf74f8c12 100644 --- a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/GeoIpDatabaseProvider.java +++ b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/IpDatabaseProvider.java @@ -10,9 +10,9 @@ package org.elasticsearch.ingest.geoip; /** - * Provides construction and initialization logic for {@link GeoIpDatabase} instances. + * Provides construction and initialization logic for {@link IpDatabase} instances. */ -public interface GeoIpDatabaseProvider { +public interface IpDatabaseProvider { /** * Determines if the given database name corresponds to an expired database. Expired databases will not be loaded. @@ -30,5 +30,5 @@ public interface GeoIpDatabaseProvider { * @param name the name of the database to provide. * @return a ready-to-use database instance, or null if no database could be loaded. */ - GeoIpDatabase getDatabase(String name); + IpDatabase getDatabase(String name); } diff --git a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/ConfigDatabasesTests.java b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/ConfigDatabasesTests.java index 64c3b91aabda8..83b3d2cfbbc27 100644 --- a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/ConfigDatabasesTests.java +++ b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/ConfigDatabasesTests.java @@ -11,7 +11,6 @@ import com.maxmind.geoip2.model.CityResponse; -import org.elasticsearch.common.network.InetAddresses; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; import org.elasticsearch.test.ESTestCase; @@ -127,7 +126,7 @@ public void testDatabasesUpdateExistingConfDatabase() throws Exception { DatabaseReaderLazyLoader loader = configDatabases.getDatabase("GeoLite2-City.mmdb"); assertThat(loader.getDatabaseType(), equalTo("GeoLite2-City")); - CityResponse cityResponse = loader.getCity(InetAddresses.forString("89.160.20.128")); + CityResponse cityResponse = loader.getCity("89.160.20.128"); assertThat(cityResponse.getCity().getName(), equalTo("Tumba")); assertThat(cache.count(), equalTo(1)); } @@ -139,7 +138,7 @@ public void testDatabasesUpdateExistingConfDatabase() throws Exception { DatabaseReaderLazyLoader loader = configDatabases.getDatabase("GeoLite2-City.mmdb"); assertThat(loader.getDatabaseType(), equalTo("GeoLite2-City")); - CityResponse cityResponse = loader.getCity(InetAddresses.forString("89.160.20.128")); + CityResponse cityResponse = loader.getCity("89.160.20.128"); assertThat(cityResponse.getCity().getName(), equalTo("Linköping")); assertThat(cache.count(), equalTo(1)); }); diff --git a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpCacheTests.java b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpCacheTests.java index e6ad2b550bf23..0c92aca882913 100644 --- a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpCacheTests.java +++ b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpCacheTests.java @@ -11,12 +11,10 @@ 
import com.maxmind.geoip2.model.AbstractResponse; -import org.elasticsearch.common.network.InetAddresses; import org.elasticsearch.core.TimeValue; import org.elasticsearch.ingest.geoip.stats.CacheStats; import org.elasticsearch.test.ESTestCase; -import java.net.InetAddress; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Function; @@ -32,34 +30,34 @@ public void testCachesAndEvictsResults() { AbstractResponse response2 = mock(AbstractResponse.class); // add a key - AbstractResponse cachedResponse = cache.putIfAbsent(InetAddresses.forString("127.0.0.1"), "path/to/db", ip -> response1); + AbstractResponse cachedResponse = cache.putIfAbsent("127.0.0.1", "path/to/db", ip -> response1); assertSame(cachedResponse, response1); - assertSame(cachedResponse, cache.putIfAbsent(InetAddresses.forString("127.0.0.1"), "path/to/db", ip -> response1)); - assertSame(cachedResponse, cache.get(InetAddresses.forString("127.0.0.1"), "path/to/db")); + assertSame(cachedResponse, cache.putIfAbsent("127.0.0.1", "path/to/db", ip -> response1)); + assertSame(cachedResponse, cache.get("127.0.0.1", "path/to/db")); // evict old key by adding another value - cachedResponse = cache.putIfAbsent(InetAddresses.forString("127.0.0.2"), "path/to/db", ip -> response2); + cachedResponse = cache.putIfAbsent("127.0.0.2", "path/to/db", ip -> response2); assertSame(cachedResponse, response2); - assertSame(cachedResponse, cache.putIfAbsent(InetAddresses.forString("127.0.0.2"), "path/to/db", ip -> response2)); - assertSame(cachedResponse, cache.get(InetAddresses.forString("127.0.0.2"), "path/to/db")); - assertNotSame(response1, cache.get(InetAddresses.forString("127.0.0.1"), "path/to/db")); + assertSame(cachedResponse, cache.putIfAbsent("127.0.0.2", "path/to/db", ip -> response2)); + assertSame(cachedResponse, cache.get("127.0.0.2", "path/to/db")); + assertNotSame(response1, cache.get("127.0.0.1", "path/to/db")); } public void testCachesNoResult() { GeoIpCache cache = new GeoIpCache(1); final AtomicInteger count = new AtomicInteger(0); - Function countAndReturnNull = (ip) -> { + Function countAndReturnNull = (ip) -> { count.incrementAndGet(); return null; }; - AbstractResponse response = cache.putIfAbsent(InetAddresses.forString("127.0.0.1"), "path/to/db", countAndReturnNull); + AbstractResponse response = cache.putIfAbsent("127.0.0.1", "path/to/db", countAndReturnNull); assertNull(response); - assertNull(cache.putIfAbsent(InetAddresses.forString("127.0.0.1"), "path/to/db", countAndReturnNull)); + assertNull(cache.putIfAbsent("127.0.0.1", "path/to/db", countAndReturnNull)); assertEquals(1, count.get()); // the cached value is not actually *null*, it's the NO_RESULT sentinel - assertSame(GeoIpCache.NO_RESULT, cache.get(InetAddresses.forString("127.0.0.1"), "path/to/db")); + assertSame(GeoIpCache.NO_RESULT, cache.get("127.0.0.1", "path/to/db")); } public void testCacheKey() { @@ -67,17 +65,17 @@ public void testCacheKey() { AbstractResponse response1 = mock(AbstractResponse.class); AbstractResponse response2 = mock(AbstractResponse.class); - assertSame(response1, cache.putIfAbsent(InetAddresses.forString("127.0.0.1"), "path/to/db1", ip -> response1)); - assertSame(response2, cache.putIfAbsent(InetAddresses.forString("127.0.0.1"), "path/to/db2", ip -> response2)); - assertSame(response1, cache.get(InetAddresses.forString("127.0.0.1"), "path/to/db1")); - assertSame(response2, cache.get(InetAddresses.forString("127.0.0.1"), "path/to/db2")); + assertSame(response1, 
cache.putIfAbsent("127.0.0.1", "path/to/db1", ip -> response1)); + assertSame(response2, cache.putIfAbsent("127.0.0.1", "path/to/db2", ip -> response2)); + assertSame(response1, cache.get("127.0.0.1", "path/to/db1")); + assertSame(response2, cache.get("127.0.0.1", "path/to/db2")); } public void testThrowsFunctionsException() { GeoIpCache cache = new GeoIpCache(1); IllegalArgumentException ex = expectThrows( IllegalArgumentException.class, - () -> cache.putIfAbsent(InetAddresses.forString("127.0.0.1"), "path/to/db", ip -> { + () -> cache.putIfAbsent("127.0.0.1", "path/to/db", ip -> { throw new IllegalArgumentException("bad"); }) ); @@ -96,9 +94,9 @@ public void testGetCacheStats() { GeoIpCache cache = new GeoIpCache(maxCacheSize, () -> testNanoTime.addAndGet(TimeValue.timeValueMillis(1).getNanos())); AbstractResponse response = mock(AbstractResponse.class); String databasePath = "path/to/db1"; - InetAddress key1 = InetAddresses.forString("127.0.0.1"); - InetAddress key2 = InetAddresses.forString("127.0.0.2"); - InetAddress key3 = InetAddresses.forString("127.0.0.3"); + String key1 = "127.0.0.1"; + String key2 = "127.0.0.2"; + String key3 = "127.0.0.3"; cache.putIfAbsent(key1, databasePath, ip -> response); // cache miss cache.putIfAbsent(key2, databasePath, ip -> response); // cache miss diff --git a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorFactoryTests.java b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorFactoryTests.java index 1bd0eb6ef9abc..b15aa50fc98fa 100644 --- a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorFactoryTests.java +++ b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorFactoryTests.java @@ -287,9 +287,9 @@ public void testBuildIllegalFieldOption() { public void testBuildUnsupportedDatabase() throws Exception { // mock up some unsupported database (it has a databaseType that we don't recognize) - GeoIpDatabase database = mock(GeoIpDatabase.class); + IpDatabase database = mock(IpDatabase.class); when(database.getDatabaseType()).thenReturn("some-unsupported-database"); - GeoIpDatabaseProvider provider = mock(GeoIpDatabaseProvider.class); + IpDatabaseProvider provider = mock(IpDatabaseProvider.class); when(provider.getDatabase(anyString())).thenReturn(database); GeoIpProcessor.Factory factory = new GeoIpProcessor.Factory(provider); @@ -306,9 +306,9 @@ public void testBuildUnsupportedDatabase() throws Exception { public void testBuildNullDatabase() throws Exception { // mock up a provider that returns a null databaseType - GeoIpDatabase database = mock(GeoIpDatabase.class); + IpDatabase database = mock(IpDatabase.class); when(database.getDatabaseType()).thenReturn(null); - GeoIpDatabaseProvider provider = mock(GeoIpDatabaseProvider.class); + IpDatabaseProvider provider = mock(IpDatabaseProvider.class); when(provider.getDatabase(anyString())).thenReturn(database); GeoIpProcessor.Factory factory = new GeoIpProcessor.Factory(provider); diff --git a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorTests.java b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorTests.java index 4a5d445e3ff5b..3a77e26dbbf7a 100644 --- a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorTests.java +++ b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorTests.java @@ -785,7 +785,7 @@ public void testNoDatabase_ignoreMissing() throws Exception { 
assertIngestDocument(originalIngestDocument, ingestDocument); } - private CheckedSupplier loader(final String path) { + private CheckedSupplier loader(final String path) { var loader = loader(path, null); return () -> loader; } diff --git a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/MaxMindSupportTests.java b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/MaxMindSupportTests.java index 1cf4ce7facda0..ec05054615bd8 100644 --- a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/MaxMindSupportTests.java +++ b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/MaxMindSupportTests.java @@ -57,7 +57,7 @@ * - Fail if we add support for a new mmdb file type (enterprise for example) but don't update the test with which fields we do and do not * support. * - Fail if MaxMind adds a new mmdb file type that we don't know about - * - Fail if we expose a MaxMind type through GeoIpDatabase, but don't update the test to know how to handle it + * - Fail if we expose a MaxMind type through IpDatabase, but don't update the test to know how to handle it */ public class MaxMindSupportTests extends ESTestCase { @@ -469,7 +469,7 @@ public void testUnknownMaxMindResponseClassess() { } /* - * This tests that this test has a mapping in TYPE_TO_MAX_MIND_CLASS for all MaxMind classes exposed through GeoIpDatabase. + * This tests that this test has a mapping in TYPE_TO_MAX_MIND_CLASS for all MaxMind classes exposed through IpDatabase. */ public void testUsedMaxMindResponseClassesAreAccountedFor() { Set> usedMaxMindResponseClasses = getUsedMaxMindResponseClasses(); @@ -479,7 +479,7 @@ public void testUsedMaxMindResponseClassesAreAccountedFor() { supportedMaxMindClasses ); assertThat( - "GeoIpDatabase exposes MaxMind response classes that this test does not know what to do with. Add mappings to " + "IpDatabase exposes MaxMind response classes that this test does not know what to do with. Add mappings to " + "TYPE_TO_MAX_MIND_CLASS for the following: " + usedButNotSupportedMaxMindResponseClasses, usedButNotSupportedMaxMindResponseClasses, @@ -490,7 +490,7 @@ public void testUsedMaxMindResponseClassesAreAccountedFor() { usedMaxMindResponseClasses ); assertThat( - "This test claims to support MaxMind response classes that are not exposed in GeoIpDatabase. Remove the following from " + "This test claims to support MaxMind response classes that are not exposed in IpDatabase. Remove the following from " + "TYPE_TO_MAX_MIND_CLASS: " + supportedButNotUsedMaxMindClasses, supportedButNotUsedMaxMindClasses, @@ -618,11 +618,11 @@ private static String getFormattedList(Set fields) { } /* - * This returns all AbstractResponse classes that are returned from getter methods on GeoIpDatabase. + * This returns all AbstractResponse classes that are returned from getter methods on IpDatabase. 
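That reflective enumeration is short enough to distill into a standalone helper; the sketch below spells out the generic types and uses an illustrative class and method name (GetterScanSketch is not part of the test):

    import java.lang.reflect.Method;
    import java.util.HashSet;
    import java.util.Set;

    class GetterScanSketch {
        // Mirrors the test's filter: collect the return type of every public
        // method on the given type whose name starts with "get".
        static Set<Class<?>> usedResponseClasses(Class<?> iface) {
            Set<Class<?>> result = new HashSet<>();
            for (Method method : iface.getMethods()) {
                if (method.getName().startsWith("get")) {
                    result.add(method.getReturnType());
                }
            }
            return result;
        }
    }

Invoked as usedResponseClasses(IpDatabase.class), this yields CityResponse, CountryResponse, and the other response types that the test cross-checks against TYPE_TO_MAX_MIND_CLASS.
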
*/ private static Set> getUsedMaxMindResponseClasses() { Set> result = new HashSet<>(); - Method[] methods = GeoIpDatabase.class.getMethods(); + Method[] methods = IpDatabase.class.getMethods(); for (Method method : methods) { if (method.getName().startsWith("get")) { Class returnType = method.getReturnType(); diff --git a/muted-tests.yml b/muted-tests.yml index e02f8208b7fb0..f68ed6caab08d 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -210,14 +210,40 @@ tests: - class: org.elasticsearch.xpack.ml.integration.MlJobIT method: testDeleteJobAfterMissingAliases issue: https://github.com/elastic/elasticsearch/issues/112823 -- class: org.elasticsearch.xpack.test.rest.XPackRestIT - issue: https://github.com/elastic/elasticsearch/issues/111944 - class: org.elasticsearch.datastreams.logsdb.qa.StandardVersusLogsIndexModeRandomDataChallengeRestIT method: testDateHistogramAggregation issue: https://github.com/elastic/elasticsearch/issues/112919 - class: org.elasticsearch.datastreams.logsdb.qa.StandardVersusLogsIndexModeRandomDataChallengeRestIT method: testTermsQuery issue: https://github.com/elastic/elasticsearch/issues/112462 +- class: org.elasticsearch.repositories.blobstore.testkit.analyze.HdfsRepositoryAnalysisRestIT + issue: https://github.com/elastic/elasticsearch/issues/112889 +- class: org.elasticsearch.xpack.test.rest.XPackRestIT + issue: https://github.com/elastic/elasticsearch/issues/111944 +- class: org.elasticsearch.xpack.ml.integration.MlJobIT + method: testCreateJob_WithClashingFieldMappingsFails + issue: https://github.com/elastic/elasticsearch/issues/113046 +- class: org.elasticsearch.xpack.esql.qa.multi_node.EsqlSpecIT + method: test {categorize.Categorize SYNC} + issue: https://github.com/elastic/elasticsearch/issues/113054 +- class: org.elasticsearch.xpack.esql.qa.multi_node.EsqlSpecIT + method: test {categorize.Categorize ASYNC} + issue: https://github.com/elastic/elasticsearch/issues/113055 +- class: org.elasticsearch.upgrades.MultiVersionRepositoryAccessIT + method: testCreateAndRestoreSnapshot + issue: https://github.com/elastic/elasticsearch/issues/113058 +- class: org.elasticsearch.xpack.restart.FullClusterRestartIT + method: testRollupAfterRestart {cluster=UPGRADED} + issue: https://github.com/elastic/elasticsearch/issues/113059 +- class: org.elasticsearch.xpack.restart.FullClusterRestartIT + method: testRollupAfterRestart {cluster=OLD} + issue: https://github.com/elastic/elasticsearch/issues/113061 +- class: org.elasticsearch.upgrades.MultiVersionRepositoryAccessIT + method: testUpgradeMovesRepoToNewMetaVersion + issue: https://github.com/elastic/elasticsearch/issues/113062 +- class: org.elasticsearch.upgrades.MultiVersionRepositoryAccessIT + method: testReadOnlyRepo + issue: https://github.com/elastic/elasticsearch/issues/113060 # Examples: # diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CCSUsageTelemetryIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CCSUsageTelemetryIT.java index 331c9142d3f56..27e7a15d2107a 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CCSUsageTelemetryIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CCSUsageTelemetryIT.java @@ -63,6 +63,7 @@ import static org.elasticsearch.action.admin.cluster.stats.CCSUsageTelemetry.ASYNC_FEATURE; import static org.elasticsearch.action.admin.cluster.stats.CCSUsageTelemetry.MRT_FEATURE; +import static org.elasticsearch.action.admin.cluster.stats.CCSUsageTelemetry.PIT_FEATURE; import static 
org.elasticsearch.action.admin.cluster.stats.CCSUsageTelemetry.WILDCARD_FEATURE; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse; @@ -622,6 +623,7 @@ public void testPITSearch() throws ExecutionException, InterruptedException { assertThat(telemetry.getTotalCount(), equalTo(2L)); assertThat(telemetry.getSuccessCount(), equalTo(2L)); + assertThat(telemetry.getFeatureCounts().get(PIT_FEATURE), equalTo(2L)); } public void testCompoundRetrieverSearch() throws ExecutionException, InterruptedException { diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 217e65d4cdda9..20b92b9b64137 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -215,6 +215,8 @@ static TransportVersion def(int id) { public static final TransportVersion CCS_TELEMETRY_STATS = def(8_739_00_0); public static final TransportVersion GLOBAL_RETENTION_TELEMETRY = def(8_740_00_0); public static final TransportVersion ROUTING_TABLE_VERSION_REMOVED = def(8_741_00_0); + public static final TransportVersion ML_SCHEDULED_EVENT_TIME_SHIFT_CONFIGURATION = def(8_742_00_0); + public static final TransportVersion SIMULATE_COMPONENT_TEMPLATES_SUBSTITUTIONS = def(8_743_00_0); /* * STOP! READ THIS FIRST! No, really, diff --git a/server/src/main/java/org/elasticsearch/action/admin/cluster/stats/CCSUsageTelemetry.java b/server/src/main/java/org/elasticsearch/action/admin/cluster/stats/CCSUsageTelemetry.java index f2eb9eb945dc3..6c8178282d3c3 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/cluster/stats/CCSUsageTelemetry.java +++ b/server/src/main/java/org/elasticsearch/action/admin/cluster/stats/CCSUsageTelemetry.java @@ -65,6 +65,7 @@ public String getName() { public static final String MRT_FEATURE = "mrt_on"; public static final String ASYNC_FEATURE = "async"; public static final String WILDCARD_FEATURE = "wildcards"; + public static final String PIT_FEATURE = "pit"; // The list of known Elastic clients. May be incomplete. public static final Set KNOWN_CLIENTS = Set.of( diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkRequest.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkRequest.java index 6e6af4c016cf6..1425dde28ea3b 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkRequest.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkRequest.java @@ -21,6 +21,7 @@ import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.action.support.replication.ReplicationRequest; import org.elasticsearch.action.update.UpdateRequest; +import org.elasticsearch.cluster.metadata.ComponentTemplate; import org.elasticsearch.common.Strings; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; @@ -38,6 +39,7 @@ import java.util.Collections; import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.Set; @@ -475,4 +477,27 @@ public Set getIndices() { public boolean isSimulated() { return false; // Always false, but may be overridden by a subclass } + + /* + * Returns any component template substitutions that are to be used as part of this bulk request. We would likely only have + * substitutions in the event of a simulated request. 
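Reading the new javadoc together with the parsing code that follows it: the component_template_substitutions JSON arrives as nested maps, the shape XContentHelper.convertToMap() produces. A hedged sketch of that raw shape for the my-template-1 example from the javadoc (the helper class is illustrative):

    import java.util.Map;

    class SimulateSubstitutionsSketch {
        // Builds the kind of raw substitution map (nested String-to-Object maps)
        // that SimulateBulkRequest now accepts alongside pipeline substitutions.
        static Map<String, Map<String, Object>> exampleComponentTemplateSubstitutions() {
            return Map.of(
                "my-template-1",
                Map.of("template", Map.of("settings", Map.of("number_of_shards", 1)))
            );
        }
    }

A request built as new SimulateBulkRequest(null, exampleComponentTemplateSubstitutions()) would then surface a parsed ComponentTemplate from getComponentTemplateSubstitutions(), as implemented just below.
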
+ */ + public Map getComponentTemplateSubstitutions() throws IOException { + return Map.of(); + } + + /* + * This copies this bulk request, but without all of its inner requests or the set of indices found in those requests + */ + public BulkRequest shallowClone() { + BulkRequest bulkRequest = new BulkRequest(globalIndex); + bulkRequest.setRefreshPolicy(getRefreshPolicy()); + bulkRequest.waitForActiveShards(waitForActiveShards()); + bulkRequest.timeout(timeout()); + bulkRequest.pipeline(pipeline()); + bulkRequest.routing(routing()); + bulkRequest.requireAlias(requireAlias()); + bulkRequest.requireDataStream(requireDataStream()); + return bulkRequest; + } } diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkRequestModifier.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkRequestModifier.java index b3048dc18008b..3e47c78a76354 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkRequestModifier.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkRequestModifier.java @@ -87,10 +87,7 @@ BulkRequest getBulkRequest() { if (itemResponses.isEmpty()) { return bulkRequest; } else { - BulkRequest modifiedBulkRequest = new BulkRequest(); - modifiedBulkRequest.setRefreshPolicy(bulkRequest.getRefreshPolicy()); - modifiedBulkRequest.waitForActiveShards(bulkRequest.waitForActiveShards()); - modifiedBulkRequest.timeout(bulkRequest.timeout()); + BulkRequest modifiedBulkRequest = bulkRequest.shallowClone(); int slot = 0; List> requests = bulkRequest.requests(); diff --git a/server/src/main/java/org/elasticsearch/action/bulk/SimulateBulkRequest.java b/server/src/main/java/org/elasticsearch/action/bulk/SimulateBulkRequest.java index 86e5d3b0985d0..3cc7fa12733bf 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/SimulateBulkRequest.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/SimulateBulkRequest.java @@ -9,16 +9,21 @@ package org.elasticsearch.action.bulk; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.cluster.metadata.ComponentTemplate; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.XContentParserConfiguration; import java.io.IOException; +import java.util.HashMap; import java.util.Map; /** - * This extends BulkRequest with support for providing substitute pipeline definitions. In a user request, the pipeline substitutions - * will look something like this: + * This extends BulkRequest with support for providing substitute pipeline definitions and component template definitions. In a user + * request, the substitutions will look something like this: * * "pipeline_substitutions": { * "my-pipeline-1": { @@ -45,6 +50,29 @@ * } * ] * } + * }, + * "component_template_substitutions": { + * "my-template-1": { + * "template": { + * "settings": { + * "number_of_shards": 1 + * }, + * "mappings": { + * "_source": { + * "enabled": false + * }, + * "properties": { + * "host_name": { + * "type": "keyword" + * }, + * "created_at": { + * "type": "date", + * "format": "EEE MMM dd HH:mm:ss Z yyyy" + * } + * } + * } + * } + * } * } * * The pipelineSubstitutions Map held by this class is intended to be the result of XContentHelper.convertToMap(). 
The top-level keys @@ -53,27 +81,42 @@ */ public class SimulateBulkRequest extends BulkRequest { private final Map> pipelineSubstitutions; + private final Map> componentTemplateSubstitutions; /** * @param pipelineSubstitutions The pipeline definitions that are to be used in place of any pre-existing pipeline definitions with * the same pipelineId. The key of the map is the pipelineId, and the value the pipeline definition as * parsed by XContentHelper.convertToMap(). + * @param componentTemplateSubstitutions The component template definitions that are to be used in place of any pre-existing + * component template definitions with the same name. */ - public SimulateBulkRequest(@Nullable Map> pipelineSubstitutions) { + public SimulateBulkRequest( + @Nullable Map> pipelineSubstitutions, + @Nullable Map> componentTemplateSubstitutions + ) { super(); this.pipelineSubstitutions = pipelineSubstitutions; + this.componentTemplateSubstitutions = componentTemplateSubstitutions; } @SuppressWarnings("unchecked") public SimulateBulkRequest(StreamInput in) throws IOException { super(in); this.pipelineSubstitutions = (Map>) in.readGenericValue(); + if (in.getTransportVersion().onOrAfter(TransportVersions.SIMULATE_COMPONENT_TEMPLATES_SUBSTITUTIONS)) { + this.componentTemplateSubstitutions = (Map>) in.readGenericValue(); + } else { + componentTemplateSubstitutions = Map.of(); + } } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeGenericValue(pipelineSubstitutions); + if (out.getTransportVersion().onOrAfter(TransportVersions.SIMULATE_COMPONENT_TEMPLATES_SUBSTITUTIONS)) { + out.writeGenericValue(componentTemplateSubstitutions); + } } public Map> getPipelineSubstitutions() { @@ -84,4 +127,37 @@ public Map> getPipelineSubstitutions() { public boolean isSimulated() { return true; } + + @Override + public Map getComponentTemplateSubstitutions() throws IOException { + if (componentTemplateSubstitutions == null) { + return Map.of(); + } + Map result = new HashMap<>(componentTemplateSubstitutions.size()); + for (Map.Entry> rawEntry : componentTemplateSubstitutions.entrySet()) { + result.put(rawEntry.getKey(), convertRawTemplateToComponentTemplate(rawEntry.getValue())); + } + return result; + } + + private static ComponentTemplate convertRawTemplateToComponentTemplate(Map rawTemplate) throws IOException { + ComponentTemplate componentTemplate; + try (var parser = XContentHelper.mapToXContentParser(XContentParserConfiguration.EMPTY, rawTemplate)) { + componentTemplate = ComponentTemplate.parse(parser); + } + return componentTemplate; + } + + @Override + public BulkRequest shallowClone() { + BulkRequest bulkRequest = new SimulateBulkRequest(pipelineSubstitutions, componentTemplateSubstitutions); + bulkRequest.setRefreshPolicy(getRefreshPolicy()); + bulkRequest.waitForActiveShards(waitForActiveShards()); + bulkRequest.timeout(timeout()); + bulkRequest.pipeline(pipeline()); + bulkRequest.routing(routing()); + bulkRequest.requireAlias(requireAlias()); + bulkRequest.requireDataStream(requireDataStream()); + return bulkRequest; + } } diff --git a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java index cb17f69340f48..e3d663ec13618 100644 --- a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java @@ -377,6 +377,9 @@ void executeRequest( if (task.isAsync()) 
{ tl.setFeature(CCSUsageTelemetry.ASYNC_FEATURE); } + if (original.pointInTimeBuilder() != null) { + tl.setFeature(CCSUsageTelemetry.PIT_FEATURE); + } String client = task.getHeader(Task.X_ELASTIC_PRODUCT_ORIGIN_HTTP_HEADER); if (client != null) { tl.setClient(client); diff --git a/server/src/main/java/org/elasticsearch/common/settings/IndexScopedSettings.java b/server/src/main/java/org/elasticsearch/common/settings/IndexScopedSettings.java index e4588e648318c..778136cbf5d31 100644 --- a/server/src/main/java/org/elasticsearch/common/settings/IndexScopedSettings.java +++ b/server/src/main/java/org/elasticsearch/common/settings/IndexScopedSettings.java @@ -33,6 +33,7 @@ import org.elasticsearch.index.engine.EngineConfig; import org.elasticsearch.index.fielddata.IndexFieldDataService; import org.elasticsearch.index.mapper.FieldMapper; +import org.elasticsearch.index.mapper.IgnoredSourceFieldMapper; import org.elasticsearch.index.mapper.MapperService; import org.elasticsearch.index.similarity.SimilarityService; import org.elasticsearch.index.store.FsDirectoryFactory; @@ -183,6 +184,8 @@ public final class IndexScopedSettings extends AbstractScopedSettings { IndexSettings.PREFER_ILM_SETTING, DataStreamFailureStoreDefinition.FAILURE_STORE_DEFINITION_VERSION_SETTING, FieldMapper.SYNTHETIC_SOURCE_KEEP_INDEX_SETTING, + IgnoredSourceFieldMapper.SKIP_IGNORED_SOURCE_WRITE_SETTING, + IgnoredSourceFieldMapper.SKIP_IGNORED_SOURCE_READ_SETTING, // validate that built-in similarities don't get redefined Setting.groupSetting("index.similarity.", (s) -> { diff --git a/server/src/main/java/org/elasticsearch/common/util/AbstractBigArray.java b/server/src/main/java/org/elasticsearch/common/util/AbstractBigArray.java index c2544ee02e0d8..f09d773ac86d0 100644 --- a/server/src/main/java/org/elasticsearch/common/util/AbstractBigArray.java +++ b/server/src/main/java/org/elasticsearch/common/util/AbstractBigArray.java @@ -12,6 +12,7 @@ import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.recycler.Recycler; +import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Releasables; import java.lang.reflect.Array; @@ -20,6 +21,7 @@ /** Common implementation for array lists that slice data into fixed-size blocks. 
*/ abstract class AbstractBigArray extends AbstractArray { + @Nullable protected final PageCacheRecycler recycler; private Recycler.V[] cache; diff --git a/server/src/main/java/org/elasticsearch/common/util/BigArrays.java b/server/src/main/java/org/elasticsearch/common/util/BigArrays.java index ddf11128e222b..a33ee4c2edeac 100644 --- a/server/src/main/java/org/elasticsearch/common/util/BigArrays.java +++ b/server/src/main/java/org/elasticsearch/common/util/BigArrays.java @@ -462,6 +462,7 @@ public T set(long index, T value) { } + @Nullable final PageCacheRecycler recycler; @Nullable private final CircuitBreakerService breakerService; @@ -471,13 +472,13 @@ public T set(long index, T value) { private final BigArrays circuitBreakingInstance; private final String breakerName; - public BigArrays(PageCacheRecycler recycler, @Nullable final CircuitBreakerService breakerService, String breakerName) { + public BigArrays(@Nullable PageCacheRecycler recycler, @Nullable final CircuitBreakerService breakerService, String breakerName) { // Checking the breaker is disabled if not specified this(recycler, breakerService, breakerName, false); } protected BigArrays( - PageCacheRecycler recycler, + @Nullable PageCacheRecycler recycler, @Nullable final CircuitBreakerService breakerService, String breakerName, boolean checkBreaker diff --git a/server/src/main/java/org/elasticsearch/index/IndexSettings.java b/server/src/main/java/org/elasticsearch/index/IndexSettings.java index 140133378ec06..41523c6dc2c7e 100644 --- a/server/src/main/java/org/elasticsearch/index/IndexSettings.java +++ b/server/src/main/java/org/elasticsearch/index/IndexSettings.java @@ -25,6 +25,7 @@ import org.elasticsearch.common.unit.ByteSizeUnit; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.index.mapper.IgnoredSourceFieldMapper; import org.elasticsearch.index.mapper.Mapper; import org.elasticsearch.index.translog.Translog; import org.elasticsearch.ingest.IngestService; @@ -778,6 +779,8 @@ private void setRetentionLeaseMillis(final TimeValue retentionLease) { private volatile long mappingDepthLimit; private volatile long mappingFieldNameLengthLimit; private volatile long mappingDimensionFieldsLimit; + private volatile boolean skipIgnoredSourceWrite; + private volatile boolean skipIgnoredSourceRead; /** * The maximum number of refresh listeners allows on this shard. 
@@ -936,6 +939,8 @@ public IndexSettings(final IndexMetadata indexMetadata, final Settings nodeSetti indexRouting = IndexRouting.fromIndexMetadata(indexMetadata); sourceKeepMode = scopedSettings.get(Mapper.SYNTHETIC_SOURCE_KEEP_INDEX_SETTING); es87TSDBCodecEnabled = scopedSettings.get(TIME_SERIES_ES87TSDB_CODEC_ENABLED_SETTING); + skipIgnoredSourceWrite = scopedSettings.get(IgnoredSourceFieldMapper.SKIP_IGNORED_SOURCE_WRITE_SETTING); + skipIgnoredSourceRead = scopedSettings.get(IgnoredSourceFieldMapper.SKIP_IGNORED_SOURCE_READ_SETTING); scopedSettings.addSettingsUpdateConsumer( MergePolicyConfig.INDEX_COMPOUND_FORMAT_SETTING, @@ -1018,6 +1023,11 @@ public IndexSettings(final IndexMetadata indexMetadata, final Settings nodeSetti scopedSettings.addSettingsUpdateConsumer(INDEX_MAPPING_DEPTH_LIMIT_SETTING, this::setMappingDepthLimit); scopedSettings.addSettingsUpdateConsumer(INDEX_MAPPING_FIELD_NAME_LENGTH_LIMIT_SETTING, this::setMappingFieldNameLengthLimit); scopedSettings.addSettingsUpdateConsumer(INDEX_MAPPING_DIMENSION_FIELDS_LIMIT_SETTING, this::setMappingDimensionFieldsLimit); + scopedSettings.addSettingsUpdateConsumer( + IgnoredSourceFieldMapper.SKIP_IGNORED_SOURCE_WRITE_SETTING, + this::setSkipIgnoredSourceWrite + ); + scopedSettings.addSettingsUpdateConsumer(IgnoredSourceFieldMapper.SKIP_IGNORED_SOURCE_READ_SETTING, this::setSkipIgnoredSourceRead); } private void setSearchIdleAfter(TimeValue searchIdleAfter) { @@ -1594,6 +1604,22 @@ private void setMappingDimensionFieldsLimit(long value) { this.mappingDimensionFieldsLimit = value; } + public boolean getSkipIgnoredSourceWrite() { + return skipIgnoredSourceWrite; + } + + private void setSkipIgnoredSourceWrite(boolean value) { + this.skipIgnoredSourceWrite = value; + } + + public boolean getSkipIgnoredSourceRead() { + return skipIgnoredSourceRead; + } + + private void setSkipIgnoredSourceRead(boolean value) { + this.skipIgnoredSourceRead = value; + } + /** * The bounds for {@code @timestamp} on this index or * {@code null} if there are no bounds. diff --git a/server/src/main/java/org/elasticsearch/index/mapper/DocumentParserContext.java b/server/src/main/java/org/elasticsearch/index/mapper/DocumentParserContext.java index 26db211345eae..38fa8d02db84e 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/DocumentParserContext.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/DocumentParserContext.java @@ -333,7 +333,7 @@ final boolean getClonedSource() { } public final boolean canAddIgnoredField() { - return mappingLookup.isSourceSynthetic() && clonedSource == false; + return mappingLookup.isSourceSynthetic() && clonedSource == false && indexSettings().getSkipIgnoredSourceWrite() == false; } Mapper.SourceKeepMode sourceKeepModeFromIndexSettings() { @@ -367,7 +367,7 @@ public boolean isFieldAppliedFromTemplate(String name) { public void markFieldAsCopyTo(String fieldName) { copyToFields.add(fieldName); - if (mappingLookup.isSourceSynthetic()) { + if (mappingLookup.isSourceSynthetic() && indexSettings().getSkipIgnoredSourceWrite() == false) { /* Mark this field as containing copied data meaning it should not be present in synthetic _source (to be consistent with stored _source). 
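Usage sketch (not part of the patch): the two valves wired into IndexSettings above are dynamic, index-scoped settings, so they can be flipped on a live index without a close and reopen. The helper below assumes a connected Client and a caller-supplied index name; the setting keys are copied verbatim from the IgnoredSourceFieldMapper hunks that follow.

import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.settings.Settings;

final class IgnoredSourceValvesSketch {
    // Flips the synthetic-source safety valves on a live index. Both settings are
    // declared with Setting.Property.Dynamic, so the update applies in place.
    static void setIgnoredSourceValves(Client client, String index, boolean skipWrite, boolean skipRead) {
        Settings update = Settings.builder()
            .put("index.mapping.synthetic_source.skip_ignored_source_write", skipWrite)
            .put("index.mapping.synthetic_source.skip_ignored_source_read", skipRead)
            .build();
        client.admin().indices().prepareUpdateSettings(index).setSettings(update).get();
    }
}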
diff --git a/server/src/main/java/org/elasticsearch/index/mapper/IgnoredSourceFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/IgnoredSourceFieldMapper.java index 35bbec6355762..94d862d5cc516 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/IgnoredSourceFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/IgnoredSourceFieldMapper.java @@ -10,12 +10,15 @@ package org.elasticsearch.index.mapper; import org.apache.lucene.document.StoredField; +import org.apache.lucene.index.LeafReader; import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.util.ByteUtils; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.Tuple; import org.elasticsearch.features.NodeFeature; +import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.query.SearchExecutionContext; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentType; @@ -27,6 +30,8 @@ import java.util.Comparator; import java.util.List; import java.util.Map; +import java.util.Set; +import java.util.stream.Stream; /** @@ -40,6 +45,7 @@ * if we can replace it for all use cases to avoid duplication, assuming that the storage tradeoff is favorable. */ public class IgnoredSourceFieldMapper extends MetadataFieldMapper { + private final IndexSettings indexSettings; // This factor is used to combine two offsets within the same integer: // - the offset of the end of the parent field within the field name (N / PARENT_OFFSET_IN_NAME_OFFSET) @@ -49,12 +55,32 @@ public class IgnoredSourceFieldMapper extends MetadataFieldMapper { public static final String NAME = "_ignored_source"; - public static final IgnoredSourceFieldMapper INSTANCE = new IgnoredSourceFieldMapper(); - - public static final TypeParser PARSER = new FixedTypeParser(context -> INSTANCE); + public static final TypeParser PARSER = new FixedTypeParser(context -> new IgnoredSourceFieldMapper(context.getIndexSettings())); static final NodeFeature TRACK_IGNORED_SOURCE = new NodeFeature("mapper.track_ignored_source"); + /* + Setting to disable encoding and writing values for this field. + This is needed to unblock index functionality in case there is a bug on this code path. + */ + public static final Setting<Boolean> SKIP_IGNORED_SOURCE_WRITE_SETTING = Setting.boolSetting( + "index.mapping.synthetic_source.skip_ignored_source_write", + false, + Setting.Property.Dynamic, + Setting.Property.IndexScope + ); + + /* + Setting to disable reading and decoding values stored in this field. + This is needed to unblock search functionality in case there is a bug on this code path.
+ */ + public static final Setting<Boolean> SKIP_IGNORED_SOURCE_READ_SETTING = Setting.boolSetting( + "index.mapping.synthetic_source.skip_ignored_source_read", + false, + Setting.Property.Dynamic, + Setting.Property.IndexScope + ); + /* * Container for the ignored field data: * - the full name @@ -108,8 +134,9 @@ public ValueFetcher valueFetcher(SearchExecutionContext context, String format) } } - private IgnoredSourceFieldMapper() { + private IgnoredSourceFieldMapper(IndexSettings indexSettings) { super(IgnoredValuesFieldMapperType.INSTANCE); + this.indexSettings = indexSettings; } @Override @@ -151,6 +178,64 @@ static NameValue decode(Object field) { return new NameValue(name, parentOffset, value, null); } + // In rare cases decoding values stored in this field can fail leading to entire source + // not being available. + // We would like to have an option to lose some values in synthetic source + // but have search not fail. + public static Set<String> ensureLoaded(Set<String> fieldsToLoadForSyntheticSource, IndexSettings indexSettings) { + if (indexSettings.getSkipIgnoredSourceRead() == false) { + fieldsToLoadForSyntheticSource.add(NAME); + } + + return fieldsToLoadForSyntheticSource; + } + + @Override + protected SyntheticSourceSupport syntheticSourceSupport() { + // This loader controls if this field is loaded in scope of synthetic source constructions. + // In rare cases decoding values stored in this field can fail leading to entire source + // not being available. + // We would like to have an option to lose some values in synthetic source + // but have search not fail. + return new SyntheticSourceSupport.Native(new SourceLoader.SyntheticFieldLoader() { + @Override + public Stream<Map.Entry<String, StoredFieldLoader>> storedFieldLoaders() { + if (indexSettings.getSkipIgnoredSourceRead()) { + return Stream.empty(); + } + + // Values are handled in `SourceLoader`. + return Stream.of(Map.entry(NAME, (v) -> {})); + } + + @Override + public DocValuesLoader docValuesLoader(LeafReader leafReader, int[] docIdsInLeaf) throws IOException { + return null; + } + + @Override + public boolean hasValue() { + return false; + } + + @Override + public void write(XContentBuilder b) throws IOException { + + } + + @Override + public String fieldName() { + // Does not really matter.
+ return NAME; + } + + @Override + public void reset() { + + } + }); + } + public record MappedNameValue(NameValue nameValue, XContentType type, Map map) {} /** diff --git a/server/src/main/java/org/elasticsearch/index/mapper/NestedObjectMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/NestedObjectMapper.java index adf1b329d9e83..fc5f28dd51c9d 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/NestedObjectMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/NestedObjectMapper.java @@ -19,6 +19,7 @@ import org.elasticsearch.common.Explicit; import org.elasticsearch.common.lucene.search.Queries; import org.elasticsearch.common.xcontent.support.XContentMapValues; +import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.IndexVersions; import org.elasticsearch.index.fieldvisitor.LeafStoredFieldLoader; @@ -48,11 +49,18 @@ public static class Builder extends ObjectMapper.Builder { private Explicit includeInParent = Explicit.IMPLICIT_FALSE; private final IndexVersion indexCreatedVersion; private final Function bitSetProducer; + private final IndexSettings indexSettings; - public Builder(String name, IndexVersion indexCreatedVersion, Function bitSetProducer) { + public Builder( + String name, + IndexVersion indexCreatedVersion, + Function bitSetProducer, + IndexSettings indexSettings + ) { super(name, Optional.empty()); this.indexCreatedVersion = indexCreatedVersion; this.bitSetProducer = bitSetProducer; + this.indexSettings = indexSettings; } Builder includeInRoot(boolean includeInRoot) { @@ -113,7 +121,8 @@ public NestedObjectMapper build(MapperBuilderContext context) { parentTypeFilter, nestedTypePath, nestedTypeFilter, - bitSetProducer + bitSetProducer, + indexSettings ); } } @@ -128,7 +137,8 @@ public Mapper.Builder parse(String name, Map node, MappingParser NestedObjectMapper.Builder builder = new NestedObjectMapper.Builder( name, parserContext.indexVersionCreated(), - parserContext::bitSetProducer + parserContext::bitSetProducer, + parserContext.getIndexSettings() ); parseNested(name, node, builder); parseObjectFields(node, parserContext, builder); @@ -195,6 +205,7 @@ public MapperBuilderContext createChildContext(String name, Dynamic dynamic) { private final Query nestedTypeFilter; // Function to create a bitset for identifying parent documents private final Function bitsetProducer; + private final IndexSettings indexSettings; NestedObjectMapper( String name, @@ -208,7 +219,8 @@ public MapperBuilderContext createChildContext(String name, Dynamic dynamic) { Query parentTypeFilter, String nestedTypePath, Query nestedTypeFilter, - Function bitsetProducer + Function bitsetProducer, + IndexSettings indexSettings ) { super(name, fullPath, enabled, Optional.empty(), storeArraySource, dynamic, mappers); this.parentTypeFilter = parentTypeFilter; @@ -217,6 +229,7 @@ public MapperBuilderContext createChildContext(String name, Dynamic dynamic) { this.includeInParent = includeInParent; this.includeInRoot = includeInRoot; this.bitsetProducer = bitsetProducer; + this.indexSettings = indexSettings; } public Query parentTypeFilter() { @@ -254,7 +267,7 @@ public Map getChildren() { @Override public ObjectMapper.Builder newBuilder(IndexVersion indexVersionCreated) { - NestedObjectMapper.Builder builder = new NestedObjectMapper.Builder(leafName(), indexVersionCreated, bitsetProducer); + NestedObjectMapper.Builder builder = new NestedObjectMapper.Builder(leafName(), indexVersionCreated, 
bitsetProducer, indexSettings); builder.enabled = enabled; builder.dynamic = dynamic; builder.includeInRoot = includeInRoot; @@ -276,7 +289,8 @@ NestedObjectMapper withoutMappers() { parentTypeFilter, nestedTypePath, nestedTypeFilter, - bitsetProducer + bitsetProducer, + indexSettings ); } @@ -351,7 +365,8 @@ public ObjectMapper merge(Mapper mergeWith, MapperMergeContext parentMergeContex parentTypeFilter, nestedTypePath, nestedTypeFilter, - bitsetProducer + bitsetProducer, + indexSettings ); } @@ -384,7 +399,9 @@ public SourceLoader.SyntheticFieldLoader syntheticFieldLoader() { } SourceLoader sourceLoader = new SourceLoader.Synthetic(() -> super.syntheticFieldLoader(mappers.values().stream(), true), NOOP); - var storedFieldLoader = org.elasticsearch.index.fieldvisitor.StoredFieldLoader.create(false, sourceLoader.requiredStoredFields()); + // Some synthetic source use cases require using _ignored_source field + var requiredStoredFields = IgnoredSourceFieldMapper.ensureLoaded(sourceLoader.requiredStoredFields(), indexSettings); + var storedFieldLoader = org.elasticsearch.index.fieldvisitor.StoredFieldLoader.create(false, requiredStoredFields); return new NestedSyntheticFieldLoader( storedFieldLoader, sourceLoader, diff --git a/server/src/main/java/org/elasticsearch/index/mapper/SourceLoader.java b/server/src/main/java/org/elasticsearch/index/mapper/SourceLoader.java index 43bf6f2bd83dd..baff3835d104b 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/SourceLoader.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/SourceLoader.java @@ -118,7 +118,6 @@ public Synthetic(Supplier fieldLoaderSupplier, SourceField .storedFieldLoaders() .map(Map.Entry::getKey) .collect(Collectors.toSet()); - this.requiredStoredFields.add(IgnoredSourceFieldMapper.NAME); this.metrics = metrics; } diff --git a/server/src/main/java/org/elasticsearch/rest/action/ingest/RestSimulateIngestAction.java b/server/src/main/java/org/elasticsearch/rest/action/ingest/RestSimulateIngestAction.java index ef9252072c526..6de15b0046f1b 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/ingest/RestSimulateIngestAction.java +++ b/server/src/main/java/org/elasticsearch/rest/action/ingest/RestSimulateIngestAction.java @@ -75,7 +75,8 @@ public RestChannelConsumer prepareRequest(final RestRequest request, final NodeC Tuple sourceTuple = request.contentOrSourceParam(); Map sourceMap = XContentHelper.convertToMap(sourceTuple.v2(), false, sourceTuple.v1()).v2(); SimulateBulkRequest bulkRequest = new SimulateBulkRequest( - (Map>) sourceMap.remove("pipeline_substitutions") + (Map>) sourceMap.remove("pipeline_substitutions"), + (Map>) sourceMap.remove("component_template_substitutions") ); BytesReference transformedData = convertToBulkRequestXContentBytes(sourceMap); bulkRequest.add( diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/EmptyTDigestState.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/EmptyTDigestState.java index cf93e4f8fac93..6ae9c655df3e8 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/EmptyTDigestState.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/EmptyTDigestState.java @@ -9,10 +9,12 @@ package org.elasticsearch.search.aggregations.metrics; +import org.elasticsearch.tdigest.arrays.WrapperTDigestArrays; + public final class EmptyTDigestState extends TDigestState { public EmptyTDigestState() { // Use the sorting implementation to minimize memory allocation. 
- super(Type.SORTING, 1.0D); + super(WrapperTDigestArrays.INSTANCE, Type.SORTING, 1.0D); } @Override diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/TDigestState.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/TDigestState.java index 336efb78dfc03..48bdb59e430a5 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/TDigestState.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/TDigestState.java @@ -13,6 +13,8 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.tdigest.Centroid; import org.elasticsearch.tdigest.TDigest; +import org.elasticsearch.tdigest.arrays.TDigestArrays; +import org.elasticsearch.tdigest.arrays.WrapperTDigestArrays; import java.io.IOException; import java.util.Collection; @@ -47,14 +49,22 @@ static Type valueForHighAccuracy() { private final Type type; + /** + * @deprecated this method will be removed after all usages are replaced + */ + @Deprecated + public static TDigestState create(double compression) { + return create(WrapperTDigestArrays.INSTANCE, compression); + } + /** * Default factory for TDigestState. The underlying {@link org.elasticsearch.tdigest.TDigest} implementation is optimized for * performance, potentially providing slightly inaccurate results compared to other, substantially slower implementations. * @param compression the compression factor for the underlying {@link org.elasticsearch.tdigest.TDigest} object * @return a TDigestState object that's optimized for performance */ - public static TDigestState create(double compression) { - return new TDigestState(Type.defaultValue(), compression); + public static TDigestState create(TDigestArrays arrays, double compression) { + return new TDigestState(arrays, Type.defaultValue(), compression); } /** @@ -62,8 +72,16 @@ public static TDigestState create(double compression) { * @param compression the compression factor for the underlying {@link org.elasticsearch.tdigest.TDigest} object * @return a TDigestState object that's optimized for performance */ - public static TDigestState createOptimizedForAccuracy(double compression) { - return new TDigestState(Type.valueForHighAccuracy(), compression); + public static TDigestState createOptimizedForAccuracy(TDigestArrays arrays, double compression) { + return new TDigestState(arrays, Type.valueForHighAccuracy(), compression); + } + + /** + * @deprecated this method will be removed after all usages are replaced + */ + @Deprecated + public static TDigestState create(double compression, TDigestExecutionHint executionHint) { + return create(WrapperTDigestArrays.INSTANCE, compression, executionHint); } /** @@ -74,10 +92,10 @@ public static TDigestState createOptimizedForAccuracy(double compression) { * @param executionHint controls which implementation is used; accepted values are 'high_accuracy' and '' (default) * @return a TDigestState object */ - public static TDigestState create(double compression, TDigestExecutionHint executionHint) { + public static TDigestState create(TDigestArrays arrays, double compression, TDigestExecutionHint executionHint) { return switch (executionHint) { - case HIGH_ACCURACY -> createOptimizedForAccuracy(compression); - case DEFAULT -> create(compression); + case HIGH_ACCURACY -> createOptimizedForAccuracy(arrays, compression); + case DEFAULT -> create(arrays, compression); }; } @@ -88,15 +106,15 @@ public static TDigestState create(double compression, TDigestExecutionHint execu * @return a 
TDigestState object */ public static TDigestState createUsingParamsFrom(TDigestState state) { - return new TDigestState(state.type, state.compression); + return new TDigestState(WrapperTDigestArrays.INSTANCE, state.type, state.compression); } - protected TDigestState(Type type, double compression) { + protected TDigestState(TDigestArrays arrays, Type type, double compression) { tdigest = switch (type) { - case HYBRID -> TDigest.createHybridDigest(compression); - case AVL_TREE -> TDigest.createAvlTreeDigest(compression); - case SORTING -> TDigest.createSortingDigest(); - case MERGING -> TDigest.createMergingDigest(compression); + case HYBRID -> TDigest.createHybridDigest(arrays, compression); + case AVL_TREE -> TDigest.createAvlTreeDigest(arrays, compression); + case SORTING -> TDigest.createSortingDigest(arrays); + case MERGING -> TDigest.createMergingDigest(arrays, compression); }; this.type = type; this.compression = compression; @@ -120,15 +138,23 @@ public static void write(TDigestState state, StreamOutput out) throws IOExceptio } } + /** + * @deprecated this method will be removed after all usages are replaced + */ + @Deprecated public static TDigestState read(StreamInput in) throws IOException { + return read(WrapperTDigestArrays.INSTANCE, in); + } + + public static TDigestState read(TDigestArrays arrays, StreamInput in) throws IOException { double compression = in.readDouble(); TDigestState state; long size = 0; if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_9_X)) { - state = new TDigestState(Type.valueOf(in.readString()), compression); + state = new TDigestState(arrays, Type.valueOf(in.readString()), compression); size = in.readVLong(); } else { - state = new TDigestState(Type.valueForHighAccuracy(), compression); + state = new TDigestState(arrays, Type.valueForHighAccuracy(), compression); } int n = in.readVInt(); if (size > 0) { diff --git a/server/src/test/java/org/elasticsearch/action/bulk/BulkRequestTests.java b/server/src/test/java/org/elasticsearch/action/bulk/BulkRequestTests.java index 643e2d90bf615..c601401a1c49d 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/BulkRequestTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/BulkRequestTests.java @@ -475,4 +475,24 @@ public void testUnsupportedAction() { allOf(containsString("Malformed action/metadata line [1]"), containsString("found [get")) ); } + + public void testShallowClone() { + BulkRequest bulkRequest = new BulkRequest(randomBoolean() ? null : randomAlphaOfLength(10)); + bulkRequest.setRefreshPolicy(randomFrom(RefreshPolicy.values())); + bulkRequest.waitForActiveShards(randomIntBetween(1, 10)); + bulkRequest.timeout(randomTimeValue()); + bulkRequest.pipeline(randomBoolean() ? null : randomAlphaOfLength(10)); + bulkRequest.routing(randomBoolean() ? 
null : randomAlphaOfLength(10)); + bulkRequest.requireAlias(randomBoolean()); + bulkRequest.requireDataStream(randomBoolean()); + BulkRequest shallowCopy = bulkRequest.shallowClone(); + assertThat(shallowCopy.requests, equalTo(List.of())); + assertThat(shallowCopy.getRefreshPolicy(), equalTo(bulkRequest.getRefreshPolicy())); + assertThat(shallowCopy.waitForActiveShards(), equalTo(bulkRequest.waitForActiveShards())); + assertThat(shallowCopy.timeout(), equalTo(bulkRequest.timeout())); + assertThat(shallowCopy.pipeline(), equalTo(bulkRequest.pipeline())); + assertThat(shallowCopy.routing(), equalTo(bulkRequest.routing())); + assertThat(shallowCopy.requireAlias(), equalTo(bulkRequest.requireAlias())); + assertThat(shallowCopy.requireDataStream(), equalTo(bulkRequest.requireDataStream())); + } } diff --git a/server/src/test/java/org/elasticsearch/action/bulk/SimulateBulkRequestTests.java b/server/src/test/java/org/elasticsearch/action/bulk/SimulateBulkRequestTests.java index 2e347e052125a..b6b1770e2ed5c 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/SimulateBulkRequestTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/SimulateBulkRequestTests.java @@ -9,24 +9,36 @@ package org.elasticsearch.action.bulk; +import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.cluster.metadata.ComponentTemplate; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentType; import java.io.IOException; +import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Map; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; public class SimulateBulkRequestTests extends ESTestCase { public void testSerialization() throws Exception { - testSerialization(getTestPipelineSubstitutions()); - testSerialization(null); - testSerialization(Map.of()); + testSerialization(getTestPipelineSubstitutions(), getTestTemplateSubstitutions()); + testSerialization(getTestPipelineSubstitutions(), null); + testSerialization(null, getTestTemplateSubstitutions()); + testSerialization(null, null); + testSerialization(Map.of(), Map.of()); } - private void testSerialization(Map> pipelineSubstitutions) throws IOException { - SimulateBulkRequest simulateBulkRequest = new SimulateBulkRequest(pipelineSubstitutions); + private void testSerialization( + Map> pipelineSubstitutions, + Map> templateSubstitutions + ) throws IOException { + SimulateBulkRequest simulateBulkRequest = new SimulateBulkRequest(pipelineSubstitutions, templateSubstitutions); /* * Note: SimulateBulkRequest does not implement equals or hashCode, so we can't test serialization in the usual way for a * Writable @@ -35,6 +47,94 @@ private void testSerialization(Map> pipelineSubstitu assertThat(copy.getPipelineSubstitutions(), equalTo(simulateBulkRequest.getPipelineSubstitutions())); } + @SuppressWarnings({ "unchecked", "rawtypes" }) + public void testGetComponentTemplateSubstitutions() throws IOException { + SimulateBulkRequest simulateBulkRequest = new SimulateBulkRequest(Map.of(), Map.of()); + assertThat(simulateBulkRequest.getComponentTemplateSubstitutions(), equalTo(Map.of())); + String substituteComponentTemplatesString = """ + { + "mappings_template": { + "template": { + "mappings": { + "dynamic": "true", + "properties": { + "foo": { + "type": "keyword" + } + } + } + } + }, + "settings_template": { + "template": { + 
"settings": { + "index": { + "default_pipeline": "bar-pipeline" + } + } + } + } + } + """; + + Map tempMap = XContentHelper.convertToMap( + new BytesArray(substituteComponentTemplatesString.getBytes(StandardCharsets.UTF_8)), + randomBoolean(), + XContentType.JSON + ).v2(); + Map> substituteComponentTemplates = (Map>) tempMap; + simulateBulkRequest = new SimulateBulkRequest(Map.of(), substituteComponentTemplates); + Map componentTemplateSubstitutions = simulateBulkRequest.getComponentTemplateSubstitutions(); + assertThat(componentTemplateSubstitutions.size(), equalTo(2)); + assertThat( + XContentHelper.convertToMap( + XContentHelper.toXContent( + componentTemplateSubstitutions.get("mappings_template").template(), + XContentType.JSON, + randomBoolean() + ), + randomBoolean(), + XContentType.JSON + ).v2(), + equalTo(substituteComponentTemplates.get("mappings_template").get("template")) + ); + assertNull(componentTemplateSubstitutions.get("mappings_template").template().settings()); + assertNull(componentTemplateSubstitutions.get("settings_template").template().mappings()); + assertThat(componentTemplateSubstitutions.get("settings_template").template().settings().size(), equalTo(1)); + assertThat( + componentTemplateSubstitutions.get("settings_template").template().settings().get("index.default_pipeline"), + equalTo("bar-pipeline") + ); + } + + public void testShallowClone() throws IOException { + SimulateBulkRequest simulateBulkRequest = new SimulateBulkRequest(getTestPipelineSubstitutions(), getTestTemplateSubstitutions()); + simulateBulkRequest.setRefreshPolicy(randomFrom(WriteRequest.RefreshPolicy.values())); + simulateBulkRequest.waitForActiveShards(randomIntBetween(1, 10)); + simulateBulkRequest.timeout(randomTimeValue()); + simulateBulkRequest.pipeline(randomBoolean() ? null : randomAlphaOfLength(10)); + simulateBulkRequest.routing(randomBoolean() ? 
null : randomAlphaOfLength(10)); + simulateBulkRequest.requireAlias(randomBoolean()); + simulateBulkRequest.requireDataStream(randomBoolean()); + BulkRequest shallowCopy = simulateBulkRequest.shallowClone(); + assertThat(shallowCopy, instanceOf(SimulateBulkRequest.class)); + SimulateBulkRequest simulateBulkRequestCopy = (SimulateBulkRequest) shallowCopy; + assertThat(simulateBulkRequestCopy.requests, equalTo(List.of())); + assertThat( + simulateBulkRequestCopy.getComponentTemplateSubstitutions(), + equalTo(simulateBulkRequest.getComponentTemplateSubstitutions()) + ); + assertThat(simulateBulkRequestCopy.getPipelineSubstitutions(), equalTo(simulateBulkRequest.getPipelineSubstitutions())); + assertThat(simulateBulkRequestCopy.getRefreshPolicy(), equalTo(simulateBulkRequest.getRefreshPolicy())); + assertThat(simulateBulkRequestCopy.waitForActiveShards(), equalTo(simulateBulkRequest.waitForActiveShards())); + assertThat(simulateBulkRequestCopy.timeout(), equalTo(simulateBulkRequest.timeout())); + assertThat(shallowCopy.pipeline(), equalTo(simulateBulkRequest.pipeline())); + assertThat(shallowCopy.routing(), equalTo(simulateBulkRequest.routing())); + assertThat(shallowCopy.requireAlias(), equalTo(simulateBulkRequest.requireAlias())); + assertThat(shallowCopy.requireDataStream(), equalTo(simulateBulkRequest.requireDataStream())); + + } + private static Map> getTestPipelineSubstitutions() { return Map.of( "pipeline1", @@ -43,4 +143,16 @@ private static Map> getTestPipelineSubstitutions() { Map.of("processors", List.of(Map.of("processor3", Map.of()))) ); } + + private static Map> getTestTemplateSubstitutions() { + return Map.of( + "template1", + Map.of( + "template", + Map.of("mappings", Map.of("_source", Map.of("enabled", false), "properties", Map.of()), "settings", Map.of()) + ), + "template2", + Map.of("template", Map.of("mappings", Map.of(), "settings", Map.of())) + ); + } } diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportSimulateBulkActionTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportSimulateBulkActionTests.java index 77a463e5ca3e5..e3c863ee69985 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportSimulateBulkActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportSimulateBulkActionTests.java @@ -135,7 +135,7 @@ public void tearDown() throws Exception { public void testIndexData() { Task task = mock(Task.class); // unused - BulkRequest bulkRequest = new SimulateBulkRequest((Map>) null); + BulkRequest bulkRequest = new SimulateBulkRequest(null, null); int bulkItemCount = randomIntBetween(0, 200); for (int i = 0; i < bulkItemCount; i++) { Map source = Map.of(randomAlphaOfLength(10), randomAlphaOfLength(5)); @@ -218,7 +218,7 @@ public void testIndexDataWithValidation() throws IOException { * (7) An indexing request to a nonexistent index that matches no templates */ Task task = mock(Task.class); // unused - BulkRequest bulkRequest = new SimulateBulkRequest((Map>) null); + BulkRequest bulkRequest = new SimulateBulkRequest(null, null); int bulkItemCount = randomIntBetween(0, 200); Map indicesMap = new HashMap<>(); Map v1Templates = new HashMap<>(); diff --git a/server/src/test/java/org/elasticsearch/index/mapper/FieldAliasMapperValidationTests.java b/server/src/test/java/org/elasticsearch/index/mapper/FieldAliasMapperValidationTests.java index f303171c7e465..d48c5550631cd 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/FieldAliasMapperValidationTests.java +++ 
b/server/src/test/java/org/elasticsearch/index/mapper/FieldAliasMapperValidationTests.java @@ -186,9 +186,8 @@ private static ObjectMapper createObjectMapper(String name) { } private static NestedObjectMapper createNestedObjectMapper(String name) { - return new NestedObjectMapper.Builder(name, IndexVersion.current(), query -> { throw new UnsupportedOperationException(); }).build( - MapperBuilderContext.root(false, false) - ); + return new NestedObjectMapper.Builder(name, IndexVersion.current(), query -> { throw new UnsupportedOperationException(); }, null) + .build(MapperBuilderContext.root(false, false)); } private static MappingLookup createMappingLookup( diff --git a/server/src/test/java/org/elasticsearch/index/mapper/IgnoredSourceFieldMapperConfigurationTests.java b/server/src/test/java/org/elasticsearch/index/mapper/IgnoredSourceFieldMapperConfigurationTests.java new file mode 100644 index 0000000000000..e08ace01e88e8 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/mapper/IgnoredSourceFieldMapperConfigurationTests.java @@ -0,0 +1,144 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.mapper; + +import org.apache.lucene.index.DirectoryReader; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.CheckedConsumer; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Map; + +public class IgnoredSourceFieldMapperConfigurationTests extends MapperServiceTestCase { + public void testDisableIgnoredSourceRead() throws IOException { + var mapperService = mapperServiceWithCustomSettings( + Map.of(IgnoredSourceFieldMapper.SKIP_IGNORED_SOURCE_READ_SETTING.getKey(), true), + b -> { + b.startObject("fallback_field"); + { + b.field("type", "long").field("doc_values", "false"); + } + b.endObject(); + b.startObject("disabled_object"); + { + b.field("enabled", "false"); + b.startObject("properties"); + { + b.startObject("field").field("type", "keyword").endObject(); + } + b.endObject(); + } + b.endObject(); + } + ); + + CheckedConsumer inputDocument = b -> { + b.field("fallback_field", 111); + b.startObject("disabled_object"); + { + b.field("field", "hey"); + } + b.endObject(); + }; + + var doc = mapperService.documentMapper().parse(source(inputDocument)); + // Field was written. + assertNotNull(doc.docs().get(0).getField(IgnoredSourceFieldMapper.NAME)); + + String syntheticSource = syntheticSource(mapperService.documentMapper(), inputDocument); + // Values are not loaded. + assertEquals("{}", syntheticSource); + + mapperService.getIndexSettings() + .getScopedSettings() + .applySettings(Settings.builder().put(IgnoredSourceFieldMapper.SKIP_IGNORED_SOURCE_READ_SETTING.getKey(), false).build()); + + doc = mapperService.documentMapper().parse(source(inputDocument)); + // Field was written. + assertNotNull(doc.docs().get(0).getField(IgnoredSourceFieldMapper.NAME)); + + syntheticSource = syntheticSource(mapperService.documentMapper(), inputDocument); + // Values are loaded. 
+ assertEquals("{\"disabled_object\":{\"field\":\"hey\"},\"fallback_field\":111}", syntheticSource); + } + + public void testDisableIgnoredSourceWrite() throws IOException { + var mapperService = mapperServiceWithCustomSettings( + Map.of(IgnoredSourceFieldMapper.SKIP_IGNORED_SOURCE_WRITE_SETTING.getKey(), true), + b -> { + b.startObject("fallback_field"); + { + b.field("type", "long").field("doc_values", "false"); + } + b.endObject(); + b.startObject("disabled_object"); + { + b.field("enabled", "false"); + b.startObject("properties"); + { + b.startObject("field").field("type", "keyword").endObject(); + } + b.endObject(); + } + b.endObject(); + } + ); + + CheckedConsumer inputDocument = b -> { + b.field("fallback_field", 111); + b.startObject("disabled_object"); + { + b.field("field", "hey"); + } + b.endObject(); + }; + + var doc = mapperService.documentMapper().parse(source(inputDocument)); + // Field is not written. + assertNull(doc.docs().get(0).getField(IgnoredSourceFieldMapper.NAME)); + + String syntheticSource = syntheticSource(mapperService.documentMapper(), inputDocument); + // Values are not loaded. + assertEquals("{}", syntheticSource); + + mapperService.getIndexSettings() + .getScopedSettings() + .applySettings(Settings.builder().put(IgnoredSourceFieldMapper.SKIP_IGNORED_SOURCE_WRITE_SETTING.getKey(), false).build()); + + doc = mapperService.documentMapper().parse(source(inputDocument)); + // Field was written. + assertNotNull(doc.docs().get(0).getField(IgnoredSourceFieldMapper.NAME)); + + syntheticSource = syntheticSource(mapperService.documentMapper(), inputDocument); + // Values are loaded. + assertEquals("{\"disabled_object\":{\"field\":\"hey\"},\"fallback_field\":111}", syntheticSource); + } + + private MapperService mapperServiceWithCustomSettings( + Map customSettings, + CheckedConsumer mapping + ) throws IOException { + var settings = Settings.builder(); + for (var entry : customSettings.entrySet()) { + settings.put(entry.getKey(), entry.getValue()); + } + + return createMapperService(settings.build(), syntheticSourceMapping(mapping)); + } + + protected void validateRoundTripReader(String syntheticSource, DirectoryReader reader, DirectoryReader roundTripReader) + throws IOException { + // Disabling this field via index settings leads to some values not being present in source and assertReaderEquals validation to + // fail as a result. + // This is expected, these settings are introduced only as a safety net when related logic blocks ingestion or search + // and we would rather lose some part of source but unblock the workflow. 
+ } +} diff --git a/server/src/test/java/org/elasticsearch/index/mapper/NestedLookupTests.java b/server/src/test/java/org/elasticsearch/index/mapper/NestedLookupTests.java index 4953a330107b4..d209d08b48469 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/NestedLookupTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/NestedLookupTests.java @@ -65,9 +65,8 @@ public void testMultiLevelParents() throws IOException { } private static NestedObjectMapper buildMapper(String name) { - return new NestedObjectMapper.Builder(name, IndexVersion.current(), query -> { throw new UnsupportedOperationException(); }).build( - MapperBuilderContext.root(false, false) - ); + return new NestedObjectMapper.Builder(name, IndexVersion.current(), query -> { throw new UnsupportedOperationException(); }, null) + .build(MapperBuilderContext.root(false, false)); } public void testAllParentFilters() { diff --git a/server/src/test/java/org/elasticsearch/index/mapper/NestedObjectMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/NestedObjectMapperTests.java index 37971a1908fa2..0a954115e77f6 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/NestedObjectMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/NestedObjectMapperTests.java @@ -1505,10 +1505,10 @@ public void testIndexTemplatesMergeIncludes() throws IOException { public void testMergeNested() { NestedObjectMapper firstMapper = new NestedObjectMapper.Builder("nested1", IndexVersion.current(), query -> { throw new UnsupportedOperationException(); - }).includeInParent(true).includeInRoot(true).build(MapperBuilderContext.root(false, false)); + }, null).includeInParent(true).includeInRoot(true).build(MapperBuilderContext.root(false, false)); NestedObjectMapper secondMapper = new NestedObjectMapper.Builder("nested1", IndexVersion.current(), query -> { throw new UnsupportedOperationException(); - }).includeInParent(false).includeInRoot(true).build(MapperBuilderContext.root(false, false)); + }, null).includeInParent(false).includeInRoot(true).build(MapperBuilderContext.root(false, false)); MapperException e = expectThrows( MapperException.class, @@ -1855,7 +1855,7 @@ public void testNestedMapperBuilderContextConstructor() { MergeReason mergeReason = randomFrom(MergeReason.values()); MapperBuilderContext mapperBuilderContext = MapperBuilderContext.root(isSourceSynthetic, isDataStream, mergeReason); mapperBuilderContext = mapperBuilderContext.createChildContext("name", parentContainsDimensions, randomFrom(Dynamic.values())); - NestedObjectMapper.Builder builder = new NestedObjectMapper.Builder("name", IndexVersion.current(), query -> null); + NestedObjectMapper.Builder builder = new NestedObjectMapper.Builder("name", IndexVersion.current(), query -> null, null); builder.add(new Mapper.Builder("name") { @Override public Mapper build(MapperBuilderContext context) { @@ -1876,7 +1876,7 @@ public void testNestedMapperMergeContextRootConstructor() { MergeReason mergeReason = randomFrom(MergeReason.values()); { MapperBuilderContext mapperBuilderContext = MapperBuilderContext.root(false, false, mergeReason); - NestedObjectMapper.Builder builder = new NestedObjectMapper.Builder("name", IndexVersion.current(), query -> null); + NestedObjectMapper.Builder builder = new NestedObjectMapper.Builder("name", IndexVersion.current(), query -> null, null); NestedObjectMapper nestedObjectMapper = builder.build(mapperBuilderContext); MapperMergeContext mapperMergeContext = 
MapperMergeContext.root(isSourceSynthetic, isDataStream, mergeReason, randomLong()); MapperMergeContext childMergeContext = nestedObjectMapper.createChildContext(mapperMergeContext, "name"); @@ -1907,7 +1907,7 @@ public void testNestedMapperMergeContextFromConstructor() { MergeReason mergeReason = randomFrom(MergeReason.values()); MapperBuilderContext mapperBuilderContext = MapperBuilderContext.root(isSourceSynthetic, isDataStream, mergeReason); mapperBuilderContext = mapperBuilderContext.createChildContext("name", parentContainsDimensions, randomFrom(Dynamic.values())); - NestedObjectMapper.Builder builder = new NestedObjectMapper.Builder("name", IndexVersion.current(), query -> null); + NestedObjectMapper.Builder builder = new NestedObjectMapper.Builder("name", IndexVersion.current(), query -> null, null); NestedObjectMapper nestedObjectMapper = builder.build(mapperBuilderContext); MapperMergeContext mapperMergeContext = MapperMergeContext.from(mapperBuilderContext, randomLong()); diff --git a/server/src/test/java/org/elasticsearch/ingest/SimulateIngestServiceTests.java b/server/src/test/java/org/elasticsearch/ingest/SimulateIngestServiceTests.java index 554bc34cce7cc..332a04e40e43d 100644 --- a/server/src/test/java/org/elasticsearch/ingest/SimulateIngestServiceTests.java +++ b/server/src/test/java/org/elasticsearch/ingest/SimulateIngestServiceTests.java @@ -65,7 +65,7 @@ public void testGetPipeline() { ingestService.innerUpdatePipelines(ingestMetadata); { // First we make sure that if there are no substitutions that we get our original pipeline back: - SimulateBulkRequest simulateBulkRequest = new SimulateBulkRequest((Map>) null); + SimulateBulkRequest simulateBulkRequest = new SimulateBulkRequest(null, null); SimulateIngestService simulateIngestService = new SimulateIngestService(ingestService, simulateBulkRequest); Pipeline pipeline = simulateIngestService.getPipeline("pipeline1"); assertThat(pipeline.getProcessors(), contains(transformedMatch(Processor::getType, equalTo("processor1")))); @@ -83,7 +83,7 @@ public void testGetPipeline() { ); pipelineSubstitutions.put("pipeline2", newHashMap("processors", List.of(newHashMap("processor3", Collections.emptyMap())))); - SimulateBulkRequest simulateBulkRequest = new SimulateBulkRequest(pipelineSubstitutions); + SimulateBulkRequest simulateBulkRequest = new SimulateBulkRequest(pipelineSubstitutions, null); SimulateIngestService simulateIngestService = new SimulateIngestService(ingestService, simulateBulkRequest); Pipeline pipeline1 = simulateIngestService.getPipeline("pipeline1"); assertThat( @@ -103,7 +103,7 @@ public void testGetPipeline() { */ Map> pipelineSubstitutions = new HashMap<>(); pipelineSubstitutions.put("pipeline2", newHashMap("processors", List.of(newHashMap("processor3", Collections.emptyMap())))); - SimulateBulkRequest simulateBulkRequest = new SimulateBulkRequest(pipelineSubstitutions); + SimulateBulkRequest simulateBulkRequest = new SimulateBulkRequest(pipelineSubstitutions, null); SimulateIngestService simulateIngestService = new SimulateIngestService(ingestService, simulateBulkRequest); Pipeline pipeline1 = simulateIngestService.getPipeline("pipeline1"); assertThat(pipeline1.getProcessors(), contains(transformedMatch(Processor::getType, equalTo("processor1")))); diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/bucket/nested/NestedAggregatorTests.java b/server/src/test/java/org/elasticsearch/search/aggregations/bucket/nested/NestedAggregatorTests.java index 78943ed6ccdd7..c7e9f02c283ab 100644 --- 
a/server/src/test/java/org/elasticsearch/search/aggregations/bucket/nested/NestedAggregatorTests.java +++ b/server/src/test/java/org/elasticsearch/search/aggregations/bucket/nested/NestedAggregatorTests.java @@ -913,8 +913,7 @@ protected List objectMappers() { ); public static NestedObjectMapper nestedObject(String path) { - return new NestedObjectMapper.Builder(path, IndexVersion.current(), query -> { throw new UnsupportedOperationException(); }).build( - MapperBuilderContext.root(false, false) - ); + return new NestedObjectMapper.Builder(path, IndexVersion.current(), query -> { throw new UnsupportedOperationException(); }, null) + .build(MapperBuilderContext.root(false, false)); } } diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/metrics/TDigestStateTests.java b/server/src/test/java/org/elasticsearch/search/aggregations/metrics/TDigestStateTests.java index 84fb4728ec6d6..e7799a133b5af 100644 --- a/server/src/test/java/org/elasticsearch/search/aggregations/metrics/TDigestStateTests.java +++ b/server/src/test/java/org/elasticsearch/search/aggregations/metrics/TDigestStateTests.java @@ -16,6 +16,8 @@ import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.tdigest.arrays.TDigestArrays; +import org.elasticsearch.tdigest.arrays.WrapperTDigestArrays; import org.elasticsearch.test.ESTestCase; import org.hamcrest.Matchers; @@ -31,7 +33,7 @@ public class TDigestStateTests extends ESTestCase { public void testMoreThan4BValues() { // Regression test for #19528 // See https://github.com/tdunning/t-digest/pull/70/files#diff-4487072cee29b939694825647928f742R439 - TDigestState digest = TDigestState.create(100); + TDigestState digest = TDigestState.create(arrays(), 100); for (int i = 0; i < 1000; ++i) { digest.add(randomDouble()); } @@ -57,9 +59,9 @@ public void testMoreThan4BValues() { public void testEqualsHashCode() { final TDigestState empty1 = new EmptyTDigestState(); final TDigestState empty2 = new EmptyTDigestState(); - final TDigestState a = TDigestState.create(200); - final TDigestState b = TDigestState.create(100); - final TDigestState c = TDigestState.create(100); + final TDigestState a = TDigestState.create(arrays(), 200); + final TDigestState b = TDigestState.create(arrays(), 100); + final TDigestState c = TDigestState.create(arrays(), 100); assertEquals(empty1, empty2); assertEquals(empty1.hashCode(), empty2.hashCode()); @@ -101,9 +103,9 @@ public void testHash() { final Set set = new HashSet<>(); final TDigestState empty1 = new EmptyTDigestState(); final TDigestState empty2 = new EmptyTDigestState(); - final TDigestState a = TDigestState.create(200); - final TDigestState b = TDigestState.create(100); - final TDigestState c = TDigestState.create(100); + final TDigestState a = TDigestState.create(arrays(), 200); + final TDigestState b = TDigestState.create(arrays(), 100); + final TDigestState c = TDigestState.create(arrays(), 100); a.add(randomDouble()); b.add(randomDouble()); @@ -139,9 +141,9 @@ public void testHash() { } public void testFactoryMethods() { - TDigestState fast = TDigestState.create(100); - TDigestState anotherFast = TDigestState.create(100); - TDigestState accurate = TDigestState.createOptimizedForAccuracy(100); + TDigestState fast = TDigestState.create(arrays(), 100); + TDigestState anotherFast = TDigestState.create(arrays(), 100); + TDigestState accurate = 
TDigestState.createOptimizedForAccuracy(arrays(), 100); TDigestState anotherAccurate = TDigestState.createUsingParamsFrom(accurate); for (int i = 0; i < 100; i++) { @@ -173,7 +175,7 @@ private static TDigestState writeToAndReadFrom(TDigestState state, TransportVers ) ) { in.setTransportVersion(version); - return TDigestState.read(in); + return TDigestState.read(arrays(), in); } } @@ -188,8 +190,8 @@ private static BytesRef serialize(TDigestState state, TransportVersion version) public void testSerialization() throws IOException { // Past default was the accuracy-optimized version. - TDigestState state = TDigestState.create(100); - TDigestState backwardsCompatible = TDigestState.createOptimizedForAccuracy(100); + TDigestState state = TDigestState.create(arrays(), 100); + TDigestState backwardsCompatible = TDigestState.createOptimizedForAccuracy(arrays(), 100); for (int i = 0; i < 1000; i++) { state.add(i); backwardsCompatible.add(i); @@ -202,4 +204,8 @@ public void testSerialization() throws IOException { assertNotEquals(serializedBackwardsCompatible, state); assertEquals(serializedBackwardsCompatible, backwardsCompatible); } + + private static TDigestArrays arrays() { + return WrapperTDigestArrays.INSTANCE; + } } diff --git a/server/src/test/java/org/elasticsearch/search/sort/AbstractSortTestCase.java b/server/src/test/java/org/elasticsearch/search/sort/AbstractSortTestCase.java index 0e9ca00702b68..583cdf302ad65 100644 --- a/server/src/test/java/org/elasticsearch/search/sort/AbstractSortTestCase.java +++ b/server/src/test/java/org/elasticsearch/search/sort/AbstractSortTestCase.java @@ -197,7 +197,7 @@ protected final SearchExecutionContext createMockSearchExecutionContext(IndexSea }; NestedLookup nestedLookup = NestedLookup.build(List.of(new NestedObjectMapper.Builder("path", IndexVersion.current(), query -> { throw new UnsupportedOperationException(); - }).build(MapperBuilderContext.root(false, false)))); + }, null).build(MapperBuilderContext.root(false, false)))); return new SearchExecutionContext( 0, 0, diff --git a/test/framework/src/main/java/org/elasticsearch/test/InternalTestCluster.java b/test/framework/src/main/java/org/elasticsearch/test/InternalTestCluster.java index 2ae1085f2ab25..ff66d59a21c5b 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/InternalTestCluster.java +++ b/test/framework/src/main/java/org/elasticsearch/test/InternalTestCluster.java @@ -130,6 +130,7 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Function; import java.util.function.Predicate; +import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.IntStream; import java.util.stream.Stream; @@ -1249,27 +1250,29 @@ public synchronized void validateClusterFormed() { .isTimedOut() ); try { + final Object[] previousStates = new Object[1]; assertBusy(() -> { final List states = nodes.values() .stream() .map(node -> getInstanceFromNode(ClusterService.class, node.node())) .map(ClusterService::state) .toList(); - final String debugString = ", expected nodes: " + expectedNodes + " and actual cluster states " + states; + if (previousStates[0] != null && previousStates[0].equals(states)) { + throw new AssertionError("unchanged"); + } + previousStates[0] = states; + final Supplier debugString = () -> ", expected nodes: " + expectedNodes + " and actual cluster states " + states; // all nodes have a master - assertTrue("Missing master" + debugString, states.stream().allMatch(cs -> cs.nodes().getMasterNodeId() != null)); + assert 
states.stream().allMatch(cs -> cs.nodes().getMasterNodeId() != null) : "Missing master" + debugString.get(); // all nodes have the same master (in same term) - assertEquals( - "Not all masters in same term" + debugString, - 1, - states.stream().mapToLong(ClusterState::term).distinct().count() - ); + assert 1L == states.stream().mapToLong(ClusterState::term).distinct().count() + : "Not all masters in same term" + debugString.get(); // all nodes know about all other nodes states.forEach(cs -> { DiscoveryNodes discoveryNodes = cs.nodes(); - assertEquals("Node size mismatch" + debugString, expectedNodes.size(), discoveryNodes.getSize()); + assert expectedNodes.size() == discoveryNodes.getSize() : "Node size mismatch" + debugString.get(); for (DiscoveryNode expectedNode : expectedNodes) { - assertTrue("Expected node to exist: " + expectedNode + debugString, discoveryNodes.nodeExists(expectedNode)); + assert discoveryNodes.nodeExists(expectedNode) : "Expected node to exist: " + expectedNode + debugString.get(); } }); }, 30, TimeUnit.SECONDS); diff --git a/x-pack/plugin/ccr/src/test/java/org/elasticsearch/xpack/ccr/action/TransportResumeFollowActionTests.java b/x-pack/plugin/ccr/src/test/java/org/elasticsearch/xpack/ccr/action/TransportResumeFollowActionTests.java index 7de0d775ba150..b4be0b33a464e 100644 --- a/x-pack/plugin/ccr/src/test/java/org/elasticsearch/xpack/ccr/action/TransportResumeFollowActionTests.java +++ b/x-pack/plugin/ccr/src/test/java/org/elasticsearch/xpack/ccr/action/TransportResumeFollowActionTests.java @@ -16,6 +16,7 @@ import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.MapperTestUtils; +import org.elasticsearch.index.mapper.IgnoredSourceFieldMapper; import org.elasticsearch.index.mapper.MapperService; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.ccr.Ccr; @@ -331,6 +332,8 @@ public void testDynamicIndexSettingsAreClassified() { replicatedSettings.add(IndexSettings.MAX_SHINGLE_DIFF_SETTING); replicatedSettings.add(IndexSettings.TIME_SERIES_END_TIME); replicatedSettings.add(IndexSettings.PREFER_ILM_SETTING); + replicatedSettings.add(IgnoredSourceFieldMapper.SKIP_IGNORED_SOURCE_READ_SETTING); + replicatedSettings.add(IgnoredSourceFieldMapper.SKIP_IGNORED_SOURCE_WRITE_SETTING); for (Setting setting : IndexScopedSettings.BUILT_IN_INDEX_SETTINGS) { // removed settings have no effect, they are only there for BWC diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingChatCompletionResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingChatCompletionResults.java new file mode 100644 index 0000000000000..672ad74419c8f --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingChatCompletionResults.java @@ -0,0 +1,61 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
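The new streaming result type below carries only a java.util.concurrent.Flow.Publisher, so consumers subscribe and handle chunks as they arrive rather than reading a materialized list. A minimal subscriber sketch, assuming ChunkedToXContent items (inferred from the imports in the new file; the real element type may differ):

import org.elasticsearch.common.xcontent.ChunkedToXContent;

import java.util.concurrent.Flow;

final class StreamingCompletionSubscriberSketch implements Flow.Subscriber<ChunkedToXContent> {
    private Flow.Subscription subscription;

    @Override
    public void onSubscribe(Flow.Subscription subscription) {
        this.subscription = subscription;
        subscription.request(1); // pull-based: ask for one chunk at a time
    }

    @Override
    public void onNext(ChunkedToXContent chunk) {
        // serialize the chunk out to the client, then request the next one
        subscription.request(1);
    }

    @Override
    public void onError(Throwable throwable) {
        // surface the failure to whoever is awaiting the inference response
    }

    @Override
    public void onComplete() {
        // end of the completion stream; close the response channel
    }
}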
+ */
+
+package org.elasticsearch.xpack.core.inference.results;
+
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.ChunkedToXContent;
+import org.elasticsearch.inference.InferenceResults;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.xcontent.ToXContent;
+
+import java.io.IOException;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.Flow;
+
+/**
+ * Chat Completion results that only contain a Flow.Publisher.
+ */
+public record StreamingChatCompletionResults(Flow.Publisher<? extends ChunkedToXContent> publisher) implements InferenceServiceResults {
+
+    @Override
+    public boolean isStreaming() {
+        return true;
+    }
+
+    @Override
+    public List<? extends InferenceResults> transformToCoordinationFormat() {
+        throw new UnsupportedOperationException("Not implemented");
+    }
+
+    @Override
+    public List<? extends InferenceResults> transformToLegacyFormat() {
+        throw new UnsupportedOperationException("Not implemented");
+    }
+
+    @Override
+    public Map<String, Object> asMap() {
+        throw new UnsupportedOperationException("Not implemented");
+    }
+
+    @Override
+    public String getWriteableName() {
+        throw new UnsupportedOperationException("Not implemented");
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        throw new UnsupportedOperationException("Not implemented");
+    }
+
+    @Override
+    public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
+        throw new UnsupportedOperationException("Not implemented");
+    }
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlMetaIndex.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlMetaIndex.java
index 58978537cd2c2..a373c4b9a70fd 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlMetaIndex.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlMetaIndex.java
@@ -17,7 +17,7 @@ public final class MlMetaIndex {
     private static final String INDEX_NAME = ".ml-meta";

     private static final String MAPPINGS_VERSION_VARIABLE = "xpack.ml.version";

-    private static final int META_INDEX_MAPPINGS_VERSION = 1;
+    private static final int META_INDEX_MAPPINGS_VERSION = 2;

     /**
      * Where to store the ml info in Elasticsearch - must match what's
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/calendars/ScheduledEvent.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/calendars/ScheduledEvent.java
index 95a68e4391da7..c6fa4e052c683 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/calendars/ScheduledEvent.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/calendars/ScheduledEvent.java
@@ -6,6 +6,7 @@
  */
 package org.elasticsearch.xpack.core.ml.calendars;

+import org.elasticsearch.TransportVersions;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
@@ -20,6 +21,8 @@
 import org.elasticsearch.xpack.core.ml.job.config.Operator;
 import org.elasticsearch.xpack.core.ml.job.config.RuleAction;
 import org.elasticsearch.xpack.core.ml.job.config.RuleCondition;
+import org.elasticsearch.xpack.core.ml.job.config.RuleParams;
+import org.elasticsearch.xpack.core.ml.job.config.RuleParamsForForceTimeShift;
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.core.ml.utils.Intervals;
@@ -36,6 +39,9 @@ public class
ScheduledEvent implements ToXContentObject, Writeable { public static final ParseField DESCRIPTION = new ParseField("description"); public static final ParseField START_TIME = new ParseField("start_time"); public static final ParseField END_TIME = new ParseField("end_time"); + public static final ParseField SKIP_RESULT = new ParseField("skip_result"); + public static final ParseField SKIP_MODEL_UPDATE = new ParseField("skip_model_update"); + public static final ParseField FORCE_TIME_SHIFT = new ParseField("force_time_shift"); public static final ParseField TYPE = new ParseField("type"); public static final ParseField EVENT_ID = new ParseField("event_id"); @@ -63,6 +69,9 @@ private static ObjectParser createParser(boolean i END_TIME, ObjectParser.ValueType.VALUE ); + parser.declareBoolean(ScheduledEvent.Builder::skipResult, SKIP_RESULT); + parser.declareBoolean(ScheduledEvent.Builder::skipModelUpdate, SKIP_MODEL_UPDATE); + parser.declareInt(ScheduledEvent.Builder::forceTimeShift, FORCE_TIME_SHIFT); parser.declareString(ScheduledEvent.Builder::calendarId, Calendar.ID); parser.declareString((builder, s) -> {}, TYPE); @@ -76,13 +85,28 @@ public static String documentId(String eventId) { private final String description; private final Instant startTime; private final Instant endTime; + private final Boolean skipResult; + private final Boolean skipModelUpdate; + private final Integer forceTimeShift; private final String calendarId; private final String eventId; - ScheduledEvent(String description, Instant startTime, Instant endTime, String calendarId, @Nullable String eventId) { + ScheduledEvent( + String description, + Instant startTime, + Instant endTime, + Boolean skipResult, + Boolean skipModelUpdate, + @Nullable Integer forceTimeShift, + String calendarId, + @Nullable String eventId + ) { this.description = Objects.requireNonNull(description); this.startTime = Instant.ofEpochMilli(Objects.requireNonNull(startTime).toEpochMilli()); this.endTime = Instant.ofEpochMilli(Objects.requireNonNull(endTime).toEpochMilli()); + this.skipResult = Objects.requireNonNull(skipResult); + this.skipModelUpdate = Objects.requireNonNull(skipModelUpdate); + this.forceTimeShift = forceTimeShift; this.calendarId = Objects.requireNonNull(calendarId); this.eventId = eventId; } @@ -91,6 +115,15 @@ public ScheduledEvent(StreamInput in) throws IOException { description = in.readString(); startTime = in.readInstant(); endTime = in.readInstant(); + if (in.getTransportVersion().onOrAfter(TransportVersions.ML_SCHEDULED_EVENT_TIME_SHIFT_CONFIGURATION)) { + skipResult = in.readBoolean(); + skipModelUpdate = in.readBoolean(); + forceTimeShift = in.readOptionalInt(); + } else { + skipResult = true; + skipModelUpdate = true; + forceTimeShift = null; + } calendarId = in.readString(); eventId = in.readOptionalString(); } @@ -111,6 +144,18 @@ public String getCalendarId() { return calendarId; } + public Boolean getSkipResult() { + return skipResult; + } + + public Boolean getSkipModelUpdate() { + return skipModelUpdate; + } + + public Integer getForceTimeShift() { + return forceTimeShift; + } + public String getEventId() { return eventId; } @@ -138,7 +183,19 @@ public DetectionRule toDetectionRule(TimeValue bucketSpan) { conditions.add(RuleCondition.createTime(Operator.LT, bucketEndTime)); DetectionRule.Builder builder = new DetectionRule.Builder(conditions); - builder.setActions(RuleAction.SKIP_RESULT, RuleAction.SKIP_MODEL_UPDATE); + List ruleActions = new ArrayList<>(); + if (skipResult) { + 
ruleActions.add(RuleAction.SKIP_RESULT.toString()); + builder.setActions(RuleAction.SKIP_RESULT); + } + if (skipModelUpdate) { + ruleActions.add(RuleAction.SKIP_MODEL_UPDATE.toString()); + } + if (forceTimeShift != null) { + ruleActions.add(RuleAction.FORCE_TIME_SHIFT.toString()); + builder.setParams(new RuleParams(new RuleParamsForForceTimeShift(forceTimeShift))); + } + builder.setActions(ruleActions); return builder.build(); } @@ -147,6 +204,11 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(description); out.writeInstant(startTime); out.writeInstant(endTime); + if (out.getTransportVersion().onOrAfter(TransportVersions.ML_SCHEDULED_EVENT_TIME_SHIFT_CONFIGURATION)) { + out.writeBoolean(skipResult); + out.writeBoolean(skipModelUpdate); + out.writeOptionalInt(forceTimeShift); + } out.writeString(calendarId); out.writeOptionalString(eventId); } @@ -157,6 +219,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(DESCRIPTION.getPreferredName(), description); builder.timeField(START_TIME.getPreferredName(), START_TIME.getPreferredName() + "_string", startTime.toEpochMilli()); builder.timeField(END_TIME.getPreferredName(), END_TIME.getPreferredName() + "_string", endTime.toEpochMilli()); + builder.field(SKIP_RESULT.getPreferredName(), skipResult); + builder.field(SKIP_MODEL_UPDATE.getPreferredName(), skipModelUpdate); + if (forceTimeShift != null) { + builder.field(FORCE_TIME_SHIFT.getPreferredName(), forceTimeShift); + } builder.field(Calendar.ID.getPreferredName(), calendarId); if (eventId != null) { builder.field(EVENT_ID.getPreferredName(), eventId); @@ -182,18 +249,24 @@ public boolean equals(Object obj) { return description.equals(other.description) && Objects.equals(startTime, other.startTime) && Objects.equals(endTime, other.endTime) + && Objects.equals(skipResult, other.skipResult) + && Objects.equals(skipModelUpdate, other.skipModelUpdate) + && Objects.equals(forceTimeShift, other.forceTimeShift) && calendarId.equals(other.calendarId); } @Override public int hashCode() { - return Objects.hash(description, startTime, endTime, calendarId); + return Objects.hash(description, startTime, endTime, skipResult, skipModelUpdate, forceTimeShift, calendarId); } public static class Builder { private String description; private Instant startTime; private Instant endTime; + private Boolean skipResult; + private Boolean skipModelUpdate; + private Integer forceTimeShift; private String calendarId; private String eventId; @@ -212,6 +285,21 @@ public Builder endTime(Instant endTime) { return this; } + public Builder skipResult(Boolean skipResult) { + this.skipResult = skipResult; + return this; + } + + public Builder skipModelUpdate(Boolean skipModelUpdate) { + this.skipModelUpdate = skipModelUpdate; + return this; + } + + public Builder forceTimeShift(Integer forceTimeShift) { + this.forceTimeShift = forceTimeShift; + return this; + } + public Builder calendarId(String calendarId) { this.calendarId = calendarId; return this; @@ -255,7 +343,19 @@ public ScheduledEvent build() { ); } - ScheduledEvent event = new ScheduledEvent(description, startTime, endTime, calendarId, eventId); + skipResult = skipResult == null || skipResult; + skipModelUpdate = skipModelUpdate == null || skipModelUpdate; + + ScheduledEvent event = new ScheduledEvent( + description, + startTime, + endTime, + skipResult, + skipModelUpdate, + forceTimeShift, + calendarId, + eventId + ); return event; } diff --git 
a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/calendars/ScheduledEventTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/calendars/ScheduledEventTests.java index 0a78b7b6d6f2c..891430057513e 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/calendars/ScheduledEventTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/calendars/ScheduledEventTests.java @@ -16,19 +16,51 @@ import org.elasticsearch.xpack.core.ml.job.config.Operator; import org.elasticsearch.xpack.core.ml.job.config.RuleAction; import org.elasticsearch.xpack.core.ml.job.config.RuleCondition; +import org.elasticsearch.xpack.core.ml.job.config.RuleParams; +import org.elasticsearch.xpack.core.ml.job.config.RuleParamsForForceTimeShift; import java.io.IOException; import java.time.Instant; -import java.util.EnumSet; import java.util.List; +import java.util.Optional; import static org.hamcrest.Matchers.containsString; public class ScheduledEventTests extends AbstractXContentSerializingTestCase { + private static final long BUCKET_SPAN_SECS = 300; + public static ScheduledEvent createScheduledEvent(String calendarId) { Instant start = Instant.now(); - return new ScheduledEvent(randomAlphaOfLength(10), start, start.plusSeconds(randomIntBetween(1, 10000)), calendarId, null); + return new ScheduledEvent( + randomAlphaOfLength(10), + start, + start.plusSeconds(randomIntBetween(1, 10000)), + randomBoolean(), + randomBoolean(), + randomBoolean() ? null : randomInt(), + calendarId, + null + ); + } + + public static ScheduledEvent createScheduledEvent( + String calendarId, + Boolean skipResult, + Boolean skipModelUpdate, + Integer forceTimeShift + ) { + Instant start = Instant.now(); + return new ScheduledEvent( + randomAlphaOfLength(10), + start, + start.plusSeconds(randomIntBetween(1, 10000)), + skipResult, + skipModelUpdate, + forceTimeShift, + calendarId, + null + ); } @Override @@ -51,12 +83,52 @@ protected ScheduledEvent doParseInstance(XContentParser parser) throws IOExcepti return ScheduledEvent.STRICT_PARSER.apply(parser, null).build(); } - public void testToDetectionRule() { - long bucketSpanSecs = 300; - ScheduledEvent event = createTestInstance(); - DetectionRule rule = event.toDetectionRule(TimeValue.timeValueSeconds(bucketSpanSecs)); + public void testToDetectionRule_SetsSkipResultActionProperly() { + List validValues = List.of(true, false); + validValues.forEach((skipResult) -> { + Boolean skipModelUpdate = randomBoolean(); + Integer forceTimeShift = randomInt(); + ScheduledEvent event = createScheduledEvent(randomAlphaOfLength(10), skipResult, skipModelUpdate, forceTimeShift); + DetectionRule rule = event.toDetectionRule(TimeValue.timeValueSeconds(BUCKET_SPAN_SECS)); + validateDetectionRule(event, rule, skipResult, skipModelUpdate, forceTimeShift); + }); + } - assertEquals(rule.getActions(), EnumSet.of(RuleAction.SKIP_RESULT, RuleAction.SKIP_MODEL_UPDATE)); + public void testToDetectionRule_SetsSkipModelUpdateActionProperly() { + List validValues = List.of(true, false); + validValues.forEach((skipModelUpdate) -> { + Boolean skipResult = randomBoolean(); + Integer forceTimeShift = randomInt(); + ScheduledEvent event = createScheduledEvent(randomAlphaOfLength(10), skipResult, skipModelUpdate, forceTimeShift); + DetectionRule rule = event.toDetectionRule(TimeValue.timeValueSeconds(BUCKET_SPAN_SECS)); + validateDetectionRule(event, rule, skipResult, skipModelUpdate, forceTimeShift); + }); + } + + public void 
testToDetectionRule_SetsForceTimeShiftActionProperly() { + List> validValues = List.of(Optional.of(randomInt()), Optional.empty()); + validValues.forEach((forceTimeShift) -> { + Boolean skipResult = randomBoolean(); + Boolean skipModelUpdate = randomBoolean(); + ScheduledEvent event = createScheduledEvent(randomAlphaOfLength(10), skipResult, skipModelUpdate, forceTimeShift.orElse(null)); + DetectionRule rule = event.toDetectionRule(TimeValue.timeValueSeconds(BUCKET_SPAN_SECS)); + validateDetectionRule(event, rule, skipResult, skipModelUpdate, forceTimeShift.orElse(null)); + }); + } + + private void validateDetectionRule( + ScheduledEvent event, + DetectionRule rule, + Boolean skipResult, + Boolean skipModelUpdate, + Integer forceTimeShift + ) { + assertEquals(skipResult, rule.getActions().contains(RuleAction.SKIP_RESULT)); + assertEquals(skipModelUpdate, rule.getActions().contains(RuleAction.SKIP_MODEL_UPDATE)); + if (forceTimeShift != null) { + RuleParams expectedRuleParams = new RuleParams(new RuleParamsForForceTimeShift(forceTimeShift)); + assertEquals(expectedRuleParams, rule.getParams()); + } List conditions = rule.getConditions(); assertEquals(2, conditions.size()); @@ -67,16 +139,92 @@ public void testToDetectionRule() { // Check times are aligned with the bucket long conditionStartTime = (long) conditions.get(0).getValue(); - assertEquals(0, conditionStartTime % bucketSpanSecs); - long bucketCount = conditionStartTime / bucketSpanSecs; - assertEquals(bucketSpanSecs * bucketCount, conditionStartTime); + assertEquals(0, conditionStartTime % BUCKET_SPAN_SECS); + long bucketCount = conditionStartTime / BUCKET_SPAN_SECS; + assertEquals(BUCKET_SPAN_SECS * bucketCount, conditionStartTime); long conditionEndTime = (long) conditions.get(1).getValue(); - assertEquals(0, conditionEndTime % bucketSpanSecs); + assertEquals(0, conditionEndTime % BUCKET_SPAN_SECS); long eventTime = event.getEndTime().getEpochSecond() - conditionStartTime; - long numbBucketsInEvent = (eventTime + bucketSpanSecs - 1) / bucketSpanSecs; - assertEquals(bucketSpanSecs * (bucketCount + numbBucketsInEvent), conditionEndTime); + long numbBucketsInEvent = (eventTime + BUCKET_SPAN_SECS - 1) / BUCKET_SPAN_SECS; + assertEquals(BUCKET_SPAN_SECS * (bucketCount + numbBucketsInEvent), conditionEndTime); + } + + public void testBuild_DescriptionNull() { + ScheduledEvent.Builder builder = new ScheduledEvent.Builder(); + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, builder::build); + assertEquals("Field [description] cannot be null", e.getMessage()); + } + + public void testBuild_StartTimeNull() { + ScheduledEvent.Builder builder = new ScheduledEvent.Builder().description("foo"); + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, builder::build); + assertEquals("Field [start_time] cannot be null", e.getMessage()); + } + + public void testBuild_EndTimeNull() { + ScheduledEvent.Builder builder = new ScheduledEvent.Builder().description("foo").startTime(Instant.now()); + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, builder::build); + assertEquals("Field [end_time] cannot be null", e.getMessage()); + } + + public void testBuild_CalendarIdNull() { + ScheduledEvent.Builder builder = new ScheduledEvent.Builder().description("foo") + .startTime(Instant.now()) + .endTime(Instant.now().plusSeconds(1 * 60 * 60)); + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, builder::build); + assertEquals("Field 
[calendar_id] cannot be null", e.getMessage()); + } + + public void testBuild_StartTimeAfterEndTime() { + Instant now = Instant.now(); + ScheduledEvent.Builder builder = new ScheduledEvent.Builder().description("f") + .calendarId("c") + .startTime(now) + .endTime(now.minusSeconds(2 * 60 * 60)); + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, builder::build); + assertThat(e.getMessage(), containsString("must come before end time")); + } + + public void testBuild_SucceedsWithDefaultSkipResultAndSkipModelUpdatesValues() { + validateScheduledEventSuccessfulBuild(Optional.empty(), Optional.empty(), Optional.empty()); + } + + public void testBuild_SucceedsWithProvidedSkipResultAndSkipModelUpdatesValues() { + Boolean skipResult = randomBoolean(); + Boolean skipModelUpdate = randomBoolean(); + Integer forceTimeShift = randomBoolean() ? null : randomInt(); + + validateScheduledEventSuccessfulBuild(Optional.of(skipResult), Optional.of(skipModelUpdate), Optional.ofNullable(forceTimeShift)); + } + + private void validateScheduledEventSuccessfulBuild( + Optional skipResult, + Optional skipModelUpdate, + Optional forceTimeShift + ) { + String description = randomAlphaOfLength(10); + String calendarId = randomAlphaOfLength(10); + Instant startTime = Instant.ofEpochMilli(Instant.now().toEpochMilli()); + Instant endTime = startTime.plusSeconds(randomInt(3600)); + + ScheduledEvent.Builder builder = new ScheduledEvent.Builder().description(description) + .calendarId(calendarId) + .startTime(startTime) + .endTime(endTime); + skipResult.ifPresent(builder::skipResult); + skipModelUpdate.ifPresent(builder::skipModelUpdate); + forceTimeShift.ifPresent(builder::forceTimeShift); + + ScheduledEvent event = builder.build(); + assertEquals(description, event.getDescription()); + assertEquals(calendarId, event.getCalendarId()); + assertEquals(startTime, event.getStartTime()); + assertEquals(endTime, event.getEndTime()); + assertEquals(skipResult.orElse(true), event.getSkipResult()); + assertEquals(skipModelUpdate.orElse(true), event.getSkipModelUpdate()); + assertEquals(forceTimeShift.orElse(null), event.getForceTimeShift()); } public void testBuild() { diff --git a/x-pack/plugin/core/template-resources/src/main/resources/ml/meta_index_mappings.json b/x-pack/plugin/core/template-resources/src/main/resources/ml/meta_index_mappings.json index 4606cb0d75d8f..5862e5d195f25 100644 --- a/x-pack/plugin/core/template-resources/src/main/resources/ml/meta_index_mappings.json +++ b/x-pack/plugin/core/template-resources/src/main/resources/ml/meta_index_mappings.json @@ -36,6 +36,15 @@ "end_time": { "type": "date" }, + "skip_result": { + "type": "boolean" + }, + "skip_model_update": { + "type": "boolean" + }, + "force_time_shift": { + "type": "integer" + }, "type": { "type": "keyword" } diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/multivalue_geometries.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/multivalue_geometries.csv-spec index d79f40711e9f6..7e3ff5effd354 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/multivalue_geometries.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/multivalue_geometries.csv-spec @@ -7,6 +7,7 @@ #################################################################################################### spatialGeometryCollectionStats +required_capability: spatial_shapes FROM multivalue_geometries | MV_EXPAND shape diff --git 
a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/EnrichLookupService.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/EnrichLookupService.java index 373bf5a99056d..e4ae181915c8a 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/EnrichLookupService.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/EnrichLookupService.java @@ -77,7 +77,6 @@ import org.elasticsearch.xpack.esql.core.expression.Alias; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; import org.elasticsearch.xpack.esql.core.type.DataType; -import org.elasticsearch.xpack.esql.io.stream.PlanNameRegistry; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput; import org.elasticsearch.xpack.esql.planner.EsPhysicalOperationProviders; @@ -471,7 +470,7 @@ private static class LookupRequest extends TransportRequest implements IndicesRe this.inputPage = new Page(bsi); } this.toRelease = inputPage; - PlanStreamInput planIn = new PlanStreamInput(in, PlanNameRegistry.INSTANCE, in.namedWriteableRegistry(), null); + PlanStreamInput planIn = new PlanStreamInput(in, in.namedWriteableRegistry(), null); this.extractFields = planIn.readNamedWriteableCollectionAsList(NamedExpression.class); } @@ -486,7 +485,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(matchType); out.writeString(matchField); out.writeWriteable(inputPage); - PlanStreamOutput planOut = new PlanStreamOutput(out, PlanNameRegistry.INSTANCE, null); + PlanStreamOutput planOut = new PlanStreamOutput(out, null); planOut.writeNamedWriteableCollection(extractFields); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/EnrichPolicyResolver.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/EnrichPolicyResolver.java index f77bfa6d3f862..447df09942ca8 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/EnrichPolicyResolver.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/EnrichPolicyResolver.java @@ -40,7 +40,6 @@ import org.elasticsearch.xpack.esql.core.type.EsField; import org.elasticsearch.xpack.esql.core.util.StringUtils; import org.elasticsearch.xpack.esql.index.EsIndex; -import org.elasticsearch.xpack.esql.io.stream.PlanNameRegistry; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput; import org.elasticsearch.xpack.esql.plan.logical.Enrich; @@ -329,14 +328,14 @@ private static class LookupResponse extends TransportResponse { } LookupResponse(StreamInput in) throws IOException { - PlanStreamInput planIn = new PlanStreamInput(in, PlanNameRegistry.INSTANCE, in.namedWriteableRegistry(), null); + PlanStreamInput planIn = new PlanStreamInput(in, in.namedWriteableRegistry(), null); this.policies = planIn.readMap(StreamInput::readString, ResolvedEnrichPolicy::new); this.failures = planIn.readMap(StreamInput::readString, StreamInput::readString); } @Override public void writeTo(StreamOutput out) throws IOException { - PlanStreamOutput pso = new PlanStreamOutput(out, new PlanNameRegistry(), null); + PlanStreamOutput pso = new PlanStreamOutput(out, null); pso.writeMap(policies, StreamOutput::writeWriteable); pso.writeMap(failures, StreamOutput::writeString); } diff --git 
a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/package-info.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/package-info.java
index 1c10c7d2fa9ef..4106df331d101 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/package-info.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/package-info.java
@@ -89,13 +89,12 @@
  *     Check all usages of other aggregations there, and replicate the logic.
  * </li>
  * <li>
- *     Add it to {@link org.elasticsearch.xpack.esql.io.stream.PlanNamedTypes}.
- *     Consider adding a {@code writeTo} method and a constructor/{@code readFrom} method inside your function,
- *     to keep all the logic in one place.
- *     <p>
- *     You can find examples of other aggregations using this method,
- *     like {@link org.elasticsearch.xpack.esql.expression.function.aggregate.Top#writeTo(PlanStreamOutput)}
- *     </p>
+ *     Implement serialization for your aggregation by implementing
+ *     {@link org.elasticsearch.common.io.stream.NamedWriteable#getWriteableName},
+ *     {@link org.elasticsearch.common.io.stream.NamedWriteable#writeTo},
+ *     and a deserializing constructor. Then add an {@link org.elasticsearch.common.io.stream.NamedWriteableRegistry.Entry}
+ *     constant and add that constant to the list in
+ *     {@link org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction#getNamedWriteables}.
  * </li>
 * <li>
 *     Do the same with {@link org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry}.
@@ -199,5 +198,3 @@
  *
  */
 package org.elasticsearch.xpack.esql.expression.function.aggregate;
-
-import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput;
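In concrete terms, the serialization step described in the new javadoc above is a small amount of per-function boilerplate. A minimal sketch, assuming a hypothetical MyAgg aggregation: the Expression category class, the (Source, Expression) constructor shape, and the field() accessor mirror existing aggregations in this package, and tree-node plumbing such as info(), replaceChildren() and dataType() is omitted for brevity.

package org.elasticsearch.xpack.esql.expression.function.aggregate;

import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;

import java.io.IOException;

// Hedged sketch of the NamedWriteable convention described above; MyAgg is hypothetical.
public class MyAgg extends AggregateFunction {
    // The registry entry that AggregateFunction#getNamedWriteables would list.
    public static final NamedWriteableRegistry.Entry ENTRY =
        new NamedWriteableRegistry.Entry(Expression.class, "MyAgg", MyAgg::new);

    public MyAgg(Source source, Expression field) {
        super(source, field);
    }

    // Deserializing constructor: reads fields in exactly the order writeTo writes them.
    private MyAgg(StreamInput in) throws IOException {
        this(Source.readFrom((PlanStreamInput) in), in.readNamedWriteable(Expression.class));
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        Source.EMPTY.writeTo(out);
        out.writeNamedWriteable(field());
    }

    @Override
    public String getWriteableName() {
        return ENTRY.name;
    }

    // info(), replaceChildren(), dataType(), etc. omitted for brevity.
}

Keeping writeTo and the deserializing constructor co-located makes it easy to verify that the two stay field-for-field in sync.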
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanNameRegistry.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanNameRegistry.java
deleted file mode 100644
index 15368dc0fdb36..0000000000000
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanNameRegistry.java
+++ /dev/null
@@ -1,272 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.esql.io.stream;
-
-import org.elasticsearch.common.io.stream.NamedWriteable;
-import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
-import org.elasticsearch.common.io.stream.StreamInput;
-import org.elasticsearch.common.io.stream.StreamOutput;
-import org.elasticsearch.common.io.stream.Writeable;
-
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.Comparator;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Map;
-import java.util.Objects;
-import java.util.Set;
-import java.util.stream.Collectors;
-
-/**
- * A registry of ESQL names to readers and writers, that can be used to serialize a physical plan
- * fragment. Allows to serialize the non-(Named)Writable types in both the QL and ESQL modules.
- * Serialization is from the outside in, rather than from within.
- * <p>
    - * This class is somewhat analogous to NamedWriteableRegistry, but does not require the types to - * be NamedWriteable. - */ -public class PlanNameRegistry { - - public static final PlanNameRegistry INSTANCE = new PlanNameRegistry(); - - /** Adaptable writer interface to bridge between ESQL and regular stream outputs. */ - @FunctionalInterface - public interface PlanWriter extends Writeable.Writer { - - void write(PlanStreamOutput out, V value) throws IOException; - - @Override - default void write(StreamOutput out, V value) throws IOException { - write((PlanStreamOutput) out, value); - } - - static Writeable.Writer writerFromPlanWriter(PlanWriter planWriter) { - return planWriter; - } - } - - /** Adaptable reader interface to bridge between ESQL and regular stream inputs. */ - @FunctionalInterface - public interface PlanReader extends Writeable.Reader { - - V read(PlanStreamInput in) throws IOException; - - @Override - default V read(StreamInput in) throws IOException { - return read((PlanStreamInput) in); - } - - static Writeable.Reader readerFromPlanReader(PlanReader planReader) { - return planReader; - } - } - - /** Adaptable reader interface that allows access to the reader name. */ - @FunctionalInterface - interface PlanNamedReader extends PlanReader { - - V read(PlanStreamInput in, String name) throws IOException; - - default V read(PlanStreamInput in) throws IOException { - throw new UnsupportedOperationException("should not reach here"); - } - } - - record Entry( - /** The superclass of a writeable category will be read by a reader. */ - Class categoryClass, - /** A name for the writeable which is unique to the categoryClass. */ - String name, - /** A writer for non-NamedWriteable class */ - PlanWriter writer, - /** A reader capability of reading the writeable. */ - PlanReader reader - ) { - - /** Creates a new entry which can be stored by the registry. */ - Entry { - Objects.requireNonNull(categoryClass); - Objects.requireNonNull(name); - Objects.requireNonNull(writer); - Objects.requireNonNull(reader); - } - - static Entry of( - Class categoryClass, - Class concreteClass, - PlanWriter writer, - PlanReader reader - ) { - return new Entry(categoryClass, PlanNamedTypes.name(concreteClass), writer, reader); - } - - static Entry of(Class categoryClass, NamedWriteableRegistry.Entry entry) { - return new Entry( - categoryClass, - entry.name, - (o, v) -> categoryClass.cast(v).writeTo(o), - in -> categoryClass.cast(entry.reader.read(in)) - ); - } - - static Entry of( - Class categoryClass, - Class concreteClass, - PlanWriter writer, - PlanNamedReader reader - ) { - return new Entry(categoryClass, PlanNamedTypes.name(concreteClass), writer, reader); - } - } - - /** - * The underlying data of the registry maps from the category to an inner - * map of name unique to that category, to the actual reader. - */ - private final Map, Map>> readerRegistry; - - /** - * The underlying data of the registry maps from the category to an inner - * map of name unique to that category, to the actual writer. - */ - private final Map, Map>> writerRegistry; - - public PlanNameRegistry() { - this(PlanNamedTypes.namedTypeEntries()); - } - - /** Constructs a new registry from the given entries. 
*/ - PlanNameRegistry(List entries) { - entries = new ArrayList<>(entries); - entries.sort(Comparator.comparing(e -> e.categoryClass().getName())); - - Map, Map>> rr = new HashMap<>(); - Map, Map>> wr = new HashMap<>(); - for (Entry entry : entries) { - Class categoryClass = entry.categoryClass; - Map> readers = rr.computeIfAbsent(categoryClass, v -> new HashMap<>()); - Map> writers = wr.computeIfAbsent(categoryClass, v -> new HashMap<>()); - - PlanReader oldReader = readers.put(entry.name, entry.reader); - if (oldReader != null) { - throwAlreadyRegisteredReader(categoryClass, entry.name, oldReader.getClass(), entry.reader.getClass()); - } - PlanWriter oldWriter = writers.put(entry.name, entry.writer); - if (oldWriter != null) { - throwAlreadyRegisteredReader(categoryClass, entry.name, oldWriter.getClass(), entry.writer.getClass()); - } - } - - // add subclass categories, e.g. NamedExpressions are also Expressions - Map, List>> subCategories = subCategories(entries); - for (var entry : subCategories.entrySet()) { - var readers = rr.get(entry.getKey()); - var writers = wr.get(entry.getKey()); - for (Class subCategory : entry.getValue()) { - readers.putAll(rr.get(subCategory)); - writers.putAll(wr.get(subCategory)); - } - } - - this.readerRegistry = Map.copyOf(rr); - this.writerRegistry = Map.copyOf(wr); - } - - /** Determines the subclass relation of category classes.*/ - static Map, List>> subCategories(List entries) { - Map, Set>> map = new HashMap<>(); - for (Entry entry : entries) { - Class category = entry.categoryClass; - for (Entry entry1 : entries) { - Class category1 = entry1.categoryClass; - if (category == category1) { - continue; - } - if (category.isAssignableFrom(category1)) { // category is a superclass/interface of category1 - Set> set = map.computeIfAbsent(category, v -> new HashSet<>()); - set.add(category1); - } - } - } - return map.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, s -> new ArrayList<>(s.getValue()))); - } - - PlanReader getReader(Class categoryClass, String name) { - Map> readers = getReaders(categoryClass); - return getReader(categoryClass, name, readers); - } - - static PlanReader getReader(Class categoryClass, String name, Map> readers) { - @SuppressWarnings("unchecked") - PlanReader reader = (PlanReader) readers.get(name); - if (reader == null) { - throwOnUnknownReadable(categoryClass, name); - } - return reader; - } - - Map> getReaders(Class categoryClass) { - Map> readers = readerRegistry.get(categoryClass); - if (readers == null) { - throwOnUnknownCategory(categoryClass); - } - return readers; - } - - PlanWriter getWriter(Class categoryClass, String name, Map> writers) { - @SuppressWarnings("unchecked") - PlanWriter writer = (PlanWriter) writers.get(name); - if (writer == null) { - throwOnUnknownWritable(categoryClass, name); - } - return writer; - } - - public Map> getWriters(Class categoryClass) { - Map> writers = writerRegistry.get(categoryClass); - if (writers == null) { - throwOnUnknownCategory(categoryClass); - } - return writers; - } - - public PlanWriter getWriter(Class categoryClass, String name) { - Map> writers = getWriters(categoryClass); - return getWriter(categoryClass, name, writers); - } - - private static void throwAlreadyRegisteredReader(Class categoryClass, String entryName, Class oldReader, Class entryReader) { - throw new IllegalArgumentException( - "PlanReader [" - + categoryClass.getName() - + "][" - + entryName - + "]" - + " is already registered for [" - + oldReader.getName() - + "]," - + " cannot register [" 
- + entryReader.getName() - + "]" - ); - } - - private static void throwOnUnknownWritable(Class categoryClass, String name) { - throw new IllegalArgumentException("Unknown writeable [" + categoryClass.getName() + "][" + name + "]"); - } - - private static void throwOnUnknownCategory(Class categoryClass) { - throw new IllegalArgumentException("Unknown writeable category [" + categoryClass.getName() + "]"); - } - - private static void throwOnUnknownReadable(Class categoryClass, String name) { - throw new IllegalArgumentException("Unknown readable [" + categoryClass.getName() + "][" + name + "]"); - } -} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypes.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypes.java deleted file mode 100644 index 9b5c521516363..0000000000000 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypes.java +++ /dev/null @@ -1,190 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.esql.io.stream; - -import org.elasticsearch.common.io.stream.NamedWriteable; -import org.elasticsearch.xpack.esql.core.expression.Alias; -import org.elasticsearch.xpack.esql.core.expression.Attribute; -import org.elasticsearch.xpack.esql.core.expression.Expression; -import org.elasticsearch.xpack.esql.core.expression.NamedExpression; -import org.elasticsearch.xpack.esql.core.tree.Source; -import org.elasticsearch.xpack.esql.plan.physical.AggregateExec; -import org.elasticsearch.xpack.esql.plan.physical.DissectExec; -import org.elasticsearch.xpack.esql.plan.physical.EnrichExec; -import org.elasticsearch.xpack.esql.plan.physical.EsQueryExec; -import org.elasticsearch.xpack.esql.plan.physical.EsSourceExec; -import org.elasticsearch.xpack.esql.plan.physical.EvalExec; -import org.elasticsearch.xpack.esql.plan.physical.ExchangeExec; -import org.elasticsearch.xpack.esql.plan.physical.ExchangeSinkExec; -import org.elasticsearch.xpack.esql.plan.physical.ExchangeSourceExec; -import org.elasticsearch.xpack.esql.plan.physical.FieldExtractExec; -import org.elasticsearch.xpack.esql.plan.physical.FilterExec; -import org.elasticsearch.xpack.esql.plan.physical.FragmentExec; -import org.elasticsearch.xpack.esql.plan.physical.GrokExec; -import org.elasticsearch.xpack.esql.plan.physical.HashJoinExec; -import org.elasticsearch.xpack.esql.plan.physical.LimitExec; -import org.elasticsearch.xpack.esql.plan.physical.LocalSourceExec; -import org.elasticsearch.xpack.esql.plan.physical.MvExpandExec; -import org.elasticsearch.xpack.esql.plan.physical.OrderExec; -import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan; -import org.elasticsearch.xpack.esql.plan.physical.ProjectExec; -import org.elasticsearch.xpack.esql.plan.physical.RowExec; -import org.elasticsearch.xpack.esql.plan.physical.ShowExec; -import org.elasticsearch.xpack.esql.plan.physical.TopNExec; - -import java.io.IOException; -import java.util.List; - -import static org.elasticsearch.xpack.esql.io.stream.PlanNameRegistry.Entry.of; - -/** - * A utility class that consists solely of static methods that describe how to serialize and - * deserialize QL and ESQL plan types. - *
<p>
    - * All types that require to be serialized should have a pair of co-located `readFoo` and `writeFoo` - * methods that deserialize and serialize respectively. - *
<p>
    - * A type can be named or non-named. A named type has a name written to the stream before its - * contents (similar to NamedWriteable), whereas a non-named type does not (similar to Writable). - * Named types allow to determine specific deserialization implementations for more general types, - * e.g. Literal, which is an Expression. Named types must have an entries in the namedTypeEntries - * list. - */ -public final class PlanNamedTypes { - - private PlanNamedTypes() {} - - /** - * Determines the writeable name of the give class. The simple class name is commonly used for - * {@link NamedWriteable}s and is sufficient here too, but it could be almost anything else. - */ - public static String name(Class cls) { - return cls.getSimpleName(); - } - - /** - * List of named type entries that link concrete names to stream reader and writer implementations. - * Entries have the form: category, name, serializer method, deserializer method. - */ - public static List namedTypeEntries() { - List declared = List.of( - // Physical Plan Nodes - of(PhysicalPlan.class, AggregateExec.ENTRY), - of(PhysicalPlan.class, DissectExec.ENTRY), - of(PhysicalPlan.class, EsQueryExec.ENTRY), - of(PhysicalPlan.class, EsSourceExec.ENTRY), - of(PhysicalPlan.class, EvalExec.ENTRY), - of(PhysicalPlan.class, EnrichExec.ENTRY), - of(PhysicalPlan.class, ExchangeExec.ENTRY), - of(PhysicalPlan.class, ExchangeSinkExec.ENTRY), - of(PhysicalPlan.class, ExchangeSourceExec.ENTRY), - of(PhysicalPlan.class, FieldExtractExec.ENTRY), - of(PhysicalPlan.class, FilterExec.ENTRY), - of(PhysicalPlan.class, FragmentExec.ENTRY), - of(PhysicalPlan.class, GrokExec.ENTRY), - of(PhysicalPlan.class, LimitExec.ENTRY), - of(PhysicalPlan.class, LocalSourceExec.ENTRY), - of(PhysicalPlan.class, HashJoinExec.ENTRY), - of(PhysicalPlan.class, MvExpandExec.class, PlanNamedTypes::writeMvExpandExec, PlanNamedTypes::readMvExpandExec), - of(PhysicalPlan.class, OrderExec.class, PlanNamedTypes::writeOrderExec, PlanNamedTypes::readOrderExec), - of(PhysicalPlan.class, ProjectExec.class, PlanNamedTypes::writeProjectExec, PlanNamedTypes::readProjectExec), - of(PhysicalPlan.class, RowExec.class, PlanNamedTypes::writeRowExec, PlanNamedTypes::readRowExec), - of(PhysicalPlan.class, ShowExec.class, PlanNamedTypes::writeShowExec, PlanNamedTypes::readShowExec), - of(PhysicalPlan.class, TopNExec.class, PlanNamedTypes::writeTopNExec, PlanNamedTypes::readTopNExec) - ); - return declared; - } - - // -- physical plan nodes - static MvExpandExec readMvExpandExec(PlanStreamInput in) throws IOException { - return new MvExpandExec( - Source.readFrom(in), - in.readPhysicalPlanNode(), - in.readNamedWriteable(NamedExpression.class), - in.readNamedWriteable(Attribute.class) - ); - } - - static void writeMvExpandExec(PlanStreamOutput out, MvExpandExec mvExpandExec) throws IOException { - Source.EMPTY.writeTo(out); - out.writePhysicalPlanNode(mvExpandExec.child()); - out.writeNamedWriteable(mvExpandExec.target()); - out.writeNamedWriteable(mvExpandExec.expanded()); - } - - static OrderExec readOrderExec(PlanStreamInput in) throws IOException { - return new OrderExec( - Source.readFrom(in), - in.readPhysicalPlanNode(), - in.readCollectionAsList(org.elasticsearch.xpack.esql.expression.Order::new) - ); - } - - static void writeOrderExec(PlanStreamOutput out, OrderExec orderExec) throws IOException { - Source.EMPTY.writeTo(out); - out.writePhysicalPlanNode(orderExec.child()); - out.writeCollection(orderExec.order()); - } - - static ProjectExec readProjectExec(PlanStreamInput in) throws 
IOException { - return new ProjectExec( - Source.readFrom(in), - in.readPhysicalPlanNode(), - in.readNamedWriteableCollectionAsList(NamedExpression.class) - ); - } - - static void writeProjectExec(PlanStreamOutput out, ProjectExec projectExec) throws IOException { - Source.EMPTY.writeTo(out); - out.writePhysicalPlanNode(projectExec.child()); - out.writeNamedWriteableCollection(projectExec.projections()); - } - - static RowExec readRowExec(PlanStreamInput in) throws IOException { - return new RowExec(Source.readFrom(in), in.readCollectionAsList(Alias::new)); - } - - static void writeRowExec(PlanStreamOutput out, RowExec rowExec) throws IOException { - assert rowExec.children().size() == 0; - Source.EMPTY.writeTo(out); - out.writeCollection(rowExec.fields()); - } - - @SuppressWarnings("unchecked") - static ShowExec readShowExec(PlanStreamInput in) throws IOException { - return new ShowExec( - Source.readFrom(in), - in.readNamedWriteableCollectionAsList(Attribute.class), - (List>) in.readGenericValue() - ); - } - - static void writeShowExec(PlanStreamOutput out, ShowExec showExec) throws IOException { - Source.EMPTY.writeTo(out); - out.writeNamedWriteableCollection(showExec.output()); - out.writeGenericValue(showExec.values()); - } - - static TopNExec readTopNExec(PlanStreamInput in) throws IOException { - return new TopNExec( - Source.readFrom(in), - in.readPhysicalPlanNode(), - in.readCollectionAsList(org.elasticsearch.xpack.esql.expression.Order::new), - in.readNamedWriteable(Expression.class), - in.readOptionalVInt() - ); - } - - static void writeTopNExec(PlanStreamOutput out, TopNExec topNExec) throws IOException { - Source.EMPTY.writeTo(out); - out.writePhysicalPlanNode(topNExec.child()); - out.writeCollection(topNExec.order()); - out.writeNamedWriteable(topNExec.limit()); - out.writeOptionalVInt(topNExec.estimatedRowSize()); - } -} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamInput.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamInput.java index 6c372d5e725a7..ef4417a1c7a02 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamInput.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamInput.java @@ -30,9 +30,6 @@ import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.NameId; import org.elasticsearch.xpack.esql.core.type.EsField; -import org.elasticsearch.xpack.esql.io.stream.PlanNameRegistry.PlanNamedReader; -import org.elasticsearch.xpack.esql.io.stream.PlanNameRegistry.PlanReader; -import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan; import org.elasticsearch.xpack.esql.session.Configuration; import java.io.IOException; @@ -69,56 +66,17 @@ public NameId apply(long streamNameId) { private EsField[] esFieldsCache = new EsField[64]; - private final PlanNameRegistry registry; - // hook for nameId, where can cache and map, for now just return a NameId of the same long value. 
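        // A hedged usage sketch, not part of this file: with PlanNameRegistry gone, callers
        // now build the stream from the transport-layer registry alone, as the call sites in
        // this diff do:
        //     PlanStreamInput planIn = new PlanStreamInput(in, in.namedWriteableRegistry(), null);
        // and plan nodes travel as ordinary named writeables:
        //     PhysicalPlan child = planIn.readNamedWriteable(PhysicalPlan.class);
        // This pair replaces the readPhysicalPlanNode()/readNamed(...) methods deleted below.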
private final LongFunction nameIdFunction; private final Configuration configuration; - public PlanStreamInput( - StreamInput streamInput, - PlanNameRegistry registry, - NamedWriteableRegistry namedWriteableRegistry, - Configuration configuration - ) { + public PlanStreamInput(StreamInput streamInput, NamedWriteableRegistry namedWriteableRegistry, Configuration configuration) { super(streamInput, namedWriteableRegistry); - this.registry = registry; this.configuration = configuration; this.nameIdFunction = new NameIdMapper(); } - public PhysicalPlan readPhysicalPlanNode() throws IOException { - return readNamed(PhysicalPlan.class); - } - - public PhysicalPlan readOptionalPhysicalPlanNode() throws IOException { - return readOptionalNamed(PhysicalPlan.class); - } - - public T readNamed(Class type) throws IOException { - String name = readString(); - @SuppressWarnings("unchecked") - PlanReader reader = (PlanReader) registry.getReader(type, name); - if (reader instanceof PlanNamedReader namedReader) { - return namedReader.read(this, name); - } else { - return reader.read(this); - } - } - - public T readOptionalNamed(Class type) throws IOException { - if (readBoolean()) { - T t = readNamed(type); - if (t == null) { - throwOnNullOptionalRead(type); - } - return t; - } else { - return null; - } - } - public Configuration configuration() throws IOException { return configuration; } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamOutput.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamOutput.java index 815c9e82e0460..fe66b799195f8 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamOutput.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamOutput.java @@ -23,14 +23,11 @@ import org.elasticsearch.xpack.esql.core.InvalidArgumentException; import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.type.EsField; -import org.elasticsearch.xpack.esql.io.stream.PlanNameRegistry.PlanWriter; -import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan; import org.elasticsearch.xpack.esql.session.Configuration; import java.io.IOException; import java.util.IdentityHashMap; import java.util.Map; -import java.util.function.Function; /** * A customized stream output used to serialize ESQL physical plan fragments. 
Complements stream @@ -69,28 +66,17 @@ public final class PlanStreamOutput extends StreamOutput implements org.elastics protected final Map cachedEsFields = new IdentityHashMap<>(); private final StreamOutput delegate; - private final PlanNameRegistry registry; - - private final Function, String> nameSupplier; private int nextCachedBlock = 0; - private int maxSerializedAttributes; + private final int maxSerializedAttributes; - public PlanStreamOutput(StreamOutput delegate, PlanNameRegistry registry, @Nullable Configuration configuration) throws IOException { - this(delegate, registry, configuration, PlanNamedTypes::name, MAX_SERIALIZED_ATTRIBUTES); + public PlanStreamOutput(StreamOutput delegate, @Nullable Configuration configuration) throws IOException { + this(delegate, configuration, MAX_SERIALIZED_ATTRIBUTES); } - public PlanStreamOutput( - StreamOutput delegate, - PlanNameRegistry registry, - @Nullable Configuration configuration, - Function, String> nameSupplier, - int maxSerializedAttributes - ) throws IOException { + public PlanStreamOutput(StreamOutput delegate, @Nullable Configuration configuration, int maxSerializedAttributes) throws IOException { this.delegate = delegate; - this.registry = registry; - this.nameSupplier = nameSupplier; if (configuration != null) { for (Map.Entry> table : configuration.tables().entrySet()) { for (Map.Entry column : table.getValue().entrySet()) { @@ -101,28 +87,6 @@ public PlanStreamOutput( this.maxSerializedAttributes = maxSerializedAttributes; } - public void writePhysicalPlanNode(PhysicalPlan physicalPlan) throws IOException { - assert physicalPlan.children().size() <= 1; - writeNamed(PhysicalPlan.class, physicalPlan); - } - - public void writeOptionalPhysicalPlanNode(PhysicalPlan physicalPlan) throws IOException { - if (physicalPlan == null) { - writeBoolean(false); - } else { - writeBoolean(true); - writePhysicalPlanNode(physicalPlan); - } - } - - public void writeNamed(Class type, T value) throws IOException { - String name = nameSupplier.apply(value.getClass()); - @SuppressWarnings("unchecked") - PlanWriter writer = (PlanWriter) registry.getWriter(type, name); - writeString(name); - writer.write(this, value); - } - @Override public void writeByte(byte b) throws IOException { delegate.writeByte(b); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/AggregateExec.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/AggregateExec.java index f003abca7d1da..dff55f0738975 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/AggregateExec.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/AggregateExec.java @@ -19,7 +19,6 @@ import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; -import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput; import org.elasticsearch.xpack.esql.plan.logical.Aggregate; import java.io.IOException; @@ -71,7 +70,7 @@ private AggregateExec(StreamInput in) throws IOException { // So, we do not have to consider previous transport versions here, because old nodes will not send AggregateExecs to new nodes. 
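        // A hedged aside; the constant below appears in this diff, the gated body is abridged.
        // Writers still guard the newer fields for mixed-version clusters:
        //     if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_AGGREGATE_EXEC_TRACKS_INTERMEDIATE_ATTRS)) {
        //         // ... write the intermediate attributes ...
        //     }
        // but this reader stays unguarded because, per the comment above, old nodes never
        // send AggregateExecs to new nodes.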
this( Source.readFrom((PlanStreamInput) in), - ((PlanStreamInput) in).readPhysicalPlanNode(), + in.readNamedWriteable(PhysicalPlan.class), in.readNamedWriteableCollectionAsList(Expression.class), in.readNamedWriteableCollectionAsList(NamedExpression.class), in.readEnum(AggregatorMode.class), @@ -83,7 +82,7 @@ private AggregateExec(StreamInput in) throws IOException { @Override public void writeTo(StreamOutput out) throws IOException { Source.EMPTY.writeTo(out); - ((PlanStreamOutput) out).writePhysicalPlanNode(child()); + out.writeNamedWriteable(child()); out.writeNamedWriteableCollection(groupings()); out.writeNamedWriteableCollection(aggregates()); if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_AGGREGATE_EXEC_TRACKS_INTERMEDIATE_ATTRS)) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/DissectExec.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/DissectExec.java index 35a364126a66b..e22dbdd2961c4 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/DissectExec.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/DissectExec.java @@ -15,7 +15,6 @@ import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; -import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput; import org.elasticsearch.xpack.esql.plan.logical.Dissect; import java.io.IOException; @@ -45,7 +44,7 @@ public DissectExec( private DissectExec(StreamInput in) throws IOException { this( Source.readFrom((PlanStreamInput) in), - ((PlanStreamInput) in).readPhysicalPlanNode(), + in.readNamedWriteable(PhysicalPlan.class), in.readNamedWriteable(Expression.class), Dissect.Parser.readFrom(in), in.readNamedWriteableCollectionAsList(Attribute.class) @@ -55,7 +54,7 @@ private DissectExec(StreamInput in) throws IOException { @Override public void writeTo(StreamOutput out) throws IOException { Source.EMPTY.writeTo(out); - ((PlanStreamOutput) out).writePhysicalPlanNode(child()); + out.writeNamedWriteable(child()); out.writeNamedWriteable(inputExpression()); parser().writeTo(out); out.writeNamedWriteableCollection(extractedFields()); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/EnrichExec.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/EnrichExec.java index 40ff7318b889b..a14332ebef7c3 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/EnrichExec.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/EnrichExec.java @@ -20,7 +20,6 @@ import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.index.EsIndex; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; -import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput; import org.elasticsearch.xpack.esql.plan.logical.Enrich; import java.io.IOException; @@ -76,7 +75,7 @@ public EnrichExec( private static EnrichExec readFrom(StreamInput in) throws IOException { final Source source = Source.readFrom((PlanStreamInput) in); - final PhysicalPlan child = ((PlanStreamInput) in).readPhysicalPlanNode(); + final PhysicalPlan child = in.readNamedWriteable(PhysicalPlan.class); final NamedExpression matchField = in.readNamedWriteable(NamedExpression.class); final String policyName = in.readString(); final String matchType = 
(in.getTransportVersion().onOrAfter(TransportVersions.V_8_14_0)) ? in.readString() : "match"; @@ -110,7 +109,7 @@ private static EnrichExec readFrom(StreamInput in) throws IOException { @Override public void writeTo(StreamOutput out) throws IOException { Source.EMPTY.writeTo(out); - ((PlanStreamOutput) out).writePhysicalPlanNode(child()); + out.writeNamedWriteable(child()); out.writeNamedWriteable(matchField()); out.writeString(policyName()); if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_14_0)) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/EsStatsQueryExec.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/EsStatsQueryExec.java index dbcdf3378028a..5a98ecc7d6594 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/EsStatsQueryExec.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/EsStatsQueryExec.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.esql.plan.physical; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.Expression; @@ -17,6 +18,7 @@ import org.elasticsearch.xpack.esql.core.util.Queries; import org.elasticsearch.xpack.esql.index.EsIndex; +import java.io.IOException; import java.util.List; import java.util.Objects; @@ -64,6 +66,16 @@ public EsStatsQueryExec( this.stats = stats; } + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new UnsupportedOperationException("not serialized"); + } + + @Override + public String getWriteableName() { + throw new UnsupportedOperationException("not serialized"); + } + @Override protected NodeInfo info() { return NodeInfo.create(this, EsStatsQueryExec::new, index, query, limit, attrs, stats); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/EvalExec.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/EvalExec.java index 860ba1489f572..5e8007f6b6ec5 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/EvalExec.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/EvalExec.java @@ -16,7 +16,6 @@ import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; -import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput; import org.elasticsearch.xpack.esql.plan.logical.Eval; import java.io.IOException; @@ -40,13 +39,13 @@ public EvalExec(Source source, PhysicalPlan child, List fields) { } private EvalExec(StreamInput in) throws IOException { - this(Source.readFrom((PlanStreamInput) in), ((PlanStreamInput) in).readPhysicalPlanNode(), in.readCollectionAsList(Alias::new)); + this(Source.readFrom((PlanStreamInput) in), in.readNamedWriteable(PhysicalPlan.class), in.readCollectionAsList(Alias::new)); } @Override public void writeTo(StreamOutput out) throws IOException { Source.EMPTY.writeTo(out); - ((PlanStreamOutput) out).writePhysicalPlanNode(child()); + out.writeNamedWriteable(child()); out.writeCollection(fields()); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/ExchangeExec.java 
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/ExchangeExec.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/ExchangeExec.java
index f20b218f28efb..5530b3ea54d3d 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/ExchangeExec.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/ExchangeExec.java
@@ -14,7 +14,6 @@
 import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
 import org.elasticsearch.xpack.esql.core.tree.Source;
 import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
-import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput;
 
 import java.io.IOException;
 import java.util.List;
@@ -47,7 +46,7 @@ private ExchangeExec(StreamInput in) throws IOException {
             Source.readFrom((PlanStreamInput) in),
             in.readNamedWriteableCollectionAsList(Attribute.class),
             in.readBoolean(),
-            ((PlanStreamInput) in).readPhysicalPlanNode()
+            in.readNamedWriteable(PhysicalPlan.class)
         );
     }
@@ -56,7 +55,7 @@ public void writeTo(StreamOutput out) throws IOException {
         Source.EMPTY.writeTo(out);
         out.writeNamedWriteableCollection(output);
         out.writeBoolean(inBetweenAggs());
-        ((PlanStreamOutput) out).writePhysicalPlanNode(child());
+        out.writeNamedWriteable(child());
     }
 
     @Override
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/ExchangeSinkExec.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/ExchangeSinkExec.java
index 2992619da75ef..e342f17363bc8 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/ExchangeSinkExec.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/ExchangeSinkExec.java
@@ -14,7 +14,6 @@
 import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
 import org.elasticsearch.xpack.esql.core.tree.Source;
 import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
-import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput;
 
 import java.io.IOException;
 import java.util.List;
@@ -41,7 +40,7 @@ private ExchangeSinkExec(StreamInput in) throws IOException {
             Source.readFrom((PlanStreamInput) in),
             in.readNamedWriteableCollectionAsList(Attribute.class),
             in.readBoolean(),
-            ((PlanStreamInput) in).readPhysicalPlanNode()
+            in.readNamedWriteable(PhysicalPlan.class)
         );
     }
@@ -50,7 +49,7 @@ public void writeTo(StreamOutput out) throws IOException {
         Source.EMPTY.writeTo(out);
         out.writeNamedWriteableCollection(output());
         out.writeBoolean(isIntermediateAgg());
-        ((PlanStreamOutput) out).writePhysicalPlanNode(child());
+        out.writeNamedWriteable(child());
     }
 
     @Override
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/FieldExtractExec.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/FieldExtractExec.java
index 7b51450d9f5e8..35c6e4846bd88 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/FieldExtractExec.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/FieldExtractExec.java
@@ -16,7 +16,6 @@
 import org.elasticsearch.xpack.esql.core.tree.NodeUtils;
 import org.elasticsearch.xpack.esql.core.tree.Source;
 import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
-import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput;
 
 import java.io.IOException;
 import java.util.ArrayList;
@@ -60,7 +59,7 @@ private FieldExtractExec(Source source, PhysicalPlan child, List<Attribute> attr
     private FieldExtractExec(StreamInput in) throws IOException {
         this(
             Source.readFrom((PlanStreamInput) in),
-            ((PlanStreamInput) in).readPhysicalPlanNode(),
+            in.readNamedWriteable(PhysicalPlan.class),
             in.readNamedWriteableCollectionAsList(Attribute.class)
         );
         // docValueAttributes are only used on the data node and never serialized.
@@ -69,7 +68,7 @@ private FieldExtractExec(StreamInput in) throws IOException {
     @Override
     public void writeTo(StreamOutput out) throws IOException {
         Source.EMPTY.writeTo(out);
-        ((PlanStreamOutput) out).writePhysicalPlanNode(child());
+        out.writeNamedWriteable(child());
         out.writeNamedWriteableCollection(attributesToExtract());
         // docValueAttributes are only used on the data node and never serialized.
     }
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/FilterExec.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/FilterExec.java
index 14785a8764ad2..0802fc3423b23 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/FilterExec.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/FilterExec.java
@@ -14,7 +14,6 @@
 import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
 import org.elasticsearch.xpack.esql.core.tree.Source;
 import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
-import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput;
 
 import java.io.IOException;
 import java.util.List;
@@ -35,13 +34,13 @@ public FilterExec(Source source, PhysicalPlan child, Expression condition) {
     }
 
     private FilterExec(StreamInput in) throws IOException {
-        this(Source.readFrom((PlanStreamInput) in), ((PlanStreamInput) in).readPhysicalPlanNode(), in.readNamedWriteable(Expression.class));
+        this(Source.readFrom((PlanStreamInput) in), in.readNamedWriteable(PhysicalPlan.class), in.readNamedWriteable(Expression.class));
     }
 
     @Override
     public void writeTo(StreamOutput out) throws IOException {
         Source.EMPTY.writeTo(out);
-        ((PlanStreamOutput) out).writePhysicalPlanNode(child());
+        out.writeNamedWriteable(child());
         out.writeNamedWriteable(condition());
     }
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/FragmentExec.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/FragmentExec.java
index f8ba9e06ed3a5..7594c971b7ffc 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/FragmentExec.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/FragmentExec.java
@@ -16,7 +16,6 @@
 import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
 import org.elasticsearch.xpack.esql.core.tree.Source;
 import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
-import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput;
 import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
 
 import java.io.IOException;
@@ -58,7 +57,7 @@ private FragmentExec(StreamInput in) throws IOException {
             in.readNamedWriteable(LogicalPlan.class),
             in.readOptionalNamedWriteable(QueryBuilder.class),
             in.readOptionalVInt(),
-            in.getTransportVersion().onOrAfter(TransportVersions.V_8_14_0) ? ((PlanStreamInput) in).readOptionalPhysicalPlanNode() : null
+            in.getTransportVersion().onOrAfter(TransportVersions.V_8_14_0) ? in.readOptionalNamedWriteable(PhysicalPlan.class) : null
         );
     }
@@ -69,7 +68,7 @@ public void writeTo(StreamOutput out) throws IOException {
         out.writeOptionalNamedWriteable(esFilter());
         out.writeOptionalVInt(estimatedRowSize());
         if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_14_0)) {
-            ((PlanStreamOutput) out).writeOptionalPhysicalPlanNode(reducer());
+            out.writeOptionalNamedWriteable(reducer);
         }
     }
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/GrokExec.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/GrokExec.java
index d0c87630e94f0..59d25a4a2f472 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/GrokExec.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/GrokExec.java
@@ -15,7 +15,6 @@
 import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
 import org.elasticsearch.xpack.esql.core.tree.Source;
 import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
-import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput;
 import org.elasticsearch.xpack.esql.plan.logical.Grok;
 
 import java.io.IOException;
@@ -46,7 +45,7 @@ private static GrokExec readFrom(StreamInput in) throws IOException {
         Source source = Source.readFrom((PlanStreamInput) in);
         return new GrokExec(
             source,
-            ((PlanStreamInput) in).readPhysicalPlanNode(),
+            in.readNamedWriteable(PhysicalPlan.class),
             in.readNamedWriteable(Expression.class),
             Grok.pattern(source, in.readString()),
             in.readNamedWriteableCollectionAsList(Attribute.class)
@@ -56,7 +55,7 @@ private static GrokExec readFrom(StreamInput in) throws IOException {
     @Override
     public void writeTo(StreamOutput out) throws IOException {
         Source.EMPTY.writeTo(out);
-        ((PlanStreamOutput) out).writePhysicalPlanNode(child());
+        out.writeNamedWriteable(child());
         out.writeNamedWriteable(inputExpression());
         out.writeString(pattern().pattern());
         out.writeNamedWriteableCollection(extractedFields());
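Note: the FragmentExec hunk above combines the version gate with the optional-writeable idiom. writeOptionalNamedWriteable emits a presence flag before the name and body, so a null reducer round-trips cleanly; a sketch of that pattern, with reducer as the field from the hunk:

    // read side: the presence flag is consumed first, yielding null when absent
    PhysicalPlan reducer = in.getTransportVersion().onOrAfter(TransportVersions.V_8_14_0)
        ? in.readOptionalNamedWriteable(PhysicalPlan.class)
        : null;

    // write side: false for null, true followed by name + body otherwise
    if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_14_0)) {
        out.writeOptionalNamedWriteable(reducer);
    }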
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/HashJoinExec.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/HashJoinExec.java
index 51bcf7f0722cb..5b83c4d95cabf 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/HashJoinExec.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/HashJoinExec.java
@@ -16,7 +16,6 @@
 import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
 import org.elasticsearch.xpack.esql.core.tree.Source;
 import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
-import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput;
 
 import java.io.IOException;
 import java.util.List;
@@ -55,7 +54,7 @@ public HashJoinExec(
     }
 
     private HashJoinExec(StreamInput in) throws IOException {
-        super(Source.readFrom((PlanStreamInput) in), ((PlanStreamInput) in).readPhysicalPlanNode());
+        super(Source.readFrom((PlanStreamInput) in), in.readNamedWriteable(PhysicalPlan.class));
         this.joinData = new LocalSourceExec(in);
         this.matchFields = in.readNamedWriteableCollectionAsList(Attribute.class);
         this.leftFields = in.readNamedWriteableCollectionAsList(Attribute.class);
@@ -66,7 +65,7 @@ private HashJoinExec(StreamInput in) throws IOException {
     @Override
     public void writeTo(StreamOutput out) throws IOException {
         source().writeTo(out);
-        ((PlanStreamOutput) out).writePhysicalPlanNode(child());
+        out.writeNamedWriteable(child());
         joinData.writeTo(out);
         out.writeNamedWriteableCollection(matchFields);
         out.writeNamedWriteableCollection(leftFields);
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/LimitExec.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/LimitExec.java
index 6c9d3e7867b02..8445fea08111c 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/LimitExec.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/LimitExec.java
@@ -14,7 +14,6 @@
 import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
 import org.elasticsearch.xpack.esql.core.tree.Source;
 import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
-import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput;
 
 import java.io.IOException;
 import java.util.Objects;
@@ -34,13 +33,13 @@ public LimitExec(Source source, PhysicalPlan child, Expression limit) {
     }
 
     private LimitExec(StreamInput in) throws IOException {
-        this(Source.readFrom((PlanStreamInput) in), ((PlanStreamInput) in).readPhysicalPlanNode(), in.readNamedWriteable(Expression.class));
+        this(Source.readFrom((PlanStreamInput) in), in.readNamedWriteable(PhysicalPlan.class), in.readNamedWriteable(Expression.class));
     }
 
     @Override
     public void writeTo(StreamOutput out) throws IOException {
         Source.EMPTY.writeTo(out);
-        ((PlanStreamOutput) out).writePhysicalPlanNode(child());
+        out.writeNamedWriteable(child());
         out.writeNamedWriteable(limit());
     }
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/MvExpandExec.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/MvExpandExec.java
index 2e7531a880742..071bad060320e 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/MvExpandExec.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/MvExpandExec.java
@@ -6,18 +6,28 @@
  */
 package org.elasticsearch.xpack.esql.plan.physical;
 
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.xpack.esql.core.expression.Attribute;
 import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
 import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
 import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
 import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
 
+import java.io.IOException;
 import java.util.List;
 import java.util.Objects;
 
 import static org.elasticsearch.xpack.esql.plan.logical.MvExpand.calculateOutput;
 
 public class MvExpandExec extends UnaryExec {
+    public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
+        PhysicalPlan.class,
+        "MvExpandExec",
+        MvExpandExec::new
+    );
 
     private final NamedExpression target;
     private final Attribute expanded;
@@ -30,6 +40,28 @@ public MvExpandExec(Source source, PhysicalPlan child, NamedExpression target, A
         this.output = calculateOutput(child.output(), target, expanded);
     }
 
+    private MvExpandExec(StreamInput in) throws IOException {
+        this(
+            Source.readFrom((PlanStreamInput) in),
+            in.readNamedWriteable(PhysicalPlan.class),
+            in.readNamedWriteable(NamedExpression.class),
+            in.readNamedWriteable(Attribute.class)
+        );
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        Source.EMPTY.writeTo(out);
+        out.writeNamedWriteable(child());
+        out.writeNamedWriteable(target());
+        out.writeNamedWriteable(expanded());
+    }
+
+    @Override
+    public String getWriteableName() {
+        return ENTRY.name;
+    }
+
     @Override
     protected AttributeSet computeReferences() {
         return target.references();
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/OrderExec.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/OrderExec.java
index 63b838e068f1e..9d53e828f4f81 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/OrderExec.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/OrderExec.java
@@ -7,14 +7,24 @@
 package org.elasticsearch.xpack.esql.plan.physical;
 
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
 import org.elasticsearch.xpack.esql.core.tree.Source;
 import org.elasticsearch.xpack.esql.expression.Order;
+import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
 
+import java.io.IOException;
 import java.util.List;
 import java.util.Objects;
 
 public class OrderExec extends UnaryExec {
+    public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
+        PhysicalPlan.class,
+        "OrderExec",
+        OrderExec::new
+    );
 
     private final List<Order> order;
 
@@ -23,6 +33,26 @@ public OrderExec(Source source, PhysicalPlan child, List<Order> order) {
         this.order = order;
     }
 
+    private OrderExec(StreamInput in) throws IOException {
+        this(
+            Source.readFrom((PlanStreamInput) in),
+            in.readNamedWriteable(PhysicalPlan.class),
+            in.readCollectionAsList(org.elasticsearch.xpack.esql.expression.Order::new)
+        );
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        Source.EMPTY.writeTo(out);
+        out.writeNamedWriteable(child());
+        out.writeCollection(order());
+    }
+
+    @Override
+    public String getWriteableName() {
+        return ENTRY.name;
+    }
+
     @Override
     protected NodeInfo<OrderExec> info() {
         return NodeInfo.create(this, OrderExec::new, child(), order);
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/OutputExec.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/OutputExec.java
index b4a5608e31dfd..c7d2f53d3c748 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/OutputExec.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/OutputExec.java
@@ -7,11 +7,13 @@
 package org.elasticsearch.xpack.esql.plan.physical;
 
+import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.compute.data.Page;
 import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
 import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
 import org.elasticsearch.xpack.esql.core.tree.Source;
 
+import java.io.IOException;
 import java.util.function.Consumer;
 
 public class OutputExec extends UnaryExec {
@@ -28,6 +30,16 @@ public OutputExec(Source source, PhysicalPlan child, Consumer<Page> pageConsumer
         this.pageConsumer = pageConsumer;
     }
 
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        throw new UnsupportedOperationException("not serialized");
+    }
+
+    @Override
+    public String getWriteableName() {
+        throw new UnsupportedOperationException("not serialized");
+    }
+
     public Consumer<Page> getPageConsumer() {
         return pageConsumer;
     }
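Note: MvExpandExec and OrderExec above are the first of several nodes in this diff that gain the standard registration triple. The shape, with a hypothetical FooExec standing in for any of them:

    // category class, wire name, and reader; FooExec is illustrative only
    public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
        PhysicalPlan.class,   // the category under which the name is looked up
        "FooExec",            // the name written to the stream before the node body
        FooExec::new          // the reader invoked when that name is seen
    );

    @Override
    public String getWriteableName() {
        return ENTRY.name;    // keeps the wire name and the registry entry in sync
    }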
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/PhysicalPlan.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/PhysicalPlan.java
index 32bff02959097..9ddcd97218069 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/PhysicalPlan.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/PhysicalPlan.java
@@ -8,11 +8,9 @@
 package org.elasticsearch.xpack.esql.plan.physical;
 
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
-import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.xpack.esql.core.tree.Source;
 import org.elasticsearch.xpack.esql.plan.QueryPlan;
 
-import java.io.IOException;
 import java.util.List;
 
 /**
@@ -37,9 +35,15 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
             FilterExec.ENTRY,
             FragmentExec.ENTRY,
             GrokExec.ENTRY,
+            HashJoinExec.ENTRY,
             LimitExec.ENTRY,
             LocalSourceExec.ENTRY,
-            HashJoinExec.ENTRY
+            MvExpandExec.ENTRY,
+            OrderExec.ENTRY,
+            ProjectExec.ENTRY,
+            RowExec.ENTRY,
+            ShowExec.ENTRY,
+            TopNExec.ENTRY
         );
     }
 
@@ -47,17 +51,6 @@ public PhysicalPlan(Source source, List<PhysicalPlan> children) {
         super(source, children);
     }
 
-    @Override
-    public void writeTo(StreamOutput out) throws IOException {
-        // TODO remove when all PhysicalPlans are migrated to NamedWriteable
-        throw new UnsupportedOperationException();
-    }
-
-    @Override
-    public String getWriteableName() {
-        throw new UnsupportedOperationException();
-    }
-
     @Override
     public abstract int hashCode();
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/ProjectExec.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/ProjectExec.java
index 95fef43f7e6aa..b1e9d517889aa 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/ProjectExec.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/ProjectExec.java
@@ -6,16 +6,26 @@
  */
 package org.elasticsearch.xpack.esql.plan.physical;
 
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.xpack.esql.core.expression.Attribute;
 import org.elasticsearch.xpack.esql.core.expression.Expressions;
 import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
 import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
 import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
 
+import java.io.IOException;
 import java.util.List;
 import java.util.Objects;
 
 public class ProjectExec extends UnaryExec {
     // TODO implement EstimatesRowSize *somehow*
+    public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
+        PhysicalPlan.class,
+        "ProjectExec",
+        ProjectExec::new
+    );
 
     private final List<? extends NamedExpression> projections;
 
@@ -24,6 +34,26 @@ public ProjectExec(Source source, PhysicalPlan child, List<? extends NamedExpre
         this.projections = projections;
     }
 
+    private ProjectExec(StreamInput in) throws IOException {
+        this(
+            Source.readFrom((PlanStreamInput) in),
+            in.readNamedWriteable(PhysicalPlan.class),
+            in.readNamedWriteableCollectionAsList(NamedExpression.class)
+        );
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        Source.EMPTY.writeTo(out);
+        out.writeNamedWriteable(child());
+        out.writeNamedWriteableCollection(projections());
+    }
+
+    @Override
+    public String getWriteableName() {
+        return ENTRY.name;
+    }
+
     @Override
     protected NodeInfo<ProjectExec> info() {
         return NodeInfo.create(this, ProjectExec::new, child(), projections);
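Note: the ProjectExec serialization members above are reconstructed by analogy with the MvExpandExec and OrderExec hunks, since the extracted text lost the body of that +26-line hunk; they follow the same pattern as every other node in this change. PhysicalPlan.getNamedWriteables() is what ties the per-node ENTRY constants together; a sketch of how a reader resolves a node, assuming only the classes listed in that hunk:

    // build a registry from every serializable plan node's ENTRY
    NamedWriteableRegistry registry = new NamedWriteableRegistry(PhysicalPlan.getNamedWriteables());
    // a registry-aware stream can now dispatch by name: readNamedWriteable reads
    // e.g. "TopNExec" off the wire and calls TopNExec::new on the remaining bytes
    PhysicalPlan node = in.readNamedWriteable(PhysicalPlan.class);

Removing the throwing writeTo()/getWriteableName() defaults from PhysicalPlan also forces the remaining non-serializable nodes (EsStatsQueryExec, OutputExec) to declare that fact themselves.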
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/RowExec.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/RowExec.java
index a80b2bee36292..3a104d4bc292b 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/RowExec.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/RowExec.java
@@ -7,16 +7,23 @@
 package org.elasticsearch.xpack.esql.plan.physical;
 
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.xpack.esql.core.expression.Alias;
 import org.elasticsearch.xpack.esql.core.expression.Attribute;
 import org.elasticsearch.xpack.esql.core.expression.Expressions;
 import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
 import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
 
+import java.io.IOException;
 import java.util.List;
 import java.util.Objects;
 
 public class RowExec extends LeafExec {
+    public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(PhysicalPlan.class, "RowExec", RowExec::new);
+
     private final List<Alias> fields;
 
     public RowExec(Source source, List<Alias> fields) {
@@ -24,6 +31,21 @@ public RowExec(Source source, List<Alias> fields) {
         this.fields = fields;
     }
 
+    private RowExec(StreamInput in) throws IOException {
+        this(Source.readFrom((PlanStreamInput) in), in.readCollectionAsList(Alias::new));
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        Source.EMPTY.writeTo(out);
+        out.writeCollection(fields());
+    }
+
+    @Override
+    public String getWriteableName() {
+        return ENTRY.name;
+    }
+
     public List<Alias> fields() {
         return fields;
     }
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/ShowExec.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/ShowExec.java
index 700e3282b9efc..23864e4001279 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/ShowExec.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/ShowExec.java
@@ -7,14 +7,24 @@
 package org.elasticsearch.xpack.esql.plan.physical;
 
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.xpack.esql.core.expression.Attribute;
 import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
 import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
 
+import java.io.IOException;
 import java.util.List;
 import java.util.Objects;
 
 public class ShowExec extends LeafExec {
+    public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
+        PhysicalPlan.class,
+        "ShowExec",
+        ShowExec::new
+    );
 
     private final List<Attribute> attributes;
     private final List<List<Object>> values;
 
@@ -25,6 +35,27 @@ public ShowExec(Source source, List<Attribute> attributes, List<List<Object>> va
         this.values = values;
     }
 
+    @SuppressWarnings("unchecked")
+    private ShowExec(StreamInput in) throws IOException {
+        this(
+            Source.readFrom((PlanStreamInput) in),
+            in.readNamedWriteableCollectionAsList(Attribute.class),
+            (List<List<Object>>) in.readGenericValue()
+        );
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        Source.EMPTY.writeTo(out);
+        out.writeNamedWriteableCollection(output());
+        out.writeGenericValue(values());
+    }
+
+    @Override
+    public String getWriteableName() {
+        return ENTRY.name;
+    }
+
     @Override
     protected NodeInfo<ShowExec> info() {
         return NodeInfo.create(this, ShowExec::new, attributes, values);
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/TopNExec.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/TopNExec.java
index a671e85b3a754..61e40b3fa4693 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/TopNExec.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/TopNExec.java
@@ -7,15 +7,26 @@
 package org.elasticsearch.xpack.esql.plan.physical;
 
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.xpack.esql.core.expression.Expression;
 import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
 import org.elasticsearch.xpack.esql.core.tree.Source;
 import org.elasticsearch.xpack.esql.expression.Order;
+import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
 
+import java.io.IOException;
 import java.util.List;
 import java.util.Objects;
 
 public class TopNExec extends UnaryExec implements EstimatesRowSize {
+    public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
+        PhysicalPlan.class,
+        "TopNExec",
+        TopNExec::new
+    );
+
     private final Expression limit;
     private final List<Order> order;
 
@@ -32,6 +43,30 @@ public TopNExec(Source source, PhysicalPlan child, List<Order> order, Expression
         this.estimatedRowSize = estimatedRowSize;
     }
 
+    private TopNExec(StreamInput in) throws IOException {
+        this(
+            Source.readFrom((PlanStreamInput) in),
+            in.readNamedWriteable(PhysicalPlan.class),
+            in.readCollectionAsList(org.elasticsearch.xpack.esql.expression.Order::new),
+            in.readNamedWriteable(Expression.class),
+            in.readOptionalVInt()
+        );
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        Source.EMPTY.writeTo(out);
+        out.writeNamedWriteable(child());
+        out.writeCollection(order());
+        out.writeNamedWriteable(limit());
+        out.writeOptionalVInt(estimatedRowSize());
+    }
+
+    @Override
+    public String getWriteableName() {
+        return ENTRY.name;
+    }
+
     @Override
     protected NodeInfo<TopNExec> info() {
         return NodeInfo.create(this, TopNExec::new, child(), order, limit, estimatedRowSize);
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ClusterComputeRequest.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ClusterComputeRequest.java
index f82060419f73a..74935e116f064 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ClusterComputeRequest.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ClusterComputeRequest.java
@@ -21,7 +21,6 @@
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.transport.TransportRequest;
-import org.elasticsearch.xpack.esql.io.stream.PlanNameRegistry;
 import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
 import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput;
 import org.elasticsearch.xpack.esql.session.Configuration;
@@ -37,7 +36,6 @@
  * will poll pages from this sink. Internally, this compute will trigger sub-computes on data nodes via {@link DataNodeRequest}.
  */
 final class ClusterComputeRequest extends TransportRequest implements IndicesRequest.Replaceable {
-    private static final PlanNameRegistry planNameRegistry = new PlanNameRegistry();
     private final String clusterAlias;
     private final String sessionId;
     private final Configuration configuration;
@@ -69,7 +67,7 @@ final class ClusterComputeRequest extends TransportRequest implements IndicesReq
             // TODO make EsqlConfiguration Releasable
             new BlockStreamInput(in, new BlockFactory(new NoopCircuitBreaker(CircuitBreaker.REQUEST), BigArrays.NON_RECYCLING_INSTANCE))
         );
-        this.plan = RemoteClusterPlan.from(new PlanStreamInput(in, planNameRegistry, in.namedWriteableRegistry(), configuration));
+        this.plan = RemoteClusterPlan.from(new PlanStreamInput(in, in.namedWriteableRegistry(), configuration));
         this.indices = plan.originalIndices().indices();
     }
 
@@ -79,7 +77,7 @@ public void writeTo(StreamOutput out) throws IOException {
         out.writeString(clusterAlias);
         out.writeString(sessionId);
         configuration.writeTo(out);
-        plan.writeTo(new PlanStreamOutput(out, planNameRegistry, configuration));
+        plan.writeTo(new PlanStreamOutput(out, configuration));
     }
 
     @Override
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequest.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequest.java
index 4174cc25552f5..8f890e63bf54e 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequest.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequest.java
@@ -25,7 +25,6 @@
 import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.transport.RemoteClusterAware;
 import org.elasticsearch.transport.TransportRequest;
-import org.elasticsearch.xpack.esql.io.stream.PlanNameRegistry;
 import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
 import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput;
 import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
@@ -38,7 +37,6 @@
 import java.util.Objects;
 
 final class DataNodeRequest extends TransportRequest implements IndicesRequest.Replaceable {
-    private static final PlanNameRegistry planNameRegistry = new PlanNameRegistry();
     private final String sessionId;
     private final Configuration configuration;
     private final String clusterAlias;
@@ -82,7 +80,7 @@ final class DataNodeRequest extends TransportRequest implements IndicesRequest.R
         }
         this.shardIds = in.readCollectionAsList(ShardId::new);
         this.aliasFilters = in.readMap(Index::new, AliasFilter::readFrom);
-        this.plan = new PlanStreamInput(in, planNameRegistry, in.namedWriteableRegistry(), configuration).readPhysicalPlanNode();
+        this.plan = new PlanStreamInput(in, in.namedWriteableRegistry(), configuration).readNamedWriteable(PhysicalPlan.class);
         if (in.getTransportVersion().onOrAfter(TransportVersions.ESQL_ORIGINAL_INDICES)) {
             this.indices = in.readStringArray();
             this.indicesOptions = IndicesOptions.readIndicesOptions(in);
@@ -102,7 +100,7 @@ public void writeTo(StreamOutput out) throws IOException {
         }
         out.writeCollection(shardIds);
         out.writeMap(aliasFilters);
-        new PlanStreamOutput(out, planNameRegistry, configuration).writePhysicalPlanNode(plan);
+        new PlanStreamOutput(out, configuration).writeNamedWriteable(plan);
         if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_ORIGINAL_INDICES)) {
             out.writeStringArray(indices);
             indicesOptions.writeIndicesOptions(out);
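Note: with the per-request PlanNameRegistry fields gone, the transport requests above construct the plan streams with one fewer argument. The full write/read symmetry, using only the constructors shown in this diff:

    // write side (as in DataNodeRequest.writeTo above)
    new PlanStreamOutput(out, configuration).writeNamedWriteable(plan);

    // read side (as in the DataNodeRequest constructor above); the registry comes
    // from the wrapping NamedWriteableAwareStreamInput
    PhysicalPlan plan = new PlanStreamInput(in, in.namedWriteableRegistry(), configuration)
        .readNamedWriteable(PhysicalPlan.class);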
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/EsqlPlugin.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/EsqlPlugin.java
index c630051e79a26..315309fbad677 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/EsqlPlugin.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/EsqlPlugin.java
@@ -67,6 +67,7 @@
 import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
 import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction;
 import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
+import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
 import org.elasticsearch.xpack.esql.querydsl.query.SingleValueQuery;
 import org.elasticsearch.xpack.esql.session.IndexResolver;
 
@@ -200,6 +201,7 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
         entries.addAll(EsqlScalarFunction.getNamedWriteables());
         entries.addAll(AggregateFunction.getNamedWriteables());
         entries.addAll(LogicalPlan.getNamedWriteables());
+        entries.addAll(PhysicalPlan.getNamedWriteables());
         return entries;
     }
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/RemoteClusterPlan.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/RemoteClusterPlan.java
index 3fd8b17f01778..8564e4b3afde1 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/RemoteClusterPlan.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/RemoteClusterPlan.java
@@ -18,7 +18,7 @@ record RemoteClusterPlan(PhysicalPlan plan, String[] targetIndices, OriginalIndices originalIndices) {
     static RemoteClusterPlan from(PlanStreamInput planIn) throws IOException {
-        var plan = planIn.readPhysicalPlanNode();
+        var plan = planIn.readNamedWriteable(PhysicalPlan.class);
         var targetIndices = planIn.readStringArray();
         final OriginalIndices originalIndices;
         if (planIn.getTransportVersion().onOrAfter(TransportVersions.ESQL_ORIGINAL_INDICES)) {
@@ -30,7 +30,7 @@ static RemoteClusterPlan from(PlanStreamInput planIn) throws IOException {
     }
 
     public void writeTo(PlanStreamOutput out) throws IOException {
-        out.writePhysicalPlanNode(plan);
+        out.writeNamedWriteable(plan);
         out.writeStringArray(targetIndices);
         if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_ORIGINAL_INDICES)) {
             OriginalIndices.writeOriginalIndices(originalIndices, out);
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/SerializationTestUtils.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/SerializationTestUtils.java
index 339e7159ed87d..fe74883a0c24f 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/SerializationTestUtils.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/SerializationTestUtils.java
@@ -30,7 +30,6 @@
 import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute;
 import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
 import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction;
-import org.elasticsearch.xpack.esql.io.stream.PlanNameRegistry;
 import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
 import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput;
 import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
@@ -44,9 +43,6 @@
 import java.util.List;
 
 public class SerializationTestUtils {
-
-    private static final PlanNameRegistry planNameRegistry = new PlanNameRegistry();
-
     public static void assertSerialization(PhysicalPlan plan) {
         assertSerialization(plan, EsqlTestUtils.TEST_CFG);
     }
 
@@ -54,8 +50,8 @@ public static void assertSerialization(PhysicalPlan plan) {
     public static void assertSerialization(PhysicalPlan plan, Configuration configuration) {
         var deserPlan = serializeDeserialize(
             plan,
-            PlanStreamOutput::writePhysicalPlanNode,
-            PlanStreamInput::readPhysicalPlanNode,
+            PlanStreamOutput::writeNamedWriteable,
+            in -> in.readNamedWriteable(PhysicalPlan.class),
             configuration
         );
         EqualsHashCodeTestUtils.checkEqualsAndHashCode(plan, unused -> deserPlan);
@@ -86,13 +82,13 @@ public static <T> T serializeDeserialize(T orig, Serializer<T> serializer, Deser
     public static <T> T serializeDeserialize(T orig, Serializer<T> serializer, Deserializer<T> deserializer, Configuration config) {
         try (BytesStreamOutput out = new BytesStreamOutput()) {
-            PlanStreamOutput planStreamOutput = new PlanStreamOutput(out, planNameRegistry, config);
+            PlanStreamOutput planStreamOutput = new PlanStreamOutput(out, config);
             serializer.write(planStreamOutput, orig);
             StreamInput in = new NamedWriteableAwareStreamInput(
                 ByteBufferStreamInput.wrap(BytesReference.toBytes(out.bytes())),
                 writableRegistry()
             );
-            PlanStreamInput planStreamInput = new PlanStreamInput(in, planNameRegistry, in.namedWriteableRegistry(), config);
+            PlanStreamInput planStreamInput = new PlanStreamInput(in, in.namedWriteableRegistry(), config);
             return deserializer.read(planStreamInput);
         } catch (IOException e) {
             throw new UncheckedIOException(e);
@@ -127,6 +123,7 @@ public static NamedWriteableRegistry writableRegistry() {
         entries.addAll(AggregateFunction.getNamedWriteables());
         entries.addAll(Block.getNamedWriteables());
         entries.addAll(LogicalPlan.getNamedWriteables());
+        entries.addAll(PhysicalPlan.getNamedWriteables());
         return new NamedWriteableRegistry(entries);
     }
 }
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/AbstractExpressionSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/AbstractExpressionSerializationTests.java
index 596ff2af5fb5a..ab20a5ce0cc6b 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/AbstractExpressionSerializationTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/AbstractExpressionSerializationTests.java
@@ -23,7 +23,7 @@
 public abstract class AbstractExpressionSerializationTests<T extends Expression> extends AbstractNodeSerializationTests<T> {
     public static Expression randomChild() {
-        return ReferenceAttributeTests.randomReferenceAttribute();
+        return ReferenceAttributeTests.randomReferenceAttribute(false);
     }
 
     @Override
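Note: the updated SerializationTestUtils above is the canonical way to exercise all of this in tests. A usage sketch, where plan is whatever node a test builds:

    // round-trip through the plan streams and the full registry, per the hunks above
    PhysicalPlan copy = SerializationTestUtils.serializeDeserialize(
        plan,
        PlanStreamOutput::writeNamedWriteable,               // write via the registry name
        in -> in.readNamedWriteable(PhysicalPlan.class),     // read back by dispatching on it
        EsqlTestUtils.TEST_CFG
    );
    // assertSerialization(plan) does exactly this plus an equals/hashCode check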
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/AliasTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/AliasTests.java
index 2a6791a1f5300..ccbed01994bf7 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/AliasTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/AliasTests.java
@@ -19,7 +19,6 @@
 import org.elasticsearch.xpack.esql.core.tree.SourceTests;
 import org.elasticsearch.xpack.esql.expression.function.ReferenceAttributeTests;
 import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute;
-import org.elasticsearch.xpack.esql.io.stream.PlanNameRegistry;
 import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
 import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput;
 
@@ -35,7 +34,7 @@ public static Alias randomAlias() {
         Source source = SourceTests.randomSource();
         String name = randomAlphaOfLength(5);
         // TODO better randomChild
-        Expression child = ReferenceAttributeTests.randomReferenceAttribute();
+        Expression child = ReferenceAttributeTests.randomReferenceAttribute(false);
         boolean synthetic = randomBoolean();
         return new Alias(source, name, child, new NameId(), synthetic);
     }
@@ -53,7 +52,7 @@ protected Alias mutateInstance(Alias instance) throws IOException {
         boolean synthetic = instance.synthetic();
         switch (between(0, 2)) {
             case 0 -> name = randomAlphaOfLength(name.length() + 1);
-            case 1 -> child = randomValueOtherThan(child, ReferenceAttributeTests::randomReferenceAttribute);
+            case 1 -> child = randomValueOtherThan(child, () -> ReferenceAttributeTests.randomReferenceAttribute(false));
             case 2 -> synthetic = false == synthetic;
         }
         return new Alias(source, name, child, instance.id(), synthetic);
@@ -64,9 +63,9 @@ protected Alias copyInstance(Alias instance, TransportVersion version) throws IO
         return copyInstance(
             instance,
             getNamedWriteableRegistry(),
-            (out, v) -> new PlanStreamOutput(out, new PlanNameRegistry(), null).writeNamedWriteable(v),
+            (out, v) -> new PlanStreamOutput(out, null).writeNamedWriteable(v),
             in -> {
-                PlanStreamInput pin = new PlanStreamInput(in, new PlanNameRegistry(), in.namedWriteableRegistry(), null);
+                PlanStreamInput pin = new PlanStreamInput(in, in.namedWriteableRegistry(), null);
                 Alias deser = (Alias) pin.readNamedWriteable(NamedExpression.class);
                 assertThat(deser.id(), equalTo(pin.mapNameId(Long.parseLong(instance.id().toString()))));
                 return deser;
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAttributeTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAttributeTestCase.java
index 76b813f08d818..a9750acdb1b84 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAttributeTestCase.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAttributeTestCase.java
@@ -15,7 +15,6 @@
 import org.elasticsearch.xpack.esql.core.expression.Attribute;
 import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
 import org.elasticsearch.xpack.esql.core.tree.Source;
-import org.elasticsearch.xpack.esql.io.stream.PlanNameRegistry;
 import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
 import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput;
 import org.elasticsearch.xpack.esql.session.Configuration;
@@ -61,7 +60,7 @@ protected final NamedWriteableRegistry getNamedWriteableRegistry() {
     @Override
     protected final Writeable.Reader<ExtraAttribute> instanceReader() {
         return in -> {
-            PlanStreamInput pin = new PlanStreamInput(in, PlanNameRegistry.INSTANCE, in.namedWriteableRegistry(), config);
+            PlanStreamInput pin = new PlanStreamInput(in, in.namedWriteableRegistry(), config);
             pin.setTransportVersion(in.getTransportVersion());
             return new ExtraAttribute(pin);
         };
@@ -84,7 +83,7 @@ public static class ExtraAttribute implements Writeable {
         @Override
         public void writeTo(StreamOutput out) throws IOException {
-            new PlanStreamOutput(out, new PlanNameRegistry(), EsqlTestUtils.TEST_CFG).writeNamedWriteable(a);
+            new PlanStreamOutput(out, EsqlTestUtils.TEST_CFG).writeNamedWriteable(a);
         }
 
         @Override
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/ReferenceAttributeTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/ReferenceAttributeTests.java
index 493cecffe8b3f..1c03f121a91ff 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/ReferenceAttributeTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/ReferenceAttributeTests.java
@@ -14,10 +14,12 @@
 import org.elasticsearch.xpack.esql.core.type.DataType;
 
 public class ReferenceAttributeTests extends AbstractAttributeTestCase<ReferenceAttribute> {
-    public static ReferenceAttribute randomReferenceAttribute() {
+    public static ReferenceAttribute randomReferenceAttribute(boolean onlyRepresentable) {
         Source source = Source.EMPTY;
         String name = randomAlphaOfLength(5);
-        DataType type = randomFrom(DataType.types());
+        DataType type = onlyRepresentable
+            ? randomValueOtherThanMany(t -> false == DataType.isRepresentable(t), () -> randomFrom(DataType.types()))
+            : randomFrom(DataType.types());
         Nullability nullability = randomFrom(Nullability.values());
         boolean synthetic = randomBoolean();
         return new ReferenceAttribute(source, name, type, nullability, new NameId(), synthetic);
@@ -25,7 +27,7 @@ public static ReferenceAttribute randomReferenceAttribute() {
     @Override
     protected ReferenceAttribute create() {
-        return randomReferenceAttribute();
+        return randomReferenceAttribute(false);
     }
 
     @Override
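Note: randomReferenceAttribute(boolean onlyRepresentable) above uses the rejection-sampling helper from ESTestCase. The general shape, with the predicate and generator taken from that hunk:

    // keep drawing candidates until one survives the reject-predicate
    DataType type = randomValueOtherThanMany(
        t -> false == DataType.isRepresentable(t),   // reject types that cannot round-trip
        () -> randomFrom(DataType.types())           // candidate generator
    );

Callers that need any type at all pass false and skip the filter, as the create() override above does.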

diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/index/EsIndexSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/index/EsIndexSerializationTests.java
index 76fc8c52d91ab..7be200baf6c58 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/index/EsIndexSerializationTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/index/EsIndexSerializationTests.java
@@ -15,7 +15,6 @@
 import org.elasticsearch.xpack.esql.core.type.DataType;
 import org.elasticsearch.xpack.esql.core.type.EsField;
 import org.elasticsearch.xpack.esql.core.type.InvalidMappedField;
-import org.elasticsearch.xpack.esql.io.stream.PlanNameRegistry;
 import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
 import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput;
 import org.elasticsearch.xpack.esql.type.EsFieldTests;
@@ -57,12 +56,12 @@ private static Map randomConcreteIndices() {
     @Override
     protected Writeable.Reader<EsIndex> instanceReader() {
-        return a -> new EsIndex(new PlanStreamInput(a, new PlanNameRegistry(), a.namedWriteableRegistry(), null));
+        return a -> new EsIndex(new PlanStreamInput(a, a.namedWriteableRegistry(), null));
     }
 
     @Override
     protected Writeable.Writer<EsIndex> instanceWriter() {
-        return (out, idx) -> new PlanStreamOutput(out, new PlanNameRegistry(), null).writeWriteable(idx);
+        return (out, idx) -> new PlanStreamOutput(out, null).writeWriteable(idx);
     }
 
     @Override
@@ -176,7 +175,7 @@ public void testManyTypeConflictsWithParent() throws IOException {
      */
     private void testManyTypeConflicts(boolean withParent, ByteSizeValue expected) throws IOException {
-        try (BytesStreamOutput out = new BytesStreamOutput(); var pso = new PlanStreamOutput(out, new PlanNameRegistry(), null)) {
+        try (BytesStreamOutput out = new BytesStreamOutput(); var pso = new PlanStreamOutput(out, null)) {
             indexWithManyConflicts(withParent).writeTo(pso);
             assertThat(ByteSizeValue.ofBytes(out.bytes().length()), byteSizeEquals(expected));
         }
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypesTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypesTests.java
deleted file mode 100644
index 853062676a0dc..0000000000000
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypesTests.java
+++ /dev/null
@@ -1,214 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.esql.io.stream;
-
-import org.elasticsearch.common.bytes.BytesReference;
-import org.elasticsearch.common.io.stream.ByteBufferStreamInput;
-import org.elasticsearch.common.io.stream.BytesStreamOutput;
-import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput;
-import org.elasticsearch.common.io.stream.StreamInput;
-import org.elasticsearch.test.ESTestCase;
-import org.elasticsearch.test.EqualsHashCodeTestUtils;
-import org.elasticsearch.xpack.esql.EsqlTestUtils;
-import org.elasticsearch.xpack.esql.SerializationTestUtils;
-import org.elasticsearch.xpack.esql.core.expression.Alias;
-import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
-import org.elasticsearch.xpack.esql.core.expression.NameId;
-import org.elasticsearch.xpack.esql.core.expression.Nullability;
-import org.elasticsearch.xpack.esql.core.tree.Source;
-import org.elasticsearch.xpack.esql.core.type.DataType;
-import org.elasticsearch.xpack.esql.core.type.EsField;
-import org.elasticsearch.xpack.esql.plan.physical.AggregateExec;
-import org.elasticsearch.xpack.esql.plan.physical.DissectExec;
-import org.elasticsearch.xpack.esql.plan.physical.EnrichExec;
-import org.elasticsearch.xpack.esql.plan.physical.EsQueryExec;
-import org.elasticsearch.xpack.esql.plan.physical.EsSourceExec;
-import org.elasticsearch.xpack.esql.plan.physical.EvalExec;
-import org.elasticsearch.xpack.esql.plan.physical.ExchangeExec;
-import org.elasticsearch.xpack.esql.plan.physical.ExchangeSinkExec;
-import org.elasticsearch.xpack.esql.plan.physical.ExchangeSourceExec;
-import org.elasticsearch.xpack.esql.plan.physical.FieldExtractExec;
-import org.elasticsearch.xpack.esql.plan.physical.FilterExec;
-import org.elasticsearch.xpack.esql.plan.physical.FragmentExec;
-import org.elasticsearch.xpack.esql.plan.physical.GrokExec;
-import org.elasticsearch.xpack.esql.plan.physical.HashJoinExec;
-import org.elasticsearch.xpack.esql.plan.physical.LimitExec;
-import org.elasticsearch.xpack.esql.plan.physical.LocalSourceExec;
-import org.elasticsearch.xpack.esql.plan.physical.MvExpandExec;
-import org.elasticsearch.xpack.esql.plan.physical.OrderExec;
-import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
-import org.elasticsearch.xpack.esql.plan.physical.ProjectExec;
-import org.elasticsearch.xpack.esql.plan.physical.RowExec;
-import org.elasticsearch.xpack.esql.plan.physical.ShowExec;
-import org.elasticsearch.xpack.esql.plan.physical.TopNExec;
-
-import java.io.IOException;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-
-import static org.elasticsearch.test.ListMatcher.matchesList;
-import static org.elasticsearch.test.MapMatcher.assertMap;
-import static org.elasticsearch.xpack.esql.core.expression.FieldAttributeTestUtils.newFieldAttributeWithType;
-import static org.hamcrest.Matchers.equalTo;
-
-public class PlanNamedTypesTests extends ESTestCase {
-
-    // List of known serializable physical plan nodes - this should be kept up to date or retrieved
-    // programmatically. Excludes LocalSourceExec
-    public static final List<Class<? extends PhysicalPlan>> PHYSICAL_PLAN_NODE_CLS = List.of(
-        AggregateExec.class,
-        DissectExec.class,
-        EsQueryExec.class,
-        EsSourceExec.class,
-        EvalExec.class,
-        EnrichExec.class,
-        ExchangeExec.class,
-        ExchangeSinkExec.class,
-        ExchangeSourceExec.class,
-        FieldExtractExec.class,
-        FilterExec.class,
-        FragmentExec.class,
-        GrokExec.class,
-        LimitExec.class,
-        LocalSourceExec.class,
-        HashJoinExec.class,
-        MvExpandExec.class,
-        OrderExec.class,
-        ProjectExec.class,
-        RowExec.class,
-        ShowExec.class,
-        TopNExec.class
-    );
-
-    // Tests that all physical plan nodes have a suitably named serialization entry.
-    public void testPhysicalPlanEntries() {
-        var expected = PHYSICAL_PLAN_NODE_CLS.stream().map(Class::getSimpleName).toList();
-        var actual = PlanNamedTypes.namedTypeEntries()
-            .stream()
-            .filter(e -> e.categoryClass().isAssignableFrom(PhysicalPlan.class))
-            .map(PlanNameRegistry.Entry::name)
-            .toList();
-        assertMap(actual, matchesList(expected));
-    }
-
-    // Tests that all names are unique - there should be a good reason if this is not the case.
-    public void testUniqueNames() {
-        var actual = PlanNamedTypes.namedTypeEntries().stream().map(PlanNameRegistry.Entry::name).distinct().toList();
-        assertThat(actual.size(), equalTo(PlanNamedTypes.namedTypeEntries().size()));
-    }
-
-    // Tests that reader from the original(outer) stream and inner(plan) streams work together.
-    public void testWrappedStreamSimple() throws IOException {
-        // write
-        BytesStreamOutput bso = new BytesStreamOutput();
-        bso.writeString("hello");
-        PlanStreamOutput out = new PlanStreamOutput(bso, planNameRegistry, null);
-        var plan = new RowExec(Source.EMPTY, List.of(new Alias(Source.EMPTY, "foo", field("field", DataType.LONG))));
-        out.writePhysicalPlanNode(plan);
-        bso.writeVInt(11_345);
-
-        // read
-        StreamInput in = ByteBufferStreamInput.wrap(BytesReference.toBytes(bso.bytes()));
-        assertThat(in.readString(), equalTo("hello"));
-        var planStreamInput = new PlanStreamInput(in, planNameRegistry, SerializationTestUtils.writableRegistry(), EsqlTestUtils.TEST_CFG);
-        var deser = (RowExec) planStreamInput.readPhysicalPlanNode();
-        EqualsHashCodeTestUtils.checkEqualsAndHashCode(plan, unused -> deser);
-        assertThat(in.readVInt(), equalTo(11_345));
-    }
-
-    static FieldAttribute randomFieldAttributeOrNull() {
-        return randomBoolean() ? randomFieldAttribute() : null;
-    }
-
-    static FieldAttribute randomFieldAttribute() {
-        return newFieldAttributeWithType(
-            Source.EMPTY,
-            randomFieldAttributeOrNull(), // parent
-            randomAlphaOfLength(randomIntBetween(1, 25)), // name
-            randomDataType(),
-            randomEsField(),
-            randomNullability(),
-            nameIdOrNull(),
-            randomBoolean() // synthetic
-        );
-    }
-
-    static NameId nameIdOrNull() {
-        return randomBoolean() ? new NameId() : null;
-    }
-
-    static Nullability randomNullability() {
-        int i = randomInt(2);
-        return switch (i) {
-            case 0 -> Nullability.UNKNOWN;
-            case 1 -> Nullability.TRUE;
-            case 2 -> Nullability.FALSE;
-            default -> throw new AssertionError(i);
-        };
-    }
-
-    public static EsField randomEsField() {
-        return randomEsField(0);
-    }
-
-    static EsField randomEsField(int depth) {
-        return new EsField(
-            randomAlphaOfLength(randomIntBetween(1, 25)),
-            randomDataType(),
-            randomProperties(depth),
-            randomBoolean(), // aggregatable
-            randomBoolean() // isAlias
-        );
-    }
-
-    static Map<String, EsField> randomProperties(int depth) {
-        if (depth > 2) {
-            return Map.of(); // prevent infinite recursion (between EsField and properties)
-        }
-        depth += 1;
-        int size = randomIntBetween(0, 5);
-        Map<String, EsField> map = new HashMap<>();
-        for (int i = 0; i < size; i++) {
-            map.put(
-                randomAlphaOfLength(randomIntBetween(1, 10)), // name
-                randomEsField(depth)
-            );
-        }
-        return Map.copyOf(map);
-    }
-
-    static List<DataType> DATA_TYPES = DataType.types().stream().toList();
-
-    static DataType randomDataType() {
-        return DATA_TYPES.get(randomIntBetween(0, DATA_TYPES.size() - 1));
-    }
-
-    static String randomStringOrNull() {
-        return randomBoolean() ? randomAlphaOfLength(randomIntBetween(1, 25)) : null;
-    }
-
-    static String randomName() {
-        return randomAlphaOfLength(randomIntBetween(1, 25));
-    }
-
-    static FieldAttribute field(String name, DataType type) {
-        return new FieldAttribute(Source.EMPTY, name, new EsField(name, type, Collections.emptyMap(), false));
-    }
-
-    static PlanNameRegistry planNameRegistry = new PlanNameRegistry();
-
-    static PlanStreamInput planStreamInput(BytesStreamOutput out) {
-        StreamInput in = new NamedWriteableAwareStreamInput(
-            ByteBufferStreamInput.wrap(BytesReference.toBytes(out.bytes())),
-            SerializationTestUtils.writableRegistry()
-        );
-        return new PlanStreamInput(in, planNameRegistry, SerializationTestUtils.writableRegistry(), EsqlTestUtils.TEST_CFG);
-    }
-}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamOutputTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamOutputTests.java
index cdb6c5384e16a..33252b9dbaaa3 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamOutputTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamOutputTests.java
@@ -24,11 +24,13 @@
 import org.elasticsearch.xpack.esql.core.expression.NameId;
 import org.elasticsearch.xpack.esql.core.type.DataType;
 import org.elasticsearch.xpack.esql.core.type.EsField;
+import org.elasticsearch.xpack.esql.expression.function.FieldAttributeTests;
 import org.elasticsearch.xpack.esql.expression.function.MetadataAttributeTests;
 import org.elasticsearch.xpack.esql.expression.function.ReferenceAttributeTests;
 import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute;
 import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttributeTests;
 import org.elasticsearch.xpack.esql.session.Configuration;
+import org.elasticsearch.xpack.esql.type.EsFieldTests;
 
 import java.io.IOException;
 import java.util.ArrayList;
@@ -51,7 +53,7 @@ public void testTransportVersion() throws IOException {
         BytesStreamOutput out = new BytesStreamOutput();
         TransportVersion v1 = TransportVersionUtils.randomCompatibleVersion(random());
         out.setTransportVersion(v1);
-        PlanStreamOutput planOut = new PlanStreamOutput(out, PlanNameRegistry.INSTANCE, randomBoolean() ? null : randomConfiguration());
null : randomConfiguration()); + PlanStreamOutput planOut = new PlanStreamOutput(out, randomBoolean() ? null : randomConfiguration()); assertThat(planOut.getTransportVersion(), equalTo(v1)); TransportVersion v2 = TransportVersionUtils.randomCompatibleVersion(random()); planOut.setTransportVersion(v2); @@ -64,15 +66,10 @@ public void testWriteBlockFromConfig() throws IOException { String columnName = randomAlphaOfLength(10); try (Column c = randomColumn()) { Configuration configuration = randomConfiguration("query_" + randomAlphaOfLength(1), Map.of(tableName, Map.of(columnName, c))); - try ( - BytesStreamOutput out = new BytesStreamOutput(); - PlanStreamOutput planStream = new PlanStreamOutput(out, PlanNameRegistry.INSTANCE, configuration) - ) { + try (BytesStreamOutput out = new BytesStreamOutput(); PlanStreamOutput planStream = new PlanStreamOutput(out, configuration)) { planStream.writeCachedBlock(c.values()); assertThat(out.bytes().length(), equalTo(3 + tableName.length() + columnName.length())); - try ( - PlanStreamInput in = new PlanStreamInput(out.bytes().streamInput(), PlanNameRegistry.INSTANCE, REGISTRY, configuration) - ) { + try (PlanStreamInput in = new PlanStreamInput(out.bytes().streamInput(), REGISTRY, configuration)) { assertThat(in.readCachedBlock(), sameInstance(c.values())); } } @@ -82,16 +79,11 @@ public void testWriteBlockFromConfig() throws IOException { public void testWriteBlockOnce() throws IOException { try (Block b = randomColumn().values()) { Configuration configuration = randomConfiguration(); - try ( - BytesStreamOutput out = new BytesStreamOutput(); - PlanStreamOutput planStream = new PlanStreamOutput(out, PlanNameRegistry.INSTANCE, configuration) - ) { + try (BytesStreamOutput out = new BytesStreamOutput(); PlanStreamOutput planStream = new PlanStreamOutput(out, configuration)) { planStream.writeCachedBlock(b); assertThat(out.bytes().length(), greaterThan(4 * LEN)); assertThat(out.bytes().length(), lessThan(8 * LEN)); - try ( - PlanStreamInput in = new PlanStreamInput(out.bytes().streamInput(), PlanNameRegistry.INSTANCE, REGISTRY, configuration) - ) { + try (PlanStreamInput in = new PlanStreamInput(out.bytes().streamInput(), REGISTRY, configuration)) { Block read = in.readCachedBlock(); assertThat(read, not(sameInstance(b))); assertThat(read, equalTo(b)); @@ -103,17 +95,12 @@ public void testWriteBlockOnce() throws IOException { public void testWriteBlockTwice() throws IOException { try (Block b = randomColumn().values()) { Configuration configuration = randomConfiguration(); - try ( - BytesStreamOutput out = new BytesStreamOutput(); - PlanStreamOutput planStream = new PlanStreamOutput(out, PlanNameRegistry.INSTANCE, configuration) - ) { + try (BytesStreamOutput out = new BytesStreamOutput(); PlanStreamOutput planStream = new PlanStreamOutput(out, configuration)) { planStream.writeCachedBlock(b); planStream.writeCachedBlock(b); assertThat(out.bytes().length(), greaterThan(4 * LEN)); assertThat(out.bytes().length(), lessThan(8 * LEN)); - try ( - PlanStreamInput in = new PlanStreamInput(out.bytes().streamInput(), PlanNameRegistry.INSTANCE, REGISTRY, configuration) - ) { + try (PlanStreamInput in = new PlanStreamInput(out.bytes().streamInput(), REGISTRY, configuration)) { Block read = in.readCachedBlock(); assertThat(read, not(sameInstance(b))); assertThat(read, equalTo(b)); @@ -126,10 +113,7 @@ public void testWriteBlockTwice() throws IOException { public void testWriteAttributeMultipleTimes() throws IOException { Attribute attribute = randomAttribute(); 
Configuration configuration = randomConfiguration(); - try ( - BytesStreamOutput out = new BytesStreamOutput(); - PlanStreamOutput planStream = new PlanStreamOutput(out, PlanNameRegistry.INSTANCE, configuration) - ) { + try (BytesStreamOutput out = new BytesStreamOutput(); PlanStreamOutput planStream = new PlanStreamOutput(out, configuration)) { int occurrences = randomIntBetween(2, 150); for (int i = 0; i < occurrences; i++) { planStream.writeNamedWriteable(attribute); @@ -141,7 +125,7 @@ public void testWriteAttributeMultipleTimes() throws IOException { parent = parent instanceof FieldAttribute f ? f.parent() : null; } assertThat(planStream.cachedAttributes.size(), is(depth)); - try (PlanStreamInput in = new PlanStreamInput(out.bytes().streamInput(), PlanNameRegistry.INSTANCE, REGISTRY, configuration)) { + try (PlanStreamInput in = new PlanStreamInput(out.bytes().streamInput(), REGISTRY, configuration)) { Attribute first = in.readNamedWriteable(Attribute.class); for (int i = 1; i < occurrences; i++) { Attribute next = in.readNamedWriteable(Attribute.class); @@ -160,10 +144,7 @@ public void testWriteAttributeMultipleTimes() throws IOException { public void testWriteMultipleAttributes() throws IOException { Configuration configuration = randomConfiguration(); - try ( - BytesStreamOutput out = new BytesStreamOutput(); - PlanStreamOutput planStream = new PlanStreamOutput(out, PlanNameRegistry.INSTANCE, configuration) - ) { + try (BytesStreamOutput out = new BytesStreamOutput(); PlanStreamOutput planStream = new PlanStreamOutput(out, configuration)) { List attrs = new ArrayList<>(); int occurrences = randomIntBetween(2, 300); for (int i = 0; i < occurrences; i++) { @@ -177,7 +158,7 @@ public void testWriteMultipleAttributes() throws IOException { } } - try (PlanStreamInput in = new PlanStreamInput(out.bytes().streamInput(), PlanNameRegistry.INSTANCE, REGISTRY, configuration)) { + try (PlanStreamInput in = new PlanStreamInput(out.bytes().streamInput(), REGISTRY, configuration)) { List readAttrs = new ArrayList<>(); for (int i = 0; i < occurrences; i++) { readAttrs.add(in.readNamedWriteable(Attribute.class)); @@ -196,10 +177,7 @@ public void testWriteMultipleAttributes() throws IOException { public void testWriteMultipleAttributesWithSmallCache() throws IOException { Configuration configuration = randomConfiguration(); - try ( - BytesStreamOutput out = new BytesStreamOutput(); - PlanStreamOutput planStream = new PlanStreamOutput(out, PlanNameRegistry.INSTANCE, configuration, PlanNamedTypes::name, 10) - ) { + try (BytesStreamOutput out = new BytesStreamOutput(); PlanStreamOutput planStream = new PlanStreamOutput(out, configuration, 10)) { expectThrows(InvalidArgumentException.class, () -> { for (int i = 0; i <= 10; i++) { planStream.writeNamedWriteable(randomAttribute()); @@ -210,10 +188,7 @@ public void testWriteMultipleAttributesWithSmallCache() throws IOException { public void testWriteEqualAttributesDifferentID() throws IOException { Configuration configuration = randomConfiguration(); - try ( - BytesStreamOutput out = new BytesStreamOutput(); - PlanStreamOutput planStream = new PlanStreamOutput(out, PlanNameRegistry.INSTANCE, configuration) - ) { + try (BytesStreamOutput out = new BytesStreamOutput(); PlanStreamOutput planStream = new PlanStreamOutput(out, configuration)) { Attribute one = randomAttribute(); Attribute two = one.withId(new NameId()); @@ -221,7 +196,7 @@ public void testWriteEqualAttributesDifferentID() throws IOException { planStream.writeNamedWriteable(one); 
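testWriteEqualAttributesDifferentID (in progress here) and testWriteDifferentAttributesSameID below pin the cache-key semantics: two attributes that compare equal but carry different NameIds must not collapse into one cache entry. A rough sketch of a key that keeps equal-but-differently-id'd values apart; the names are hypothetical and this is not the actual PlanStreamOutput cache.

    import java.util.HashMap;
    import java.util.Map;

    final class IdAwareCache<V> {
        // the id travels with the value in the key, so equal values with different ids get distinct slots
        private record Key<V>(V value, long id) {}

        private final Map<Key<V>, Integer> slots = new HashMap<>();

        int slotFor(V value, long id) {
            Key<V> key = new Key<>(value, id);
            Integer slot = slots.get(key);
            if (slot == null) {
                slot = slots.size();
                slots.put(key, slot);
            }
            return slot;
        }
    }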
planStream.writeNamedWriteable(two); - try (PlanStreamInput in = new PlanStreamInput(out.bytes().streamInput(), PlanNameRegistry.INSTANCE, REGISTRY, configuration)) { + try (PlanStreamInput in = new PlanStreamInput(out.bytes().streamInput(), REGISTRY, configuration)) { Attribute oneCopy = in.readNamedWriteable(Attribute.class); Attribute twoCopy = in.readNamedWriteable(Attribute.class); @@ -235,10 +210,7 @@ public void testWriteEqualAttributesDifferentID() throws IOException { public void testWriteDifferentAttributesSameID() throws IOException { Configuration configuration = randomConfiguration(); - try ( - BytesStreamOutput out = new BytesStreamOutput(); - PlanStreamOutput planStream = new PlanStreamOutput(out, PlanNameRegistry.INSTANCE, configuration) - ) { + try (BytesStreamOutput out = new BytesStreamOutput(); PlanStreamOutput planStream = new PlanStreamOutput(out, configuration)) { Attribute one = randomAttribute(); Attribute two = randomAttribute().withId(one.id()); @@ -246,7 +218,7 @@ public void testWriteDifferentAttributesSameID() throws IOException { planStream.writeNamedWriteable(one); planStream.writeNamedWriteable(two); - try (PlanStreamInput in = new PlanStreamInput(out.bytes().streamInput(), PlanNameRegistry.INSTANCE, REGISTRY, configuration)) { + try (PlanStreamInput in = new PlanStreamInput(out.bytes().streamInput(), REGISTRY, configuration)) { Attribute oneCopy = in.readNamedWriteable(Attribute.class); Attribute twoCopy = in.readNamedWriteable(Attribute.class); @@ -261,14 +233,11 @@ public void testWriteDifferentAttributesSameID() throws IOException { public void testWriteMultipleEsFields() throws IOException { Configuration configuration = randomConfiguration(); - try ( - BytesStreamOutput out = new BytesStreamOutput(); - PlanStreamOutput planStream = new PlanStreamOutput(out, PlanNameRegistry.INSTANCE, configuration) - ) { + try (BytesStreamOutput out = new BytesStreamOutput(); PlanStreamOutput planStream = new PlanStreamOutput(out, configuration)) { List fields = new ArrayList<>(); int occurrences = randomIntBetween(2, 300); for (int i = 0; i < occurrences; i++) { - fields.add(PlanNamedTypesTests.randomEsField()); + fields.add(EsFieldTests.randomEsField(4)); } // send all the EsFields, three times @@ -278,7 +247,7 @@ public void testWriteMultipleEsFields() throws IOException { } } - try (PlanStreamInput in = new PlanStreamInput(out.bytes().streamInput(), PlanNameRegistry.INSTANCE, REGISTRY, configuration)) { + try (PlanStreamInput in = new PlanStreamInput(out.bytes().streamInput(), REGISTRY, configuration)) { List readFields = new ArrayList<>(); for (int i = 0; i < occurrences; i++) { readFields.add(EsField.readFrom(in)); @@ -297,8 +266,8 @@ public void testWriteMultipleEsFields() throws IOException { private static Attribute randomAttribute() { return switch (randomInt(3)) { - case 0 -> PlanNamedTypesTests.randomFieldAttribute(); - case 1 -> ReferenceAttributeTests.randomReferenceAttribute(); + case 0 -> FieldAttributeTests.createFieldAttribute(0, false); + case 1 -> ReferenceAttributeTests.randomReferenceAttribute(false); case 2 -> UnsupportedAttributeTests.randomUnsupportedAttribute(); case 3 -> MetadataAttributeTests.randomMetadataAttribute(); default -> throw new IllegalArgumentException(); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/AbstractNodeSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/AbstractNodeSerializationTests.java index e6f5d6e4fac70..e6faa9a253d76 100644 --- 
a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/AbstractNodeSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/AbstractNodeSerializationTests.java @@ -13,7 +13,6 @@ import org.elasticsearch.xpack.esql.core.tree.Node; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.expression.function.FieldAttributeTests; -import org.elasticsearch.xpack.esql.io.stream.PlanNameRegistry; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput; import org.elasticsearch.xpack.esql.session.Configuration; @@ -56,9 +55,9 @@ protected final T copyInstance(T instance, TransportVersion version) throws IOEx return copyInstance( instance, getNamedWriteableRegistry(), - (out, v) -> new PlanStreamOutput(out, new PlanNameRegistry(), configuration()).writeNamedWriteable(v), + (out, v) -> new PlanStreamOutput(out, configuration()).writeNamedWriteable(v), in -> { - PlanStreamInput pin = new PlanStreamInput(in, new PlanNameRegistry(), in.namedWriteableRegistry(), configuration()); + PlanStreamInput pin = new PlanStreamInput(in, in.namedWriteableRegistry(), configuration()); @SuppressWarnings("unchecked") T deser = (T) pin.readNamedWriteable(categoryClass()); if (alwaysEmptySource()) { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/GrokSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/GrokSerializationTests.java index 770b7983fd782..dea198ee36891 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/GrokSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/GrokSerializationTests.java @@ -23,7 +23,7 @@ protected Grok createTestInstance() { LogicalPlan child = randomChild(0); Expression inputExpr = FieldAttributeTests.createFieldAttribute(3, false); String pattern = randomAlphaOfLength(5); - List extracted = randomList(1, 10, ReferenceAttributeTests::randomReferenceAttribute); + List extracted = randomList(1, 10, () -> ReferenceAttributeTests.randomReferenceAttribute(false)); return new Grok(source, child, inputExpr, Grok.pattern(source, pattern), extracted); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/MvExpandSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/MvExpandSerializationTests.java index 08e59095161ed..fe8286db046bb 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/MvExpandSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/MvExpandSerializationTests.java @@ -21,7 +21,7 @@ protected MvExpand createTestInstance() { Source source = randomSource(); LogicalPlan child = randomChild(0); NamedExpression target = FieldAttributeTests.createFieldAttribute(0, false); - Attribute expanded = ReferenceAttributeTests.randomReferenceAttribute(); + Attribute expanded = ReferenceAttributeTests.randomReferenceAttribute(false); return new MvExpand(source, child, target, expanded); } @@ -33,7 +33,7 @@ protected MvExpand mutateInstance(MvExpand instance) throws IOException { switch (between(0, 2)) { case 0 -> child = randomValueOtherThan(child, () -> randomChild(0)); case 1 -> target = randomValueOtherThan(target, () -> FieldAttributeTests.createFieldAttribute(0, false)); - case 2 -> expanded = 
randomValueOtherThan(expanded, ReferenceAttributeTests::randomReferenceAttribute); + case 2 -> expanded = randomValueOtherThan(expanded, () -> ReferenceAttributeTests.randomReferenceAttribute(false)); } return new MvExpand(instance.source(), child, target, expanded); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/local/LocalSupplierTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/local/LocalSupplierTests.java index e691d79dce00a..ccb27b41f2ed6 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/local/LocalSupplierTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/local/LocalSupplierTests.java @@ -17,7 +17,6 @@ import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.test.AbstractWireTestCase; -import org.elasticsearch.xpack.esql.io.stream.PlanNameRegistry; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput; @@ -34,10 +33,10 @@ public class LocalSupplierTests extends AbstractWireTestCase { protected LocalSupplier copyInstance(LocalSupplier instance, TransportVersion version) throws IOException { try (BytesStreamOutput output = new BytesStreamOutput()) { output.setTransportVersion(version); - instance.writeTo(new PlanStreamOutput(output, PlanNameRegistry.INSTANCE, null)); + instance.writeTo(new PlanStreamOutput(output, null)); try (StreamInput in = output.bytes().streamInput()) { in.setTransportVersion(version); - return LocalSupplier.readFrom(new PlanStreamInput(in, PlanNameRegistry.INSTANCE, getNamedWriteableRegistry(), null)); + return LocalSupplier.readFrom(new PlanStreamInput(in, getNamedWriteableRegistry(), null)); } } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/ExchangeSinkExecSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/ExchangeSinkExecSerializationTests.java index 5b6c682044a8f..dd163609de8a8 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/ExchangeSinkExecSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/ExchangeSinkExecSerializationTests.java @@ -17,7 +17,6 @@ import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.index.EsIndex; import org.elasticsearch.xpack.esql.index.EsIndexSerializationTests; -import org.elasticsearch.xpack.esql.io.stream.PlanNameRegistry; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput; import org.elasticsearch.xpack.esql.plan.logical.EsRelation; @@ -115,21 +114,11 @@ private void testManyTypeConflicts(boolean withParent, ByteSizeValue expected) t Project project = new Project(randomSource(), limit, limit.output()); FragmentExec fragmentExec = new FragmentExec(project); ExchangeSinkExec exchangeSinkExec = new ExchangeSinkExec(randomSource(), fragmentExec.output(), false, fragmentExec); - try ( - BytesStreamOutput out = new BytesStreamOutput(); - PlanStreamOutput pso = new PlanStreamOutput(out, new PlanNameRegistry(), configuration()) - ) { - pso.writePhysicalPlanNode(exchangeSinkExec); + try (BytesStreamOutput out = new BytesStreamOutput(); PlanStreamOutput pso = new PlanStreamOutput(out, configuration())) { + pso.writeNamedWriteable(exchangeSinkExec); 
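Note the new-style round trip in testManyTypeConflicts below: with writePhysicalPlanNode/readPhysicalPlanNode gone, a physical plan is just another named writeable. A minimal round-trip helper in the same style, assuming only the stream classes these tests already use (the RoundTrip class itself is illustrative, not an existing utility):

    import java.io.IOException;
    import org.elasticsearch.common.io.stream.BytesStreamOutput;
    import org.elasticsearch.common.io.stream.NamedWriteable;
    import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput;
    import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
    import org.elasticsearch.common.io.stream.StreamInput;

    final class RoundTrip {
        // serialize through the generic named-writeable path, then read back via the registry
        static <T extends NamedWriteable> T copy(T value, NamedWriteableRegistry registry, Class<T> clazz) throws IOException {
            try (BytesStreamOutput out = new BytesStreamOutput()) {
                out.writeNamedWriteable(value);
                try (StreamInput in = new NamedWriteableAwareStreamInput(out.bytes().streamInput(), registry)) {
                    return in.readNamedWriteable(clazz);
                }
            }
        }
    }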
assertThat(ByteSizeValue.ofBytes(out.bytes().length()), byteSizeEquals(expected)); - try ( - PlanStreamInput psi = new PlanStreamInput( - out.bytes().streamInput(), - new PlanNameRegistry(), - getNamedWriteableRegistry(), - configuration() - ) - ) { - assertThat(psi.readPhysicalPlanNode(), equalTo(exchangeSinkExec)); + try (PlanStreamInput psi = new PlanStreamInput(out.bytes().streamInput(), getNamedWriteableRegistry(), configuration())) { + assertThat(psi.readNamedWriteable(PhysicalPlan.class), equalTo(exchangeSinkExec)); } } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/MvExpandExecSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/MvExpandExecSerializationTests.java new file mode 100644 index 0000000000000..ce159eae909f2 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/MvExpandExecSerializationTests.java @@ -0,0 +1,50 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.plan.physical; + +import org.elasticsearch.xpack.esql.core.expression.Attribute; +import org.elasticsearch.xpack.esql.core.expression.NamedExpression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.function.FieldAttributeTests; +import org.elasticsearch.xpack.esql.expression.function.ReferenceAttributeTests; + +import java.io.IOException; + +public class MvExpandExecSerializationTests extends AbstractPhysicalPlanSerializationTests { + public static MvExpandExec randomMvExpandExec(int depth) { + Source source = randomSource(); + PhysicalPlan child = randomChild(depth); + NamedExpression target = FieldAttributeTests.createFieldAttribute(0, false); + Attribute expanded = ReferenceAttributeTests.randomReferenceAttribute(false); + return new MvExpandExec(source, child, target, expanded); + } + + @Override + protected MvExpandExec createTestInstance() { + return randomMvExpandExec(0); + } + + @Override + protected MvExpandExec mutateInstance(MvExpandExec instance) throws IOException { + PhysicalPlan child = instance.child(); + NamedExpression target = instance.target(); + Attribute expanded = instance.expanded(); + switch (between(0, 2)) { + case 0 -> child = randomValueOtherThan(child, () -> randomChild(0)); + case 1 -> target = randomValueOtherThan(target, () -> FieldAttributeTests.createFieldAttribute(0, false)); + case 2 -> expanded = randomValueOtherThan(expanded, () -> ReferenceAttributeTests.randomReferenceAttribute(false)); + default -> throw new UnsupportedOperationException(); + } + return new MvExpandExec(instance.source(), child, target, expanded); + } + + @Override + protected boolean alwaysEmptySource() { + return true; + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/OrderExecSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/OrderExecSerializationTests.java new file mode 100644 index 0000000000000..755f1cd4f52da --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/OrderExecSerializationTests.java @@ -0,0 +1,46 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.plan.physical; + +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.Order; +import org.elasticsearch.xpack.esql.expression.OrderSerializationTests; + +import java.io.IOException; +import java.util.List; + +public class OrderExecSerializationTests extends AbstractPhysicalPlanSerializationTests { + public static OrderExec randomOrderExec(int depth) { + Source source = randomSource(); + PhysicalPlan child = randomChild(depth); + List order = randomList(1, 10, OrderSerializationTests::randomOrder); + return new OrderExec(source, child, order); + } + + @Override + protected OrderExec createTestInstance() { + return randomOrderExec(0); + } + + @Override + protected OrderExec mutateInstance(OrderExec instance) throws IOException { + PhysicalPlan child = instance.child(); + List order = instance.order(); + if (randomBoolean()) { + child = randomValueOtherThan(child, () -> randomChild(0)); + } else { + order = randomValueOtherThan(order, () -> randomList(1, 10, OrderSerializationTests::randomOrder)); + } + return new OrderExec(instance.source(), child, order); + } + + @Override + protected boolean alwaysEmptySource() { + return true; + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/ProjectExecSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/ProjectExecSerializationTests.java new file mode 100644 index 0000000000000..7196424ddc482 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/ProjectExecSerializationTests.java @@ -0,0 +1,45 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.plan.physical; + +import org.elasticsearch.xpack.esql.core.expression.NamedExpression; +import org.elasticsearch.xpack.esql.core.tree.Source; + +import java.io.IOException; +import java.util.List; + +public class ProjectExecSerializationTests extends AbstractPhysicalPlanSerializationTests { + public static ProjectExec randomProjectExec(int depth) { + Source source = randomSource(); + PhysicalPlan child = randomChild(depth); + List projections = randomFieldAttributes(0, 10, false); + return new ProjectExec(source, child, projections); + } + + @Override + protected ProjectExec createTestInstance() { + return randomProjectExec(0); + } + + @Override + protected ProjectExec mutateInstance(ProjectExec instance) throws IOException { + PhysicalPlan child = instance.child(); + List projections = instance.projections(); + if (randomBoolean()) { + child = randomValueOtherThan(child, () -> randomChild(0)); + } else { + projections = randomValueOtherThan(projections, () -> randomFieldAttributes(0, 10, false)); + } + return new ProjectExec(instance.source(), child, projections); + } + + @Override + protected boolean alwaysEmptySource() { + return true; + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/RowExecSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/RowExecSerializationTests.java new file mode 100644 index 0000000000000..3dd44cd20e369 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/RowExecSerializationTests.java @@ -0,0 +1,51 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.plan.physical; + +import org.elasticsearch.xpack.esql.core.expression.Alias; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.LiteralTests; +import org.elasticsearch.xpack.esql.core.expression.NameId; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.tree.SourceTests; + +import java.io.IOException; +import java.util.List; + +public class RowExecSerializationTests extends AbstractPhysicalPlanSerializationTests { + public static RowExec randomRowExec() { + Source source = randomSource(); + List fields = randomList(1, 10, RowExecSerializationTests::randomAlias); + return new RowExec(source, fields); + } + + private static Alias randomAlias() { + Source source = SourceTests.randomSource(); + String name = randomAlphaOfLength(5); + Expression child = LiteralTests.randomLiteral(); + boolean synthetic = randomBoolean(); + return new Alias(source, name, child, new NameId(), synthetic); + } + + @Override + protected RowExec createTestInstance() { + return randomRowExec(); + } + + @Override + protected RowExec mutateInstance(RowExec instance) throws IOException { + List fields = instance.fields(); + fields = randomValueOtherThan(fields, () -> randomList(1, 10, RowExecSerializationTests::randomAlias)); + return new RowExec(instance.source(), fields); + } + + @Override + protected boolean alwaysEmptySource() { + return true; + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/ShowExecSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/ShowExecSerializationTests.java new file mode 100644 index 0000000000000..7bac909578e56 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/ShowExecSerializationTests.java @@ -0,0 +1,65 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.plan.physical; + +import org.elasticsearch.xpack.esql.core.expression.Attribute; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.function.ReferenceAttributeTests; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import static org.elasticsearch.xpack.esql.EsqlTestUtils.randomLiteral; + +public class ShowExecSerializationTests extends AbstractPhysicalPlanSerializationTests { + public static ShowExec randomShowExec() { + Source source = randomSource(); + List attributes = randomList(1, 10, () -> ReferenceAttributeTests.randomReferenceAttribute(true)); + List> values = randomValues(attributes); + return new ShowExec(source, attributes, values); + } + + private static List> randomValues(List attributes) { + int size = between(0, 1000); + List> result = new ArrayList<>(size); + for (int i = 0; i < size; i++) { + List row = new ArrayList<>(attributes.size()); + for (Attribute a : attributes) { + row.add(randomLiteral(a.dataType()).value()); + } + result.add(row); + } + return result; + } + + @Override + protected ShowExec createTestInstance() { + return randomShowExec(); + } + + @Override + protected ShowExec mutateInstance(ShowExec instance) throws IOException { + List attributes = instance.output(); + List> values = instance.values(); + if (randomBoolean()) { + attributes = randomValueOtherThan( + attributes, + () -> randomList(1, 10, () -> ReferenceAttributeTests.randomReferenceAttribute(true)) + ); + } + List finalAttributes = attributes; + values = randomValueOtherThan(values, () -> randomValues(finalAttributes)); + return new ShowExec(instance.source(), attributes, values); + } + + @Override + protected boolean alwaysEmptySource() { + return true; + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/TopNExecSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/TopNExecSerializationTests.java new file mode 100644 index 0000000000000..9606079e2f698 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/TopNExecSerializationTests.java @@ -0,0 +1,58 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.plan.physical; + +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.Order; +import org.elasticsearch.xpack.esql.expression.OrderSerializationTests; + +import java.io.IOException; +import java.util.List; + +public class TopNExecSerializationTests extends AbstractPhysicalPlanSerializationTests { + public static TopNExec randomTopNExec(int depth) { + Source source = randomSource(); + PhysicalPlan child = randomChild(depth); + List order = randomList(1, 10, OrderSerializationTests::randomOrder); + Expression limit = new Literal(randomSource(), randomNonNegativeInt(), DataType.INTEGER); + Integer estimatedRowSize = randomEstimatedRowSize(); + return new TopNExec(source, child, order, limit, estimatedRowSize); + } + + @Override + protected TopNExec createTestInstance() { + return randomTopNExec(0); + } + + @Override + protected TopNExec mutateInstance(TopNExec instance) throws IOException { + PhysicalPlan child = instance.child(); + List order = instance.order(); + Expression limit = instance.limit(); + Integer estimatedRowSize = instance.estimatedRowSize(); + switch (between(0, 3)) { + case 0 -> child = randomValueOtherThan(child, () -> randomChild(0)); + case 1 -> order = randomValueOtherThan(order, () -> randomList(1, 10, OrderSerializationTests::randomOrder)); + case 2 -> limit = randomValueOtherThan(limit, () -> new Literal(randomSource(), randomNonNegativeInt(), DataType.INTEGER)); + case 3 -> estimatedRowSize = randomValueOtherThan( + estimatedRowSize, + AbstractPhysicalPlanSerializationTests::randomEstimatedRowSize + ); + default -> throw new UnsupportedOperationException(); + } + return new TopNExec(instance.source(), child, order, limit, estimatedRowSize); + } + + @Override + protected boolean alwaysEmptySource() { + return true; + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/FilterTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/FilterTests.java index cccac8c6342a1..06fe05896a57c 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/FilterTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/FilterTests.java @@ -28,7 +28,6 @@ import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry; import org.elasticsearch.xpack.esql.index.EsIndex; import org.elasticsearch.xpack.esql.index.IndexResolution; -import org.elasticsearch.xpack.esql.io.stream.PlanNameRegistry; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput; import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext; @@ -284,16 +283,14 @@ public static QueryBuilder singleValueQuery(String query, QueryBuilder inner, St out.writeOptionalString(null); out.writeNamedWriteable(inner); out.writeString(field); - source.writeTo(new PlanStreamOutput(out, new PlanNameRegistry(), config)); + source.writeTo(new PlanStreamOutput(out, config)); StreamInput in = new NamedWriteableAwareStreamInput( ByteBufferStreamInput.wrap(BytesReference.toBytes(out.bytes())), SerializationTestUtils.writableRegistry() ); - Object obj = SingleValueQuery.ENTRY.reader.read( - new PlanStreamInput(in, new PlanNameRegistry(), in.namedWriteableRegistry(), 
config) - ); + Object obj = SingleValueQuery.ENTRY.reader.read(new PlanStreamInput(in, in.namedWriteableRegistry(), config)); return (QueryBuilder) obj; } catch (IOException e) { throw new UncheckedIOException(e); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/querydsl/query/SingleValueQuerySerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/querydsl/query/SingleValueQuerySerializationTests.java index 5ba03a3fd89fb..7dfda41a8e880 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/querydsl/query/SingleValueQuerySerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/querydsl/query/SingleValueQuerySerializationTests.java @@ -13,7 +13,6 @@ import org.elasticsearch.index.query.TermQueryBuilder; import org.elasticsearch.test.AbstractWireTestCase; import org.elasticsearch.xpack.esql.core.tree.Source; -import org.elasticsearch.xpack.esql.io.stream.PlanNameRegistry; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput; import org.elasticsearch.xpack.esql.session.Configuration; @@ -66,9 +65,9 @@ protected final SingleValueQuery.Builder copyInstance(SingleValueQuery.Builder i return copyInstance( instance, getNamedWriteableRegistry(), - (out, v) -> new PlanStreamOutput(out, new PlanNameRegistry(), config).writeNamedWriteable(v), + (out, v) -> new PlanStreamOutput(out, config).writeNamedWriteable(v), in -> { - PlanStreamInput pin = new PlanStreamInput(in, new PlanNameRegistry(), in.namedWriteableRegistry(), config); + PlanStreamInput pin = new PlanStreamInput(in, in.namedWriteableRegistry(), config); return (SingleValueQuery.Builder) pin.readNamedWriteable(QueryBuilder.class); }, version diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/type/AbstractEsFieldTypeTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/type/AbstractEsFieldTypeTests.java index 9b2bf03b5c8aa..44a83f9c964c2 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/type/AbstractEsFieldTypeTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/type/AbstractEsFieldTypeTests.java @@ -15,7 +15,6 @@ import org.elasticsearch.test.AbstractWireTestCase; import org.elasticsearch.xpack.esql.EsqlTestUtils; import org.elasticsearch.xpack.esql.core.type.EsField; -import org.elasticsearch.xpack.esql.io.stream.PlanNameRegistry; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput; @@ -45,15 +44,12 @@ public static EsField randomAnyEsField(int maxDepth) { @Override protected EsField copyInstance(EsField instance, TransportVersion version) throws IOException { NamedWriteableRegistry namedWriteableRegistry = getNamedWriteableRegistry(); - try ( - BytesStreamOutput output = new BytesStreamOutput(); - var pso = new PlanStreamOutput(output, new PlanNameRegistry(), EsqlTestUtils.TEST_CFG) - ) { + try (BytesStreamOutput output = new BytesStreamOutput(); var pso = new PlanStreamOutput(output, EsqlTestUtils.TEST_CFG)) { pso.setTransportVersion(version); instance.writeTo(pso); try ( - StreamInput in1 = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), namedWriteableRegistry); - var psi = new PlanStreamInput(in1, new PlanNameRegistry(), in1.namedWriteableRegistry(), EsqlTestUtils.TEST_CFG) + StreamInput in = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), 
namedWriteableRegistry); + var psi = new PlanStreamInput(in, in.namedWriteableRegistry(), EsqlTestUtils.TEST_CFG) ) { psi.setTransportVersion(version); return EsField.readFrom(psi); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/type/EsFieldTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/type/EsFieldTests.java index e824b4de03e26..18e0405a65892 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/type/EsFieldTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/type/EsFieldTests.java @@ -13,7 +13,7 @@ import java.util.Map; public class EsFieldTests extends AbstractEsFieldTypeTests { - static EsField randomEsField(int maxPropertiesDepth) { + public static EsField randomEsField(int maxPropertiesDepth) { String name = randomAlphaOfLength(4); DataType esDataType = randomFrom(DataType.types()); Map properties = randomProperties(maxPropertiesDepth); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/type/MultiTypeEsFieldTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/type/MultiTypeEsFieldTests.java index d4ca40b75d2f3..f533c20975aff 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/type/MultiTypeEsFieldTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/type/MultiTypeEsFieldTests.java @@ -30,7 +30,6 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToLong; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToString; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToVersion; -import org.elasticsearch.xpack.esql.io.stream.PlanNameRegistry; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput; import org.elasticsearch.xpack.esql.session.Configuration; @@ -101,16 +100,10 @@ protected final NamedWriteableRegistry getNamedWriteableRegistry() { @Override protected final MultiTypeEsField copyInstance(MultiTypeEsField instance, TransportVersion version) throws IOException { - return copyInstance( - instance, - getNamedWriteableRegistry(), - (out, v) -> v.writeTo(new PlanStreamOutput(out, new PlanNameRegistry(), config)), - in -> { - PlanStreamInput pin = new PlanStreamInput(in, new PlanNameRegistry(), in.namedWriteableRegistry(), config); - return EsField.readFrom(pin); - }, - version - ); + return copyInstance(instance, getNamedWriteableRegistry(), (out, v) -> v.writeTo(new PlanStreamOutput(out, config)), in -> { + PlanStreamInput pin = new PlanStreamInput(in, in.namedWriteableRegistry(), config); + return EsField.readFrom(pin); + }, version); } private static Map randomConvertExpressions(String name, boolean toString, DataType dataType) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java new file mode 100644 index 0000000000000..2273150d70ad6 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java @@ -0,0 +1,133 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.common; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +import java.util.concurrent.Flow; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; + +/** + * Processor that delegates the {@link java.util.concurrent.Flow.Subscription} to the upstream {@link java.util.concurrent.Flow.Publisher} + * and delegates event transmission to the downstream {@link java.util.concurrent.Flow.Subscriber}. + */ +public abstract class DelegatingProcessor<T, R> implements Flow.Processor<T, R> { + private static final Logger log = LogManager.getLogger(DelegatingProcessor.class); + private final AtomicLong pendingRequests = new AtomicLong(); + private final AtomicBoolean isClosed = new AtomicBoolean(false); + private Flow.Subscriber<? super R> downstream; + private Flow.Subscription upstream; + + @Override + public void subscribe(Flow.Subscriber<? super R> subscriber) { + if (downstream != null) { + subscriber.onError(new IllegalStateException("Another subscriber is already subscribed.")); + return; + } + + var subscription = forwardingSubscription(); + try { + downstream = subscriber; + downstream.onSubscribe(subscription); + } catch (Exception e) { + log.atDebug().withThrowable(e).log("Another publisher is already publishing to subscriber, canceling."); + subscription.cancel(); + downstream = null; + throw e; + } + } + + private Flow.Subscription forwardingSubscription() { + return new Flow.Subscription() { + @Override + public void request(long n) { + if (isClosed.get()) { + downstream.onComplete(); // shouldn't happen, but reinforce that we're no longer listening + } else if (upstream != null) { + upstream.request(n); + } else { + pendingRequests.accumulateAndGet(n, Long::sum); + } + } + + @Override + public void cancel() { + if (isClosed.compareAndSet(false, true) && upstream != null) { + upstream.cancel(); + } + } + }; + } + + @Override + public void onSubscribe(Flow.Subscription subscription) { + if (upstream != null) { + throw new IllegalStateException("Another upstream already exists. This subscriber can only subscribe to one publisher."); + } + + if (isClosed.get()) { + subscription.cancel(); + return; + } + + upstream = subscription; + var currentRequestCount = pendingRequests.getAndSet(0); + if (currentRequestCount != 0) { + upstream.request(currentRequestCount); + } + } + + @Override + public void onNext(T item) { + if (isClosed.get()) { + upstream.cancel(); + } else { + next(item); + } + } + + /** + * An {@link #onNext(Object)} that is only called when the stream is still open. + * Implementations can pass the resulting R object to the downstream subscriber via {@link #downstream()}, or the upstream can be + * accessed via {@link #upstream()}.
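To make the next(T) contract just described concrete, a hypothetical subclass (illustrative only, not part of this change) either emits the transformed R via downstream() or pulls another item via upstream() when there is nothing to forward:

    import org.elasticsearch.xpack.inference.common.DelegatingProcessor;

    final class ParsingProcessor extends DelegatingProcessor<String, Integer> {
        @Override
        protected void next(String item) {
            if (item.isBlank()) {
                upstream().request(1);   // nothing to emit for this item, ask for the next one
            } else {
                downstream().onNext(Integer.parseInt(item.trim()));
            }
        }
    }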
+ */ + protected abstract void next(T item); + + @Override + public void onError(Throwable throwable) { + if (isClosed.compareAndSet(false, true)) { + if (downstream != null) { + downstream.onError(throwable); + } else { + log.atDebug() + .withThrowable(throwable) + .log("onError was called before the downstream subscription, rethrowing to close listener."); + throw new IllegalStateException("onError was called before the downstream subscription", throwable); + } + } + } + + @Override + public void onComplete() { + if (isClosed.compareAndSet(false, true)) { + if (downstream != null) { + downstream.onComplete(); + } + } + } + + protected Flow.Subscriber<? super R> downstream() { + return downstream; + } + + protected Flow.Subscription upstream() { + return upstream; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/StreamingHttpResultPublisher.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/StreamingHttpResultPublisher.java index 49a9048a69df1..aad262617d0b3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/StreamingHttpResultPublisher.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/StreamingHttpResultPublisher.java @@ -71,8 +71,13 @@ class StreamingHttpResultPublisher implements HttpAsyncResponseConsumer subscriber.onNext(firstResponse)); + if (response.getEntity() == null || response.getEntity().getContentLength() <= 0) { + // on success, we may receive an empty content payload to initiate the stream + this.queue.offer(() -> subscriber.onNext(new HttpResult(response, new byte[0]))); + } else { + var firstResponse = HttpResult.create(settings.getMaxResponseSize(), response); + this.queue.offer(() -> subscriber.onNext(firstResponse)); + } this.listener.onResponse(this); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiChatCompletionResponseHandler.java index 7ca7cf0422fd9..37665b6228c8f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiChatCompletionResponseHandler.java @@ -7,10 +7,16 @@ package org.elasticsearch.xpack.inference.external.openai; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; import org.elasticsearch.xpack.inference.external.http.retry.RetryException; import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser; +import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor; + +import java.util.concurrent.Flow; public class OpenAiChatCompletionResponseHandler extends OpenAiResponseHandler { public OpenAiChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) { @@ -22,4 +28,19 @@ protected RetryException buildExceptionHandling429(Request request, HttpResult r // We don't retry if the chat completion input is too large
return new RetryException(false, buildError(RATE_LIMIT, request, result)); } + + @Override + public boolean canHandleStreamingResponses() { + return true; + } + + @Override + public InferenceServiceResults parseResult(Request request, HttpResult result, Flow.Publisher flow) { + var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser()); + var openAiProcessor = new OpenAiStreamingProcessor(); + + flow.subscribe(serverSentEventProcessor); + serverSentEventProcessor.subscribe(openAiProcessor); + return new StreamingChatCompletionResults(openAiProcessor); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java new file mode 100644 index 0000000000000..dcda832091e05 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java @@ -0,0 +1,203 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.openai; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.common.collect.Iterators; +import org.elasticsearch.common.xcontent.ChunkedToXContent; +import org.elasticsearch.common.xcontent.ChunkedToXContentHelper; +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.common.DelegatingProcessor; +import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent; +import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventField; + +import java.io.IOException; +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.Iterator; +import java.util.Optional; + +import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResults.COMPLETION; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField; + +/** + * Parses the OpenAI chat completion streaming responses. + * For a request like: + * + *
    + *     
    + *         {
    + *             "inputs": ["Please summarize this text: some text", "Answer the following question: Question"]
    + *         }
    + *     
    + * 
    + * + * The response would look like: + * + *
    + *     
    + *         {
    + *              "id": "chatcmpl-123",
    + *              "object": "chat.completion",
    + *              "created": 1677652288,
    + *              "model": "gpt-3.5-turbo-0613",
    + *              "system_fingerprint": "fp_44709d6fcb",
    + *              "choices": [
    + *                  {
    + *                      "index": 0,
    + *                      "delta": {
+ *                          "content": "\n\nHello there, how "
    + *                      },
    + *                      "finish_reason": ""
    + *                  }
    + *              ]
    + *          }
    + *
    + *         {
    + *              "id": "chatcmpl-123",
    + *              "object": "chat.completion",
    + *              "created": 1677652288,
    + *              "model": "gpt-3.5-turbo-0613",
    + *              "system_fingerprint": "fp_44709d6fcb",
    + *              "choices": [
    + *                  {
    + *                      "index": 1,
    + *                      "delta": {
+ *                          "content": "may I assist you today?"
    + *                      },
    + *                      "finish_reason": ""
    + *                  }
    + *              ]
    + *          }
    + *
    + *         {
    + *              "id": "chatcmpl-123",
    + *              "object": "chat.completion",
    + *              "created": 1677652288,
    + *              "model": "gpt-3.5-turbo-0613",
    + *              "system_fingerprint": "fp_44709d6fcb",
    + *              "choices": [
    + *                  {
    + *                      "index": 2,
    + *                      "delta": {},
    + *                      "finish_reason": "stop"
    + *                  }
    + *              ]
    + *          }
    + *
    + *          [DONE]
    + *     
    + * 
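Downstream, the deltaChunk and completionChunk helpers at the bottom of this class wrap each batch of parsed deltas in a single completion array. Assuming the COMPLETION constant resolves to "completion", the two content deltas in the stream above would be re-emitted roughly as:

    {
        "completion": [
            { "delta": "\n\nHello there, how " },
            { "delta": "may I assist you today?" }
        ]
    }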
+ */ +public class OpenAiStreamingProcessor extends DelegatingProcessor<Deque<ServerSentEvent>, ChunkedToXContent> { + private static final Logger log = LogManager.getLogger(OpenAiStreamingProcessor.class); + private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in OpenAI chat completions response"; + private static final String RESULT = "delta"; + + private static final String CHOICES_FIELD = "choices"; + private static final String DELTA_FIELD = "delta"; + private static final String CONTENT_FIELD = "content"; + private static final String FINISH_REASON_FIELD = "finish_reason"; + private static final String STOP_MESSAGE = "stop"; + private static final String DONE_MESSAGE = "[done]"; + + @Override + protected void next(Deque<ServerSentEvent> item) { + var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); + + var results = new ArrayDeque<ChunkedToXContent>(item.size()); + for (ServerSentEvent event : item) { + if (ServerSentEventField.DATA == event.name() && event.hasValue()) { + try { + var delta = parse(parserConfig, event); + delta.map(this::deltaChunk).ifPresent(results::offer); + } catch (Exception e) { + log.warn("Failed to parse event from inference provider: {}", event); + onError(new IOException("Failed to parse event from inference provider.", e)); + return; + } + } + } + + if (results.isEmpty()) { + upstream().request(1); + } else { + downstream().onNext(completionChunk(results.iterator())); + } + } + + private Optional<String> parse(XContentParserConfiguration parserConfig, ServerSentEvent event) throws IOException { + if (DONE_MESSAGE.equalsIgnoreCase(event.value())) { + return Optional.empty(); + } + + try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event.value())) { + moveToFirstToken(jsonParser); + + XContentParser.Token token = jsonParser.currentToken(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser); + + // choices is an array, but since we don't send 'n' in the request we only get one value in the result + positionParserAtTokenAfterField(jsonParser, CHOICES_FIELD, FAILED_TO_FIND_FIELD_TEMPLATE); + + jsonParser.nextToken(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, jsonParser.currentToken(), jsonParser); + + positionParserAtTokenAfterField(jsonParser, DELTA_FIELD, FAILED_TO_FIND_FIELD_TEMPLATE); + + token = jsonParser.currentToken(); + + ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser); + + while (token != null) { + if (token == XContentParser.Token.FIELD_NAME && jsonParser.currentName().equals(CONTENT_FIELD)) { + jsonParser.nextToken(); + var contentToken = jsonParser.currentToken(); + ensureExpectedToken(XContentParser.Token.VALUE_STRING, contentToken, jsonParser); + return Optional.ofNullable(jsonParser.text()); + } else if (token == XContentParser.Token.FIELD_NAME && jsonParser.currentName().equals(FINISH_REASON_FIELD)) { + jsonParser.nextToken(); + var contentToken = jsonParser.currentToken(); + ensureExpectedToken(XContentParser.Token.VALUE_STRING, contentToken, jsonParser); + if (STOP_MESSAGE.equalsIgnoreCase(jsonParser.text())) { + return Optional.empty(); + } + } + token = jsonParser.nextToken(); + } + + throw new IllegalStateException(format(FAILED_TO_FIND_FIELD_TEMPLATE, CONTENT_FIELD)); + } + } + + private ChunkedToXContent deltaChunk(String delta) { + return params -> Iterators.concat( + ChunkedToXContentHelper.startObject(), + ChunkedToXContentHelper.field(RESULT, delta),
ChunkedToXContentHelper.endObject() + ); + } + + private ChunkedToXContent completionChunk(Iterator<ChunkedToXContent> delta) { + return params -> Iterators.concat( + ChunkedToXContentHelper.startObject(), + ChunkedToXContentHelper.startArray(COMPLETION), + Iterators.flatMap(delta, d -> d.toXContentChunked(params)), + ChunkedToXContentHelper.endArray(), + ChunkedToXContentHelper.endObject() + ); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/ServerSentEvent.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/ServerSentEvent.java new file mode 100644 index 0000000000000..915072f43651a --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/ServerSentEvent.java @@ -0,0 +1,32 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.streaming; + +/** + * Server-Sent Event message: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation + * Messages always contain a {@link ServerSentEventField} and a non-null payload value. + * When the stream is parsed and there is no value associated with a {@link ServerSentEventField}, an empty string is set as the value. + */ +public record ServerSentEvent(ServerSentEventField name, String value) { + + private static final String EMPTY = ""; + + public ServerSentEvent(ServerSentEventField name) { + this(name, EMPTY); + } + + // treat null value as an empty string, don't break parsing + public ServerSentEvent(ServerSentEventField name, String value) { + this.name = name; + this.value = value != null ? value : EMPTY; + } + + public boolean hasValue() { + return value.isBlank() == false; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/ServerSentEventField.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/ServerSentEventField.java new file mode 100644 index 0000000000000..eabe8248ad0fd --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/ServerSentEventField.java @@ -0,0 +1,38 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.streaming; + +import java.util.Arrays; +import java.util.Locale; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * The named Server-Sent Event fields: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation + * Unnamed fields are not recognized and are ignored.
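A small illustration of the two conventions around these types, the empty-value default on the record above and the case-insensitive name lookup in the enum that follows (hypothetical snippet; oneOf is package-private, so this has to live in the same package):

    package org.elasticsearch.xpack.inference.external.response.streaming;

    public class SseFieldDemo {
        public static void main(String[] args) {
            // lookup is case-insensitive and returns Optional.empty() for unknown names
            System.out.println(ServerSentEventField.oneOf("DATA"));    // Optional[DATA]
            System.out.println(ServerSentEventField.oneOf("comment")); // Optional.empty
            // a field without a payload defaults to "" and reports hasValue() == false
            System.out.println(new ServerSentEvent(ServerSentEventField.DATA).hasValue()); // false
        }
    }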
+ */ +public enum ServerSentEventField { + EVENT, + DATA, + ID, + RETRY; + + private static final Set<String> possibleValues = Arrays.stream(values()) + .map(Enum::name) + .map(name -> name.toLowerCase(Locale.ROOT)) + .collect(Collectors.toSet()); + + static Optional<ServerSentEventField> oneOf(String name) { + if (name != null && possibleValues.contains(name.toLowerCase(Locale.ROOT))) { + return Optional.of(valueOf(name.toUpperCase(Locale.ROOT))); + } else { + return Optional.empty(); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/ServerSentEventParser.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/ServerSentEventParser.java new file mode 100644 index 0000000000000..9856a116f17d2 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/ServerSentEventParser.java @@ -0,0 +1,75 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.streaming; + +import java.nio.charset.StandardCharsets; +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.Optional; +import java.util.regex.Pattern; + +/** + * https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation + * Lines are separated by LF, CR, or CRLF. + * If the line is empty, we do not dispatch the event since we do that automatically. Instead, we discard this event. + * If the line starts with a colon, we discard this event. + * If the line contains a colon, we process it into {@link ServerSentEvent} with a non-empty value. + * If the line does not contain a colon, we process it into {@link ServerSentEvent} with an empty string value. + * If the line's field is not one of {@link ServerSentEventField}, we discard this event.
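A usage sketch of the parser implemented just below, showing one complete message plus a partial trailing line that is buffered until the next chunk arrives (hypothetical main(); assumes the same package, since these types live together):

    package org.elasticsearch.xpack.inference.external.response.streaming;

    import java.nio.charset.StandardCharsets;
    import java.util.Deque;

    public class SseParseDemo {
        public static void main(String[] args) {
            var parser = new ServerSentEventParser();
            // one complete message plus a dangling fragment; the fragment is buffered internally
            Deque<ServerSentEvent> events = parser.parse("event: message\ndata: hello\nda".getBytes(StandardCharsets.UTF_8));
            events.forEach(e -> System.out.println(e.name() + " -> " + e.value())); // EVENT -> message, DATA -> hello
            // the next chunk completes the buffered "da" into a "data" line
            events = parser.parse("ta: world\n".getBytes(StandardCharsets.UTF_8));
            events.forEach(e -> System.out.println(e.name() + " -> " + e.value())); // DATA -> world
        }
    }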
+ */ +public class ServerSentEventParser { + private static final Pattern END_OF_LINE_REGEX = Pattern.compile("\\r\\n|\\n|\\r"); // CRLF must come first so it matches as a single separator + private static final String BOM = "\uFEFF"; + private volatile String previousTokens = ""; + + public Deque<ServerSentEvent> parse(byte[] bytes) { + if (bytes == null || bytes.length == 0) { + return new ArrayDeque<>(0); + } + + var body = previousTokens + new String(bytes, StandardCharsets.UTF_8); + var lines = END_OF_LINE_REGEX.split(body, -1); // -1 because we actually want trailing empty strings + + var collector = new ArrayDeque<ServerSentEvent>(lines.length); + for (var i = 0; i < lines.length - 1; i++) { + var line = lines[i].replace(BOM, ""); + + if (line.isBlank() == false && line.startsWith(":") == false) { + if (line.contains(":")) { + fieldValueEvent(line).ifPresent(collector::offer); + } else { + ServerSentEventField.oneOf(line).map(ServerSentEvent::new).ifPresent(collector::offer); + } + } + } + + // we can sometimes get bytes for incomplete messages, so we save them for the next onNext invocation + // if we get an onComplete before we clear this cache, we follow the spec and treat the cached bytes as an incomplete event, + // discarding them, since they were not followed by a blank line + previousTokens = lines[lines.length - 1]; + return collector; + } + + private Optional<ServerSentEvent> fieldValueEvent(String lineWithColon) { + var firstColon = lineWithColon.indexOf(":"); + var fieldStr = lineWithColon.substring(0, firstColon); + var serverSentField = ServerSentEventField.oneOf(fieldStr); + + if ((firstColon + 1) != lineWithColon.length()) { + var value = lineWithColon.substring(firstColon + 1); + if (value.equals(" ") == false) { + var trimmedValue = value.charAt(0) == ' ' ? value.substring(1) : value; + return serverSentField.map(field -> new ServerSentEvent(field, trimmedValue)); + } + } + + // if we have "data:" or "data: ", treat it like a no-value line + return serverSentField.map(ServerSentEvent::new); + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/ServerSentEventProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/ServerSentEventProcessor.java new file mode 100644 index 0000000000000..a028ca1224f35 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/ServerSentEventProcessor.java @@ -0,0 +1,39 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0.
+ */ + +package org.elasticsearch.xpack.inference.external.response.streaming; + +import org.elasticsearch.xpack.inference.common.DelegatingProcessor; +import org.elasticsearch.xpack.inference.external.http.HttpResult; + +import java.util.Deque; + +public class ServerSentEventProcessor extends DelegatingProcessor<HttpResult, Deque<ServerSentEvent>> { + private final ServerSentEventParser serverSentEventParser; + + public ServerSentEventProcessor(ServerSentEventParser serverSentEventParser) { + this.serverSentEventParser = serverSentEventParser; + } + + @Override + public void next(HttpResult item) { + if (item.isBodyEmpty()) { + // discard empty result and go to the next + upstream().request(1); + return; + } + + var response = serverSentEventParser.parse(item.body()); + if (response.isEmpty()) { + // discard empty result and go to the next + upstream().request(1); + return; + } + + downstream().onNext(response); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 1577fbc4a642a..81dfba769136b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -18,6 +18,7 @@ import org.elasticsearch.common.xcontent.support.XContentMapValues; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Tuple; +import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.fielddata.FieldDataContext; import org.elasticsearch.index.fielddata.IndexFieldData; @@ -80,13 +81,16 @@ public class SemanticTextFieldMapper extends FieldMapper implements InferenceFieldMapper { public static final String CONTENT_TYPE = "semantic_text"; + private final IndexSettings indexSettings; + public static final TypeParser PARSER = new TypeParser( - (n, c) -> new Builder(n, c.indexVersionCreated(), c::bitSetProducer), + (n, c) -> new Builder(n, c.indexVersionCreated(), c::bitSetProducer, c.getIndexSettings()), List.of(notInMultiFields(CONTENT_TYPE), notFromDynamicTemplates(CONTENT_TYPE)) ); public static class Builder extends FieldMapper.Builder { private final IndexVersion indexVersionCreated; + private final IndexSettings indexSettings; private final Parameter<String> inferenceId = Parameter.stringParam( "inference_id", @@ -113,10 +117,22 @@ public static class Builder extends FieldMapper.Builder { private Function<MapperBuilderContext, ObjectMapper> inferenceFieldBuilder; - public Builder(String name, IndexVersion indexVersionCreated, Function<Query, BitSetProducer> bitSetProducer) { + public Builder( + String name, + IndexVersion indexVersionCreated, + Function<Query, BitSetProducer> bitSetProducer, + IndexSettings indexSettings + ) { super(name); this.indexVersionCreated = indexVersionCreated; - this.inferenceFieldBuilder = c -> createInferenceField(c, indexVersionCreated, modelSettings.get(), bitSetProducer); + this.indexSettings = indexSettings; + this.inferenceFieldBuilder = c -> createInferenceField( + c, + indexVersionCreated, + modelSettings.get(), + bitSetProducer, + indexSettings + ); } public Builder setInferenceId(String id) { @@ -170,13 +186,20 @@ public SemanticTextFieldMapper build(MapperBuilderContext context) { indexVersionCreated, meta.getValue() ), - builderParams(this, context) + builderParams(this, context), + indexSettings ); } } - private SemanticTextFieldMapper(String simpleName,
MappedFieldType mappedFieldType, BuilderParams builderParams) { + private SemanticTextFieldMapper( + String simpleName, + MappedFieldType mappedFieldType, + BuilderParams builderParams, + IndexSettings indexSettings + ) { super(simpleName, mappedFieldType, builderParams); + this.indexSettings = indexSettings; } @Override @@ -188,7 +211,9 @@ public Iterator<Mapper> iterator() { @Override public FieldMapper.Builder getMergeBuilder() { - return new Builder(leafName(), fieldType().indexVersionCreated, fieldType().getChunksField().bitsetProducer()).init(this); + return new Builder(leafName(), fieldType().indexVersionCreated, fieldType().getChunksField().bitsetProducer(), indexSettings).init( + this + ); } @Override @@ -229,7 +254,8 @@ protected void parseCreateField(DocumentParserContext context) throws IOException Builder builder = (Builder) new Builder( leafName(), fieldType().indexVersionCreated, - fieldType().getChunksField().bitsetProducer() + fieldType().getChunksField().bitsetProducer(), + indexSettings ).init(this); try { mapper = builder.setModelSettings(field.inference().modelSettings()) @@ -473,19 +499,26 @@ private static ObjectMapper createInferenceField( MapperBuilderContext context, IndexVersion indexVersionCreated, @Nullable SemanticTextField.ModelSettings modelSettings, - Function<Query, BitSetProducer> bitSetProducer + Function<Query, BitSetProducer> bitSetProducer, + IndexSettings indexSettings ) { return new ObjectMapper.Builder(INFERENCE_FIELD, Optional.of(ObjectMapper.Subobjects.ENABLED)).dynamic(ObjectMapper.Dynamic.FALSE) - .add(createChunksField(indexVersionCreated, modelSettings, bitSetProducer)) + .add(createChunksField(indexVersionCreated, modelSettings, bitSetProducer, indexSettings)) .build(context); } private static NestedObjectMapper.Builder createChunksField( IndexVersion indexVersionCreated, @Nullable SemanticTextField.ModelSettings modelSettings, - Function<Query, BitSetProducer> bitSetProducer + Function<Query, BitSetProducer> bitSetProducer, + IndexSettings indexSettings ) { - NestedObjectMapper.Builder chunksField = new NestedObjectMapper.Builder(CHUNKS_FIELD, indexVersionCreated, bitSetProducer); + NestedObjectMapper.Builder chunksField = new NestedObjectMapper.Builder( + CHUNKS_FIELD, + indexVersionCreated, + bitSetProducer, + indexSettings + ); chunksField.dynamic(ObjectMapper.Dynamic.FALSE); KeywordFieldMapper.Builder chunkTextField = new KeywordFieldMapper.Builder(CHUNKED_TEXT_FIELD, indexVersionCreated).indexed(false) .docValues(false); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/DelegatingProcessorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/DelegatingProcessorTests.java new file mode 100644 index 0000000000000..826eaf6cd6860 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/DelegatingProcessorTests.java @@ -0,0 +1,218 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0.
+ */ + +package org.elasticsearch.xpack.inference.common; + +import org.elasticsearch.test.ESTestCase; +import org.mockito.ArgumentCaptor; + +import java.util.concurrent.Flow; +import java.util.concurrent.atomic.AtomicReference; + +import static org.hamcrest.Matchers.notNullValue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +public class DelegatingProcessorTests extends ESTestCase { + + public static <T, R> R onNext(DelegatingProcessor<T, R> processor, T item) { + var response = new AtomicReference<R>(); + + processor.onSubscribe(mock()); + + Flow.Subscriber<R> downstream = mock(); + doAnswer(ans -> { + response.set(ans.getArgument(0)); + return null; + }).when(downstream).onNext(any()); + processor.subscribe(downstream); + + processor.onNext(item); + assertThat("Response from processor was null", response.get(), notNullValue()); + return response.get(); + } + + public static <T, R> Throwable onError(DelegatingProcessor<T, R> processor, T item) { + var response = new AtomicReference<Throwable>(); + + processor.onSubscribe(mock()); + + Flow.Subscriber<R> downstream = mock(); + doAnswer(ans -> { + response.set(ans.getArgument(0)); + return null; + }).when(downstream).onError(any()); + processor.subscribe(downstream); + + processor.onNext(item); + assertThat("Error from processor was null", response.get(), notNullValue()); + return response.get(); + } + + public void testRequestBeforeOnSubscribe() { + var processor = delegatingProcessor(); + var expectedRequestCount = randomLongBetween(2, 100); + + Flow.Subscriber<String> downstream = mock(); + processor.subscribe(downstream); + + var subscription = ArgumentCaptor.forClass(Flow.Subscription.class); + verify(downstream, times(1)).onSubscribe(subscription.capture()); + subscription.getValue().request(expectedRequestCount); + + Flow.Subscription upstream = mock(); + processor.onSubscribe(upstream); + verify(upstream, times(1)).request(eq(expectedRequestCount)); + } + + public void testRequestAfterOnSubscribe() { + var processor = delegatingProcessor(); + var expectedRequestCount = randomLongBetween(2, 100); + + Flow.Subscription upstream = mock(); + processor.onSubscribe(upstream); + verify(upstream, never()).request(anyInt()); + + Flow.Subscriber<String> downstream = mock(); + processor.subscribe(downstream); + + var subscription = ArgumentCaptor.forClass(Flow.Subscription.class); + verify(downstream, times(1)).onSubscribe(subscription.capture()); + + subscription.getValue().request(expectedRequestCount); + verify(upstream, times(1)).request(eq(expectedRequestCount)); + } + + public void testOnNextAfterCancelDoesNotForwardItem() { + var expectedItem = "hello"; + + var processor = delegatingProcessor(); + processor.onSubscribe(mock()); + + Flow.Subscriber<String> downstream = mock(); + doAnswer(ans -> { + Flow.Subscription sub = ans.getArgument(0); + sub.cancel(); + return null; + }).when(downstream).onSubscribe(any()); + processor.subscribe(downstream); + + processor.onNext(expectedItem); + + verify(downstream, never()).onNext(any()); + } + + public void testCancelForwardsToUpstream() { + var processor = delegatingProcessor(); + Flow.Subscription upstream = mock(); + processor.onSubscribe(upstream); + + Flow.Subscriber<String> downstream = mock(); + doAnswer(ans -> {
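+ // cancel as soon as the downstream receives its subscription; the processor is then expected to forward the cancel upstream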
+ Flow.Subscription sub = ans.getArgument(0); + sub.cancel(); + return null; + }).when(downstream).onSubscribe(any()); + processor.subscribe(downstream); + + verify(upstream, times(1)).cancel(); + } + + public void testRequestForwardsToUpstream() { + var expectedRequestCount = randomLongBetween(2, 20); + var processor = delegatingProcessor(); + Flow.Subscription upstream = mock(); + processor.onSubscribe(upstream); + + Flow.Subscriber<String> downstream = mock(); + doAnswer(ans -> { + Flow.Subscription sub = ans.getArgument(0); + sub.request(expectedRequestCount); + return null; + }).when(downstream).onSubscribe(any()); + processor.subscribe(downstream); + + verify(upstream, times(1)).request(expectedRequestCount); + } + + public void testOnErrorBeforeSubscriptionThrowsException() { + assertThrows(IllegalStateException.class, () -> delegatingProcessor().onError(new NullPointerException())); + } + + public void testOnError() { + var expectedException = new IllegalStateException("hello"); + + var processor = delegatingProcessor(); + + Flow.Subscriber<String> downstream = mock(); + processor.subscribe(downstream); + + processor.onError(expectedException); + + verify(downstream, times(1)).onError(eq(expectedException)); + } + + public void testOnCompleteBeforeSubscriptionInvokesOnComplete() { + var processor = delegatingProcessor(); + + Flow.Subscriber<String> downstream = mock(); + doAnswer(ans -> { + Flow.Subscription sub = ans.getArgument(0); + sub.request(1); + return null; + }).when(downstream).onSubscribe(any()); + + processor.onComplete(); + verify(downstream, times(0)).onComplete(); + + processor.subscribe(downstream); + verify(downstream, times(1)).onComplete(); + } + + public void testOnComplete() { + var processor = delegatingProcessor(); + + Flow.Subscriber<String> downstream = mock(); + processor.subscribe(downstream); + processor.onComplete(); + + verify(downstream, times(1)).onComplete(); + } + + public void testSubscriberOnlyAllowsOnePublisher() { + var publisher1 = delegatingProcessor(); + var publisher2 = delegatingProcessor(); + var subscriber1 = spy(delegatingProcessor()); + + publisher1.subscribe(subscriber1); + verify(subscriber1, times(1)).onSubscribe(any()); + + // verify we cannot reuse subscribers + assertThrows(IllegalStateException.class, () -> publisher2.subscribe(subscriber1)); + + // verify publisher resets its subscriber + var subscriber2 = spy(delegatingProcessor()); + publisher2.subscribe(subscriber2); + verify(subscriber2, times(1)).onSubscribe(any()); + } + + private DelegatingProcessor<String, String> delegatingProcessor() { + return new DelegatingProcessor<>() { + @Override + public void next(String item) { + downstream().onNext(item); + } + }; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/StreamingHttpResultPublisherTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/StreamingHttpResultPublisherTests.java index 92a332fe545e3..fd166493f09d0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/StreamingHttpResultPublisherTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/StreamingHttpResultPublisherTests.java @@ -7,16 +7,19 @@ package org.elasticsearch.xpack.inference.external.http; +import org.apache.http.HttpEntity; import org.apache.http.HttpResponse; import org.apache.http.nio.ContentDecoder; import org.apache.http.nio.IOControl; import org.elasticsearch.action.ActionListener; +import
org.elasticsearch.action.support.ActionTestUtils; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.ThreadPool; import org.junit.Before; +import java.io.ByteArrayInputStream; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; @@ -73,10 +76,7 @@ public void setUp() throws Exception { */ public void testFirstResponseCallsListener() throws IOException { var latch = new CountDownLatch(1); - var listener = ActionListener.<Flow.Publisher<HttpResult>>wrap( - r -> latch.countDown(), - e -> fail("Listener onFailure should never be called.") - ); + var listener = ActionTestUtils.<Flow.Publisher<HttpResult>>assertNoFailureListener(r -> latch.countDown()); publisher = new StreamingHttpResultPublisher(threadPool, settings, listener); publisher.responseReceived(mock(HttpResponse.class)); @@ -84,6 +84,46 @@ public void testFirstResponseCallsListener() throws IOException { assertThat("Listener's onResponse should be called when we receive a response", latch.getCount(), equalTo(0L)); } + /** + * When we receive an http response with an entity with no content + * Then we call the listener + * And we queue the initial payload + */ + public void testEmptyFirstResponseCallsListener() throws IOException { + var latch = new CountDownLatch(1); + var listener = ActionTestUtils.<Flow.Publisher<HttpResult>>assertNoFailureListener(r -> latch.countDown()); + publisher = new StreamingHttpResultPublisher(threadPool, settings, listener); + + var response = mock(HttpResponse.class); + var entity = mock(HttpEntity.class); + when(entity.getContentLength()).thenReturn(-1L); + when(response.getEntity()).thenReturn(entity); + publisher.responseReceived(response); + + assertThat("Listener's onResponse should be called when we receive a response", latch.getCount(), equalTo(0L)); + } + + /** + * When we receive an http response with an entity with content + * Then we call the listener + * And we queue the initial payload + */ + public void testNonEmptyFirstResponseCallsListener() throws IOException { + var latch = new CountDownLatch(1); + var listener = ActionTestUtils.<Flow.Publisher<HttpResult>>assertNoFailureListener(r -> latch.countDown()); + publisher = new StreamingHttpResultPublisher(threadPool, settings, listener); + + when(settings.getMaxResponseSize()).thenReturn(ByteSizeValue.ofBytes(9000)); + var response = mock(HttpResponse.class); + var entity = mock(HttpEntity.class); + when(entity.getContentLength()).thenReturn(5L); + when(entity.getContent()).thenReturn(new ByteArrayInputStream(message)); + when(response.getEntity()).thenReturn(entity); + publisher.responseReceived(response); + + assertThat("Listener's onResponse should be called when we receive a response", latch.getCount(), equalTo(0L)); + } + /** * This test combines 4 tests since it's easier to verify the exchange of data at once. * diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessorTests.java new file mode 100644 index 0000000000000..992990a476a0c --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessorTests.java @@ -0,0 +1,167 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements.
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.openai; + +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.xcontent.ChunkedToXContent; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParseException; +import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent; +import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventField; + +import java.io.IOException; +import java.util.ArrayDeque; +import java.util.concurrent.Flow; + +import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS; +import static org.elasticsearch.xpack.inference.common.DelegatingProcessorTests.onError; +import static org.elasticsearch.xpack.inference.common.DelegatingProcessorTests.onNext; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +public class OpenAiStreamingProcessorTests extends ESTestCase { + public void testParseOpenAiResponse() throws IOException { + var item = new ArrayDeque<ServerSentEvent>(); + item.offer(new ServerSentEvent(ServerSentEventField.DATA, """ + { + "id":"12345", + "object":"chat.completion.chunk", + "created":123456789, + "model":"gpt-4o-mini", + "system_fingerprint": "123456789", + "choices":[ + { + "index":0, + "delta":{ + "content":"test" + }, + "logprobs":null, + "finish_reason":null + } + ] + } + """)); + + var response = onNext(new OpenAiStreamingProcessor(), item); + var json = toJsonString(response); + + assertThat(json, equalTo(""" + {"completion":[{"delta":"test"}]}""")); + } + + public void testParseWithFinish() throws IOException { + var item = new ArrayDeque<ServerSentEvent>(); + item.offer(new ServerSentEvent(ServerSentEventField.DATA, """ + { + "id":"12345", + "object":"chat.completion.chunk", + "created":123456789, + "model":"gpt-4o-mini", + "system_fingerprint": "123456789", + "choices":[ + { + "index":0, + "delta":{ + "content":"hello, world" + }, + "logprobs":null, + "finish_reason":null + } + ] + } + """)); + item.offer(new ServerSentEvent(ServerSentEventField.DATA, """ + { + "id":"12345", + "object":"chat.completion.chunk", + "created":123456789, + "model":"gpt-4o-mini", + "system_fingerprint": "123456789", + "choices":[ + { + "index":1, + "delta":{}, + "logprobs":null, + "finish_reason":"stop" + } + ] + } + """)); + + var response = onNext(new OpenAiStreamingProcessor(), item); + var json = toJsonString(response); + + assertThat(json, equalTo(""" + {"completion":[{"delta":"hello, world"}]}""")); + } + + public void testParseErrorCallsOnError() { + var item = new ArrayDeque<ServerSentEvent>(); + item.offer(new ServerSentEvent(ServerSentEventField.DATA, "this isn't json")); + + var exception = onError(new OpenAiStreamingProcessor(), item); + assertThat(exception, instanceOf(IOException.class)); + assertThat(exception.getMessage(), equalTo("Failed to parse event from inference provider.")); + assertThat(exception.getCause(), instanceOf(XContentParseException.class)); + } + + public void testEmptyResultsRequestsMoreData() { + var emptyDeque = new ArrayDeque<ServerSentEvent>(); + + var processor = new OpenAiStreamingProcessor(); + + Flow.Subscriber<ChunkedToXContent>
downstream = mock(); + processor.subscribe(downstream); + + Flow.Subscription upstream = mock(); + processor.onSubscribe(upstream); + + processor.next(emptyDeque); + + verify(upstream, times(1)).request(1); + verify(downstream, times(0)).onNext(any()); + } + + public void testDoneMessageIsIgnored() { + var item = new ArrayDeque<ServerSentEvent>(); + item.offer(new ServerSentEvent(ServerSentEventField.DATA, "[DONE]")); + + var processor = new OpenAiStreamingProcessor(); + + Flow.Subscriber<ChunkedToXContent> downstream = mock(); + processor.subscribe(downstream); + + Flow.Subscription upstream = mock(); + processor.onSubscribe(upstream); + + processor.next(item); + + verify(upstream, times(1)).request(1); + verify(downstream, times(0)).onNext(any()); + } + + private String toJsonString(ChunkedToXContent chunkedToXContent) throws IOException { + try (var builder = XContentFactory.jsonBuilder()) { + chunkedToXContent.toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> { + try { + xContent.toXContent(builder, EMPTY_PARAMS); + } catch (IOException e) { + logger.error(e.getMessage(), e); + fail(e.getMessage()); + } + }); + return XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType()); + } + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/streaming/ServerSentEventParserTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/streaming/ServerSentEventParserTests.java new file mode 100644 index 0000000000000..863d2e3e07c5f --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/streaming/ServerSentEventParserTests.java @@ -0,0 +1,125 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0.
+ */ + +package org.elasticsearch.xpack.inference.external.response.streaming; + +import org.elasticsearch.test.ESTestCase; + +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Deque; +import java.util.List; +import java.util.Locale; +import java.util.stream.Collectors; + +import static org.hamcrest.Matchers.equalTo; + +public class ServerSentEventParserTests extends ESTestCase { + public void testParseEvents() { + var payload = (Arrays.stream(ServerSentEventField.values()) + .map(ServerSentEventField::name) + .map(name -> name.toLowerCase(Locale.ROOT)) + .collect(Collectors.joining("\n")) + + "\n").getBytes(StandardCharsets.UTF_8); + + var parser = new ServerSentEventParser(); + var events = parser.parse(payload); + + assertThat(events.size(), equalTo(ServerSentEventField.values().length)); + } + + public void testParseDataEventsWithAllEndOfLines() { + var payload = """ + event: message\n\ + data: test\n\ + \n\ + event: message\r\ + data: test2\r\ + \r\ + event: message\r\n\ + data: test3\r\n\ + \r\n\ + """.getBytes(StandardCharsets.UTF_8); + + var parser = new ServerSentEventParser(); + var events = parser.parse(payload); + + assertEvents( + events, + List.of( + new ServerSentEvent(ServerSentEventField.EVENT, "message"), + new ServerSentEvent(ServerSentEventField.DATA, "test"), + new ServerSentEvent(ServerSentEventField.EVENT, "message"), + new ServerSentEvent(ServerSentEventField.DATA, "test2"), + new ServerSentEvent(ServerSentEventField.EVENT, "message"), + new ServerSentEvent(ServerSentEventField.DATA, "test3") + ) + ); + } + + private void assertEvents(Deque<ServerSentEvent> actualEvents, List<ServerSentEvent> expectedEvents) { + assertThat(actualEvents.size(), equalTo(expectedEvents.size())); + var expectedEvent = expectedEvents.iterator(); + actualEvents.forEach(event -> assertThat(event, equalTo(expectedEvent.next()))); + } + + // by default, Java's UTF-8 decode does not remove the byte order mark + public void testByteOrderMarkIsRemoved() { + // these are the bytes for "event: message\n\n" + var payload = new byte[] { -17, -69, -65, 101, 118, 101, 110, 116, 58, 32, 109, 101, 115, 115, 97, 103, 101, 10, 10 }; + + var parser = new ServerSentEventParser(); + var events = parser.parse(payload); + + assertEvents(events, List.of(new ServerSentEvent(ServerSentEventField.EVENT, "message"))); + } + + public void testEmptyEventIsSetAsEmptyString() { + var payload = """ + event: + event:\s + + """.getBytes(StandardCharsets.UTF_8); + + var parser = new ServerSentEventParser(); + var events = parser.parse(payload); + + assertEvents( + events, + List.of(new ServerSentEvent(ServerSentEventField.EVENT, ""), new ServerSentEvent(ServerSentEventField.EVENT, "")) + ); + } + + public void testCommentsAreIgnored() { + var parser = new ServerSentEventParser(); + + var events = parser.parse(""" + :some cool comment + :event: message + + """.getBytes(StandardCharsets.UTF_8)); + + assertThat(events.isEmpty(), equalTo(true)); + } + + public void testCarryOverBytes() { + var parser = new ServerSentEventParser(); + + var events = parser.parse(""" + event: message + data""".getBytes(StandardCharsets.UTF_8)); // no newline after 'data' so the parser won't split the message up + + assertEvents(events, List.of(new ServerSentEvent(ServerSentEventField.EVENT, "message"))); + + events = parser.parse(""" + :test + + """.getBytes(StandardCharsets.UTF_8)); + + assertEvents(events, List.of(new ServerSentEvent(ServerSentEventField.DATA, "test"))); + } +} diff --git
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/streaming/ServerSentEventProcessorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/streaming/ServerSentEventProcessorTests.java new file mode 100644 index 0000000000000..0a0712c69cc3f --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/streaming/ServerSentEventProcessorTests.java @@ -0,0 +1,77 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.streaming; + +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.http.HttpResult; + +import java.nio.charset.StandardCharsets; +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.concurrent.Flow; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class ServerSentEventProcessorTests extends ESTestCase { + + public void testEmptyBody() { + var processor = new ServerSentEventProcessor(mock()); + + Flow.Subscription upstream = mock(); + processor.onSubscribe(upstream); + + Flow.Subscriber<Deque<ServerSentEvent>> downstream = mock(); + processor.subscribe(downstream); + + processor.next(new HttpResult(mock(), new byte[0])); + verify(upstream, times(1)).request(1); + verify(downstream, times(0)).onNext(any()); + } + + public void testEmptyParseResponse() { + ServerSentEventParser parser = mock(); + when(parser.parse(any())).thenReturn(new ArrayDeque<>()); + + var processor = new ServerSentEventProcessor(parser); + + Flow.Subscription upstream = mock(); + processor.onSubscribe(upstream); + + Flow.Subscriber<Deque<ServerSentEvent>> downstream = mock(); + processor.subscribe(downstream); + + processor.next(new HttpResult(mock(), "hello".getBytes(StandardCharsets.UTF_8))); + verify(upstream, times(1)).request(1); + verify(downstream, times(0)).onNext(any()); + } + + public void testResponse() { + ServerSentEventParser parser = mock(); + var deque = new ArrayDeque<ServerSentEvent>(); + deque.offer(new ServerSentEvent(ServerSentEventField.EVENT, "hello")); + when(parser.parse(any())).thenReturn(deque); + + var processor = new ServerSentEventProcessor(parser); + + Flow.Subscription upstream = mock(); + processor.onSubscribe(upstream); + + Flow.Subscriber<Deque<ServerSentEvent>> downstream = mock(); + processor.subscribe(downstream); + + processor.next(new HttpResult(mock(), "hello".getBytes(StandardCharsets.UTF_8))); + verify(upstream, times(0)).request(anyLong()); + verify(downstream, times(1)).onNext(eq(deque)); + } +} diff --git a/x-pack/plugin/security/qa/service-account/src/javaRestTest/java/org/elasticsearch/xpack/security/authc/service/ServiceAccountIT.java b/x-pack/plugin/security/qa/service-account/src/javaRestTest/java/org/elasticsearch/xpack/security/authc/service/ServiceAccountIT.java index c1686a500fb2c..595d48ea92a44 100644 --- a/x-pack/plugin/security/qa/service-account/src/javaRestTest/java/org/elasticsearch/xpack/security/authc/service/ServiceAccountIT.java +++
b/x-pack/plugin/security/qa/service-account/src/javaRestTest/java/org/elasticsearch/xpack/security/authc/service/ServiceAccountIT.java @@ -112,7 +112,8 @@ public class ServiceAccountIT extends ESRestTestCase { "cluster": [ "monitor", "manage_own_api_key", - "read_fleet_secrets" + "read_fleet_secrets", + "cluster:admin/xpack/connector/*" ], "indices": [ { @@ -284,6 +285,35 @@ public class ServiceAccountIT extends ESRestTestCase { "auto_configure" ], "allow_restricted_indices": false + }, + { + "names": [ + ".elastic-connectors*" + ], + "privileges": [ + "read", + "write", + "monitor", + "create_index", + "auto_configure", + "maintenance" + ], + "allow_restricted_indices": false + }, + { + "names": [ + "content-*", + ".search-acl-filter-*" + ], + "privileges": [ + "read", + "write", + "monitor", + "create_index", + "auto_configure", + "maintenance" + ], + "allow_restricted_indices": false } ], "applications": [ { diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/service/ElasticServiceAccounts.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/service/ElasticServiceAccounts.java index b62ce28422a9c..baa920eee275b 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/service/ElasticServiceAccounts.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/service/ElasticServiceAccounts.java @@ -81,7 +81,7 @@ final class ElasticServiceAccounts { "fleet-server", new RoleDescriptor( NAMESPACE + "/fleet-server", - new String[] { "monitor", "manage_own_api_key", "read_fleet_secrets" }, + new String[] { "monitor", "manage_own_api_key", "read_fleet_secrets", "cluster:admin/xpack/connector/*" }, new RoleDescriptor.IndicesPrivileges[] { RoleDescriptor.IndicesPrivileges.builder() .indices( @@ -156,7 +156,17 @@ final class ElasticServiceAccounts { // Fleet Server needs "read" privilege to be able to retrieve multi-agent docs .privileges("read", "write", "create_index", "auto_configure") .allowRestrictedIndices(false) - .build() }, + .build(), + // Custom permissions required for running Elastic connectors integration + RoleDescriptor.IndicesPrivileges.builder() + .indices(".elastic-connectors*") + .privileges("read", "write", "monitor", "create_index", "auto_configure", "maintenance") + .build(), + // Permissions for data indices and access control filters used by Elastic connectors integration + RoleDescriptor.IndicesPrivileges.builder() + .indices("content-*", ".search-acl-filter-*") + .privileges("read", "write", "monitor", "create_index", "auto_configure", "maintenance") + .build(), }, new RoleDescriptor.ApplicationResourcePrivileges[] { RoleDescriptor.ApplicationResourcePrivileges.builder() .application("kibana-*")