Skip to content

Commit

Permalink
Replace KnnQueryVector by KnnFloatVectorQuery (#767)
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <gaievski@amazon.com>
  • Loading branch information
martin-gaievski authored Feb 17, 2023
1 parent b8f2deb commit 71d7232
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import lombok.NonNull;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.search.KnnVectorQuery;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.Query;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryShardContext;
Expand Down Expand Up @@ -73,13 +73,13 @@ public static Query create(CreateQueryRequest createQueryRequest) {
);
try {
final Query filterQuery = createQueryRequest.getFilter().get().toQuery(queryShardContext);
return new KnnVectorQuery(fieldName, vector, k, filterQuery);
return new KnnFloatVectorQuery(fieldName, vector, k, filterQuery);
} catch (IOException e) {
throw new RuntimeException("Cannot create knn query with filter", e);
}
}
log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k));
return new KnnVectorQuery(fieldName, vector, k);
return new KnnFloatVectorQuery(fieldName, vector, k);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
package org.opensearch.knn.index.query;

import com.google.common.collect.ImmutableMap;
import org.apache.lucene.search.KnnVectorQuery;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.Query;
import org.opensearch.Version;
import org.opensearch.cluster.ClusterModule;
Expand Down Expand Up @@ -174,7 +174,7 @@ public void testDoToQuery_KnnQueryWithFilter() throws Exception {
when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField);
Query query = knnQueryBuilder.doToQuery(mockQueryShardContext);
assertNotNull(query);
assertTrue(query instanceof KnnVectorQuery);
assertTrue(query.getClass().isAssignableFrom(KnnFloatVectorQuery.class));
}

public void testDoToQuery_FromModel() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

package org.opensearch.knn.index.query;

import org.apache.lucene.search.KnnVectorQuery;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.Query;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.index.query.QueryBuilder;
Expand Down Expand Up @@ -47,7 +47,7 @@ public void testCreateLuceneDefaultQuery() {
.collect(Collectors.toList());
for (KNNEngine knnEngine : luceneDefaultQueryEngineList) {
Query query = KNNQueryFactory.create(knnEngine, testIndexName, testFieldName, testQueryVector, testK);
assertTrue(query instanceof KnnVectorQuery);
assertTrue(query.getClass().isAssignableFrom(KnnFloatVectorQuery.class));
}
}

Expand All @@ -70,7 +70,7 @@ public void testCreateLuceneQueryWithFilter() {
.filter(filter)
.build();
Query query = KNNQueryFactory.create(createQueryRequest);
assertTrue(query instanceof KnnVectorQuery);
assertTrue(query.getClass().isAssignableFrom(KnnFloatVectorQuery.class));
}
}
}

0 comments on commit 71d7232

Please sign in to comment.