Skip to content

Commit

Permalink
[Backport 1.3] Change VectorReaderListener to expect number array (#420)
Browse files Browse the repository at this point in the history
Refactors VectorReaderListener.onResponse to expect arrays of the Number
type from the search result instead of the Double type. Adds a test case to
confirm that it can handle the Integer type. Cleans up the tests in the
VectorReaderTest class.

Signed-off-by: John Mazanec <jmazane@amazon.com>
(cherry picked from commit 7735351)
  • Loading branch information
jmazanec15 authored Jun 6, 2022
1 parent ebead98 commit cf52a6d
Show file tree
Hide file tree
Showing 2 changed files with 230 additions and 116 deletions.
60 changes: 38 additions & 22 deletions src/main/java/org/opensearch/knn/training/VectorReader.java
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,15 @@ public VectorReader(Client client) {
* @param vectorConsumer consumer used to do something with the collected vectors after each search
* @param listener ActionListener that should be called once all search operations complete
*/
public void read(ClusterService clusterService, String indexName, String fieldName, int maxVectorCount,
int searchSize, Consumer<List<Float[]>> vectorConsumer, ActionListener<SearchResponse> listener) {
public void read(
ClusterService clusterService,
String indexName,
String fieldName,
int maxVectorCount,
int searchSize,
Consumer<List<Float[]>> vectorConsumer,
ActionListener<SearchResponse> listener
) {

ValidationException validationException = null;

Expand Down Expand Up @@ -94,11 +101,17 @@ public void read(ClusterService clusterService, String indexName, String fieldNa
// Start reading vectors from index
SearchScrollRequestBuilder searchScrollRequestBuilder = createSearchScrollRequestBuilder();

ActionListener<SearchResponse> vectorReaderListener = new VectorReaderListener(client, fieldName,
maxVectorCount, 0, listener, vectorConsumer, searchScrollRequestBuilder);

createSearchRequestBuilder(indexName, fieldName, Integer.min(maxVectorCount, searchSize))
.execute(vectorReaderListener);
ActionListener<SearchResponse> vectorReaderListener = new VectorReaderListener(
client,
fieldName,
maxVectorCount,
0,
listener,
vectorConsumer,
searchScrollRequestBuilder
);

createSearchRequestBuilder(indexName, fieldName, Integer.min(maxVectorCount, searchSize)).execute(vectorReaderListener);
}

private SearchRequestBuilder createSearchRequestBuilder(String indexName, String fieldName, int resultSize) {
Expand Down Expand Up @@ -142,9 +155,15 @@ private static class VectorReaderListener implements ActionListener<SearchRespon
* @param vectorConsumer Consumer used to do something with the vectors
* @param searchScrollRequestBuilder Search scroll request builder used to get next set of vectors
*/
public VectorReaderListener(Client client, String fieldName, int maxVectorCount, int collectedVectorCount,
ActionListener<SearchResponse> listener, Consumer<List<Float[]>> vectorConsumer,
SearchScrollRequestBuilder searchScrollRequestBuilder) {
public VectorReaderListener(
Client client,
String fieldName,
int maxVectorCount,
int collectedVectorCount,
ActionListener<SearchResponse> listener,
Consumer<List<Float[]>> vectorConsumer,
SearchScrollRequestBuilder searchScrollRequestBuilder
) {
this.client = client;
this.fieldName = fieldName;
this.maxVectorCount = maxVectorCount;
Expand All @@ -154,7 +173,6 @@ public VectorReaderListener(Client client, String fieldName, int maxVectorCount,
this.searchScrollRequestBuilder = searchScrollRequestBuilder;
}


@Override
@SuppressWarnings("unchecked")
public void onResponse(SearchResponse searchResponse) {
Expand All @@ -165,9 +183,9 @@ public void onResponse(SearchResponse searchResponse) {
List<Float[]> trainingData = new ArrayList<>();

for (int i = 0; i < vectorsToAdd; i++) {
trainingData.add(((List<Double>) hits[i].getSourceAsMap().get(fieldName)).stream()
.map(Double::floatValue)
.toArray(Float[]::new));
trainingData.add(
((List<Number>) hits[i].getSourceAsMap().get(fieldName)).stream().map(Number::floatValue).toArray(Float[]::new)
);
}

this.collectedVectorCount += trainingData.size();
Expand All @@ -180,10 +198,9 @@ public void onResponse(SearchResponse searchResponse) {
String scrollId = searchResponse.getScrollId();

if (scrollId != null) {
client.prepareClearScroll().addScrollId(scrollId).execute(ActionListener.wrap(
clearScrollResponse -> listener.onResponse(searchResponse),
listener::onFailure)
);
client.prepareClearScroll()
.addScrollId(scrollId)
.execute(ActionListener.wrap(clearScrollResponse -> listener.onResponse(searchResponse), listener::onFailure));
} else {
listener.onResponse(searchResponse);
}
Expand All @@ -201,10 +218,9 @@ public void onFailure(Exception e) {
String scrollId = searchScrollRequestBuilder.request().scrollId();

if (scrollId != null) {
client.prepareClearScroll().addScrollId(scrollId).execute(ActionListener.wrap(
clearScrollResponse -> listener.onFailure(e),
listener::onFailure)
);
client.prepareClearScroll()
.addScrollId(scrollId)
.execute(ActionListener.wrap(clearScrollResponse -> listener.onFailure(e), listener::onFailure));
} else {
listener.onFailure(e);
}
Expand Down
Loading

0 comments on commit cf52a6d

Please sign in to comment.