From 71d7232097cdd78c3b67cee47a5b5da41102388a Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Fri, 17 Feb 2023 09:23:34 -0800 Subject: [PATCH] Replace KnnQueryVector by KnnFloatVectorQuery (#767) Signed-off-by: Martin Gaievski --- .../org/opensearch/knn/index/query/KNNQueryFactory.java | 6 +++--- .../opensearch/knn/index/query/KNNQueryBuilderTests.java | 4 ++-- .../opensearch/knn/index/query/KNNQueryFactoryTests.java | 6 +++--- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java index c68ce9502..188bbc150 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java @@ -11,7 +11,7 @@ import lombok.NonNull; import lombok.Setter; import lombok.extern.log4j.Log4j2; -import org.apache.lucene.search.KnnVectorQuery; +import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.Query; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryShardContext; @@ -73,13 +73,13 @@ public static Query create(CreateQueryRequest createQueryRequest) { ); try { final Query filterQuery = createQueryRequest.getFilter().get().toQuery(queryShardContext); - return new KnnVectorQuery(fieldName, vector, k, filterQuery); + return new KnnFloatVectorQuery(fieldName, vector, k, filterQuery); } catch (IOException e) { throw new RuntimeException("Cannot create knn query with filter", e); } } log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k)); - return new KnnVectorQuery(fieldName, vector, k); + return new KnnFloatVectorQuery(fieldName, vector, k); } /** diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index a75c0b648..1a015f588 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -6,7 +6,7 @@ package org.opensearch.knn.index.query; import com.google.common.collect.ImmutableMap; -import org.apache.lucene.search.KnnVectorQuery; +import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.Query; import org.opensearch.Version; import org.opensearch.cluster.ClusterModule; @@ -174,7 +174,7 @@ public void testDoToQuery_KnnQueryWithFilter() throws Exception { when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); Query query = knnQueryBuilder.doToQuery(mockQueryShardContext); assertNotNull(query); - assertTrue(query instanceof KnnVectorQuery); + assertTrue(query.getClass().isAssignableFrom(KnnFloatVectorQuery.class)); } public void testDoToQuery_FromModel() { diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java index 908ea1021..0f8f43bf2 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java @@ -5,7 +5,7 @@ package org.opensearch.knn.index.query; -import org.apache.lucene.search.KnnVectorQuery; +import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.Query; import org.opensearch.index.mapper.MappedFieldType; import org.opensearch.index.query.QueryBuilder; @@ -47,7 +47,7 @@ public void testCreateLuceneDefaultQuery() { .collect(Collectors.toList()); for (KNNEngine knnEngine : luceneDefaultQueryEngineList) { Query query = KNNQueryFactory.create(knnEngine, testIndexName, testFieldName, testQueryVector, testK); - assertTrue(query instanceof KnnVectorQuery); + assertTrue(query.getClass().isAssignableFrom(KnnFloatVectorQuery.class)); } } @@ -70,7 +70,7 @@ public void testCreateLuceneQueryWithFilter() { .filter(filter) .build(); Query query = KNNQueryFactory.create(createQueryRequest); - assertTrue(query instanceof KnnVectorQuery); + assertTrue(query.getClass().isAssignableFrom(KnnFloatVectorQuery.class)); } } }