-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: sparse, multiple vectors support
- Loading branch information
Showing
7 changed files
with
218 additions
and
69 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,38 +1,79 @@ | ||
package io.qdrant.spark; | ||
|
||
import java.io.Serializable; | ||
import java.util.ArrayList; | ||
import java.util.Arrays; | ||
import java.util.List; | ||
import java.util.Map; | ||
|
||
/** This class represents the options for connecting to a Qdrant instance. */ | ||
public class QdrantOptions implements Serializable { | ||
public String qdrantUrl; | ||
public String apiKey; | ||
public String collectionName; | ||
public String embeddingField; | ||
public String idField; | ||
public String vectorName; | ||
public int batchSize = 100; | ||
|
||
public int batchSize = 64; | ||
public int retries = 3; | ||
|
||
// Should've named the option 'vectorField'. But too late now. | ||
public String embeddingField; | ||
public String vectorName; | ||
|
||
public String[] sparseVectorValueFields; | ||
public String[] sparseVectorIndexFields; | ||
public String[] sparseVectorNames; | ||
|
||
public String[] vectorFields; | ||
public String[] vectorNames; | ||
|
||
public List<String> payloadFieldsToSkip = new ArrayList<String>(); | ||
|
||
/** | ||
* Constructor for QdrantOptions. | ||
* | ||
* @param options A map of options for connecting to a Qdrant instance. | ||
*/ | ||
public QdrantOptions(Map<String, String> options) { | ||
this.qdrantUrl = options.get("qdrant_url"); | ||
this.collectionName = options.get("collection_name"); | ||
this.embeddingField = options.get("embedding_field"); | ||
this.idField = options.get("id_field"); | ||
this.apiKey = options.get("api_key"); | ||
this.vectorName = options.get("vector_name"); | ||
qdrantUrl = options.get("qdrant_url"); | ||
collectionName = options.get("collection_name"); | ||
embeddingField = options.getOrDefault("embedding_field", ""); | ||
idField = options.getOrDefault("id_field", ""); | ||
apiKey = options.getOrDefault("api_key", ""); | ||
vectorName = options.getOrDefault("vector_name", ""); | ||
sparseVectorValueFields = options.getOrDefault("sparse_vector_value_fields", "").split(","); | ||
sparseVectorIndexFields = options.getOrDefault("sparse_vector_index_fields", "").split(","); | ||
sparseVectorNames = options.getOrDefault("sparse_vector_names", "").split(","); | ||
vectorFields = options.getOrDefault("vector_fields", "").split(","); | ||
vectorNames = options.getOrDefault("vector_names", "").split(","); | ||
|
||
if (sparseVectorValueFields.length != sparseVectorIndexFields.length | ||
|| sparseVectorValueFields.length != sparseVectorNames.length) { | ||
throw new IllegalArgumentException( | ||
"Sparse vector value fields, index fields and names should be of same length"); | ||
} | ||
|
||
if (vectorFields.length != vectorNames.length) { | ||
throw new IllegalArgumentException("Vector fields and names should be of same length"); | ||
} | ||
|
||
if (options.containsKey("batch_size")) { | ||
this.batchSize = Integer.parseInt(options.get("batch_size")); | ||
batchSize = Integer.parseInt(options.get("batch_size")); | ||
} | ||
|
||
if (options.containsKey("retries")) { | ||
this.retries = Integer.parseInt(options.get("retries")); | ||
retries = Integer.parseInt(options.get("retries")); | ||
} | ||
|
||
payloadFieldsToSkip.add(idField); | ||
payloadFieldsToSkip.add(embeddingField); | ||
|
||
payloadFieldsToSkip.addAll(Arrays.asList(sparseVectorValueFields)); | ||
payloadFieldsToSkip.addAll(Arrays.asList(sparseVectorIndexFields)); | ||
payloadFieldsToSkip.addAll(Arrays.asList(sparseVectorNames)); | ||
|
||
payloadFieldsToSkip.addAll(Arrays.asList(vectorFields)); | ||
payloadFieldsToSkip.addAll(Arrays.asList(vectorNames)); | ||
|
||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
package io.qdrant.spark; | ||
|
||
import static io.qdrant.spark.QdrantValueFactory.value; | ||
|
||
import java.util.HashMap; | ||
import java.util.Map; | ||
|
||
import org.apache.spark.sql.catalyst.InternalRow; | ||
import org.apache.spark.sql.types.StructField; | ||
import org.apache.spark.sql.types.StructType; | ||
|
||
import io.qdrant.client.grpc.JsonWithInt.Value; | ||
|
||
public class QdrantPayloadHandler { | ||
static Map<String, Value> preparePayload(InternalRow record, StructType schema, QdrantOptions options) { | ||
|
||
Map<String, Value> payload = new HashMap<>(); | ||
for (StructField field : schema.fields()) { | ||
|
||
if (options.payloadFieldsToSkip.contains(field.name())) { | ||
continue; | ||
} | ||
int fieldIndex = schema.fieldIndex(field.name()); | ||
payload.put(field.name(), value(record, field, fieldIndex)); | ||
} | ||
|
||
return payload; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
package io.qdrant.spark; | ||
|
||
import static io.qdrant.client.PointIdFactory.id; | ||
|
||
import java.util.UUID; | ||
|
||
import org.apache.spark.sql.catalyst.InternalRow; | ||
import org.apache.spark.sql.types.DataType; | ||
import org.apache.spark.sql.types.StructType; | ||
|
||
import io.qdrant.client.grpc.Points.PointId; | ||
|
||
public class QdrantPointIdHandler { | ||
static PointId preparePointId(InternalRow record, StructType schema, QdrantOptions options) { | ||
String idField = options.idField; | ||
|
||
if (idField.isEmpty()) { | ||
return id(UUID.randomUUID()); | ||
} | ||
|
||
int idFieldIndex = schema.fieldIndex(idField.trim()); | ||
DataType idFieldType = schema.fields()[idFieldIndex].dataType(); | ||
switch (idFieldType.typeName()) { | ||
case "string": | ||
return id(UUID.fromString(record.getString(idFieldIndex))); | ||
|
||
case "integer": | ||
case "long": | ||
return id(record.getInt(idFieldIndex)); | ||
|
||
default: | ||
throw new IllegalArgumentException("Point ID should be of type string or integer"); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
package io.qdrant.spark; | ||
|
||
import static io.qdrant.client.VectorFactory.vector; | ||
import static io.qdrant.client.VectorsFactory.namedVectors; | ||
|
||
import java.util.Collections; | ||
import java.util.HashMap; | ||
import java.util.List; | ||
import java.util.Map; | ||
|
||
import org.apache.spark.sql.catalyst.InternalRow; | ||
import org.apache.spark.sql.types.StructType; | ||
|
||
import com.google.common.primitives.Floats; | ||
import com.google.common.primitives.Ints; | ||
|
||
import io.qdrant.client.grpc.Points.Vector; | ||
import io.qdrant.client.grpc.Points.Vectors; | ||
|
||
public class QdrantVectorHandler { | ||
static Vectors prepareVectors(InternalRow record, StructType schema, QdrantOptions options) { | ||
|
||
Vectors.Builder vectorsBuilder = Vectors.newBuilder(); | ||
Vectors sparseVectors = prepareSparseVectors(record, schema, options); | ||
Vectors denseVectors = prepareDenseVectors(record, schema, options); | ||
|
||
vectorsBuilder.mergeFrom(sparseVectors).mergeFrom(denseVectors); | ||
|
||
if (options.embeddingField.isEmpty()) { | ||
return vectorsBuilder.build(); | ||
} | ||
|
||
int vectorFieldIndex = schema.fieldIndex(options.embeddingField.trim()); | ||
float[] embeddings = record.getArray(vectorFieldIndex).toFloatArray(); | ||
|
||
// The vector name defaults to "" | ||
return vectorsBuilder | ||
.mergeFrom(namedVectors(Collections.singletonMap(options.vectorName, vector(embeddings)))).build(); | ||
|
||
} | ||
|
||
private static Vectors prepareSparseVectors(InternalRow record, StructType schema, QdrantOptions options) { | ||
Map<String, Vector> sparseVectors = new HashMap<>(); | ||
|
||
for (int i = 0; i < options.sparseVectorNames.length; i++) { | ||
String sparseVectorName = options.sparseVectorNames[i]; | ||
String sparseVectorValueField = options.sparseVectorValueFields[i]; | ||
String sparseVectorIndexField = options.sparseVectorIndexFields[i]; | ||
int sparseVectorValueFieldIndex = schema.fieldIndex(sparseVectorValueField.trim()); | ||
int sparseVectorIndexFieldIndex = schema.fieldIndex(sparseVectorIndexField.trim()); | ||
List<Float> sparseVectorValues = Floats.asList(record.getArray(sparseVectorValueFieldIndex).toFloatArray()); | ||
List<Integer> sparseVectorIndices = Ints.asList(record.getArray(sparseVectorIndexFieldIndex).toIntArray()); | ||
|
||
sparseVectors.put(sparseVectorName, vector(sparseVectorValues, sparseVectorIndices)); | ||
} | ||
|
||
return namedVectors(sparseVectors); | ||
} | ||
|
||
private static Vectors prepareDenseVectors(InternalRow record, StructType schema, QdrantOptions options) { | ||
Map<String, Vector> denseVectors = new HashMap<>(); | ||
|
||
for (int i = 0; i < options.vectorNames.length; i++) { | ||
String vectorName = options.vectorNames[i]; | ||
String vectorField = options.vectorFields[i]; | ||
int vectorFieldIndex = schema.fieldIndex(vectorField.trim()); | ||
float[] vectorValues = record.getArray(vectorFieldIndex).toFloatArray(); | ||
|
||
denseVectors.put(vectorName, vector(vectorValues)); | ||
} | ||
|
||
return namedVectors(denseVectors); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters