diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/chroma.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/chroma.adoc index a6e7793b90..2d4e0446fa 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/chroma.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/chroma.adoc @@ -205,19 +205,20 @@ TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Man === Sample Code -Create a `RestTemplate` instance with proper ChromaDB authorization configurations and Use it to create a `ChromaApi` instance: +Create a `RestClient.Builder` instance with proper ChromaDB authorization configurations and Use it to create a `ChromaApi` instance: [source,java] ---- @Bean -public RestTemplate restTemplate() { - return new RestTemplate(); +public RestClient.Builder builder() { + return RestClient.builder().requestFactory(new SimpleClientHttpRequestFactory()); } + @Bean -public ChromaApi chromaApi(RestTemplate restTemplate) { +public ChromaApi chromaApi(RestClient.Builder restClientBuilder) { String chromaUrl = "http://localhost:8000"; - ChromaApi chromaApi = new ChromaApi(chromaUrl, restTemplate); + ChromaApi chromaApi = new ChromaApi(chromaUrl, restClientBuilder); return chromaApi; } ---- diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreAutoConfiguration.java index 88d0bcf5bb..592f99421f 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreAutoConfiguration.java @@ -15,8 +15,6 @@ */ package org.springframework.ai.autoconfigure.vectorstore.chroma; -import com.fasterxml.jackson.databind.ObjectMapper; - import org.springframework.ai.chroma.ChromaApi; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.vectorstore.ChromaVectorStore; @@ -25,15 +23,18 @@ import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; +import org.springframework.http.client.SimpleClientHttpRequestFactory; import org.springframework.util.StringUtils; -import org.springframework.web.client.RestTemplate; +import org.springframework.web.client.RestClient; + +import com.fasterxml.jackson.databind.ObjectMapper; /** * @author Christian Tzolov * @author Eddú Meléndez */ @AutoConfiguration -@ConditionalOnClass({ EmbeddingModel.class, RestTemplate.class, ChromaVectorStore.class, ObjectMapper.class }) +@ConditionalOnClass({ EmbeddingModel.class, RestClient.class, ChromaVectorStore.class, ObjectMapper.class }) @EnableConfigurationProperties({ ChromaApiProperties.class, ChromaVectorStoreProperties.class }) public class ChromaVectorStoreAutoConfiguration { @@ -45,18 +46,18 @@ PropertiesChromaConnectionDetails chromaConnectionDetails(ChromaApiProperties pr @Bean @ConditionalOnMissingBean - public RestTemplate restTemplate() { - return new RestTemplate(); + public RestClient.Builder builder() { + return RestClient.builder().requestFactory(new SimpleClientHttpRequestFactory()); } @Bean @ConditionalOnMissingBean - public ChromaApi chromaApi(ChromaApiProperties apiProperties, RestTemplate restTemplate, + public ChromaApi chromaApi(ChromaApiProperties apiProperties, RestClient.Builder restClientBuilder, ChromaConnectionDetails connectionDetails) { String chromaUrl = String.format("%s:%s", connectionDetails.getHost(), connectionDetails.getPort()); - var chromaApi = new ChromaApi(chromaUrl, restTemplate, new ObjectMapper()); + var chromaApi = new ChromaApi(chromaUrl, restClientBuilder, new ObjectMapper()); if (StringUtils.hasText(apiProperties.getKeyToken())) { chromaApi.withKeyToken(apiProperties.getKeyToken()); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreAutoConfigurationIT.java index 63862fbffe..d20f148099 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreAutoConfigurationIT.java @@ -43,7 +43,7 @@ public class ChromaVectorStoreAutoConfigurationIT { @Container - static ChromaDBContainer chroma = new ChromaDBContainer("ghcr.io/chroma-core/chroma:0.4.15"); + static ChromaDBContainer chroma = new ChromaDBContainer("ghcr.io/chroma-core/chroma:0.5.0"); private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withConfiguration(AutoConfigurations.of(ChromaVectorStoreAutoConfiguration.class)) diff --git a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/ChromaApi.java b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/ChromaApi.java index 04474cd532..343489ffef 100644 --- a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/ChromaApi.java +++ b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/ChromaApi.java @@ -19,49 +19,55 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.function.Consumer; import java.util.regex.Matcher; import java.util.regex.Pattern; -import com.fasterxml.jackson.annotation.JsonProperty; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.ObjectMapper; - import org.springframework.ai.chroma.ChromaApi.QueryRequest.Include; -import org.springframework.http.HttpEntity; +import org.springframework.core.ParameterizedTypeReference; import org.springframework.http.HttpHeaders; -import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; +import org.springframework.http.client.SimpleClientHttpRequestFactory; import org.springframework.http.client.support.BasicAuthenticationInterceptor; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; import org.springframework.web.client.HttpServerErrorException; -import org.springframework.web.client.RestTemplate; +import org.springframework.web.client.RestClient; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; /** * Single-class Chroma API implementation based on the (unofficial) Chroma REST API. * * @author Christian Tzolov + * @author Eddú Meléndez */ public class ChromaApi { // Regular expression pattern that looks for a message inside the ValueError(...). private static Pattern VALUE_ERROR_PATTERN = Pattern.compile("ValueError\\('([^']*)'\\)"); - private final String baseUrl; - - private final RestTemplate restTemplate; + private RestClient restClient; private final ObjectMapper objectMapper; private String keyToken; - public ChromaApi(String baseUrl, RestTemplate restTemplate) { - this(baseUrl, restTemplate, new ObjectMapper()); + public ChromaApi(String baseUrl) { + this(baseUrl, RestClient.builder().requestFactory(new SimpleClientHttpRequestFactory()), new ObjectMapper()); } - public ChromaApi(String baseUrl, RestTemplate restTemplate, ObjectMapper objectMapper) { - this.baseUrl = baseUrl; - this.restTemplate = restTemplate; + public ChromaApi(String baseUrl, RestClient.Builder restClientBuilder) { + this(baseUrl, restClientBuilder, new ObjectMapper()); + } + + public ChromaApi(String baseUrl, RestClient.Builder restClientBuilder, ObjectMapper objectMapper) { + Consumer defaultHeaders = headers -> { + headers.setContentType(MediaType.APPLICATION_JSON); + }; + this.restClient = restClientBuilder.baseUrl(baseUrl).defaultHeaders(defaultHeaders).build(); this.objectMapper = objectMapper; } @@ -82,7 +88,9 @@ public ChromaApi withKeyToken(String keyToken) { * @param password Credentials password. */ public ChromaApi withBasicAuthCredentials(String username, String password) { - this.restTemplate.getInterceptors().add(new BasicAuthenticationInterceptor(username, password)); + this.restClient = this.restClient.mutate() + .requestInterceptor(new BasicAuthenticationInterceptor(username, password)) + .build(); return this; } @@ -265,9 +273,23 @@ public List toEmbeddingResponseList(QueryResponse queryResponse) { public Collection createCollection(CreateCollectionRequest createCollectionRequest) { - return this.restTemplate - .exchange(this.baseUrl + "/api/v1/collections", HttpMethod.POST, - this.getHttpEntityFor(createCollectionRequest), Collection.class) + return this.restClient.post() + .uri("/api/v1/collections") + .headers(this::httpHeaders) + .body(createCollectionRequest) + .retrieve() + .toEntity(Collection.class) + .getBody(); + } + + public Map createCollection2(CreateCollectionRequest createCollectionRequest) { + + return this.restClient.post() + .uri("/api/v1/collections") + .headers(this::httpHeaders) + .body(createCollectionRequest) + .retrieve() + .toEntity(Map.class) .getBody(); } @@ -278,16 +300,21 @@ public Collection createCollection(CreateCollectionRequest createCollectionReque */ public void deleteCollection(String collectionName) { - this.restTemplate.exchange(this.baseUrl + "/api/v1/collections/{collection_name}", HttpMethod.DELETE, - new HttpEntity<>(httpHeaders()), Void.class, collectionName); + this.restClient.delete() + .uri("/api/v1/collections/{collection_name}", collectionName) + .headers(this::httpHeaders) + .retrieve() + .toBodilessEntity(); } public Collection getCollection(String collectionName) { try { - return this.restTemplate - .exchange(this.baseUrl + "/api/v1/collections/{collection_name}", HttpMethod.GET, - new HttpEntity<>(httpHeaders()), Collection.class, collectionName) + return this.restClient.get() + .uri("/api/v1/collections/{collection_name}", collectionName) + .headers(this::httpHeaders) + .retrieve() + .toEntity(Collection.class) .getBody(); } catch (HttpServerErrorException e) { @@ -305,9 +332,11 @@ private static class CollectionList extends ArrayList { public List listCollections() { - return this.restTemplate - .exchange(this.baseUrl + "/api/v1/collections", HttpMethod.GET, new HttpEntity<>(httpHeaders()), - CollectionList.class) + return this.restClient.get() + .uri("/api/v1/collections") + .headers(this::httpHeaders) + .retrieve() + .toEntity(CollectionList.class) .getBody(); } @@ -317,41 +346,55 @@ public List listCollections() { public void upsertEmbeddings(String collectionId, AddEmbeddingsRequest embedding) { - this.restTemplate - .exchange(this.baseUrl + "/api/v1/collections/{collection_id}/upsert", HttpMethod.POST, - this.getHttpEntityFor(embedding), Boolean.class, collectionId) - .getBody(); + this.restClient.post() + .uri("/api/v1/collections/{collection_id}/upsert", collectionId) + .headers(this::httpHeaders) + .body(embedding) + .retrieve() + .toBodilessEntity(); } public List deleteEmbeddings(String collectionId, DeleteEmbeddingsRequest deleteRequest) { - return this.restTemplate - .exchange(this.baseUrl + "/api/v1/collections/{collection_id}/delete", HttpMethod.POST, - this.getHttpEntityFor(deleteRequest), List.class, collectionId) + return this.restClient.post() + .uri("/api/v1/collections/{collection_id}/delete", collectionId) + .headers(this::httpHeaders) + .body(deleteRequest) + .retrieve() + .toEntity(new ParameterizedTypeReference>() { + }) .getBody(); } public Long countEmbeddings(String collectionId) { - return this.restTemplate - .exchange(this.baseUrl + "/api/v1/collections/{collection_id}/count", HttpMethod.GET, - new HttpEntity<>(httpHeaders()), Long.class, collectionId) + return this.restClient.get() + .uri("/api/v1/collections/{collection_id}/count", collectionId) + .headers(this::httpHeaders) + .retrieve() + .toEntity(Long.class) .getBody(); } public QueryResponse queryCollection(String collectionId, QueryRequest queryRequest) { - return this.restTemplate - .exchange(this.baseUrl + "/api/v1/collections/{collection_id}/query", HttpMethod.POST, - this.getHttpEntityFor(queryRequest), QueryResponse.class, collectionId) + return this.restClient.post() + .uri("/api/v1/collections/{collection_id}/query", collectionId) + .headers(this::httpHeaders) + .body(queryRequest) + .retrieve() + .toEntity(QueryResponse.class) .getBody(); } public GetEmbeddingResponse getEmbeddings(String collectionId, GetEmbeddingsRequest getEmbeddingsRequest) { - return this.restTemplate - .exchange(this.baseUrl + "/api/v1/collections/{collection_id}/get", HttpMethod.POST, - this.getHttpEntityFor(getEmbeddingsRequest), GetEmbeddingResponse.class, collectionId) + return this.restClient.post() + .uri("/api/v1/collections/{collection_id}/get", collectionId) + .headers(this::httpHeaders) + .body(getEmbeddingsRequest) + .retrieve() + .toEntity(GetEmbeddingResponse.class) .getBody(); } @@ -365,17 +408,10 @@ public Map where(String text) { } } - private HttpEntity getHttpEntityFor(T body) { - return new HttpEntity<>(body, httpHeaders()); - } - - private HttpHeaders httpHeaders() { - HttpHeaders headers = new HttpHeaders(); - headers.setContentType(MediaType.APPLICATION_JSON); + private void httpHeaders(HttpHeaders headers) { if (StringUtils.hasText(this.keyToken)) { headers.setBearerAuth(this.keyToken); } - return headers; } private String getValueErrorMessage(String logString) { diff --git a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/ChromaApiIT.java b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/ChromaApiIT.java index 0d1576ea0b..0565cf16b6 100644 --- a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/ChromaApiIT.java +++ b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/ChromaApiIT.java @@ -15,26 +15,24 @@ */ package org.springframework.ai.chroma; +import static org.assertj.core.api.Assertions.assertThat; + import java.util.List; import java.util.Map; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.testcontainers.chromadb.ChromaDBContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; - -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.boot.SpringBootConfiguration; -import org.springframework.boot.test.context.SpringBootTest; -import org.springframework.context.annotation.Bean; import org.springframework.ai.chroma.ChromaApi.AddEmbeddingsRequest; import org.springframework.ai.chroma.ChromaApi.Collection; import org.springframework.ai.chroma.ChromaApi.GetEmbeddingsRequest; import org.springframework.ai.chroma.ChromaApi.QueryRequest; -import org.springframework.web.client.RestTemplate; - -import static org.assertj.core.api.Assertions.assertThat; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; +import org.testcontainers.chromadb.ChromaDBContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; /** * @author Christian Tzolov @@ -45,7 +43,7 @@ public class ChromaApiIT { @Container - static ChromaDBContainer chromaContainer = new ChromaDBContainer("ghcr.io/chroma-core/chroma:0.4.22"); + static ChromaDBContainer chromaContainer = new ChromaDBContainer("ghcr.io/chroma-core/chroma:0.4.12"); @Autowired ChromaApi chroma; @@ -179,13 +177,8 @@ public void testQueryWhere() { public static class Config { @Bean - public RestTemplate restTemplate() { - return new RestTemplate(); - } - - @Bean - public ChromaApi chromaApi(RestTemplate restTemplate) { - return new ChromaApi(chromaContainer.getEndpoint(), restTemplate); + public ChromaApi chromaApi() { + return new ChromaApi(chromaContainer.getEndpoint()); } } diff --git a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/BasicAuthChromaWhereIT.java b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/BasicAuthChromaWhereIT.java index 460863313e..5780f7258b 100644 --- a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/BasicAuthChromaWhereIT.java +++ b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/BasicAuthChromaWhereIT.java @@ -15,28 +15,27 @@ */ package org.springframework.ai.vectorstore; +import static org.assertj.core.api.Assertions.assertThat; + import java.util.List; import java.util.Map; import org.junit.jupiter.api.Test; -import org.testcontainers.chromadb.ChromaDBContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; - import org.springframework.ai.chroma.ChromaApi; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; -import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.OpenAiEmbeddingModel; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; -import org.springframework.web.client.RestTemplate; +import org.springframework.http.client.SimpleClientHttpRequestFactory; +import org.springframework.web.client.RestClient; +import org.testcontainers.chromadb.ChromaDBContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; import org.testcontainers.utility.MountableFile; -import static org.assertj.core.api.Assertions.assertThat; - /** * ChromaDB with Basic Authentication: * https://docs.trychroma.com/usage-guide#basic-authentication @@ -55,7 +54,7 @@ public class BasicAuthChromaWhereIT { * https://docs.trychroma.com/usage-guide#basic-authentication */ @Container - static ChromaDBContainer chromaContainer = new ChromaDBContainer("ghcr.io/chroma-core/chroma:0.4.22") + static ChromaDBContainer chromaContainer = new ChromaDBContainer("ghcr.io/chroma-core/chroma:0.5.0") .withEnv("CHROMA_SERVER_AUTH_CREDENTIALS_FILE", "/chroma/server.htpasswd") .withEnv("CHROMA_SERVER_AUTH_CREDENTIALS_PROVIDER", "chromadb.auth.providers.HtpasswdFileServerAuthCredentialsProvider") @@ -96,14 +95,13 @@ public void withInFiltersExpressions1() { public static class TestApplication { @Bean - public RestTemplate restTemplate() { - return new RestTemplate(); + public RestClient.Builder builder() { + return RestClient.builder().requestFactory(new SimpleClientHttpRequestFactory()); } @Bean - public ChromaApi chromaApi(RestTemplate restTemplate) { - return new ChromaApi(chromaContainer.getEndpoint(), restTemplate).withBasicAuthCredentials("admin", - "password"); + public ChromaApi chromaApi(RestClient.Builder builder) { + return new ChromaApi(chromaContainer.getEndpoint(), builder).withBasicAuthCredentials("admin", "password"); } @Bean diff --git a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/ChromaVectorStoreIT.java b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/ChromaVectorStoreIT.java index c29477f0d7..ec5e1fb75d 100644 --- a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/ChromaVectorStoreIT.java +++ b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/ChromaVectorStoreIT.java @@ -15,16 +15,14 @@ */ package org.springframework.ai.vectorstore; +import static org.assertj.core.api.Assertions.assertThat; + import java.util.Collections; import java.util.List; import java.util.Map; import java.util.UUID; import org.junit.jupiter.api.Test; -import org.testcontainers.chromadb.ChromaDBContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; - import org.springframework.ai.chroma.ChromaApi; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; @@ -33,9 +31,11 @@ import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; -import org.springframework.web.client.RestTemplate; - -import static org.assertj.core.api.Assertions.assertThat; +import org.springframework.http.client.SimpleClientHttpRequestFactory; +import org.springframework.web.client.RestClient; +import org.testcontainers.chromadb.ChromaDBContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; /** * @author Christian Tzolov @@ -45,7 +45,7 @@ public class ChromaVectorStoreIT { @Container - static ChromaDBContainer chromaContainer = new ChromaDBContainer("ghcr.io/chroma-core/chroma:0.4.22"); + static ChromaDBContainer chromaContainer = new ChromaDBContainer("ghcr.io/chroma-core/chroma:0.5.0"); List documents = List.of( new Document("Spring AI rocks!! Spring AI rocks!! Spring AI rocks!! Spring AI rocks!! Spring AI rocks!!", @@ -202,13 +202,13 @@ public void searchThresholdTest() { public static class TestApplication { @Bean - public RestTemplate restTemplate() { - return new RestTemplate(); + public RestClient.Builder builder() { + return RestClient.builder().requestFactory(new SimpleClientHttpRequestFactory()); } @Bean - public ChromaApi chromaApi(RestTemplate restTemplate) { - return new ChromaApi(chromaContainer.getEndpoint(), restTemplate); + public ChromaApi chromaApi(RestClient.Builder builder) { + return new ChromaApi(chromaContainer.getEndpoint(), builder); } @Bean diff --git a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/TokenSecuredChromaWhereIT.java b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/TokenSecuredChromaWhereIT.java index 761ac1fc19..7b6f5e8dd7 100644 --- a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/TokenSecuredChromaWhereIT.java +++ b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/TokenSecuredChromaWhereIT.java @@ -15,26 +15,25 @@ */ package org.springframework.ai.vectorstore; +import static org.assertj.core.api.Assertions.assertThat; + import java.util.List; import java.util.Map; import org.junit.jupiter.api.Test; -import org.testcontainers.chromadb.ChromaDBContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; - import org.springframework.ai.chroma.ChromaApi; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; -import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.OpenAiEmbeddingModel; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; -import org.springframework.web.client.RestTemplate; - -import static org.assertj.core.api.Assertions.assertThat; +import org.springframework.http.client.SimpleClientHttpRequestFactory; +import org.springframework.web.client.RestClient; +import org.testcontainers.chromadb.ChromaDBContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; /** * ChromaDB with static API Token Authentication: @@ -57,7 +56,7 @@ public class TokenSecuredChromaWhereIT { * https://docs.trychroma.com/usage-guide#static-api-token-authentication */ @Container - static ChromaDBContainer chromaContainer = new ChromaDBContainer("ghcr.io/chroma-core/chroma:0.4.22") + static ChromaDBContainer chromaContainer = new ChromaDBContainer("ghcr.io/chroma-core/chroma:0.5.0") .withEnv("CHROMA_SERVER_AUTH_CREDENTIALS", CHROMA_SERVER_AUTH_CREDENTIALS) .withEnv("CHROMA_SERVER_AUTH_CREDENTIALS_PROVIDER", "chromadb.auth.token.TokenConfigServerAuthCredentialsProvider") @@ -127,13 +126,13 @@ public void withInFiltersExpressions() { public static class TestApplication { @Bean - public RestTemplate restTemplate() { - return new RestTemplate(); + public RestClient.Builder builder() { + return RestClient.builder().requestFactory(new SimpleClientHttpRequestFactory()); } @Bean - public ChromaApi chromaApi(RestTemplate restTemplate) { - var chromaApi = new ChromaApi(chromaContainer.getEndpoint(), restTemplate); + public ChromaApi chromaApi(RestClient.Builder builder) { + var chromaApi = new ChromaApi(chromaContainer.getEndpoint(), builder); chromaApi.withKeyToken(CHROMA_SERVER_AUTH_CREDENTIALS); return chromaApi; }