Skip to content

Commit

Permalink
Add batching strategy to more vector stores
Browse files Browse the repository at this point in the history
 Apply batching when adding Documents to the following vector stores:

- Azure vector store
- Cassandra
- MongoDB Atlas
- OpenSearch
- Oracle
- Pinecone

 This improves efficiency by processing multiple Documents at once instead of individually, reducing the overhead for each operation.

Related to #1261
  • Loading branch information
sobychacko authored and Mark Pollack committed Sep 17, 2024
1 parent 4b123a7 commit 66455b9
Show file tree
Hide file tree
Showing 23 changed files with 225 additions and 95 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023 - 2024 the original author or authors.
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.autoconfigure.vectorstore.azure;

import com.azure.core.credential.AzureKeyCredential;
Expand All @@ -23,7 +24,9 @@

import java.util.List;

import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.ai.vectorstore.azure.AzureVectorStore;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
import org.springframework.beans.factory.ObjectProvider;
Expand All @@ -36,6 +39,7 @@

/**
* @author Christian Tzolov
* @author Soby Chacko
*/
@AutoConfiguration
@ConditionalOnClass({ EmbeddingModel.class, SearchIndexClient.class, AzureVectorStore.class })
Expand All @@ -51,15 +55,22 @@ public SearchIndexClient searchIndexClient(AzureVectorStoreProperties properties
.buildClient();
}

/**
 * Supplies the fallback {@link BatchingStrategy} for the Azure vector store
 * auto-configuration: a {@link TokenCountBatchingStrategy} built with its no-argument
 * constructor. The bean is conditional on no other {@link BatchingStrategy} bean being
 * present.
 * @return a new {@link TokenCountBatchingStrategy}
 */
@Bean
@ConditionalOnMissingBean(value = BatchingStrategy.class)
BatchingStrategy batchingStrategy() {
	BatchingStrategy tokenCountDefault = new TokenCountBatchingStrategy();
	return tokenCountDefault;
}

@Bean
@ConditionalOnMissingBean
public AzureVectorStore vectorStore(SearchIndexClient searchIndexClient, EmbeddingModel embeddingModel,
AzureVectorStoreProperties properties, ObjectProvider<ObservationRegistry> observationRegistry,
ObjectProvider<VectorStoreObservationConvention> customObservationConvention) {
ObjectProvider<VectorStoreObservationConvention> customObservationConvention,
BatchingStrategy batchingStrategy) {

var vectorStore = new AzureVectorStore(searchIndexClient, embeddingModel, properties.isInitializeSchema(),
List.of(), observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP),
customObservationConvention.getIfAvailable(() -> null));
customObservationConvention.getIfAvailable(() -> null), batchingStrategy);

vectorStore.setIndexName(properties.getIndexName());

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2024 - 2024 the original author or authors.
* Copyright 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.autoconfigure.vectorstore.cassandra;

import java.time.Duration;
Expand All @@ -22,7 +23,9 @@

import io.micrometer.observation.ObservationRegistry;

import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.ai.vectorstore.CassandraVectorStore;
import org.springframework.ai.vectorstore.CassandraVectorStoreConfig;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
Expand All @@ -38,18 +41,26 @@
/**
* @author Mick Semb Wever
* @author Christian Tzolov
* @author Soby Chacko
* @since 1.0.0
*/
@AutoConfiguration(after = CassandraAutoConfiguration.class)
@ConditionalOnClass({ CassandraVectorStore.class, CqlSession.class })
@EnableConfigurationProperties(CassandraVectorStoreProperties.class)
public class CassandraVectorStoreAutoConfiguration {

/**
 * Supplies the fallback {@link BatchingStrategy} for the Cassandra vector store
 * auto-configuration: a {@link TokenCountBatchingStrategy} built with its no-argument
 * constructor. The bean is conditional on no other {@link BatchingStrategy} bean being
 * present.
 * @return a new {@link TokenCountBatchingStrategy}
 */
@Bean
@ConditionalOnMissingBean(value = BatchingStrategy.class)
BatchingStrategy batchingStrategy() {
	BatchingStrategy tokenCountDefault = new TokenCountBatchingStrategy();
	return tokenCountDefault;
}

@Bean
@ConditionalOnMissingBean
public CassandraVectorStore vectorStore(EmbeddingModel embeddingModel, CassandraVectorStoreProperties properties,
CqlSession cqlSession, ObjectProvider<ObservationRegistry> observationRegistry,
ObjectProvider<VectorStoreObservationConvention> customObservationConvention) {
ObjectProvider<VectorStoreObservationConvention> customObservationConvention,
BatchingStrategy batchingStrategy) {

var builder = CassandraVectorStoreConfig.builder().withCqlSession(cqlSession);

Expand All @@ -69,7 +80,7 @@ public CassandraVectorStore vectorStore(EmbeddingModel embeddingModel, Cassandra

return new CassandraVectorStore(builder.build(), embeddingModel,
observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP),
customObservationConvention.getIfAvailable(() -> null));
customObservationConvention.getIfAvailable(() -> null), batchingStrategy);
}

