Skip to content

Commit

Permalink
Hide inference services API
Browse files Browse the repository at this point in the history
  • Loading branch information
davidkyle committed Nov 18, 2024
1 parent e4f4c95 commit 092f2db
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 153 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,8 @@
import org.elasticsearch.client.ResponseException;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceFeature;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
Expand Down Expand Up @@ -131,153 +128,6 @@ public void testApisWithoutTaskType() throws IOException {
deleteModel(modelId);
}

@SuppressWarnings("unchecked")
public void testGetServicesWithoutTaskType() throws IOException {
List<Object> services = getAllServices();
if (ElasticInferenceServiceFeature.ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled()) {
assertThat(services.size(), equalTo(18));
} else {
assertThat(services.size(), equalTo(17));
}

String[] providers = new String[services.size()];
for (int i = 0; i < services.size(); i++) {
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
providers[i] = (String) serviceConfig.get("provider");
}

Arrays.sort(providers);

var providerList = new ArrayList<>(
Arrays.asList(
"alibabacloud-ai-search",
"amazonbedrock",
"anthropic",
"azureaistudio",
"azureopenai",
"cohere",
"elasticsearch",
"googleaistudio",
"googlevertexai",
"hugging_face",
"mistral",
"openai",
"streaming_completion_test_service",
"test_reranking_service",
"test_service",
"text_embedding_test_service",
"watsonxai"
)
);
if (ElasticInferenceServiceFeature.ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled()) {
providerList.add(6, "elastic");
}
assertArrayEquals(providers, providerList.toArray());
}

@SuppressWarnings("unchecked")
public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
List<Object> services = getServices(TaskType.TEXT_EMBEDDING);
assertThat(services.size(), equalTo(13));

String[] providers = new String[services.size()];
for (int i = 0; i < services.size(); i++) {
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
providers[i] = (String) serviceConfig.get("provider");
}

Arrays.sort(providers);
assertArrayEquals(
providers,
List.of(
"alibabacloud-ai-search",
"amazonbedrock",
"azureaistudio",
"azureopenai",
"cohere",
"elasticsearch",
"googleaistudio",
"googlevertexai",
"hugging_face",
"mistral",
"openai",
"text_embedding_test_service",
"watsonxai"
).toArray()
);
}

@SuppressWarnings("unchecked")
public void testGetServicesWithRerankTaskType() throws IOException {
List<Object> services = getServices(TaskType.RERANK);
assertThat(services.size(), equalTo(5));

String[] providers = new String[services.size()];
for (int i = 0; i < services.size(); i++) {
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
providers[i] = (String) serviceConfig.get("provider");
}

Arrays.sort(providers);
assertArrayEquals(
providers,
List.of("alibabacloud-ai-search", "cohere", "elasticsearch", "googlevertexai", "test_reranking_service").toArray()
);
}

@SuppressWarnings("unchecked")
public void testGetServicesWithCompletionTaskType() throws IOException {
List<Object> services = getServices(TaskType.COMPLETION);
assertThat(services.size(), equalTo(9));

String[] providers = new String[services.size()];
for (int i = 0; i < services.size(); i++) {
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
providers[i] = (String) serviceConfig.get("provider");
}

Arrays.sort(providers);
assertArrayEquals(
providers,
List.of(
"alibabacloud-ai-search",
"amazonbedrock",
"anthropic",
"azureaistudio",
"azureopenai",
"cohere",
"googleaistudio",
"openai",
"streaming_completion_test_service"
).toArray()
);
}

@SuppressWarnings("unchecked")
public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {
List<Object> services = getServices(TaskType.SPARSE_EMBEDDING);

if (ElasticInferenceServiceFeature.ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled()) {
assertThat(services.size(), equalTo(5));
} else {
assertThat(services.size(), equalTo(4));
}

String[] providers = new String[services.size()];
for (int i = 0; i < services.size(); i++) {
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
providers[i] = (String) serviceConfig.get("provider");
}

Arrays.sort(providers);

var providerList = new ArrayList<>(Arrays.asList("alibabacloud-ai-search", "elasticsearch", "hugging_face", "test_service"));
if (ElasticInferenceServiceFeature.ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled()) {
providerList.add(1, "elastic");
}
assertArrayEquals(providers, providerList.toArray());
}

public void testSkipValidationAndStart() throws IOException {
String openAiConfigWithBadApiKey = """
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@
import org.elasticsearch.xpack.inference.rest.RestDeleteInferenceEndpointAction;
import org.elasticsearch.xpack.inference.rest.RestGetInferenceDiagnosticsAction;
import org.elasticsearch.xpack.inference.rest.RestGetInferenceModelAction;
import org.elasticsearch.xpack.inference.rest.RestGetInferenceServicesAction;
import org.elasticsearch.xpack.inference.rest.RestInferenceAction;
import org.elasticsearch.xpack.inference.rest.RestPutInferenceModelAction;
import org.elasticsearch.xpack.inference.rest.RestStreamInferenceAction;
Expand Down Expand Up @@ -182,8 +181,7 @@ public List<RestHandler> getRestHandlers(
new RestPutInferenceModelAction(),
new RestUpdateInferenceModelAction(),
new RestDeleteInferenceEndpointAction(),
new RestGetInferenceDiagnosticsAction(),
new RestGetInferenceServicesAction()
new RestGetInferenceDiagnosticsAction()
);
}

Expand Down

0 comments on commit 092f2db

Please sign in to comment.