Skip to content

Commit

Permalink
feat: sparse, multiple vectors support
Browse files Browse the repository at this point in the history
  • Loading branch information
Anush008 committed Mar 1, 2024
1 parent 7579916 commit 0e337e9
Show file tree
Hide file tree
Showing 7 changed files with 218 additions and 69 deletions.
2 changes: 1 addition & 1 deletion src/main/java/io/qdrant/spark/Qdrant.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
public class Qdrant implements TableProvider, DataSourceRegister {

private final String[] requiredFields =
new String[] {"schema", "collection_name", "embedding_field", "qdrant_url"};
new String[] {"schema", "collection_name", "qdrant_url"};

/**
* Returns the short name of the data source.
Expand Down
80 changes: 25 additions & 55 deletions src/main/java/io/qdrant/spark/QdrantDataWriter.java
Original file line number Diff line number Diff line change
@@ -1,34 +1,30 @@
package io.qdrant.spark;

import static io.qdrant.client.PointIdFactory.id;
import static io.qdrant.client.VectorFactory.vector;
import static io.qdrant.client.VectorsFactory.namedVectors;
import static io.qdrant.client.VectorsFactory.vectors;
import static io.qdrant.spark.QdrantValueFactory.value;

import io.qdrant.client.grpc.JsonWithInt.Value;
import io.qdrant.client.grpc.Points.PointStruct;
import java.io.Serializable;
import java.net.URL;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.UUID;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.write.DataWriter;
import org.apache.spark.sql.connector.write.WriterCommitMessage;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import io.qdrant.client.grpc.JsonWithInt.Value;
import io.qdrant.client.grpc.Points.PointId;
import io.qdrant.client.grpc.Points.PointStruct;
import io.qdrant.client.grpc.Points.Vectors;

/**
* A DataWriter implementation that writes data to Qdrant, a vector search engine. This class takes
* QdrantOptions and StructType as input and writes data to QdrantGRPC. It implements the DataWriter
* interface and overrides its methods write, commit, abort and close. It also has a private method
* write that is used to upload a batch of points to Qdrant. The class uses a Point class to
* A DataWriter implementation that writes data to Qdrant, a vector search
* engine. This class takes
* QdrantOptions and StructType as input and writes data to QdrantGRPC. It
* implements the DataWriter
* interface and overrides its methods write, commit, abort and close. It also
* has a private method
* write that is used to upload a batch of points to Qdrant. The class uses a
* Point class to
* represent a data point and an ArrayList to store the points.
*/
public class QdrantDataWriter implements DataWriter<InternalRow>, Serializable {
Expand All @@ -40,7 +36,7 @@ public class QdrantDataWriter implements DataWriter<InternalRow>, Serializable {

private final ArrayList<PointStruct> points = new ArrayList<>();

public QdrantDataWriter(QdrantOptions options, StructType schema) throws Exception {
public QdrantDataWriter(QdrantOptions options, StructType schema) {
this.options = options;
this.schema = schema;
this.qdrantUrl = options.qdrantUrl;
Expand All @@ -50,43 +46,14 @@ public QdrantDataWriter(QdrantOptions options, StructType schema) throws Excepti
@Override
public void write(InternalRow record) {
PointStruct.Builder pointBuilder = PointStruct.newBuilder();
Map<String, Value> payload = new HashMap<>();

if (this.options.idField == null) {
pointBuilder.setId(id(UUID.randomUUID()));
}
for (StructField field : this.schema.fields()) {
int fieldIndex = this.schema.fieldIndex(field.name());
if (this.options.idField != null && field.name().equals(this.options.idField)) {

DataType dataType = field.dataType();
switch (dataType.typeName()) {
case "string":
pointBuilder.setId(id(UUID.fromString(record.getString(fieldIndex))));
break;

case "integer":
case "long":
pointBuilder.setId(id(record.getInt(fieldIndex)));
break;

default:
throw new IllegalArgumentException("Point ID should be of type string or integer");
}

} else if (field.name().equals(this.options.embeddingField)) {
float[] embeddings = record.getArray(fieldIndex).toFloatArray();
if (options.vectorName != null) {
pointBuilder.setVectors(
namedVectors(Collections.singletonMap(options.vectorName, vector(embeddings))));
} else {
pointBuilder.setVectors(vectors(embeddings));
}
} else {
payload.put(field.name(), value(record, field, fieldIndex));
}
}
PointId pointId = QdrantPointIdHandler.preparePointId(record, this.schema, this.options);
pointBuilder.setId(pointId);

Vectors vectors = QdrantVectorHandler.prepareVectors(record, this.schema, this.options);
pointBuilder.setVectors(vectors);

Map<String, Value> payload = QdrantPayloadHandler.preparePayload(record, this.schema, this.options);
pointBuilder.putAllPayload(payload);
this.points.add(pointBuilder.build());

Expand Down Expand Up @@ -132,8 +99,11 @@ public void write(int retries) {
}

@Override
public void abort() {}
public void abort() {
}

@Override
public void close() {}
public void close() {
}

}
63 changes: 52 additions & 11 deletions src/main/java/io/qdrant/spark/QdrantOptions.java
Original file line number Diff line number Diff line change
@@ -1,38 +1,79 @@
package io.qdrant.spark;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

/** This class represents the options for connecting to a Qdrant instance. */
public class QdrantOptions implements Serializable {
public String qdrantUrl;
public String apiKey;
public String collectionName;
public String embeddingField;
public String idField;
public String vectorName;
public int batchSize = 100;

public int batchSize = 64;
public int retries = 3;

// Should've named the option 'vectorField'. But too late now.
public String embeddingField;
public String vectorName;

public String[] sparseVectorValueFields;
public String[] sparseVectorIndexFields;
public String[] sparseVectorNames;

public String[] vectorFields;
public String[] vectorNames;

public List<String> payloadFieldsToSkip = new ArrayList<String>();

/**
* Constructor for QdrantOptions.
*
* @param options A map of options for connecting to a Qdrant instance.
*/
public QdrantOptions(Map<String, String> options) {
this.qdrantUrl = options.get("qdrant_url");
this.collectionName = options.get("collection_name");
this.embeddingField = options.get("embedding_field");
this.idField = options.get("id_field");
this.apiKey = options.get("api_key");
this.vectorName = options.get("vector_name");
qdrantUrl = options.get("qdrant_url");
collectionName = options.get("collection_name");
embeddingField = options.getOrDefault("embedding_field", "");
idField = options.getOrDefault("id_field", "");
apiKey = options.getOrDefault("api_key", "");
vectorName = options.getOrDefault("vector_name", "");
sparseVectorValueFields = options.getOrDefault("sparse_vector_value_fields", "").split(",");
sparseVectorIndexFields = options.getOrDefault("sparse_vector_index_fields", "").split(",");
sparseVectorNames = options.getOrDefault("sparse_vector_names", "").split(",");
vectorFields = options.getOrDefault("vector_fields", "").split(",");
vectorNames = options.getOrDefault("vector_names", "").split(",");

if (sparseVectorValueFields.length != sparseVectorIndexFields.length
|| sparseVectorValueFields.length != sparseVectorNames.length) {
throw new IllegalArgumentException(
"Sparse vector value fields, index fields and names should be of same length");
}

if (vectorFields.length != vectorNames.length) {
throw new IllegalArgumentException("Vector fields and names should be of same length");
}

if (options.containsKey("batch_size")) {
this.batchSize = Integer.parseInt(options.get("batch_size"));
batchSize = Integer.parseInt(options.get("batch_size"));
}

if (options.containsKey("retries")) {
this.retries = Integer.parseInt(options.get("retries"));
retries = Integer.parseInt(options.get("retries"));
}

payloadFieldsToSkip.add(idField);
payloadFieldsToSkip.add(embeddingField);

payloadFieldsToSkip.addAll(Arrays.asList(sparseVectorValueFields));
payloadFieldsToSkip.addAll(Arrays.asList(sparseVectorIndexFields));
payloadFieldsToSkip.addAll(Arrays.asList(sparseVectorNames));

payloadFieldsToSkip.addAll(Arrays.asList(vectorFields));
payloadFieldsToSkip.addAll(Arrays.asList(vectorNames));

}
}
29 changes: 29 additions & 0 deletions src/main/java/io/qdrant/spark/QdrantPayloadHandler.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package io.qdrant.spark;

import static io.qdrant.spark.QdrantValueFactory.value;

import java.util.HashMap;
import java.util.Map;

import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import io.qdrant.client.grpc.JsonWithInt.Value;

/** Builds the JSON payload of a Qdrant point from a Spark row. */
public class QdrantPayloadHandler {

    /**
     * Converts every schema field that is not reserved for the point ID or for
     * vector data into a payload entry.
     *
     * @param record the Spark row being written
     * @param schema the row's schema
     * @param options data source options; {@code payloadFieldsToSkip} lists the
     *     ID/vector field names that must not be duplicated in the payload
     * @return map of field name to converted Qdrant value
     */
    static Map<String, Value> preparePayload(InternalRow record, StructType schema, QdrantOptions options) {

        Map<String, Value> payload = new HashMap<>();
        for (StructField field : schema.fields()) {
            String fieldName = field.name();
            // ID and vector columns are carried on the point itself, not in the payload.
            if (!options.payloadFieldsToSkip.contains(fieldName)) {
                payload.put(fieldName, value(record, field, schema.fieldIndex(fieldName)));
            }
        }

        return payload;
    }
}
35 changes: 35 additions & 0 deletions src/main/java/io/qdrant/spark/QdrantPointIdHandler.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package io.qdrant.spark;

import static io.qdrant.client.PointIdFactory.id;

import java.util.UUID;

import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.StructType;

import io.qdrant.client.grpc.Points.PointId;

/** Resolves the Qdrant point ID for a Spark row. */
public class QdrantPointIdHandler {

    /**
     * Builds the point ID from the configured ID column, or generates a random
     * UUID when no ID column is configured.
     *
     * @param record the Spark row being written
     * @param schema the row's schema
     * @param options data source options; {@code idField} is "" when unset
     * @return the Qdrant point ID
     * @throws IllegalArgumentException if the ID column is not a string, integer or long
     */
    static PointId preparePointId(InternalRow record, StructType schema, QdrantOptions options) {
        String idField = options.idField;

        // No ID column configured: assign a fresh random UUID per point.
        if (idField.isEmpty()) {
            return id(UUID.randomUUID());
        }

        int idFieldIndex = schema.fieldIndex(idField.trim());
        DataType idFieldType = schema.fields()[idFieldIndex].dataType();
        switch (idFieldType.typeName()) {
            case "string":
                // String IDs must be valid UUIDs; UUID.fromString throws otherwise.
                return id(UUID.fromString(record.getString(idFieldIndex)));

            case "integer":
                return id(record.getInt(idFieldIndex));

            case "long":
                // Fix: long columns were read with getInt, which misreads values
                // outside the 32-bit range; read the full 64-bit value instead.
                return id(record.getLong(idFieldIndex));

            default:
                throw new IllegalArgumentException("Point ID should be of type string or integer");
        }
    }
}
74 changes: 74 additions & 0 deletions src/main/java/io/qdrant/spark/QdrantVectorHandler.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package io.qdrant.spark;

import static io.qdrant.client.VectorFactory.vector;
import static io.qdrant.client.VectorsFactory.namedVectors;

import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.types.StructType;

import com.google.common.primitives.Floats;
import com.google.common.primitives.Ints;

import io.qdrant.client.grpc.Points.Vector;
import io.qdrant.client.grpc.Points.Vectors;

/** Assembles the dense and sparse vectors of a Qdrant point from a Spark row. */
public class QdrantVectorHandler {

    /**
     * Builds the combined {@link Vectors} message for a point: all configured
     * named dense vectors, all configured named sparse vectors, and — when
     * {@code embedding_field} is set — the legacy single embedding column.
     *
     * @param record the Spark row being written
     * @param schema the row's schema
     * @param options data source options naming the vector columns
     * @return the merged vectors message
     */
    static Vectors prepareVectors(InternalRow record, StructType schema, QdrantOptions options) {

        Vectors.Builder vectorsBuilder = Vectors.newBuilder();
        Vectors sparseVectors = prepareSparseVectors(record, schema, options);
        Vectors denseVectors = prepareDenseVectors(record, schema, options);

        // Protobuf mergeFrom folds each message's named-vector map into the builder.
        vectorsBuilder.mergeFrom(sparseVectors).mergeFrom(denseVectors);

        // Legacy option not set: only the named sparse/dense vectors apply.
        if (options.embeddingField.isEmpty()) {
            return vectorsBuilder.build();
        }

        int vectorFieldIndex = schema.fieldIndex(options.embeddingField.trim());
        float[] embeddings = record.getArray(vectorFieldIndex).toFloatArray();

        // The vector name defaults to ""
        // NOTE(review): an empty-string name yields a *named* vector keyed by "",
        // not an unnamed/default vector — confirm the target collection's vector
        // config matches; the pre-refactor writer used an unnamed vector here.
        return vectorsBuilder
            .mergeFrom(namedVectors(Collections.singletonMap(options.vectorName, vector(embeddings)))).build();

    }

    /**
     * Reads each configured (values column, indices column) pair and returns the
     * result as named sparse vectors. The three option arrays are validated to
     * have equal length by QdrantOptions, so parallel indexing is safe here.
     * Returns an empty named-vector map when none are configured.
     */
    private static Vectors prepareSparseVectors(InternalRow record, StructType schema, QdrantOptions options) {
        Map<String, Vector> sparseVectors = new HashMap<>();

        for (int i = 0; i < options.sparseVectorNames.length; i++) {
            String sparseVectorName = options.sparseVectorNames[i];
            String sparseVectorValueField = options.sparseVectorValueFields[i];
            String sparseVectorIndexField = options.sparseVectorIndexFields[i];
            int sparseVectorValueFieldIndex = schema.fieldIndex(sparseVectorValueField.trim());
            int sparseVectorIndexFieldIndex = schema.fieldIndex(sparseVectorIndexField.trim());
            // Guava converts the primitive arrays to the List views the client factory expects.
            List<Float> sparseVectorValues = Floats.asList(record.getArray(sparseVectorValueFieldIndex).toFloatArray());
            List<Integer> sparseVectorIndices = Ints.asList(record.getArray(sparseVectorIndexFieldIndex).toIntArray());

            sparseVectors.put(sparseVectorName, vector(sparseVectorValues, sparseVectorIndices));
        }

        return namedVectors(sparseVectors);
    }

    /**
     * Reads each configured dense-vector column and returns the result as named
     * dense vectors. vectorNames and vectorFields are validated to have equal
     * length by QdrantOptions. Returns an empty named-vector map when none are
     * configured.
     */
    private static Vectors prepareDenseVectors(InternalRow record, StructType schema, QdrantOptions options) {
        Map<String, Vector> denseVectors = new HashMap<>();

        for (int i = 0; i < options.vectorNames.length; i++) {
            String vectorName = options.vectorNames[i];
            String vectorField = options.vectorFields[i];
            int vectorFieldIndex = schema.fieldIndex(vectorField.trim());
            float[] vectorValues = record.getArray(vectorFieldIndex).toFloatArray();

            denseVectors.put(vectorName, vector(vectorValues));
        }

        return namedVectors(denseVectors);
    }
}
4 changes: 2 additions & 2 deletions src/test/java/io/qdrant/spark/TestQdrantOptions.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public void testQdrantOptions() {
assertEquals("my-id-field", qdrantOptions.idField);

// Test default values
assertEquals(100, qdrantOptions.batchSize);
assertEquals(3, qdrantOptions.retries);
assertEquals(qdrantOptions.batchSize, 64);
assertEquals(qdrantOptions.retries, 3);
}
}

0 comments on commit 0e337e9

Please sign in to comment.