@Bean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@

package org.springframework.ai.autoconfigure.vectorstore.gemfire;

import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.ai.vectorstore.GemFireVectorStore;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
import org.springframework.beans.factory.ObjectProvider;
Expand Down Expand Up @@ -47,11 +49,18 @@ GemFireVectorStoreAutoConfiguration.PropertiesGemFireConnectionDetails gemfireCo
return new GemFireVectorStoreAutoConfiguration.PropertiesGemFireConnectionDetails(properties);
}

/**
 * Supplies the fallback {@link BatchingStrategy} for the GemFire vector store
 * auto-configuration: a {@link TokenCountBatchingStrategy} built with its no-argument
 * constructor. The bean is conditional on no other {@link BatchingStrategy} bean being
 * present.
 * @return a new {@link TokenCountBatchingStrategy}
 */
@Bean
@ConditionalOnMissingBean(value = BatchingStrategy.class)
BatchingStrategy batchingStrategy() {
	BatchingStrategy tokenCountDefault = new TokenCountBatchingStrategy();
	return tokenCountDefault;
}

@Bean
@ConditionalOnMissingBean
public GemFireVectorStore gemfireVectorStore(EmbeddingModel embeddingModel, GemFireVectorStoreProperties properties,
GemFireConnectionDetails gemFireConnectionDetails, ObjectProvider<ObservationRegistry> observationRegistry,
ObjectProvider<VectorStoreObservationConvention> customObservationConvention) {
ObjectProvider<VectorStoreObservationConvention> customObservationConvention,
BatchingStrategy batchingStrategy) {
var builder = new GemFireVectorStore.GemFireVectorStoreConfig.Builder();

builder.setHost(gemFireConnectionDetails.getHost())
Expand All @@ -65,7 +74,7 @@ public GemFireVectorStore gemfireVectorStore(EmbeddingModel embeddingModel, GemF
.setSslEnabled(properties.isSslEnabled());
return new GemFireVectorStore(builder.build(), embeddingModel, properties.isInitializeSchema(),
observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP),
customObservationConvention.getIfAvailable(() -> null));
customObservationConvention.getIfAvailable(() -> null), batchingStrategy);
}

