diff --git a/src/main/java/org/opensearch/knn/training/VectorReader.java b/src/main/java/org/opensearch/knn/training/VectorReader.java
index 6392ecabe..9b7db6d99 100644
--- a/src/main/java/org/opensearch/knn/training/VectorReader.java
+++ b/src/main/java/org/opensearch/knn/training/VectorReader.java
@@ -58,8 +58,15 @@ public VectorReader(Client client) {
      * @param vectorConsumer consumer used to do something with the collected vectors after each search
      * @param listener ActionListener that should be called once all search operations complete
      */
-    public void read(ClusterService clusterService, String indexName, String fieldName, int maxVectorCount,
-                     int searchSize, Consumer<List<Float[]>> vectorConsumer, ActionListener<SearchResponse> listener) {
+    public void read(
+        ClusterService clusterService,
+        String indexName,
+        String fieldName,
+        int maxVectorCount,
+        int searchSize,
+        Consumer<List<Float[]>> vectorConsumer,
+        ActionListener<SearchResponse> listener
+    ) {
 
         ValidationException validationException = null;
 
@@ -94,11 +101,17 @@ public void read(ClusterService clusterService, String indexName, String fieldNa
         // Start reading vectors from index
         SearchScrollRequestBuilder searchScrollRequestBuilder = createSearchScrollRequestBuilder();
 
-        ActionListener<SearchResponse> vectorReaderListener = new VectorReaderListener(client, fieldName,
-                maxVectorCount, 0, listener, vectorConsumer, searchScrollRequestBuilder);
-
-        createSearchRequestBuilder(indexName, fieldName, Integer.min(maxVectorCount, searchSize))
-                .execute(vectorReaderListener);
+        ActionListener<SearchResponse> vectorReaderListener = new VectorReaderListener(
+            client,
+            fieldName,
+            maxVectorCount,
+            0,
+            listener,
+            vectorConsumer,
+            searchScrollRequestBuilder
+        );
+
+        createSearchRequestBuilder(indexName, fieldName, Integer.min(maxVectorCount, searchSize)).execute(vectorReaderListener);
     }
 
     private SearchRequestBuilder createSearchRequestBuilder(String indexName, String fieldName, int resultSize) {
@@ -142,9 +155,15 @@ private static class VectorReaderListener implements ActionListener<SearchRespon
-        public VectorReaderListener(Client client, String fieldName, int maxVectorCount,
-                                    int collectedVectorCount, ActionListener<SearchResponse> listener, Consumer<List<Float[]>> vectorConsumer,
-                                    SearchScrollRequestBuilder searchScrollRequestBuilder) {
+        public VectorReaderListener(
+            Client client,
+            String fieldName,
+            int maxVectorCount,
+            int collectedVectorCount,
+            ActionListener<SearchResponse> listener,
+            Consumer<List<Float[]>> vectorConsumer,
+            SearchScrollRequestBuilder searchScrollRequestBuilder
+        ) {
             this.client = client;
             this.fieldName = fieldName;
             this.maxVectorCount = maxVectorCount;
@@ -154,7 +173,6 @@ public VectorReaderListener(Client client, String fieldName, int maxVectorCount,
             this.searchScrollRequestBuilder = searchScrollRequestBuilder;
         }
 
-
         @Override
         @SuppressWarnings("unchecked")
         public void onResponse(SearchResponse searchResponse) {
@@ -165,9 +183,9 @@ public void onResponse(SearchResponse searchResponse) {
             List<Float[]> trainingData = new ArrayList<>();
 
             for (int i = 0; i < vectorsToAdd; i++) {
-                trainingData.add(((List<Double>) hits[i].getSourceAsMap().get(fieldName)).stream()
-                        .map(Double::floatValue)
-                        .toArray(Float[]::new));
+                trainingData.add(
+                    ((List<Number>) hits[i].getSourceAsMap().get(fieldName)).stream().map(Number::floatValue).toArray(Float[]::new)
+                );
             }
 
             this.collectedVectorCount += trainingData.size();
@@ -180,10 +198,9 @@ public void onResponse(SearchResponse searchResponse) {
                 String scrollId = searchResponse.getScrollId();
 
                 if (scrollId != null) {
-                    client.prepareClearScroll().addScrollId(scrollId).execute(ActionListener.wrap(
-                            clearScrollResponse -> listener.onResponse(searchResponse),
-                            listener::onFailure)
-                    );
+                    client.prepareClearScroll()
+                        .addScrollId(scrollId)
+                        .execute(ActionListener.wrap(clearScrollResponse -> listener.onResponse(searchResponse), listener::onFailure));
                 } else {
                     listener.onResponse(searchResponse);
                 }
@@ -201,10 +218,9 @@ public void onFailure(Exception e) {
             String scrollId = searchScrollRequestBuilder.request().scrollId();
 
             if (scrollId != null) {
-                client.prepareClearScroll().addScrollId(scrollId).execute(ActionListener.wrap(
-                        clearScrollResponse -> listener.onFailure(e),
-                        listener::onFailure)
-                );
+                client.prepareClearScroll()
+                    .addScrollId(scrollId)
+                    .execute(ActionListener.wrap(clearScrollResponse -> listener.onFailure(e), listener::onFailure));
             } else {
                 listener.onFailure(e);
             }
diff --git a/src/test/java/org/opensearch/knn/training/VectorReaderTests.java b/src/test/java/org/opensearch/knn/training/VectorReaderTests.java
index 2b88603d9..0b1e15289 100644
--- a/src/test/java/org/opensearch/knn/training/VectorReaderTests.java
+++ b/src/test/java/org/opensearch/knn/training/VectorReaderTests.java
@@ -11,9 +11,8 @@
 
 package org.opensearch.knn.training;
 
-import org.apache.logging.log4j.LogManager;
-import org.apache.logging.log4j.Logger;
 import org.opensearch.action.ActionListener;
+import org.opensearch.action.search.SearchResponse;
 import org.opensearch.cluster.service.ClusterService;
 import org.opensearch.common.ValidationException;
 import org.opensearch.knn.KNNSingleNodeTestCase;
@@ -32,32 +31,65 @@ public class VectorReaderTests extends KNNSingleNodeTestCase {
 
-    public static Logger logger = LogManager.getLogger(VectorReaderTests.class);
+    private final static int DEFAULT_LATCH_TIMEOUT = 100;
+    private final static String DEFAULT_INDEX_NAME = "test-index";
+    private final static String DEFAULT_FIELD_NAME = "test-field";
+    private final static int DEFAULT_DIMENSION = 16;
+    private final static int DEFAULT_NUM_VECTORS = 100;
+    private final static int DEFAULT_MAX_VECTOR_COUNT = 10000;
+    private final static int DEFAULT_SEARCH_SIZE = 10;
 
     public void testRead_valid_completeIndex() throws InterruptedException, ExecutionException, IOException {
-        // Create an index with knn disabled
-        String indexName = "test-index";
-        String fieldName = "test-field";
-        int dim = 16;
-        int numVectors = 100;
-        createIndex(indexName);
-
-        // Add a field mapping to the index
-        createKnnIndexMapping(indexName, fieldName, dim);
+        createIndex(DEFAULT_INDEX_NAME);
+        createKnnIndexMapping(DEFAULT_INDEX_NAME, DEFAULT_FIELD_NAME, DEFAULT_DIMENSION);
 
         // Create list of random vectors and ingest
         Random random = new Random();
         List<Float[]> vectors = new ArrayList<>();
-        for (int i = 0; i < numVectors; i++) {
-            Float[] vector = new Float[dim];
+        for (int i = 0; i < DEFAULT_NUM_VECTORS; i++) {
+            Float[] vector = random.doubles(DEFAULT_DIMENSION).boxed().map(Double::floatValue).toArray(Float[]::new);
+            vectors.add(vector);
+            addKnnDoc(DEFAULT_INDEX_NAME, Integer.toString(i), DEFAULT_FIELD_NAME, vector);
+        }
 
-            for (int j = 0; j < dim; j++) {
-                vector[j] = random.nextFloat();
-            }
+        // Configure VectorReader
+        ClusterService clusterService = node().injector().getInstance(ClusterService.class);
+        VectorReader vectorReader = new VectorReader(client());
 
-            vectors.add(vector);
+        // Read all vectors and confirm they match vectors
+        TestVectorConsumer testVectorConsumer = new TestVectorConsumer();
+        final CountDownLatch inProgressLatch = new CountDownLatch(1);
+        vectorReader.read(
+            clusterService,
+            DEFAULT_INDEX_NAME,
+            DEFAULT_FIELD_NAME,
+            DEFAULT_MAX_VECTOR_COUNT,
+            DEFAULT_SEARCH_SIZE,
+            testVectorConsumer,
+            createOnSearchResponseCountDownListener(inProgressLatch)
+        );
+
+        assertLatchDecremented(inProgressLatch);
+
+        List<Float[]> consumedVectors = testVectorConsumer.getVectorsConsumed();
+        assertEquals(DEFAULT_NUM_VECTORS, consumedVectors.size());
+
+        List<Float> flatVectors = vectors.stream().flatMap(Arrays::stream).collect(Collectors.toList());
+        List<Float> flatConsumedVectors = consumedVectors.stream().flatMap(Arrays::stream).collect(Collectors.toList());
+        assertEquals(new HashSet<>(flatVectors), new HashSet<>(flatConsumedVectors));
+    }
 
-            addKnnDoc(indexName, Integer.toString(i), fieldName, vector);
+    public void testRead_valid_trainVectorsIngestedAsIntegers() throws IOException, ExecutionException, InterruptedException {
+        createIndex(DEFAULT_INDEX_NAME);
+        createKnnIndexMapping(DEFAULT_INDEX_NAME, DEFAULT_FIELD_NAME, DEFAULT_DIMENSION);
+
+        // Create list of random vectors and ingest
+        Random random = new Random();
+        List<Integer[]> vectors = new ArrayList<>();
+        for (int i = 0; i < DEFAULT_NUM_VECTORS; i++) {
+            Integer[] vector = random.ints(DEFAULT_DIMENSION).boxed().toArray(Integer[]::new);
+            vectors.add(vector);
+            addKnnDoc(DEFAULT_INDEX_NAME, Integer.toString(i), DEFAULT_FIELD_NAME, vector);
         }
 
         // Configure VectorReader
@@ -66,16 +98,23 @@ public void testRead_valid_completeIndex() throws InterruptedException, Executio
 
         // Read all vectors and confirm they match vectors
         TestVectorConsumer testVectorConsumer = new TestVectorConsumer();
-        final CountDownLatch inProgressLatch1 = new CountDownLatch(1);
-        vectorReader.read(clusterService, indexName, fieldName, 10000, 10, testVectorConsumer,
-                ActionListener.wrap(response -> inProgressLatch1.countDown(), e -> fail(e.toString())));
-
-        assertTrue(inProgressLatch1.await(100, TimeUnit.SECONDS));
+        final CountDownLatch inProgressLatch = new CountDownLatch(1);
+        vectorReader.read(
+            clusterService,
+            DEFAULT_INDEX_NAME,
+            DEFAULT_FIELD_NAME,
+            DEFAULT_MAX_VECTOR_COUNT,
+            DEFAULT_SEARCH_SIZE,
+            testVectorConsumer,
+            createOnSearchResponseCountDownListener(inProgressLatch)
+        );
+
+        assertLatchDecremented(inProgressLatch);
 
         List<Float[]> consumedVectors = testVectorConsumer.getVectorsConsumed();
-        assertEquals(numVectors, consumedVectors.size());
+        assertEquals(DEFAULT_NUM_VECTORS, consumedVectors.size());
 
-        List<Float> flatVectors = vectors.stream().flatMap(Arrays::stream).collect(Collectors.toList());
+        List<Float> flatVectors = vectors.stream().flatMap(Arrays::stream).map(Integer::floatValue).collect(Collectors.toList());
         List<Float> flatConsumedVectors = consumedVectors.stream().flatMap(Arrays::stream).collect(Collectors.toList());
         assertEquals(new HashSet<>(flatVectors), new HashSet<>(flatConsumedVectors));
     }
@@ -83,35 +122,25 @@ public void testRead_valid_completeIndex() throws InterruptedException, Executio
     public void testRead_valid_incompleteIndex() throws InterruptedException, ExecutionException, IOException {
         // Check if we get the right number of vectors if the index contains docs that are missing fields
         // Create an index with knn disabled
-        String indexName = "test-index";
-        String fieldName = "test-field";
-        int dim = 16;
-        int numVectors = 100;
-        createIndex(indexName);
+        createIndex(DEFAULT_INDEX_NAME);
 
         // Add a field mapping to the index
-        createKnnIndexMapping(indexName, fieldName, dim);
+        createKnnIndexMapping(DEFAULT_INDEX_NAME, DEFAULT_FIELD_NAME, DEFAULT_DIMENSION);
 
         // Create list of random vectors and ingest
         Random random = new Random();
         List<Float[]> vectors = new ArrayList<>();
-        for (int i = 0; i < numVectors; i++) {
-            Float[] vector = new Float[dim];
-
-            for (int j = 0; j < dim; j++) {
-                vector[j] = random.nextFloat();
-            }
-
+        for (int i = 0; i < DEFAULT_NUM_VECTORS; i++) {
+            Float[] vector = random.doubles(DEFAULT_DIMENSION).boxed().map(Double::floatValue).toArray(Float[]::new);
             vectors.add(vector);
-
-            addKnnDoc(indexName, Integer.toString(i ), fieldName, vector);
+            addKnnDoc(DEFAULT_INDEX_NAME, Integer.toString(i), DEFAULT_FIELD_NAME, vector);
         }
 
         // Create documents that do not have fieldName for training
         int docsWithoutKNN = 100;
         String fieldNameWithoutKnn = "test-field-2";
         for (int i = 0; i < docsWithoutKNN; i++) {
-            addDoc(indexName, Integer.toString(i + numVectors), fieldNameWithoutKnn, "dummyValue");
+            addDoc(DEFAULT_INDEX_NAME, Integer.toString(i + DEFAULT_NUM_VECTORS), fieldNameWithoutKnn, "dummyValue");
         }
 
         // Configure VectorReader
@@ -120,14 +149,21 @@ public void testRead_valid_incompleteIndex() throws InterruptedException, Execut
 
         // Read all vectors and confirm they match vectors
         TestVectorConsumer testVectorConsumer = new TestVectorConsumer();
-        final CountDownLatch inProgressLatch1 = new CountDownLatch(1);
-        vectorReader.read(clusterService, indexName, fieldName, 10000, 10, testVectorConsumer,
-                ActionListener.wrap(response -> inProgressLatch1.countDown(), e -> fail(e.toString())));
-
-        assertTrue(inProgressLatch1.await(100, TimeUnit.SECONDS));
+        final CountDownLatch inProgressLatch = new CountDownLatch(1);
+        vectorReader.read(
+            clusterService,
+            DEFAULT_INDEX_NAME,
+            DEFAULT_FIELD_NAME,
+            DEFAULT_MAX_VECTOR_COUNT,
+            DEFAULT_SEARCH_SIZE,
+            testVectorConsumer,
+            createOnSearchResponseCountDownListener(inProgressLatch)
+        );
+
+        assertLatchDecremented(inProgressLatch);
 
         List<Float[]> consumedVectors = testVectorConsumer.getVectorsConsumed();
-        assertEquals(numVectors, consumedVectors.size());
+        assertEquals(DEFAULT_NUM_VECTORS, consumedVectors.size());
 
         List<Float> flatVectors = vectors.stream().flatMap(Arrays::stream).collect(Collectors.toList());
         List<Float> flatConsumedVectors = consumedVectors.stream().flatMap(Arrays::stream).collect(Collectors.toList());
@@ -137,26 +173,17 @@ public void testRead_valid_incompleteIndex() throws InterruptedException, Execut
     public void testRead_valid_OnlyGetMaxVectors() throws InterruptedException, ExecutionException, IOException {
         // Check if we can limit the number of docs via max operation
        // Create an index with knn disabled
-        String indexName = "test-index";
-        String fieldName = "test-field";
-        int dim = 16;
-        int numVectorsIndex = 100;
         int maxNumVectorsRead = 20;
-        createIndex(indexName);
+        createIndex(DEFAULT_INDEX_NAME);
 
         // Add a field mapping to the index
-        createKnnIndexMapping(indexName, fieldName, dim);
+        createKnnIndexMapping(DEFAULT_INDEX_NAME, DEFAULT_FIELD_NAME, DEFAULT_DIMENSION);
 
         // Create list of random vectors and ingest
         Random random = new Random();
-        for (int i = 0; i < numVectorsIndex; i++) {
-            Float[] vector = new Float[dim];
-
-            for (int j = 0; j < dim; j++) {
-                vector[j] = random.nextFloat();
-            }
-
-            addKnnDoc(indexName, Integer.toString(i ), fieldName, vector);
+        for (int i = 0; i < DEFAULT_NUM_VECTORS; i++) {
+            Float[] vector = random.doubles(DEFAULT_DIMENSION).boxed().map(Double::floatValue).toArray(Float[]::new);
+            addKnnDoc(DEFAULT_INDEX_NAME, Integer.toString(i), DEFAULT_FIELD_NAME, vector);
         }
 
         // Configure VectorReader
@@ -165,11 +192,18 @@ public void testRead_valid_OnlyGetMaxVectors() throws InterruptedException, Exec
 
         // Read maxNumVectorsRead vectors
         TestVectorConsumer testVectorConsumer = new TestVectorConsumer();
-        final CountDownLatch inProgressLatch1 = new CountDownLatch(1);
-        vectorReader.read(clusterService, indexName, fieldName, maxNumVectorsRead, 10, testVectorConsumer,
-                ActionListener.wrap(response -> inProgressLatch1.countDown(), e -> fail(e.toString())));
-
-        assertTrue(inProgressLatch1.await(100, TimeUnit.SECONDS));
+        final CountDownLatch inProgressLatch = new CountDownLatch(1);
+        vectorReader.read(
+            clusterService,
+            DEFAULT_INDEX_NAME,
+            DEFAULT_FIELD_NAME,
+            maxNumVectorsRead,
+            DEFAULT_SEARCH_SIZE,
+            testVectorConsumer,
+            createOnSearchResponseCountDownListener(inProgressLatch)
+        );
+
+        assertLatchDecremented(inProgressLatch);
 
         List<Float[]> consumedVectors = testVectorConsumer.getVectorsConsumed();
         assertEquals(maxNumVectorsRead, consumedVectors.size());
@@ -177,82 +211,138 @@ public void testRead_valid_OnlyGetMaxVectors() throws InterruptedException, Exec
 
     public void testRead_invalid_maxVectorCount() {
         // Create the index
-        String indexName = "test-index";
-        String fieldName = "test-field";
-        int dim = 16;
-        createIndex(indexName);
+        createIndex(DEFAULT_INDEX_NAME);
 
         // Add a field mapping to the index
-        createKnnIndexMapping(indexName, fieldName, dim);
+        createKnnIndexMapping(DEFAULT_INDEX_NAME, DEFAULT_FIELD_NAME, DEFAULT_DIMENSION);
 
         // Configure VectorReader
         ClusterService clusterService = node().injector().getInstance(ClusterService.class);
         VectorReader vectorReader = new VectorReader(client());
 
-        expectThrows(ValidationException.class, () -> vectorReader.read(clusterService, indexName, fieldName, -10, 10, null, null));
+        int invalidMaxVectorCount = -10;
+        expectThrows(
+            ValidationException.class,
+            () -> vectorReader.read(
+                clusterService,
+                DEFAULT_INDEX_NAME,
+                DEFAULT_FIELD_NAME,
+                invalidMaxVectorCount,
+                DEFAULT_SEARCH_SIZE,
+                null,
+                null
+            )
+        );
     }
 
     public void testRead_invalid_searchSize() {
         // Create the index
-        String indexName = "test-index";
-        String fieldName = "test-field";
-        int dim = 16;
-        createIndex(indexName);
+        createIndex(DEFAULT_INDEX_NAME);
 
         // Add a field mapping to the index
-        createKnnIndexMapping(indexName, fieldName, dim);
+        createKnnIndexMapping(DEFAULT_INDEX_NAME, DEFAULT_FIELD_NAME, DEFAULT_DIMENSION);
 
         // Configure VectorReader
         ClusterService clusterService = node().injector().getInstance(ClusterService.class);
         VectorReader vectorReader = new VectorReader(client());
 
         // Search size is negative
-        expectThrows(ValidationException.class, () -> vectorReader.read(clusterService, indexName, fieldName, 100, -10, null, null));
+        int invalidSearchSize1 = -10;
+        expectThrows(
+            ValidationException.class,
+            () -> vectorReader.read(
+                clusterService,
+                DEFAULT_INDEX_NAME,
+                DEFAULT_FIELD_NAME,
+                DEFAULT_MAX_VECTOR_COUNT,
+                invalidSearchSize1,
+                null,
+                null
+            )
+        );
 
         // Search size is greater than 10000
-        expectThrows(ValidationException.class, () -> vectorReader.read(clusterService, indexName, fieldName, 100, 20000, null, null));
+        int invalidSearchSize2 = 20000;
+        expectThrows(
+            ValidationException.class,
+            () -> vectorReader.read(
+                clusterService,
+                DEFAULT_INDEX_NAME,
+                DEFAULT_FIELD_NAME,
+                DEFAULT_MAX_VECTOR_COUNT,
+                invalidSearchSize2,
+                null,
+                null
+            )
+        );
     }
 
     public void testRead_invalid_indexDoesNotExist() {
         // Check that read throws a validation exception when the index does not exist
-        String indexName = "test-index";
-        String fieldName = "test-field";
-
         // Configure VectorReader
         ClusterService clusterService = node().injector().getInstance(ClusterService.class);
         VectorReader vectorReader = new VectorReader(client());
 
         // Should throw a validation exception because index does not exist
-        expectThrows(ValidationException.class, () -> vectorReader.read(clusterService, indexName, fieldName, 10000, 10, null, null));
+        expectThrows(
+            ValidationException.class,
+            () -> vectorReader.read(
+                clusterService,
+                DEFAULT_INDEX_NAME,
+                DEFAULT_FIELD_NAME,
+                DEFAULT_MAX_VECTOR_COUNT,
+                DEFAULT_SEARCH_SIZE,
+                null,
+                null
+            )
+        );
     }
 
     public void testRead_invalid_fieldDoesNotExist() {
         // Check that read throws a validation exception when the field does not exist
-        String indexName = "test-index";
-        String fieldName = "test-field";
-        createIndex(indexName);
+        createIndex(DEFAULT_INDEX_NAME);
 
         // Configure VectorReader
         ClusterService clusterService = node().injector().getInstance(ClusterService.class);
         VectorReader vectorReader = new VectorReader(client());
 
         // Should throw a validation exception because field is not k-NN
-        expectThrows(ValidationException.class, () -> vectorReader.read(clusterService, indexName, fieldName, 10000, 10, null, null));
+        expectThrows(
+            ValidationException.class,
+            () -> vectorReader.read(
+                clusterService,
+                DEFAULT_INDEX_NAME,
+                DEFAULT_FIELD_NAME,
+                DEFAULT_MAX_VECTOR_COUNT,
+                DEFAULT_SEARCH_SIZE,
+                null,
+                null
+            )
+        );
     }
 
     public void testRead_invalid_fieldIsNotKnn() throws InterruptedException, ExecutionException, IOException {
         // Check that read throws a validation exception when the field does not exist
-        String indexName = "test-index";
-        String fieldName = "test-field";
-        createIndex(indexName);
-        addDoc(indexName, "test-id", fieldName, "dummy");
+        createIndex(DEFAULT_INDEX_NAME);
+        addDoc(DEFAULT_INDEX_NAME, "test-id", DEFAULT_FIELD_NAME, "dummy");
 
         // Configure VectorReader
         ClusterService clusterService = node().injector().getInstance(ClusterService.class);
         VectorReader vectorReader = new VectorReader(client());
 
         // Should throw a validation exception because field does not exist
-        expectThrows(ValidationException.class, () -> vectorReader.read(clusterService, indexName, fieldName, 10000, 10, null, null));
+        expectThrows(
+            ValidationException.class,
+            () -> vectorReader.read(
+                clusterService,
+                DEFAULT_INDEX_NAME,
+                DEFAULT_FIELD_NAME,
+                DEFAULT_MAX_VECTOR_COUNT,
+                DEFAULT_SEARCH_SIZE,
+                null,
+                null
+            )
+        );
     }
 
     private static class TestVectorConsumer implements Consumer<List<Float[]>> {
@@ -272,4 +362,12 @@ public List<Float[]> getVectorsConsumed() {
             return vectorsConsumed;
         }
     }
-}
\ No newline at end of file
+
+    private void assertLatchDecremented(CountDownLatch countDownLatch) throws InterruptedException {
+        assertTrue(countDownLatch.await(DEFAULT_LATCH_TIMEOUT, TimeUnit.SECONDS));
+    }
+
+    private ActionListener<SearchResponse> createOnSearchResponseCountDownListener(CountDownLatch countDownLatch) {
+        return ActionListener.wrap(response -> countDownLatch.countDown(), Throwable::printStackTrace);
+    }
+}