From c94dc938fd8d75260a18b72972d2797e2b683103 Mon Sep 17 00:00:00 2001 From: Anush008 Date: Sun, 28 Apr 2024 11:50:38 +0530 Subject: [PATCH] refactor: QdrantVectorHandler --- .../io/qdrant/spark/QdrantVectorHandler.java | 22 +++++++++---------- 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/src/main/java/io/qdrant/spark/QdrantVectorHandler.java b/src/main/java/io/qdrant/spark/QdrantVectorHandler.java index e074751..e42e5d6 100644 --- a/src/main/java/io/qdrant/spark/QdrantVectorHandler.java +++ b/src/main/java/io/qdrant/spark/QdrantVectorHandler.java @@ -9,7 +9,6 @@ import io.qdrant.client.grpc.Points.Vectors; import java.util.Collections; import java.util.HashMap; -import java.util.List; import java.util.Map; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.types.StructType; @@ -26,8 +25,7 @@ public static Vectors prepareVectors( // Maitaining support for the "embedding_field" and "vector_name" options if (!options.embeddingField.isEmpty()) { - int embeddingFieldIndex = schema.fieldIndex(options.embeddingField); - float[] embeddings = record.getArray(embeddingFieldIndex).toFloatArray(); + float[] embeddings = extractFloatArray(record, schema, options.embeddingField); // 'options.vectorName' defaults to "" vectorsBuilder.mergeFrom( namedVectors(Collections.singletonMap(options.vectorName, vector(embeddings)))); @@ -42,9 +40,10 @@ private static Vectors prepareSparseVectors( for (int i = 0; i < options.sparseVectorNames.length; i++) { String name = options.sparseVectorNames[i]; - List values = extractFloatArray(record, schema, options.sparseVectorValueFields[i]); - List indices = extractIntArray(record, schema, options.sparseVectorIndexFields[i]); - sparseVectors.put(name, vector(values, indices)); + float[] values = extractFloatArray(record, schema, options.sparseVectorValueFields[i]); + int[] indices = extractIntArray(record, schema, options.sparseVectorIndexFields[i]); + + sparseVectors.put(name, vector(Floats.asList(values), Ints.asList(indices))); } return namedVectors(sparseVectors); @@ -56,22 +55,21 @@ private static Vectors prepareDenseVectors( for (int i = 0; i < options.vectorNames.length; i++) { String name = options.vectorNames[i]; - List values = extractFloatArray(record, schema, options.vectorFields[i]); + float[] values = extractFloatArray(record, schema, options.vectorFields[i]); denseVectors.put(name, vector(values)); } return namedVectors(denseVectors); } - private static List extractFloatArray( + private static float[] extractFloatArray( InternalRow record, StructType schema, String fieldName) { int fieldIndex = schema.fieldIndex(fieldName); - return Floats.asList(record.getArray(fieldIndex).toFloatArray()); + return record.getArray(fieldIndex).toFloatArray(); } - private static List extractIntArray( - InternalRow record, StructType schema, String fieldName) { + private static int[] extractIntArray(InternalRow record, StructType schema, String fieldName) { int fieldIndex = schema.fieldIndex(fieldName); - return Ints.asList(record.getArray(fieldIndex).toIntArray()); + return record.getArray(fieldIndex).toIntArray(); } }