private static class PropertiesGemFireConnectionDetails implements GemFireConnectionDetails {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.autoconfigure.vectorstore.mongo;

import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.ai.vectorstore.MongoDBAtlasVectorStore;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
import org.springframework.beans.factory.ObjectProvider;
Expand Down Expand Up @@ -45,11 +48,18 @@
@EnableConfigurationProperties(MongoDBAtlasVectorStoreProperties.class)
public class MongoDBAtlasVectorStoreAutoConfiguration {

/**
 * Supplies the fallback {@link BatchingStrategy} for the MongoDB Atlas vector store
 * auto-configuration: a {@link TokenCountBatchingStrategy} built with its no-argument
 * constructor. The bean is conditional on no other {@link BatchingStrategy} bean being
 * present.
 * @return a new {@link TokenCountBatchingStrategy}
 */
@Bean
@ConditionalOnMissingBean(value = BatchingStrategy.class)
BatchingStrategy batchingStrategy() {
	BatchingStrategy tokenCountDefault = new TokenCountBatchingStrategy();
	return tokenCountDefault;
}

@Bean
@ConditionalOnMissingBean
MongoDBAtlasVectorStore vectorStore(MongoTemplate mongoTemplate, EmbeddingModel embeddingModel,
MongoDBAtlasVectorStoreProperties properties, ObjectProvider<ObservationRegistry> observationRegistry,
ObjectProvider<VectorStoreObservationConvention> customObservationConvention) {
ObjectProvider<VectorStoreObservationConvention> customObservationConvention,
BatchingStrategy batchingStrategy) {

var builder = MongoDBAtlasVectorStore.MongoDBVectorStoreConfig.builder();

Expand All @@ -66,7 +76,7 @@ MongoDBAtlasVectorStore vectorStore(MongoTemplate mongoTemplate, EmbeddingModel

return new MongoDBAtlasVectorStore(mongoTemplate, embeddingModel, config, properties.isInitializeSchema(),
observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP),
customObservationConvention.getIfAvailable(() -> null));
customObservationConvention.getIfAvailable(() -> null), batchingStrategy);
}

@Bean
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023 - 2024 the original author or authors.
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.autoconfigure.vectorstore.opensearch;

import org.apache.hc.client5.http.auth.AuthScope;
Expand All @@ -24,7 +25,10 @@
import org.opensearch.client.transport.aws.AwsSdk2Transport;
import org.opensearch.client.transport.aws.AwsSdk2TransportOptions;
import org.opensearch.client.transport.httpclient5.ApacheHttpClient5TransportBuilder;

import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.ai.vectorstore.OpenSearchVectorStore;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
import org.springframework.beans.factory.ObjectProvider;
Expand Down Expand Up @@ -58,17 +62,24 @@ PropertiesOpenSearchConnectionDetails openSearchConnectionDetails(OpenSearchVect
return new PropertiesOpenSearchConnectionDetails(properties);
}

/**
 * Supplies the fallback {@link BatchingStrategy} for the OpenSearch vector store
 * auto-configuration: a {@link TokenCountBatchingStrategy} built with its no-argument
 * constructor. The bean is conditional on no other {@link BatchingStrategy} bean being
 * present.
 * @return a new {@link TokenCountBatchingStrategy}
 */
@Bean
@ConditionalOnMissingBean(value = BatchingStrategy.class)
BatchingStrategy batchingStrategy() {
	BatchingStrategy tokenCountDefault = new TokenCountBatchingStrategy();
	return tokenCountDefault;
}

/**
 * Builds the {@link OpenSearchVectorStore} bean. It is declared with
 * {@code @ConditionalOnMissingBean}, so it is meant to back off when a matching bean is
 * already defined.
 * <p>
 * The index name falls back to {@code OpenSearchVectorStore.DEFAULT_INDEX_NAME} and the
 * mapping JSON to
 * {@code OpenSearchVectorStore.DEFAULT_MAPPING_EMBEDDING_TYPE_KNN_VECTOR_DIMENSION_1536}
 * when the corresponding property is {@code null}. The observation registry falls back to
 * {@code ObservationRegistry.NOOP} and the custom observation convention to {@code null}
 * when the providers cannot supply one.
 * @param properties source of the index name, mapping JSON and initialize-schema flag
 * @param openSearchClient client handed to the vector store
 * @param embeddingModel embedding model handed to the vector store
 * @param observationRegistry provider of the observation registry
 * @param customObservationConvention provider of an optional custom observation
 * convention
 * @param batchingStrategy batching strategy handed to the vector store
 * @return the configured {@link OpenSearchVectorStore}
 */
@Bean
@ConditionalOnMissingBean
OpenSearchVectorStore vectorStore(OpenSearchVectorStoreProperties properties, OpenSearchClient openSearchClient,
EmbeddingModel embeddingModel, ObjectProvider<ObservationRegistry> observationRegistry,
// NOTE(review): the next line looks like the pre-change form of the last parameter as shown
// by the diff view; the two lines after it are its replacement, which adds batchingStrategy.
ObjectProvider<VectorStoreObservationConvention> customObservationConvention) {
ObjectProvider<VectorStoreObservationConvention> customObservationConvention,
BatchingStrategy batchingStrategy) {
// Use the configured index name, or the store's default when none is set.
var indexName = Optional.ofNullable(properties.getIndexName()).orElse(OpenSearchVectorStore.DEFAULT_INDEX_NAME);
// Use the configured mapping JSON, or the store's default mapping constant when none is set.
var mappingJson = Optional.ofNullable(properties.getMappingJson())
.orElse(OpenSearchVectorStore.DEFAULT_MAPPING_EMBEDDING_TYPE_KNN_VECTOR_DIMENSION_1536);
return new OpenSearchVectorStore(indexName, openSearchClient, embeddingModel, mappingJson,
properties.isInitializeSchema(), observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP),
// NOTE(review): the next line looks like the pre-change final argument from the diff view;
// the line after it is the replacement that also passes batchingStrategy.
customObservationConvention.getIfAvailable(() -> null));
customObservationConvention.getIfAvailable(() -> null), batchingStrategy);
}

@Configuration(proxyBeanMethods = false)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023 - 2024 the original author or authors.
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -13,11 +13,14 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.autoconfigure.vectorstore.oracle;

import javax.sql.DataSource;

