From cf52a6d08319f8aeeaa52cf55046af3b17fc6cf5 Mon Sep 17 00:00:00 2001 From: John Mazanec Date: Mon, 6 Jun 2022 12:07:05 -0400 Subject: [PATCH] [Backport 1.3] Change VectorReaderListener to expect number array (#420) Refactors VectorReaderListener onResponse to expect arrays of Number type from search result instead of Double type. Adds test case to confirm that it can handle Integer type. Cleans up tests in VectorReaderTest class. Signed-off-by: John Mazanec (cherry picked from commit 77353512c1f15e0dc996428a982941a7ee3036fb) --- .../opensearch/knn/training/VectorReader.java | 60 ++-- .../knn/training/VectorReaderTests.java | 286 ++++++++++++------ 2 files changed, 230 insertions(+), 116 deletions(-) diff --git a/src/main/java/org/opensearch/knn/training/VectorReader.java b/src/main/java/org/opensearch/knn/training/VectorReader.java index 6392ecabe..9b7db6d99 100644 --- a/src/main/java/org/opensearch/knn/training/VectorReader.java +++ b/src/main/java/org/opensearch/knn/training/VectorReader.java @@ -58,8 +58,15 @@ public VectorReader(Client client) { * @param vectorConsumer consumer used to do something with the collected vectors after each search * @param listener ActionListener that should be called once all search operations complete */ - public void read(ClusterService clusterService, String indexName, String fieldName, int maxVectorCount, - int searchSize, Consumer> vectorConsumer, ActionListener listener) { + public void read( + ClusterService clusterService, + String indexName, + String fieldName, + int maxVectorCount, + int searchSize, + Consumer> vectorConsumer, + ActionListener listener + ) { ValidationException validationException = null; @@ -94,11 +101,17 @@ public void read(ClusterService clusterService, String indexName, String fieldNa // Start reading vectors from index SearchScrollRequestBuilder searchScrollRequestBuilder = createSearchScrollRequestBuilder(); - ActionListener vectorReaderListener = new VectorReaderListener(client, fieldName, - maxVectorCount, 0, listener, vectorConsumer, searchScrollRequestBuilder); - - createSearchRequestBuilder(indexName, fieldName, Integer.min(maxVectorCount, searchSize)) - .execute(vectorReaderListener); + ActionListener vectorReaderListener = new VectorReaderListener( + client, + fieldName, + maxVectorCount, + 0, + listener, + vectorConsumer, + searchScrollRequestBuilder + ); + + createSearchRequestBuilder(indexName, fieldName, Integer.min(maxVectorCount, searchSize)).execute(vectorReaderListener); } private SearchRequestBuilder createSearchRequestBuilder(String indexName, String fieldName, int resultSize) { @@ -142,9 +155,15 @@ private static class VectorReaderListener implements ActionListener listener, Consumer> vectorConsumer, - SearchScrollRequestBuilder searchScrollRequestBuilder) { + public VectorReaderListener( + Client client, + String fieldName, + int maxVectorCount, + int collectedVectorCount, + ActionListener listener, + Consumer> vectorConsumer, + SearchScrollRequestBuilder searchScrollRequestBuilder + ) { this.client = client; this.fieldName = fieldName; this.maxVectorCount = maxVectorCount; @@ -154,7 +173,6 @@ public VectorReaderListener(Client client, String fieldName, int maxVectorCount, this.searchScrollRequestBuilder = searchScrollRequestBuilder; } - @Override @SuppressWarnings("unchecked") public void onResponse(SearchResponse searchResponse) { @@ -165,9 +183,9 @@ public void onResponse(SearchResponse searchResponse) { List trainingData = new ArrayList<>(); for (int i = 0; i < vectorsToAdd; i++) { - trainingData.add(((List) hits[i].getSourceAsMap().get(fieldName)).stream() - .map(Double::floatValue) - .toArray(Float[]::new)); + trainingData.add( + ((List) hits[i].getSourceAsMap().get(fieldName)).stream().map(Number::floatValue).toArray(Float[]::new) + ); } this.collectedVectorCount += trainingData.size(); @@ -180,10 +198,9 @@ public void onResponse(SearchResponse searchResponse) { String scrollId = searchResponse.getScrollId(); if (scrollId != null) { - client.prepareClearScroll().addScrollId(scrollId).execute(ActionListener.wrap( - clearScrollResponse -> listener.onResponse(searchResponse), - listener::onFailure) - ); + client.prepareClearScroll() + .addScrollId(scrollId) + .execute(ActionListener.wrap(clearScrollResponse -> listener.onResponse(searchResponse), listener::onFailure)); } else { listener.onResponse(searchResponse); } @@ -201,10 +218,9 @@ public void onFailure(Exception e) { String scrollId = searchScrollRequestBuilder.request().scrollId(); if (scrollId != null) { - client.prepareClearScroll().addScrollId(scrollId).execute(ActionListener.wrap( - clearScrollResponse -> listener.onFailure(e), - listener::onFailure) - ); + client.prepareClearScroll() + .addScrollId(scrollId) + .execute(ActionListener.wrap(clearScrollResponse -> listener.onFailure(e), listener::onFailure)); } else { listener.onFailure(e); } diff --git a/src/test/java/org/opensearch/knn/training/VectorReaderTests.java b/src/test/java/org/opensearch/knn/training/VectorReaderTests.java index 2b88603d9..0b1e15289 100644 --- a/src/test/java/org/opensearch/knn/training/VectorReaderTests.java +++ b/src/test/java/org/opensearch/knn/training/VectorReaderTests.java @@ -11,9 +11,8 @@ package org.opensearch.knn.training; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.opensearch.action.ActionListener; +import org.opensearch.action.search.SearchResponse; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.ValidationException; import org.opensearch.knn.KNNSingleNodeTestCase; @@ -32,32 +31,65 @@ public class VectorReaderTests extends KNNSingleNodeTestCase { - public static Logger logger = LogManager.getLogger(VectorReaderTests.class); + private final static int DEFAULT_LATCH_TIMEOUT = 100; + private final static String DEFAULT_INDEX_NAME = "test-index"; + private final static String DEFAULT_FIELD_NAME = "test-field"; + private final static int DEFAULT_DIMENSION = 16; + private final static int DEFAULT_NUM_VECTORS = 100; + private final static int DEFAULT_MAX_VECTOR_COUNT = 10000; + private final static int DEFAULT_SEARCH_SIZE = 10; public void testRead_valid_completeIndex() throws InterruptedException, ExecutionException, IOException { - // Create an index with knn disabled - String indexName = "test-index"; - String fieldName = "test-field"; - int dim = 16; - int numVectors = 100; - createIndex(indexName); - - // Add a field mapping to the index - createKnnIndexMapping(indexName, fieldName, dim); + createIndex(DEFAULT_INDEX_NAME); + createKnnIndexMapping(DEFAULT_INDEX_NAME, DEFAULT_FIELD_NAME, DEFAULT_DIMENSION); // Create list of random vectors and ingest Random random = new Random(); List vectors = new ArrayList<>(); - for (int i = 0; i < numVectors; i++) { - Float[] vector = new Float[dim]; + for (int i = 0; i < DEFAULT_NUM_VECTORS; i++) { + Float[] vector = random.doubles(DEFAULT_DIMENSION).boxed().map(Double::floatValue).toArray(Float[]::new); + vectors.add(vector); + addKnnDoc(DEFAULT_INDEX_NAME, Integer.toString(i), DEFAULT_FIELD_NAME, vector); + } - for (int j = 0; j < dim; j++) { - vector[j] = random.nextFloat(); - } + // Configure VectorReader + ClusterService clusterService = node().injector().getInstance(ClusterService.class); + VectorReader vectorReader = new VectorReader(client()); - vectors.add(vector); + // Read all vectors and confirm they match vectors + TestVectorConsumer testVectorConsumer = new TestVectorConsumer(); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + vectorReader.read( + clusterService, + DEFAULT_INDEX_NAME, + DEFAULT_FIELD_NAME, + DEFAULT_MAX_VECTOR_COUNT, + DEFAULT_SEARCH_SIZE, + testVectorConsumer, + createOnSearchResponseCountDownListener(inProgressLatch) + ); + + assertLatchDecremented(inProgressLatch); + + List consumedVectors = testVectorConsumer.getVectorsConsumed(); + assertEquals(DEFAULT_NUM_VECTORS, consumedVectors.size()); + + List flatVectors = vectors.stream().flatMap(Arrays::stream).collect(Collectors.toList()); + List flatConsumedVectors = consumedVectors.stream().flatMap(Arrays::stream).collect(Collectors.toList()); + assertEquals(new HashSet<>(flatVectors), new HashSet<>(flatConsumedVectors)); + } - addKnnDoc(indexName, Integer.toString(i), fieldName, vector); + public void testRead_valid_trainVectorsIngestedAsIntegers() throws IOException, ExecutionException, InterruptedException { + createIndex(DEFAULT_INDEX_NAME); + createKnnIndexMapping(DEFAULT_INDEX_NAME, DEFAULT_FIELD_NAME, DEFAULT_DIMENSION); + + // Create list of random vectors and ingest + Random random = new Random(); + List vectors = new ArrayList<>(); + for (int i = 0; i < DEFAULT_NUM_VECTORS; i++) { + Integer[] vector = random.ints(DEFAULT_DIMENSION).boxed().toArray(Integer[]::new); + vectors.add(vector); + addKnnDoc(DEFAULT_INDEX_NAME, Integer.toString(i), DEFAULT_FIELD_NAME, vector); } // Configure VectorReader @@ -66,16 +98,23 @@ public void testRead_valid_completeIndex() throws InterruptedException, Executio // Read all vectors and confirm they match vectors TestVectorConsumer testVectorConsumer = new TestVectorConsumer(); - final CountDownLatch inProgressLatch1 = new CountDownLatch(1); - vectorReader.read(clusterService, indexName, fieldName, 10000, 10, testVectorConsumer, - ActionListener.wrap(response -> inProgressLatch1.countDown(), e -> fail(e.toString()))); - - assertTrue(inProgressLatch1.await(100, TimeUnit.SECONDS)); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + vectorReader.read( + clusterService, + DEFAULT_INDEX_NAME, + DEFAULT_FIELD_NAME, + DEFAULT_MAX_VECTOR_COUNT, + DEFAULT_SEARCH_SIZE, + testVectorConsumer, + createOnSearchResponseCountDownListener(inProgressLatch) + ); + + assertLatchDecremented(inProgressLatch); List consumedVectors = testVectorConsumer.getVectorsConsumed(); - assertEquals(numVectors, consumedVectors.size()); + assertEquals(DEFAULT_NUM_VECTORS, consumedVectors.size()); - List flatVectors = vectors.stream().flatMap(Arrays::stream).collect(Collectors.toList()); + List flatVectors = vectors.stream().flatMap(Arrays::stream).map(Integer::floatValue).collect(Collectors.toList()); List flatConsumedVectors = consumedVectors.stream().flatMap(Arrays::stream).collect(Collectors.toList()); assertEquals(new HashSet<>(flatVectors), new HashSet<>(flatConsumedVectors)); } @@ -83,35 +122,25 @@ public void testRead_valid_completeIndex() throws InterruptedException, Executio public void testRead_valid_incompleteIndex() throws InterruptedException, ExecutionException, IOException { // Check if we get the right number of vectors if the index contains docs that are missing fields // Create an index with knn disabled - String indexName = "test-index"; - String fieldName = "test-field"; - int dim = 16; - int numVectors = 100; - createIndex(indexName); + createIndex(DEFAULT_INDEX_NAME); // Add a field mapping to the index - createKnnIndexMapping(indexName, fieldName, dim); + createKnnIndexMapping(DEFAULT_INDEX_NAME, DEFAULT_FIELD_NAME, DEFAULT_DIMENSION); // Create list of random vectors and ingest Random random = new Random(); List vectors = new ArrayList<>(); - for (int i = 0; i < numVectors; i++) { - Float[] vector = new Float[dim]; - - for (int j = 0; j < dim; j++) { - vector[j] = random.nextFloat(); - } - + for (int i = 0; i < DEFAULT_NUM_VECTORS; i++) { + Float[] vector = random.doubles(DEFAULT_DIMENSION).boxed().map(Double::floatValue).toArray(Float[]::new); vectors.add(vector); - - addKnnDoc(indexName, Integer.toString(i ), fieldName, vector); + addKnnDoc(DEFAULT_INDEX_NAME, Integer.toString(i), DEFAULT_FIELD_NAME, vector); } // Create documents that do not have fieldName for training int docsWithoutKNN = 100; String fieldNameWithoutKnn = "test-field-2"; for (int i = 0; i < docsWithoutKNN; i++) { - addDoc(indexName, Integer.toString(i + numVectors), fieldNameWithoutKnn, "dummyValue"); + addDoc(DEFAULT_INDEX_NAME, Integer.toString(i + DEFAULT_NUM_VECTORS), fieldNameWithoutKnn, "dummyValue"); } // Configure VectorReader @@ -120,14 +149,21 @@ public void testRead_valid_incompleteIndex() throws InterruptedException, Execut // Read all vectors and confirm they match vectors TestVectorConsumer testVectorConsumer = new TestVectorConsumer(); - final CountDownLatch inProgressLatch1 = new CountDownLatch(1); - vectorReader.read(clusterService, indexName, fieldName, 10000, 10, testVectorConsumer, - ActionListener.wrap(response -> inProgressLatch1.countDown(), e -> fail(e.toString()))); - - assertTrue(inProgressLatch1.await(100, TimeUnit.SECONDS)); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + vectorReader.read( + clusterService, + DEFAULT_INDEX_NAME, + DEFAULT_FIELD_NAME, + DEFAULT_MAX_VECTOR_COUNT, + DEFAULT_SEARCH_SIZE, + testVectorConsumer, + createOnSearchResponseCountDownListener(inProgressLatch) + ); + + assertLatchDecremented(inProgressLatch); List consumedVectors = testVectorConsumer.getVectorsConsumed(); - assertEquals(numVectors, consumedVectors.size()); + assertEquals(DEFAULT_NUM_VECTORS, consumedVectors.size()); List flatVectors = vectors.stream().flatMap(Arrays::stream).collect(Collectors.toList()); List flatConsumedVectors = consumedVectors.stream().flatMap(Arrays::stream).collect(Collectors.toList()); @@ -137,26 +173,17 @@ public void testRead_valid_incompleteIndex() throws InterruptedException, Execut public void testRead_valid_OnlyGetMaxVectors() throws InterruptedException, ExecutionException, IOException { // Check if we can limit the number of docs via max operation // Create an index with knn disabled - String indexName = "test-index"; - String fieldName = "test-field"; - int dim = 16; - int numVectorsIndex = 100; int maxNumVectorsRead = 20; - createIndex(indexName); + createIndex(DEFAULT_INDEX_NAME); // Add a field mapping to the index - createKnnIndexMapping(indexName, fieldName, dim); + createKnnIndexMapping(DEFAULT_INDEX_NAME, DEFAULT_FIELD_NAME, DEFAULT_DIMENSION); // Create list of random vectors and ingest Random random = new Random(); - for (int i = 0; i < numVectorsIndex; i++) { - Float[] vector = new Float[dim]; - - for (int j = 0; j < dim; j++) { - vector[j] = random.nextFloat(); - } - - addKnnDoc(indexName, Integer.toString(i ), fieldName, vector); + for (int i = 0; i < DEFAULT_NUM_VECTORS; i++) { + Float[] vector = random.doubles(DEFAULT_DIMENSION).boxed().map(Double::floatValue).toArray(Float[]::new); + addKnnDoc(DEFAULT_INDEX_NAME, Integer.toString(i), DEFAULT_FIELD_NAME, vector); } // Configure VectorReader @@ -165,11 +192,18 @@ public void testRead_valid_OnlyGetMaxVectors() throws InterruptedException, Exec // Read maxNumVectorsRead vectors TestVectorConsumer testVectorConsumer = new TestVectorConsumer(); - final CountDownLatch inProgressLatch1 = new CountDownLatch(1); - vectorReader.read(clusterService, indexName, fieldName, maxNumVectorsRead, 10, testVectorConsumer, - ActionListener.wrap(response -> inProgressLatch1.countDown(), e -> fail(e.toString()))); - - assertTrue(inProgressLatch1.await(100, TimeUnit.SECONDS)); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + vectorReader.read( + clusterService, + DEFAULT_INDEX_NAME, + DEFAULT_FIELD_NAME, + maxNumVectorsRead, + DEFAULT_SEARCH_SIZE, + testVectorConsumer, + createOnSearchResponseCountDownListener(inProgressLatch) + ); + + assertLatchDecremented(inProgressLatch); List consumedVectors = testVectorConsumer.getVectorsConsumed(); assertEquals(maxNumVectorsRead, consumedVectors.size()); @@ -177,82 +211,138 @@ public void testRead_valid_OnlyGetMaxVectors() throws InterruptedException, Exec public void testRead_invalid_maxVectorCount() { // Create the index - String indexName = "test-index"; - String fieldName = "test-field"; - int dim = 16; - createIndex(indexName); + createIndex(DEFAULT_INDEX_NAME); // Add a field mapping to the index - createKnnIndexMapping(indexName, fieldName, dim); + createKnnIndexMapping(DEFAULT_INDEX_NAME, DEFAULT_FIELD_NAME, DEFAULT_DIMENSION); // Configure VectorReader ClusterService clusterService = node().injector().getInstance(ClusterService.class); VectorReader vectorReader = new VectorReader(client()); - expectThrows(ValidationException.class, () -> vectorReader.read(clusterService, indexName, fieldName, -10, 10, null, null)); + int invalidMaxVectorCount = -10; + expectThrows( + ValidationException.class, + () -> vectorReader.read( + clusterService, + DEFAULT_INDEX_NAME, + DEFAULT_FIELD_NAME, + invalidMaxVectorCount, + DEFAULT_SEARCH_SIZE, + null, + null + ) + ); } public void testRead_invalid_searchSize() { // Create the index - String indexName = "test-index"; - String fieldName = "test-field"; - int dim = 16; - createIndex(indexName); + createIndex(DEFAULT_INDEX_NAME); // Add a field mapping to the index - createKnnIndexMapping(indexName, fieldName, dim); + createKnnIndexMapping(DEFAULT_INDEX_NAME, DEFAULT_FIELD_NAME, DEFAULT_DIMENSION); // Configure VectorReader ClusterService clusterService = node().injector().getInstance(ClusterService.class); VectorReader vectorReader = new VectorReader(client()); // Search size is negative - expectThrows(ValidationException.class, () -> vectorReader.read(clusterService, indexName, fieldName, 100, -10, null, null)); + int invalidSearchSize1 = -10; + expectThrows( + ValidationException.class, + () -> vectorReader.read( + clusterService, + DEFAULT_INDEX_NAME, + DEFAULT_FIELD_NAME, + DEFAULT_MAX_VECTOR_COUNT, + invalidSearchSize1, + null, + null + ) + ); // Search size is greater than 10000 - expectThrows(ValidationException.class, () -> vectorReader.read(clusterService, indexName, fieldName, 100, 20000, null, null)); + int invalidSearchSize2 = 20000; + expectThrows( + ValidationException.class, + () -> vectorReader.read( + clusterService, + DEFAULT_INDEX_NAME, + DEFAULT_FIELD_NAME, + DEFAULT_MAX_VECTOR_COUNT, + invalidSearchSize2, + null, + null + ) + ); } public void testRead_invalid_indexDoesNotExist() { // Check that read throws a validation exception when the index does not exist - String indexName = "test-index"; - String fieldName = "test-field"; - // Configure VectorReader ClusterService clusterService = node().injector().getInstance(ClusterService.class); VectorReader vectorReader = new VectorReader(client()); // Should throw a validation exception because index does not exist - expectThrows(ValidationException.class, () -> vectorReader.read(clusterService, indexName, fieldName, 10000, 10, null, null)); + expectThrows( + ValidationException.class, + () -> vectorReader.read( + clusterService, + DEFAULT_INDEX_NAME, + DEFAULT_FIELD_NAME, + DEFAULT_MAX_VECTOR_COUNT, + DEFAULT_SEARCH_SIZE, + null, + null + ) + ); } public void testRead_invalid_fieldDoesNotExist() { // Check that read throws a validation exception when the field does not exist - String indexName = "test-index"; - String fieldName = "test-field"; - createIndex(indexName); + createIndex(DEFAULT_INDEX_NAME); // Configure VectorReader ClusterService clusterService = node().injector().getInstance(ClusterService.class); VectorReader vectorReader = new VectorReader(client()); // Should throw a validation exception because field is not k-NN - expectThrows(ValidationException.class, () -> vectorReader.read(clusterService, indexName, fieldName, 10000, 10, null, null)); + expectThrows( + ValidationException.class, + () -> vectorReader.read( + clusterService, + DEFAULT_INDEX_NAME, + DEFAULT_FIELD_NAME, + DEFAULT_MAX_VECTOR_COUNT, + DEFAULT_SEARCH_SIZE, + null, + null + ) + ); } public void testRead_invalid_fieldIsNotKnn() throws InterruptedException, ExecutionException, IOException { // Check that read throws a validation exception when the field does not exist - String indexName = "test-index"; - String fieldName = "test-field"; - createIndex(indexName); - addDoc(indexName, "test-id", fieldName, "dummy"); + createIndex(DEFAULT_INDEX_NAME); + addDoc(DEFAULT_INDEX_NAME, "test-id", DEFAULT_FIELD_NAME, "dummy"); // Configure VectorReader ClusterService clusterService = node().injector().getInstance(ClusterService.class); VectorReader vectorReader = new VectorReader(client()); // Should throw a validation exception because field does not exist - expectThrows(ValidationException.class, () -> vectorReader.read(clusterService, indexName, fieldName, 10000, 10, null, null)); + expectThrows( + ValidationException.class, + () -> vectorReader.read( + clusterService, + DEFAULT_INDEX_NAME, + DEFAULT_FIELD_NAME, + DEFAULT_MAX_VECTOR_COUNT, + DEFAULT_SEARCH_SIZE, + null, + null + ) + ); } private static class TestVectorConsumer implements Consumer> { @@ -272,4 +362,12 @@ public List getVectorsConsumed() { return vectorsConsumed; } } -} \ No newline at end of file + + private void assertLatchDecremented(CountDownLatch countDownLatch) throws InterruptedException { + assertTrue(countDownLatch.await(DEFAULT_LATCH_TIMEOUT, TimeUnit.SECONDS)); + } + + private ActionListener createOnSearchResponseCountDownListener(CountDownLatch countDownLatch) { + return ActionListener.wrap(response -> countDownLatch.countDown(), Throwable::printStackTrace); + } +}