import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.ai.vectorstore.OracleVectorStore;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
import org.springframework.beans.factory.ObjectProvider;
Expand All @@ -35,22 +38,30 @@
* @author Loïc Lefèvre
* @author Eddú Meléndez
* @author Christian Tzolov
* @author Soby Chacko
*/
@AutoConfiguration(after = JdbcTemplateAutoConfiguration.class)
@ConditionalOnClass({ OracleVectorStore.class, DataSource.class, JdbcTemplate.class })
@EnableConfigurationProperties(OracleVectorStoreProperties.class)
public class OracleVectorStoreAutoConfiguration {

/**
 * Supplies the fallback {@link BatchingStrategy} for the Oracle vector store
 * auto-configuration: a {@link TokenCountBatchingStrategy} built with its no-argument
 * constructor. The bean is conditional on no other {@link BatchingStrategy} bean being
 * present.
 * @return a new {@link TokenCountBatchingStrategy}
 */
@Bean
@ConditionalOnMissingBean(value = BatchingStrategy.class)
BatchingStrategy batchingStrategy() {
	BatchingStrategy tokenCountDefault = new TokenCountBatchingStrategy();
	return tokenCountDefault;
}

/**
 * Builds the {@link OracleVectorStore} bean. It is declared with
 * {@code @ConditionalOnMissingBean}, so it is meant to back off when a matching bean is
 * already defined.
 * <p>
 * Table name, index type, distance type, dimensions, search accuracy and the
 * initialize-schema, remove-existing-table and forced-normalization flags are all read
 * from {@code properties}. The observation registry falls back to
 * {@code ObservationRegistry.NOOP} and the custom observation convention to {@code null}
 * when the providers cannot supply one.
 * @param jdbcTemplate JDBC template handed to the vector store
 * @param embeddingModel embedding model handed to the vector store
 * @param properties source of the store settings listed above
 * @param observationRegistry provider of the observation registry
 * @param customObservationConvention provider of an optional custom observation
 * convention
 * @param batchingStrategy batching strategy handed to the vector store
 * @return the configured {@link OracleVectorStore}
 */
@Bean
@ConditionalOnMissingBean
public OracleVectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel,
OracleVectorStoreProperties properties, ObjectProvider<ObservationRegistry> observationRegistry,
// NOTE(review): the next line looks like the pre-change form of the last parameter as shown
// by the diff view; the two lines after it are its replacement, which adds batchingStrategy.
ObjectProvider<VectorStoreObservationConvention> customObservationConvention) {
ObjectProvider<VectorStoreObservationConvention> customObservationConvention,
BatchingStrategy batchingStrategy) {
return new OracleVectorStore(jdbcTemplate, embeddingModel, properties.getTableName(), properties.getIndexType(),
properties.getDistanceType(), properties.getDimensions(), properties.getSearchAccuracy(),
properties.isInitializeSchema(), properties.isRemoveExistingVectorStoreTable(),
properties.isForcedNormalization(), observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP),
// NOTE(review): the next line looks like the pre-change final argument from the diff view;
// the line after it is the replacement that also passes batchingStrategy.
customObservationConvention.getIfAvailable(() -> null));
customObservationConvention.getIfAvailable(() -> null), batchingStrategy);
}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023 - 2024 the original author or authors.
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -13,9 +13,12 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.autoconfigure.vectorstore.pinecone;

import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.ai.vectorstore.PineconeVectorStore;
import org.springframework.ai.vectorstore.PineconeVectorStore.PineconeVectorStoreConfig;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
Expand All @@ -30,17 +33,25 @@

/**
* @author Christian Tzolov
* @author Soby Chacko
*/
@AutoConfiguration
@ConditionalOnClass({ PineconeVectorStore.class, EmbeddingModel.class })
@EnableConfigurationProperties(PineconeVectorStoreProperties.class)
public class PineconeVectorStoreAutoConfiguration {

/**
 * Supplies the fallback {@link BatchingStrategy} for the Pinecone vector store
 * auto-configuration: a {@link TokenCountBatchingStrategy} built with its no-argument
 * constructor. The bean is conditional on no other {@link BatchingStrategy} bean being
 * present.
 * @return a new {@link TokenCountBatchingStrategy}
 */
@Bean
@ConditionalOnMissingBean(value = BatchingStrategy.class)
BatchingStrategy batchingStrategy() {
	BatchingStrategy tokenCountDefault = new TokenCountBatchingStrategy();
	return tokenCountDefault;
}

@Bean
@ConditionalOnMissingBean
public PineconeVectorStore vectorStore(EmbeddingModel embeddingModel, PineconeVectorStoreProperties properties,
ObjectProvider<ObservationRegistry> observationRegistry,
ObjectProvider<VectorStoreObservationConvention> customObservationConvention) {
ObjectProvider<VectorStoreObservationConvention> customObservationConvention,
BatchingStrategy batchingStrategy) {

var config = PineconeVectorStoreConfig.builder()
.withApiKey(properties.getApiKey())
Expand All @@ -55,7 +66,7 @@ public PineconeVectorStore vectorStore(EmbeddingModel embeddingModel, PineconeVe

return new PineconeVectorStore(config, embeddingModel,
observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP),
customObservationConvention.getIfAvailable(() -> null));
customObservationConvention.getIfAvailable(() -> null), batchingStrategy);
}

}
Loading

0 comments on commit 66455b9

Please sign in to comment.