From 4dadb23fd34616506649f0661f1930607fd4c95f Mon Sep 17 00:00:00 2001 From: Max Hniebergall <137079448+maxhniebergall@users.noreply.github.com> Date: Mon, 7 Oct 2024 14:02:08 -0400 Subject: [PATCH 01/22] [Inference API] Backport deprecate elser service (#114052) * [Inference API] Deprecate elser service * Update docs/changelog/114052.yaml * Update pr number 113216.yaml * Delete docs/changelog/113216.yaml * Update 114052.yaml * Delete docs/changelog/114052.yaml --------- Co-authored-by: Elastic Machine --- .../inference/InferenceServiceRegistry.java | 8 +- .../integration/ModelRegistryIT.java | 27 +- .../InferenceNamedWriteablesProvider.java | 4 +- .../xpack/inference/InferencePlugin.java | 2 - .../action/TransportInferenceAction.java | 2 + .../TransportPutInferenceModelAction.java | 7 + ...InferenceServiceSparseEmbeddingsModel.java | 2 +- ...erviceSparseEmbeddingsServiceSettings.java | 2 +- .../BaseElasticsearchInternalService.java | 1 - .../ElasticsearchInternalService.java | 126 +++- .../ElasticsearchInternalServiceSettings.java | 2 +- .../ElserInternalModel.java | 3 +- .../ElserInternalServiceSettings.java | 5 +- .../ElserMlNodeTaskSettings.java | 2 +- .../{elser => elasticsearch}/ElserModels.java | 4 +- .../services/elser/ElserInternalService.java | 300 ---------- .../inference/ModelConfigurationsTests.java | 4 +- ...enceServiceSparseEmbeddingsModelTests.java | 2 +- ...eSparseEmbeddingsServiceSettingsTests.java | 4 +- .../elastic/ElasticInferenceServiceTests.java | 2 +- ...ticsearchInternalServiceSettingsTests.java | 1 - .../ElasticsearchInternalServiceTests.java | 160 ++++- .../ElserInternalServiceSettingsTests.java | 6 +- .../ElserMlNodeTaskSettingsTests.java | 2 +- .../elasticsearch/ElserModelsTests.java | 39 ++ .../elser/ElserInternalServiceTests.java | 548 ------------------ .../services/elser/ElserModelsTests.java | 33 -- 27 files changed, 361 insertions(+), 937 deletions(-) rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/{elser => elasticsearch}/ElserInternalModel.java (93%) rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/{elser => elasticsearch}/ElserInternalServiceSettings.java (89%) rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/{elser => elasticsearch}/ElserMlNodeTaskSettings.java (96%) rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/{elser => elasticsearch}/ElserModels.java (87%) delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalService.java rename x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/{elser => elasticsearch}/ElserInternalServiceSettingsTests.java (89%) rename x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/{elser => elasticsearch}/ElserMlNodeTaskSettingsTests.java (93%) create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserModelsTests.java delete mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceTests.java delete mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserModelsTests.java diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java b/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java index 40b4e37f36509..f1ce94173a550 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java @@ -46,7 +46,13 @@ public Map getServices() { } public Optional getService(String serviceName) { - return Optional.ofNullable(services.get(serviceName)); + + if ("elser".equals(serviceName)) { // ElserService.NAME before removal + // here we are aliasing the elser service to use the elasticsearch service instead + return Optional.ofNullable(services.get("elasticsearch")); // ElasticsearchInternalService.NAME + } else { + return Optional.ofNullable(services.get(serviceName)); + } } public List getNamedWriteables() { diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java index 524cd5014c19e..ea8b32f36f54c 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java @@ -28,11 +28,10 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.services.elser.ElserInternalModel; -import org.elasticsearch.xpack.inference.services.elser.ElserInternalService; -import org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettingsTests; -import org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceTests; -import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeTaskSettingsTests; +import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalModel; +import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService; +import org.elasticsearch.xpack.inference.services.elasticsearch.ElserInternalServiceSettingsTests; +import org.elasticsearch.xpack.inference.services.elasticsearch.ElserMlNodeTaskSettingsTests; import org.junit.Before; import java.io.IOException; @@ -118,10 +117,10 @@ public void testGetModel() throws Exception { assertEquals(model.getConfigurations().getService(), modelHolder.get().service()); - var elserService = new ElserInternalService( + var elserService = new ElasticsearchInternalService( new InferenceServiceExtension.InferenceServiceFactoryContext(mock(Client.class), mock(ThreadPool.class)) ); - ElserInternalModel roundTripModel = elserService.parsePersistedConfigWithSecrets( + ElasticsearchInternalModel roundTripModel = (ElasticsearchInternalModel) elserService.parsePersistedConfigWithSecrets( modelHolder.get().inferenceEntityId(), modelHolder.get().taskType(), modelHolder.get().settings(), @@ -277,7 +276,17 @@ public void testGetModelWithSecrets() throws InterruptedException { } private Model buildElserModelConfig(String inferenceEntityId, TaskType taskType) { - return ElserInternalServiceTests.randomModelConfig(inferenceEntityId, taskType); + return switch (taskType) { + case SPARSE_EMBEDDING -> new org.elasticsearch.xpack.inference.services.elasticsearch.ElserInternalModel( + inferenceEntityId, + taskType, + ElasticsearchInternalService.NAME, + ElserInternalServiceSettingsTests.createRandom(), + ElserMlNodeTaskSettingsTests.createRandom() + ); + default -> throw new IllegalArgumentException("task type " + taskType + " is not supported"); + }; + } protected void blockingCall(Consumer> function, AtomicReference response, AtomicReference error) @@ -300,7 +309,7 @@ private static Model buildModelWithUnknownField(String inferenceEntityId) { new ModelWithUnknownField( inferenceEntityId, TaskType.SPARSE_EMBEDDING, - ElserInternalService.NAME, + ElasticsearchInternalService.NAME, ElserInternalServiceSettingsTests.createRandom(), ElserMlNodeTaskSettingsTests.createRandom() ) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index 336626cd1db20..02bddb6076d69 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -64,9 +64,9 @@ import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalTextEmbeddingServiceSettings; import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandRerankTaskSettings; import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings; +import org.elasticsearch.xpack.inference.services.elasticsearch.ElserInternalServiceSettings; +import org.elasticsearch.xpack.inference.services.elasticsearch.ElserMlNodeTaskSettings; import org.elasticsearch.xpack.inference.services.elasticsearch.MultilingualE5SmallInternalServiceSettings; -import org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings; -import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeTaskSettings; import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.googleaistudio.embeddings.GoogleAiStudioEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiSecretSettings; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index f2f019490444e..0ab395f4bfa39 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -86,7 +86,6 @@ import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceFeature; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService; -import org.elasticsearch.xpack.inference.services.elser.ElserInternalService; import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioService; import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiService; import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceService; @@ -229,7 +228,6 @@ public void loadExtensions(ExtensionLoader loader) { public List getInferenceServiceFactories() { return List.of( - ElserInternalService::new, context -> new HuggingFaceElserService(httpFactory.get(), serviceComponents.get()), context -> new HuggingFaceService(httpFactory.get(), serviceComponents.get()), context -> new OpenAiService(httpFactory.get(), serviceComponents.get()), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java index 4186b281a35b5..d2a73b7df77c1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java @@ -11,6 +11,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.HandledTransportAction; +import org.elasticsearch.common.logging.DeprecationLogger; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.xcontent.ChunkedToXContent; import org.elasticsearch.inference.InferenceService; @@ -42,6 +43,7 @@ public class TransportInferenceAction extends HandledTransportAction serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMap(config, ModelConfigurations.TASK_SETTINGS); + String serviceName = (String) config.remove(ModelConfigurations.SERVICE); // required for elser service in elasticsearch service throwIfNotEmptyMap(config, name()); String modelId = (String) serviceSettingsMap.get(ElasticsearchInternalServiceSettings.MODEL_ID); if (modelId == null) { - throw new ValidationException().addValidationError("Error parsing request config, model id is missing"); - } - if (MULTILINGUAL_E5_SMALL_VALID_IDS.contains(modelId)) { + if (OLD_ELSER_SERVICE_NAME.equals(serviceName)) { + // TODO complete deprecation of null model ID + // throw new ValidationException().addValidationError("Error parsing request config, model id is missing"); + DEPRECATION_LOGGER.critical( + DeprecationCategory.API, + "inference_api_null_model_id_in_elasticsearch_service", + "Putting elasticsearch service inference endpoints (including elser service) without a model_id field is" + + " deprecated and will be removed in a future release. Please specify a model_id field." + ); + platformArch.accept( + modelListener.delegateFailureAndWrap( + (delegate, arch) -> elserCase(inferenceEntityId, taskType, config, arch, serviceSettingsMap, modelListener) + ) + ); + } else { + throw new IllegalArgumentException("Error parsing service settings, model_id must be provided"); + } + } else if (MULTILINGUAL_E5_SMALL_VALID_IDS.contains(modelId)) { platformArch.accept( modelListener.delegateFailureAndWrap( (delegate, arch) -> e5Case(inferenceEntityId, taskType, config, arch, serviceSettingsMap, modelListener) ) ); + } else if (ElserModels.isValidModel(modelId)) { + platformArch.accept( + modelListener.delegateFailureAndWrap( + (delegate, arch) -> elserCase(inferenceEntityId, taskType, config, arch, serviceSettingsMap, modelListener) + ) + ); } else { customElandCase(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, modelListener); } @@ -239,7 +270,86 @@ static boolean modelVariantValidForArchitecture(Set platformArchitecture // platform agnostic model is always compatible return true; } + return modelId.equals( + selectDefaultModelVariantBasedOnClusterArchitecture( + platformArchitectures, + MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86, + MULTILINGUAL_E5_SMALL_MODEL_ID + ) + ); + } + private void elserCase( + String inferenceEntityId, + TaskType taskType, + Map config, + Set platformArchitectures, + Map serviceSettingsMap, + ActionListener modelListener + ) { + var esServiceSettingsBuilder = ElasticsearchInternalServiceSettings.fromRequestMap(serviceSettingsMap); + final String defaultModelId = selectDefaultModelVariantBasedOnClusterArchitecture( + platformArchitectures, + ELSER_V2_MODEL_LINUX_X86, + ELSER_V2_MODEL + ); + if (false == defaultModelId.equals(esServiceSettingsBuilder.getModelId())) { + + if (esServiceSettingsBuilder.getModelId() == null) { + // TODO remove this case once we remove the option to not pass model ID + esServiceSettingsBuilder.setModelId(defaultModelId); + } else if (esServiceSettingsBuilder.getModelId().equals(ELSER_V2_MODEL)) { + logger.warn( + "The platform agnostic model [{}] was requested on Linux x86_64. " + + "It is recommended to use the optimized model instead [{}]", + ELSER_V2_MODEL, + ELSER_V2_MODEL_LINUX_X86 + ); + } else { + throw new IllegalArgumentException( + "Error parsing request config, model id does not match any models available on this platform. Was [" + + esServiceSettingsBuilder.getModelId() + + "]. You may need to use a platform agnostic model." + ); + } + } + + DEPRECATION_LOGGER.warn( + DeprecationCategory.API, + "inference_api_elser_service", + "The [{}] service is deprecated and will be removed in a future release. Use the [{}] service instead, with" + + " [model_id] set to [{}] in the [service_settings]", + OLD_ELSER_SERVICE_NAME, + ElasticsearchInternalService.NAME, + defaultModelId + ); + + if (modelVariantDoesNotMatchArchitecturesAndIsNotPlatformAgnostic(platformArchitectures, esServiceSettingsBuilder.getModelId())) { + throw new IllegalArgumentException( + "Error parsing request config, model id does not match any models available on this platform. Was [" + + esServiceSettingsBuilder.getModelId() + + "]" + ); + } + + throwIfNotEmptyMap(config, name()); + throwIfNotEmptyMap(serviceSettingsMap, name()); + + modelListener.onResponse( + new ElserInternalModel( + inferenceEntityId, + taskType, + NAME, + new ElserInternalServiceSettings(esServiceSettingsBuilder.build()), + ElserMlNodeTaskSettings.DEFAULT + ) + ); + } + + private static boolean modelVariantDoesNotMatchArchitecturesAndIsNotPlatformAgnostic( + Set platformArchitectures, + String modelId + ) { return modelId.equals( selectDefaultModelVariantBasedOnClusterArchitecture( platformArchitectures, @@ -276,6 +386,14 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M NAME, new MultilingualE5SmallInternalServiceSettings(ElasticsearchInternalServiceSettings.fromPersistedMap(serviceSettingsMap)) ); + } else if (ElserModels.isValidModel(modelId)) { + return new ElserInternalModel( + inferenceEntityId, + taskType, + NAME, + new ElserInternalServiceSettings(ElasticsearchInternalServiceSettings.fromPersistedMap(serviceSettingsMap)), + ElserMlNodeTaskSettings.DEFAULT + ); } else { return createCustomElandModel( inferenceEntityId, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java index 1acf19c5373b7..f8b5837ef387e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java @@ -83,7 +83,7 @@ protected static ElasticsearchInternalServiceSettings.Builder fromMap( validationException ); - // model id is optional as the ELSER and E5 service will default it + // model id is optional as the ELSER service will default it. TODO make this a required field once the elser service is removed String modelId = extractOptionalString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); if (numAllocations == null && adaptiveAllocationsSettings == null) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalModel.java similarity index 93% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalModel.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalModel.java index bb668c314649d..827eb178f7633 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalModel.java @@ -5,7 +5,7 @@ * 2.0. */ -package org.elasticsearch.xpack.inference.services.elser; +package org.elasticsearch.xpack.inference.services.elasticsearch; import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; @@ -13,7 +13,6 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; -import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalModel; public class ElserInternalModel extends ElasticsearchInternalModel { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceSettings.java similarity index 89% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceSettings.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceSettings.java index fcbabd5a88fc6..f7bcd95c8bd28 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceSettings.java @@ -5,14 +5,13 @@ * 2.0. */ -package org.elasticsearch.xpack.inference.services.elser; +package org.elasticsearch.xpack.inference.services.elasticsearch; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; -import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings; import java.io.IOException; import java.util.Arrays; @@ -22,7 +21,7 @@ public class ElserInternalServiceSettings extends ElasticsearchInternalServiceSe public static final String NAME = "elser_mlnode_service_settings"; - public static ElasticsearchInternalServiceSettings.Builder fromRequestMap(Map map) { + public static Builder fromRequestMap(Map map) { ValidationException validationException = new ValidationException(); var baseSettings = ElasticsearchInternalServiceSettings.fromMap(map, validationException); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserMlNodeTaskSettings.java similarity index 96% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeTaskSettings.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserMlNodeTaskSettings.java index 9b9f6e41113e5..934edaa96a15c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserMlNodeTaskSettings.java @@ -5,7 +5,7 @@ * 2.0. */ -package org.elasticsearch.xpack.inference.services.elser; +package org.elasticsearch.xpack.inference.services.elasticsearch; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserModels.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserModels.java similarity index 87% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserModels.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserModels.java index af94d2813dd2c..37f528ea3a750 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserModels.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserModels.java @@ -5,7 +5,7 @@ * 2.0. */ -package org.elasticsearch.xpack.inference.services.elser; +package org.elasticsearch.xpack.inference.services.elasticsearch; import java.util.Set; @@ -23,7 +23,7 @@ public class ElserModels { ); public static boolean isValidModel(String model) { - return VALID_ELSER_MODEL_IDS.contains(model); + return model != null && VALID_ELSER_MODEL_IDS.contains(model); } public static boolean isValidEisModel(String model) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalService.java deleted file mode 100644 index d36b8eca7661e..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalService.java +++ /dev/null @@ -1,300 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - * - * this file has been contributed to by a Generative AI - */ - -package org.elasticsearch.xpack.inference.services.elser; - -import org.elasticsearch.ElasticsearchStatusException; -import org.elasticsearch.TransportVersion; -import org.elasticsearch.TransportVersions; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.core.Nullable; -import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; -import org.elasticsearch.inference.ChunkingOptions; -import org.elasticsearch.inference.InferenceResults; -import org.elasticsearch.inference.InferenceServiceExtension; -import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.inference.InputType; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.ModelConfigurations; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; -import org.elasticsearch.xpack.core.ml.action.InferModelAction; -import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; -import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextExpansionResults; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TokenizationConfigUpdate; -import org.elasticsearch.xpack.inference.services.ServiceUtils; -import org.elasticsearch.xpack.inference.services.elasticsearch.BaseElasticsearchInternalService; - -import java.util.ArrayList; -import java.util.EnumSet; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.function.Consumer; - -import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; -import static org.elasticsearch.xpack.inference.services.elser.ElserModels.ELSER_V2_MODEL; -import static org.elasticsearch.xpack.inference.services.elser.ElserModels.ELSER_V2_MODEL_LINUX_X86; - -public class ElserInternalService extends BaseElasticsearchInternalService { - - public static final String NAME = "elser"; - - private static final String OLD_MODEL_ID_FIELD_NAME = "model_version"; - - public ElserInternalService(InferenceServiceExtension.InferenceServiceFactoryContext context) { - super(context); - } - - // for testing - ElserInternalService( - InferenceServiceExtension.InferenceServiceFactoryContext context, - Consumer>> platformArch - ) { - super(context, platformArch); - } - - @Override - protected EnumSet supportedTaskTypes() { - return EnumSet.of(TaskType.SPARSE_EMBEDDING); - } - - @Override - public void parseRequestConfig( - String inferenceEntityId, - TaskType taskType, - Map config, - ActionListener parsedModelListener - ) { - try { - Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - var serviceSettingsBuilder = ElserInternalServiceSettings.fromRequestMap(serviceSettingsMap); - - Map taskSettingsMap; - // task settings are optional - if (config.containsKey(ModelConfigurations.TASK_SETTINGS)) { - taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); - } else { - taskSettingsMap = Map.of(); - } - - var taskSettings = taskSettingsFromMap(taskType, taskSettingsMap); - - throwIfNotEmptyMap(config, NAME); - throwIfNotEmptyMap(serviceSettingsMap, NAME); - throwIfNotEmptyMap(taskSettingsMap, NAME); - - if (serviceSettingsBuilder.getModelId() == null) { - platformArch.accept(parsedModelListener.delegateFailureAndWrap((delegate, arch) -> { - serviceSettingsBuilder.setModelId( - selectDefaultModelVariantBasedOnClusterArchitecture(arch, ELSER_V2_MODEL_LINUX_X86, ELSER_V2_MODEL) - ); - parsedModelListener.onResponse( - new ElserInternalModel( - inferenceEntityId, - taskType, - NAME, - new ElserInternalServiceSettings(serviceSettingsBuilder.build()), - taskSettings - ) - ); - })); - } else { - parsedModelListener.onResponse( - new ElserInternalModel( - inferenceEntityId, - taskType, - NAME, - new ElserInternalServiceSettings(serviceSettingsBuilder.build()), - taskSettings - ) - ); - } - } catch (Exception e) { - parsedModelListener.onFailure(e); - } - } - - @Override - public ElserInternalModel parsePersistedConfigWithSecrets( - String inferenceEntityId, - TaskType taskType, - Map config, - Map secrets - ) { - return parsePersistedConfig(inferenceEntityId, taskType, config); - } - - @Override - public ElserInternalModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) { - Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - - // Change from old model_version field name to new model_id field name as of - // TransportVersions.ML_TEXT_EMBEDDING_INFERENCE_SERVICE_ADDED - if (serviceSettingsMap.containsKey(OLD_MODEL_ID_FIELD_NAME)) { - String modelId = ServiceUtils.removeAsType(serviceSettingsMap, OLD_MODEL_ID_FIELD_NAME, String.class); - serviceSettingsMap.put(ElserInternalServiceSettings.MODEL_ID, modelId); - } - - var serviceSettings = ElserInternalServiceSettings.fromPersistedMap(serviceSettingsMap); - - Map taskSettingsMap; - // task settings are optional - if (config.containsKey(ModelConfigurations.TASK_SETTINGS)) { - taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); - } else { - taskSettingsMap = Map.of(); - } - - var taskSettings = taskSettingsFromMap(taskType, taskSettingsMap); - - return new ElserInternalModel(inferenceEntityId, taskType, NAME, new ElserInternalServiceSettings(serviceSettings), taskSettings); - } - - @Override - public void infer( - Model model, - @Nullable String query, - List inputs, - boolean stream, - Map taskSettings, - InputType inputType, - TimeValue timeout, - ActionListener listener - ) { - // No task settings to override with requestTaskSettings - - try { - checkCompatibleTaskType(model.getConfigurations().getTaskType()); - } catch (Exception e) { - listener.onFailure(e); - return; - } - - var request = buildInferenceRequest( - model.getConfigurations().getInferenceEntityId(), - TextExpansionConfigUpdate.EMPTY_UPDATE, - inputs, - inputType, - timeout, - false // chunk - ); - - client.execute( - InferModelAction.INSTANCE, - request, - listener.delegateFailureAndWrap( - (l, inferenceResult) -> l.onResponse(SparseEmbeddingResults.of(inferenceResult.getInferenceResults())) - ) - ); - } - - public void chunkedInfer( - Model model, - List input, - Map taskSettings, - InputType inputType, - @Nullable ChunkingOptions chunkingOptions, - TimeValue timeout, - ActionListener> listener - ) { - chunkedInfer(model, null, input, taskSettings, inputType, chunkingOptions, timeout, listener); - } - - @Override - public void chunkedInfer( - Model model, - @Nullable String query, - List inputs, - Map taskSettings, - InputType inputType, - @Nullable ChunkingOptions chunkingOptions, - TimeValue timeout, - ActionListener> listener - ) { - try { - checkCompatibleTaskType(model.getConfigurations().getTaskType()); - } catch (Exception e) { - listener.onFailure(e); - return; - } - - var configUpdate = chunkingOptions != null - ? new TokenizationConfigUpdate(chunkingOptions.windowSize(), chunkingOptions.span()) - : new TokenizationConfigUpdate(null, null); - - var request = buildInferenceRequest( - model.getConfigurations().getInferenceEntityId(), - configUpdate, - inputs, - inputType, - timeout, - true // chunk - ); - - client.execute( - InferModelAction.INSTANCE, - request, - listener.delegateFailureAndWrap( - (l, inferenceResult) -> l.onResponse(translateChunkedResults(inferenceResult.getInferenceResults())) - ) - ); - } - - private void checkCompatibleTaskType(TaskType taskType) { - if (TaskType.SPARSE_EMBEDDING.isAnyOrSame(taskType) == false) { - throw new ElasticsearchStatusException(TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), RestStatus.BAD_REQUEST); - } - } - - private static ElserMlNodeTaskSettings taskSettingsFromMap(TaskType taskType, Map config) { - if (taskType != TaskType.SPARSE_EMBEDDING) { - throw new ElasticsearchStatusException(TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), RestStatus.BAD_REQUEST); - } - - // no config options yet - return ElserMlNodeTaskSettings.DEFAULT; - } - - private List translateChunkedResults(List inferenceResults) { - var translated = new ArrayList(); - - for (var inferenceResult : inferenceResults) { - if (inferenceResult instanceof MlChunkedTextExpansionResults mlChunkedResult) { - translated.add(InferenceChunkedSparseEmbeddingResults.ofMlResult(mlChunkedResult)); - } else if (inferenceResult instanceof ErrorInferenceResults error) { - translated.add(new ErrorChunkedInferenceResults(error.getException())); - } else { - throw new ElasticsearchStatusException( - "Expected a chunked inference [{}] received [{}]", - RestStatus.INTERNAL_SERVER_ERROR, - MlChunkedTextExpansionResults.NAME, - inferenceResult.getWriteableName() - ); - } - } - return translated; - } - - @Override - public String name() { - return NAME; - } - - @Override - public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.V_8_12_0; - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/ModelConfigurationsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/ModelConfigurationsTests.java index 5a1922fd200f5..03613901c7816 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/ModelConfigurationsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/ModelConfigurationsTests.java @@ -16,8 +16,8 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests; -import org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettingsTests; -import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeTaskSettings; +import org.elasticsearch.xpack.inference.services.elasticsearch.ElserInternalServiceSettingsTests; +import org.elasticsearch.xpack.inference.services.elasticsearch.ElserMlNodeTaskSettings; public class ModelConfigurationsTests extends AbstractWireSerializingTestCase { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModelTests.java index af13ce7944685..c9f4234331221 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModelTests.java @@ -11,7 +11,7 @@ import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.inference.services.elser.ElserModels; +import org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels; public class ElasticInferenceServiceSparseEmbeddingsModelTests extends ESTestCase { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsServiceSettingsTests.java index a2b36cf9abdd5..1751e1c3be5e8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsServiceSettingsTests.java @@ -16,13 +16,13 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceFields; -import org.elasticsearch.xpack.inference.services.elser.ElserModels; +import org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels; import java.io.IOException; import java.util.HashMap; import java.util.Map; -import static org.elasticsearch.xpack.inference.services.elser.ElserModelsTests.randomElserModel; +import static org.elasticsearch.xpack.inference.services.elasticsearch.ElserModelsTests.randomElserModel; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.is; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index 1e88d85dcc3f2..4ed60d93efe0c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -36,7 +36,7 @@ import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests; import org.elasticsearch.xpack.inference.services.ServiceFields; -import org.elasticsearch.xpack.inference.services.elser.ElserModels; +import org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels; import org.hamcrest.MatcherAssert; import org.hamcrest.Matchers; import org.junit.After; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettingsTests.java index 41afef88d22c6..419db748d793d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettingsTests.java @@ -11,7 +11,6 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; -import org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings; import java.io.IOException; import java.util.HashMap; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java index de9298f1b08dd..cd6da4c0ad8d8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java @@ -9,10 +9,12 @@ package org.elasticsearch.xpack.inference.services.elasticsearch; +import org.apache.logging.log4j.Level; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.client.internal.Client; +import org.elasticsearch.common.logging.DeprecationLogger; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; @@ -69,6 +71,8 @@ import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86; +import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.NAME; +import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.OLD_ELSER_SERVICE_NAME; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; @@ -97,9 +101,11 @@ public void shutdownThreadPool() { } public void testParseRequestConfig() { + // Null model variant var service = createService(mock(Client.class)); - var settings = new HashMap(); - settings.put( + var config = new HashMap(); + config.put(ModelConfigurations.SERVICE, ElasticsearchInternalService.NAME); + config.put( ModelConfigurations.SERVICE_SETTINGS, new HashMap<>( Map.of(ElasticsearchInternalServiceSettings.NUM_ALLOCATIONS, 1, ElasticsearchInternalServiceSettings.NUM_THREADS, 4) @@ -112,15 +118,16 @@ public void testParseRequestConfig() { ); var taskType = randomFrom(TaskType.TEXT_EMBEDDING, TaskType.RERANK, TaskType.SPARSE_EMBEDDING); - service.parseRequestConfig(randomInferenceEntityId, taskType, settings, modelListener); + service.parseRequestConfig(randomInferenceEntityId, taskType, config, modelListener); } public void testParseRequestConfig_Misconfigured() { - // Null model variant + // Non-existent model variant { var service = createService(mock(Client.class)); - var settings = new HashMap(); - settings.put( + var config = new HashMap(); + config.put(ModelConfigurations.SERVICE, ElasticsearchInternalService.NAME); + config.put( ModelConfigurations.SERVICE_SETTINGS, new HashMap<>( Map.of(ElasticsearchInternalServiceSettings.NUM_ALLOCATIONS, 1, ElasticsearchInternalServiceSettings.NUM_THREADS, 4) @@ -133,20 +140,21 @@ public void testParseRequestConfig_Misconfigured() { ); var taskType = randomFrom(TaskType.TEXT_EMBEDDING, TaskType.RERANK, TaskType.SPARSE_EMBEDDING); - service.parseRequestConfig(randomInferenceEntityId, taskType, settings, modelListener); + service.parseRequestConfig(randomInferenceEntityId, taskType, config, modelListener); } // Invalid config map { var service = createService(mock(Client.class)); - var settings = new HashMap(); - settings.put( + var config = new HashMap(); + config.put(ModelConfigurations.SERVICE, ElasticsearchInternalService.NAME); + config.put( ModelConfigurations.SERVICE_SETTINGS, new HashMap<>( Map.of(ElasticsearchInternalServiceSettings.NUM_ALLOCATIONS, 1, ElasticsearchInternalServiceSettings.NUM_THREADS, 4) ) ); - settings.put("not_a_valid_config_setting", randomAlphaOfLength(10)); + config.put("not_a_valid_config_setting", randomAlphaOfLength(10)); ActionListener modelListener = ActionListener.wrap( model -> fail("Model parsing should have failed"), @@ -154,7 +162,7 @@ public void testParseRequestConfig_Misconfigured() { ); var taskType = randomFrom(TaskType.TEXT_EMBEDDING, TaskType.RERANK, TaskType.SPARSE_EMBEDDING); - service.parseRequestConfig(randomInferenceEntityId, taskType, settings, modelListener); + service.parseRequestConfig(randomInferenceEntityId, taskType, config, modelListener); } } @@ -182,7 +190,7 @@ public void testParseRequestConfig_E5() { randomInferenceEntityId, TaskType.TEXT_EMBEDDING, settings, - getModelVerificationActionListener(e5ServiceSettings) + getE5ModelVerificationActionListener(e5ServiceSettings) ); } @@ -214,7 +222,7 @@ public void testParseRequestConfig_E5() { randomInferenceEntityId, TaskType.TEXT_EMBEDDING, settings, - getModelVerificationActionListener(e5ServiceSettings) + getE5ModelVerificationActionListener(e5ServiceSettings) ); } @@ -247,6 +255,106 @@ public void testParseRequestConfig_E5() { } } + public void testParseRequestConfig_elser() { + // General happy case + { + Client mockClient = mock(Client.class); + when(mockClient.threadPool()).thenReturn(threadPool); + var service = createService(mockClient); + var config = new HashMap(); + config.put(ModelConfigurations.SERVICE, OLD_ELSER_SERVICE_NAME); + config.put( + ModelConfigurations.SERVICE_SETTINGS, + new HashMap<>( + Map.of( + ElasticsearchInternalServiceSettings.NUM_ALLOCATIONS, + 1, + ElasticsearchInternalServiceSettings.NUM_THREADS, + 4, + ElasticsearchInternalServiceSettings.MODEL_ID, + ElserModels.ELSER_V2_MODEL + ) + ) + ); + + var elserServiceSettings = new ElserInternalServiceSettings(1, 4, ElserModels.ELSER_V2_MODEL, null); + + service.parseRequestConfig( + randomInferenceEntityId, + TaskType.SPARSE_EMBEDDING, + config, + getElserModelVerificationActionListener( + elserServiceSettings, + null, + "The [elser] service is deprecated and will be removed in a future release. Use the [elasticsearch] service " + + "instead, with [model_id] set to [.elser_model_2] in the [service_settings]" + ) + ); + } + + // null model ID returns elser model for the provided platform (not linux) + { + Client mockClient = mock(Client.class); + when(mockClient.threadPool()).thenReturn(threadPool); + var service = createService(mockClient); + var config = new HashMap(); + config.put(ModelConfigurations.SERVICE, OLD_ELSER_SERVICE_NAME); + config.put( + ModelConfigurations.SERVICE_SETTINGS, + new HashMap<>( + Map.of(ElasticsearchInternalServiceSettings.NUM_ALLOCATIONS, 1, ElasticsearchInternalServiceSettings.NUM_THREADS, 4) + ) + ); + + var elserServiceSettings = new ElserInternalServiceSettings(1, 4, ElserModels.ELSER_V2_MODEL, null); + + String criticalWarning = + "Putting elasticsearch service inference endpoints (including elser service) without a model_id field is" + + " deprecated and will be removed in a future release. Please specify a model_id field."; + String warnWarning = + "The [elser] service is deprecated and will be removed in a future release. Use the [elasticsearch] service " + + "instead, with [model_id] set to [.elser_model_2] in the [service_settings]"; + service.parseRequestConfig( + randomInferenceEntityId, + TaskType.SPARSE_EMBEDDING, + config, + getElserModelVerificationActionListener(elserServiceSettings, criticalWarning, warnWarning) + ); + assertWarnings(true, new DeprecationWarning(DeprecationLogger.CRITICAL, criticalWarning)); + } + + // Invalid service settings + { + Client mockClient = mock(Client.class); + when(mockClient.threadPool()).thenReturn(threadPool); + var service = createService(mockClient); + var config = new HashMap(); + config.put(ModelConfigurations.SERVICE, OLD_ELSER_SERVICE_NAME); + config.put( + ModelConfigurations.SERVICE_SETTINGS, + new HashMap<>( + Map.of( + ElasticsearchInternalServiceSettings.NUM_ALLOCATIONS, + 1, + ElasticsearchInternalServiceSettings.NUM_THREADS, + 4, + ElasticsearchInternalServiceSettings.MODEL_ID, + ElserModels.ELSER_V2_MODEL, + "not_a_valid_service_setting", + randomAlphaOfLength(10) + ) + ) + ); + + ActionListener modelListener = ActionListener.wrap( + model -> fail("Model parsing should have failed"), + e -> assertThat(e, instanceOf(ElasticsearchStatusException.class)) + ); + + service.parseRequestConfig(randomInferenceEntityId, TaskType.SPARSE_EMBEDDING, config, modelListener); + } + } + @SuppressWarnings("unchecked") public void testParseRequestConfig_Rerank() { // with task settings @@ -374,7 +482,7 @@ public void testParseRequestConfig_SparseEmbedding() { service.parseRequestConfig(randomInferenceEntityId, TaskType.SPARSE_EMBEDDING, settings, modelListener); } - private ActionListener getModelVerificationActionListener(MultilingualE5SmallInternalServiceSettings e5ServiceSettings) { + private ActionListener getE5ModelVerificationActionListener(MultilingualE5SmallInternalServiceSettings e5ServiceSettings) { return ActionListener.wrap(model -> { assertEquals( new MultilingualE5SmallModel( @@ -388,6 +496,30 @@ private ActionListener getModelVerificationActionListener(MultilingualE5S }, e -> { fail("Model parsing failed " + e.getMessage()); }); } + private ActionListener getElserModelVerificationActionListener( + ElserInternalServiceSettings elserServiceSettings, + String criticalWarning, + String warnWarning + ) { + return ActionListener.wrap(model -> { + assertWarnings( + true, + new DeprecationWarning(DeprecationLogger.CRITICAL, criticalWarning), + new DeprecationWarning(Level.WARN, warnWarning) + ); + assertEquals( + new ElserInternalModel( + randomInferenceEntityId, + TaskType.SPARSE_EMBEDDING, + NAME, + elserServiceSettings, + ElserMlNodeTaskSettings.DEFAULT + ), + model + ); + }, e -> { fail("Model parsing failed " + e.getMessage()); }); + } + public void testParsePersistedConfig() { // Null model variant diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceSettingsTests.java similarity index 89% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceSettingsTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceSettingsTests.java index ffbdf1a5a6178..f4e97b2c2e5e0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceSettingsTests.java @@ -5,18 +5,16 @@ * 2.0. */ -package org.elasticsearch.xpack.inference.services.elser; +package org.elasticsearch.xpack.inference.services.elasticsearch; import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings; -import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettingsTests; import java.io.IOException; import java.util.HashSet; -import static org.elasticsearch.xpack.inference.services.elser.ElserModelsTests.randomElserModel; +import static org.elasticsearch.xpack.inference.services.elasticsearch.ElserModelsTests.randomElserModel; public class ElserInternalServiceSettingsTests extends AbstractWireSerializingTestCase { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserMlNodeTaskSettingsTests.java similarity index 93% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeTaskSettingsTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserMlNodeTaskSettingsTests.java index d55065a5f9b27..a7de3fe8b8fdc 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserMlNodeTaskSettingsTests.java @@ -5,7 +5,7 @@ * 2.0. */ -package org.elasticsearch.xpack.inference.services.elser; +package org.elasticsearch.xpack.inference.services.elasticsearch; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserModelsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserModelsTests.java new file mode 100644 index 0000000000000..fa0148ac69df5 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserModelsTests.java @@ -0,0 +1,39 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elasticsearch; + +import org.elasticsearch.test.ESTestCase; + +public class ElserModelsTests extends ESTestCase { + + public static String randomElserModel() { + return randomFrom(org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.VALID_ELSER_MODEL_IDS); + } + + public void testIsValidModel() { + assertTrue(org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.isValidModel(randomElserModel())); + } + + public void testIsValidEisModel() { + assertTrue( + org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.isValidEisModel( + org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.ELSER_V2_MODEL + ) + ); + } + + public void testIsInvalidModel() { + assertFalse(org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.isValidModel("invalid")); + } + + public void testIsInvalidEisModel() { + assertFalse( + org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.isValidEisModel(ElserModels.ELSER_V2_MODEL_LINUX_X86) + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceTests.java deleted file mode 100644 index 09abeb9b9b389..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceTests.java +++ /dev/null @@ -1,548 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - * - * this file was contributed to by a generative AI - */ - -package org.elasticsearch.xpack.inference.services.elser; - -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.client.internal.Client; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; -import org.elasticsearch.inference.ChunkingOptions; -import org.elasticsearch.inference.InferenceResults; -import org.elasticsearch.inference.InferenceServiceExtension; -import org.elasticsearch.inference.InputType; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.ModelConfigurations; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.threadpool.TestThreadPool; -import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.core.ml.action.InferModelAction; -import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction; -import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction; -import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; -import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; -import org.elasticsearch.xpack.core.ml.inference.results.InferenceChunkedTextExpansionResultsTests; -import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextExpansionResults; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TokenizationConfigUpdate; -import org.elasticsearch.xpack.inference.InferencePlugin; -import org.junit.After; -import org.junit.Before; -import org.mockito.ArgumentCaptor; -import org.mockito.Mockito; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicReference; - -import static org.hamcrest.Matchers.containsString; -import static org.hamcrest.Matchers.hasSize; -import static org.hamcrest.Matchers.instanceOf; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.same; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - -public class ElserInternalServiceTests extends ESTestCase { - - private static ThreadPool threadPool; - - @Before - public void setUpThreadPool() { - threadPool = createThreadPool(InferencePlugin.inferenceUtilityExecutor(Settings.EMPTY)); - } - - @After - public void shutdownThreadPool() { - TestThreadPool.terminate(threadPool, 30, TimeUnit.SECONDS); - } - - public static Model randomModelConfig(String inferenceEntityId, TaskType taskType) { - return switch (taskType) { - case SPARSE_EMBEDDING -> new ElserInternalModel( - inferenceEntityId, - taskType, - ElserInternalService.NAME, - ElserInternalServiceSettingsTests.createRandom(), - ElserMlNodeTaskSettingsTests.createRandom() - ); - default -> throw new IllegalArgumentException("task type " + taskType + " is not supported"); - }; - } - - public void testParseConfigStrict() { - var service = createService(mock(Client.class)); - - var settings = new HashMap(); - settings.put( - ModelConfigurations.SERVICE_SETTINGS, - new HashMap<>( - Map.of( - ElserInternalServiceSettings.NUM_ALLOCATIONS, - 1, - ElserInternalServiceSettings.NUM_THREADS, - 4, - "model_id", - ".elser_model_1" - ) - ) - ); - settings.put(ModelConfigurations.TASK_SETTINGS, Map.of()); - - var expectedModel = new ElserInternalModel( - "foo", - TaskType.SPARSE_EMBEDDING, - ElserInternalService.NAME, - new ElserInternalServiceSettings(1, 4, ".elser_model_1", null), - ElserMlNodeTaskSettings.DEFAULT - ); - - var modelVerificationListener = getModelVerificationListener(expectedModel); - - service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, modelVerificationListener); - - } - - public void testParseConfigWithoutModelId() { - Client mockClient = mock(Client.class); - when(mockClient.threadPool()).thenReturn(threadPool); - var service = createService(mockClient); - - var settings = new HashMap(); - settings.put( - ModelConfigurations.SERVICE_SETTINGS, - new HashMap<>(Map.of(ElserInternalServiceSettings.NUM_ALLOCATIONS, 1, ElserInternalServiceSettings.NUM_THREADS, 4)) - ); - - var expectedModel = new ElserInternalModel( - "foo", - TaskType.SPARSE_EMBEDDING, - ElserInternalService.NAME, - new ElserInternalServiceSettings(1, 4, ".elser_model_2", null), - ElserMlNodeTaskSettings.DEFAULT - ); - - var modelVerificationListener = getModelVerificationListener(expectedModel); - - service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, modelVerificationListener); - - } - - public void testParseConfigLooseWithOldModelId() { - var service = createService(mock(Client.class)); - - var settings = new HashMap(); - settings.put( - ModelConfigurations.SERVICE_SETTINGS, - new HashMap<>( - Map.of( - ElserInternalServiceSettings.NUM_ALLOCATIONS, - 1, - ElserInternalServiceSettings.NUM_THREADS, - 4, - "model_version", - ".elser_model_1" - ) - ) - ); - settings.put(ModelConfigurations.TASK_SETTINGS, Map.of()); - - var expectedModel = new ElserInternalModel( - "foo", - TaskType.SPARSE_EMBEDDING, - ElserInternalService.NAME, - new ElserInternalServiceSettings(1, 4, ".elser_model_1", null), - ElserMlNodeTaskSettings.DEFAULT - ); - - var realModel = service.parsePersistedConfig("foo", TaskType.SPARSE_EMBEDDING, settings); - - assertEquals(expectedModel, realModel); - - } - - private static ActionListener getModelVerificationListener(ElserInternalModel expectedModel) { - return ActionListener.wrap( - (model) -> { assertEquals(expectedModel, model); }, - (e) -> fail("Model verification should not fail " + e.getMessage()) - ); - } - - public void testParseConfigStrictWithNoTaskSettings() { - var service = createService(mock(Client.class), Set.of("Aarch64")); - - var settings = new HashMap(); - settings.put( - ModelConfigurations.SERVICE_SETTINGS, - new HashMap<>(Map.of(ElserInternalServiceSettings.NUM_ALLOCATIONS, 1, ElserInternalServiceSettings.NUM_THREADS, 4)) - ); - - var expectedModel = new ElserInternalModel( - "foo", - TaskType.SPARSE_EMBEDDING, - ElserInternalService.NAME, - new ElserInternalServiceSettings(1, 4, ElserModels.ELSER_V2_MODEL, null), - ElserMlNodeTaskSettings.DEFAULT - ); - - var modelVerificationListener = getModelVerificationListener(expectedModel); - - service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, modelVerificationListener); - } - - public void testParseConfigStrictWithUnknownSettings() { - - var service = createService(mock(Client.class)); - - for (boolean throwOnUnknown : new boolean[] { true, false }) { - { - var settings = new HashMap(); - settings.put( - ModelConfigurations.SERVICE_SETTINGS, - new HashMap<>( - Map.of( - ElserInternalServiceSettings.NUM_ALLOCATIONS, - 1, - ElserInternalServiceSettings.NUM_THREADS, - 4, - ElserInternalServiceSettings.MODEL_ID, - ".elser_model_2" - ) - ) - ); - settings.put(ModelConfigurations.TASK_SETTINGS, Map.of()); - settings.put("foo", "bar"); - - ActionListener errorVerificationListener = ActionListener.wrap((model) -> { - if (throwOnUnknown) { - fail("Model verification should fail when throwOnUnknown is true"); - } - }, (e) -> { - if (throwOnUnknown) { - assertThat( - e.getMessage(), - containsString("Model configuration contains settings [{foo=bar}] unknown to the [elser] service") - ); - } else { - fail("Model verification should not fail when throwOnUnknown is false"); - } - }); - - if (throwOnUnknown == false) { - var parsed = service.parsePersistedConfigWithSecrets( - "foo", - TaskType.SPARSE_EMBEDDING, - settings, - Collections.emptyMap() - ); - } else { - - service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, errorVerificationListener); - } - } - - { - var settings = new HashMap(); - settings.put( - ModelConfigurations.SERVICE_SETTINGS, - new HashMap<>( - Map.of( - ElserInternalServiceSettings.NUM_ALLOCATIONS, - 1, - ElserInternalServiceSettings.NUM_THREADS, - 4, - ElserInternalServiceSettings.MODEL_ID, - ".elser_model_2" - ) - ) - ); - settings.put(ModelConfigurations.TASK_SETTINGS, Map.of("foo", "bar")); - - ActionListener errorVerificationListener = ActionListener.wrap((model) -> { - if (throwOnUnknown) { - fail("Model verification should fail when throwOnUnknown is true"); - } - }, (e) -> { - if (throwOnUnknown) { - assertThat( - e.getMessage(), - containsString("Model configuration contains settings [{foo=bar}] unknown to the [elser] service") - ); - } else { - fail("Model verification should not fail when throwOnUnknown is false"); - } - }); - if (throwOnUnknown == false) { - var parsed = service.parsePersistedConfigWithSecrets( - "foo", - TaskType.SPARSE_EMBEDDING, - settings, - Collections.emptyMap() - ); - } else { - service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, errorVerificationListener); - } - } - - { - var settings = new HashMap(); - settings.put( - ModelConfigurations.SERVICE_SETTINGS, - new HashMap<>( - Map.of( - ElserInternalServiceSettings.NUM_ALLOCATIONS, - 1, - ElserInternalServiceSettings.NUM_THREADS, - 4, - ElserInternalServiceSettings.MODEL_ID, - ".elser_model_2", - "foo", - "bar" - ) - ) - ); - settings.put(ModelConfigurations.TASK_SETTINGS, Map.of("foo", "bar")); - - ActionListener errorVerificationListener = ActionListener.wrap((model) -> { - if (throwOnUnknown) { - fail("Model verification should fail when throwOnUnknown is true"); - } - }, (e) -> { - if (throwOnUnknown) { - assertThat( - e.getMessage(), - containsString("Model configuration contains settings [{foo=bar}] unknown to the [elser] service") - ); - } else { - fail("Model verification should not fail when throwOnUnknown is false"); - } - }); - if (throwOnUnknown == false) { - var parsed = service.parsePersistedConfigWithSecrets( - "foo", - TaskType.SPARSE_EMBEDDING, - settings, - Collections.emptyMap() - ); - } else { - service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, errorVerificationListener); - } - } - } - } - - public void testParseRequestConfig_DefaultModel() { - { - var service = createService(mock(Client.class), Set.of()); - var settings = new HashMap(); - settings.put( - ModelConfigurations.SERVICE_SETTINGS, - new HashMap<>(Map.of(ElserInternalServiceSettings.NUM_ALLOCATIONS, 1, ElserInternalServiceSettings.NUM_THREADS, 4)) - ); - - ActionListener modelActionListener = ActionListener.wrap((model) -> { - assertEquals(".elser_model_2", ((ElserInternalModel) model).getServiceSettings().modelId()); - }, (e) -> { fail(e, "Model verification should not fail"); }); - - service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, modelActionListener); - } - { - var service = createService(mock(Client.class), Set.of("linux-x86_64")); - var settings = new HashMap(); - settings.put( - ModelConfigurations.SERVICE_SETTINGS, - new HashMap<>(Map.of(ElserInternalServiceSettings.NUM_ALLOCATIONS, 1, ElserInternalServiceSettings.NUM_THREADS, 4)) - ); - - ActionListener modelActionListener = ActionListener.wrap((model) -> { - assertEquals(".elser_model_2_linux-x86_64", ((ElserInternalModel) model).getServiceSettings().modelId()); - }, (e) -> { fail(e, "Model verification should not fail"); }); - - service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, modelActionListener); - } - } - - @SuppressWarnings("unchecked") - public void testChunkInfer() { - var mlTrainedModelResults = new ArrayList(); - mlTrainedModelResults.add(InferenceChunkedTextExpansionResultsTests.createRandomResults()); - mlTrainedModelResults.add(InferenceChunkedTextExpansionResultsTests.createRandomResults()); - mlTrainedModelResults.add(new ErrorInferenceResults(new RuntimeException("boom"))); - var response = new InferModelAction.Response(mlTrainedModelResults, "foo", true); - - ThreadPool threadpool = new TestThreadPool("test"); - Client client = mock(Client.class); - when(client.threadPool()).thenReturn(threadpool); - doAnswer(invocationOnMock -> { - var listener = (ActionListener) invocationOnMock.getArguments()[2]; - listener.onResponse(response); - return null; - }).when(client).execute(same(InferModelAction.INSTANCE), any(InferModelAction.Request.class), any(ActionListener.class)); - - var model = new ElserInternalModel( - "foo", - TaskType.SPARSE_EMBEDDING, - "elser", - new ElserInternalServiceSettings(1, 1, "elser", null), - new ElserMlNodeTaskSettings() - ); - var service = createService(client); - - var gotResults = new AtomicBoolean(); - var resultsListener = ActionListener.>wrap(chunkedResponse -> { - assertThat(chunkedResponse, hasSize(3)); - assertThat(chunkedResponse.get(0), instanceOf(InferenceChunkedSparseEmbeddingResults.class)); - var result1 = (InferenceChunkedSparseEmbeddingResults) chunkedResponse.get(0); - assertEquals(((MlChunkedTextExpansionResults) mlTrainedModelResults.get(0)).getChunks(), result1.getChunkedResults()); - assertThat(chunkedResponse.get(1), instanceOf(InferenceChunkedSparseEmbeddingResults.class)); - var result2 = (InferenceChunkedSparseEmbeddingResults) chunkedResponse.get(1); - assertEquals(((MlChunkedTextExpansionResults) mlTrainedModelResults.get(1)).getChunks(), result2.getChunkedResults()); - var result3 = (ErrorChunkedInferenceResults) chunkedResponse.get(2); - assertThat(result3.getException(), instanceOf(RuntimeException.class)); - assertThat(result3.getException().getMessage(), containsString("boom")); - gotResults.set(true); - }, ESTestCase::fail); - - service.chunkedInfer( - model, - null, - List.of("foo", "bar"), - Map.of(), - InputType.SEARCH, - new ChunkingOptions(null, null), - InferenceAction.Request.DEFAULT_TIMEOUT, - ActionListener.runAfter(resultsListener, () -> terminate(threadpool)) - ); - - if (gotResults.get() == false) { - terminate(threadpool); - } - assertTrue("Listener not called", gotResults.get()); - } - - @SuppressWarnings("unchecked") - public void testChunkInferSetsTokenization() { - var expectedSpan = new AtomicInteger(); - var expectedWindowSize = new AtomicReference(); - - ThreadPool threadpool = new TestThreadPool("test"); - Client client = mock(Client.class); - try { - when(client.threadPool()).thenReturn(threadpool); - doAnswer(invocationOnMock -> { - var request = (InferTrainedModelDeploymentAction.Request) invocationOnMock.getArguments()[1]; - assertThat(request.getUpdate(), instanceOf(TokenizationConfigUpdate.class)); - var update = (TokenizationConfigUpdate) request.getUpdate(); - assertEquals(update.getSpanSettings().span(), expectedSpan.get()); - assertEquals(update.getSpanSettings().maxSequenceLength(), expectedWindowSize.get()); - return null; - }).when(client) - .execute( - same(InferTrainedModelDeploymentAction.INSTANCE), - any(InferTrainedModelDeploymentAction.Request.class), - any(ActionListener.class) - ); - - var model = new ElserInternalModel( - "foo", - TaskType.SPARSE_EMBEDDING, - "elser", - new ElserInternalServiceSettings(1, 1, "elser", null), - new ElserMlNodeTaskSettings() - ); - var service = createService(client); - - expectedSpan.set(-1); - expectedWindowSize.set(null); - service.chunkedInfer( - model, - List.of("foo", "bar"), - Map.of(), - InputType.SEARCH, - null, - InferenceAction.Request.DEFAULT_TIMEOUT, - ActionListener.wrap(r -> fail("unexpected result"), e -> fail(e.getMessage())) - ); - - expectedSpan.set(-1); - expectedWindowSize.set(256); - service.chunkedInfer( - model, - List.of("foo", "bar"), - Map.of(), - InputType.SEARCH, - new ChunkingOptions(256, null), - InferenceAction.Request.DEFAULT_TIMEOUT, - ActionListener.wrap(r -> fail("unexpected result"), e -> fail(e.getMessage())) - ); - } finally { - terminate(threadpool); - } - } - - @SuppressWarnings("unchecked") - public void testPutModel() { - var client = mock(Client.class); - ArgumentCaptor argument = ArgumentCaptor.forClass(PutTrainedModelAction.Request.class); - - doAnswer(invocation -> { - var listener = (ActionListener) invocation.getArguments()[2]; - listener.onResponse(new PutTrainedModelAction.Response(mock(TrainedModelConfig.class))); - return null; - }).when(client).execute(Mockito.same(PutTrainedModelAction.INSTANCE), argument.capture(), any()); - - when(client.threadPool()).thenReturn(threadPool); - - var service = createService(client); - - var model = new ElserInternalModel( - "my-elser", - TaskType.SPARSE_EMBEDDING, - "elser", - new ElserInternalServiceSettings(1, 1, ".elser_model_2", null), - ElserMlNodeTaskSettings.DEFAULT - ); - - service.putModel(model, new ActionListener<>() { - @Override - public void onResponse(Boolean success) { - assertTrue(success); - } - - @Override - public void onFailure(Exception e) { - fail(e); - } - }); - - var putConfig = argument.getValue().getTrainedModelConfig(); - assertEquals("text_field", putConfig.getInput().getFieldNames().get(0)); - } - - private ElserInternalService createService(Client client) { - var context = new InferenceServiceExtension.InferenceServiceFactoryContext(client, threadPool); - return new ElserInternalService(context); - } - - private ElserInternalService createService(Client client, Set architectures) { - var context = new InferenceServiceExtension.InferenceServiceFactoryContext(client, threadPool); - return new ElserInternalService(context, (l) -> l.onResponse(architectures)); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserModelsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserModelsTests.java deleted file mode 100644 index f56e941dcc8c0..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserModelsTests.java +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.elser; - -import org.elasticsearch.test.ESTestCase; - -public class ElserModelsTests extends ESTestCase { - - public static String randomElserModel() { - return randomFrom(ElserModels.VALID_ELSER_MODEL_IDS); - } - - public void testIsValidModel() { - assertTrue(ElserModels.isValidModel(randomElserModel())); - } - - public void testIsValidEisModel() { - assertTrue(ElserModels.isValidEisModel(ElserModels.ELSER_V2_MODEL)); - } - - public void testIsInvalidModel() { - assertFalse(ElserModels.isValidModel("invalid")); - } - - public void testIsInvalidEisModel() { - assertFalse(ElserModels.isValidEisModel(ElserModels.ELSER_V2_MODEL_LINUX_X86)); - } -} From fc55a674f6b939c0a0f45f9a4ba09b87e4f11dda Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Mon, 7 Oct 2024 21:32:34 +0300 Subject: [PATCH 02/22] backporting rrf related test updates (#114257) --- muted-tests.yml | 3 - .../retriever/RankDocRetrieverBuilderIT.java | 756 ------------ .../xpack/rank/rrf/RRFRetrieverBuilderIT.java | 36 +- .../rrf/RRFRetrieverBuilderNestedDocsIT.java | 23 +- .../test/rrf/350_rrf_retriever_pagination.yml | 1009 ++++++++++------- ...rrf_retriever_search_api_compatibility.yml | 506 ++++++++- 6 files changed, 1098 insertions(+), 1235 deletions(-) delete mode 100644 server/src/internalClusterTest/java/org/elasticsearch/search/retriever/RankDocRetrieverBuilderIT.java diff --git a/muted-tests.yml b/muted-tests.yml index 536863e7fedfd..d7f6a22b472fb 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -112,9 +112,6 @@ tests: - class: org.elasticsearch.xpack.ml.integration.MlJobIT method: testDeleteJobAsync issue: https://github.com/elastic/elasticsearch/issues/112212 -- class: org.elasticsearch.search.retriever.RankDocRetrieverBuilderIT - method: testRankDocsRetrieverWithCollapse - issue: https://github.com/elastic/elasticsearch/issues/112254 - class: org.elasticsearch.smoketest.DocsClientYamlTestSuiteIT method: test {yaml=reference/rest-api/watcher/put-watch/line_120} issue: https://github.com/elastic/elasticsearch/issues/99517 diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/retriever/RankDocRetrieverBuilderIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/retriever/RankDocRetrieverBuilderIT.java deleted file mode 100644 index b78448bfd873f..0000000000000 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/retriever/RankDocRetrieverBuilderIT.java +++ /dev/null @@ -1,756 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the "Elastic License - * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side - * Public License v 1"; you may not use this file except in compliance with, at - * your election, the "Elastic License 2.0", the "GNU Affero General Public - * License v3.0 only", or the "Server Side Public License, v 1". - */ - -package org.elasticsearch.search.retriever; - -import org.apache.lucene.search.TotalHits; -import org.apache.lucene.search.join.ScoreMode; -import org.apache.lucene.util.SetOnce; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.search.MultiSearchRequest; -import org.elasticsearch.action.search.MultiSearchResponse; -import org.elasticsearch.action.search.SearchRequest; -import org.elasticsearch.action.search.SearchRequestBuilder; -import org.elasticsearch.action.search.SearchResponse; -import org.elasticsearch.action.search.TransportMultiSearchAction; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.common.util.Maps; -import org.elasticsearch.index.query.InnerHitBuilder; -import org.elasticsearch.index.query.QueryBuilder; -import org.elasticsearch.index.query.QueryBuilders; -import org.elasticsearch.index.query.QueryRewriteContext; -import org.elasticsearch.plugins.Plugin; -import org.elasticsearch.search.MockSearchService; -import org.elasticsearch.search.aggregations.bucket.terms.Terms; -import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder; -import org.elasticsearch.search.builder.PointInTimeBuilder; -import org.elasticsearch.search.builder.SearchSourceBuilder; -import org.elasticsearch.search.collapse.CollapseBuilder; -import org.elasticsearch.search.rank.RankDoc; -import org.elasticsearch.search.sort.FieldSortBuilder; -import org.elasticsearch.search.sort.NestedSortBuilder; -import org.elasticsearch.search.sort.ScoreSortBuilder; -import org.elasticsearch.search.sort.ShardDocSortField; -import org.elasticsearch.search.sort.SortBuilder; -import org.elasticsearch.search.sort.SortOrder; -import org.elasticsearch.test.ESIntegTestCase; -import org.elasticsearch.test.hamcrest.ElasticsearchAssertions; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xcontent.XContentType; -import org.junit.Before; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.List; -import java.util.Map; - -import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_REPLICAS; -import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_SHARDS; -import static org.hamcrest.Matchers.equalTo; - -public class RankDocRetrieverBuilderIT extends ESIntegTestCase { - - @Override - protected Collection> nodePlugins() { - return List.of(MockSearchService.TestPlugin.class); - } - - public record RetrieverSource(RetrieverBuilder retriever, SearchSourceBuilder source) {} - - private static String INDEX = "test_index"; - private static final String ID_FIELD = "_id"; - private static final String DOC_FIELD = "doc"; - private static final String TEXT_FIELD = "text"; - private static final String VECTOR_FIELD = "vector"; - private static final String TOPIC_FIELD = "topic"; - private static final String LAST_30D_FIELD = "views.last30d"; - private static final String ALL_TIME_FIELD = "views.all"; - - @Before - public void setup() throws Exception { - String mapping = """ - { - "properties": { - "vector": { - "type": "dense_vector", - "dims": 3, - "element_type": "float", - "index": true, - "similarity": "l2_norm", - "index_options": { - "type": "hnsw" - } - }, - "text": { - "type": "text" - }, - "doc": { - "type": "keyword" - }, - "topic": { - "type": "keyword" - }, - "views": { - "type": "nested", - "properties": { - "last30d": { - "type": "integer" - }, - "all": { - "type": "integer" - } - } - } - } - } - """; - createIndex(INDEX, Settings.builder().put(SETTING_NUMBER_OF_SHARDS, 1).put(SETTING_NUMBER_OF_REPLICAS, 0).build()); - admin().indices().preparePutMapping(INDEX).setSource(mapping, XContentType.JSON).get(); - indexDoc( - INDEX, - "doc_1", - DOC_FIELD, - "doc_1", - TOPIC_FIELD, - "technology", - TEXT_FIELD, - "the quick brown fox jumps over the lazy dog", - LAST_30D_FIELD, - 100 - ); - indexDoc( - INDEX, - "doc_2", - DOC_FIELD, - "doc_2", - TOPIC_FIELD, - "astronomy", - TEXT_FIELD, - "you know, for Search!", - VECTOR_FIELD, - new float[] { 1.0f, 2.0f, 3.0f }, - LAST_30D_FIELD, - 3 - ); - indexDoc(INDEX, "doc_3", DOC_FIELD, "doc_3", TOPIC_FIELD, "technology", VECTOR_FIELD, new float[] { 6.0f, 6.0f, 6.0f }); - indexDoc( - INDEX, - "doc_4", - DOC_FIELD, - "doc_4", - TOPIC_FIELD, - "technology", - TEXT_FIELD, - "aardvark is a really awesome animal, but not very quick", - ALL_TIME_FIELD, - 100, - LAST_30D_FIELD, - 40 - ); - indexDoc(INDEX, "doc_5", DOC_FIELD, "doc_5", TOPIC_FIELD, "science", TEXT_FIELD, "irrelevant stuff"); - indexDoc( - INDEX, - "doc_6", - DOC_FIELD, - "doc_6", - TEXT_FIELD, - "quick quick quick quick search", - VECTOR_FIELD, - new float[] { 10.0f, 30.0f, 100.0f }, - LAST_30D_FIELD, - 15 - ); - indexDoc( - INDEX, - "doc_7", - DOC_FIELD, - "doc_7", - TOPIC_FIELD, - "biology", - TEXT_FIELD, - "dog", - VECTOR_FIELD, - new float[] { 3.0f, 3.0f, 3.0f }, - ALL_TIME_FIELD, - 1000 - ); - refresh(INDEX); - } - - public void testRankDocsRetrieverBasicWithPagination() { - final int rankWindowSize = 100; - SearchSourceBuilder source = new SearchSourceBuilder(); - StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(); - // this one retrieves docs 1, 4, and 6 - standard0.queryBuilder = QueryBuilders.boolQuery() - .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L)) - .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(9L)) - .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(8L)); - StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(); - // this one retrieves docs 2 and 6 due to prefilter - standard1.queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.termsQuery(ID_FIELD, "doc_2", "doc_3", "doc_6")).boost(20L); - standard1.preFilterQueryBuilders.add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); - // this one retrieves docs 7, 2, 3, and 6 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder( - VECTOR_FIELD, - new float[] { 3.0f, 3.0f, 3.0f }, - null, - 10, - 100, - null - ); - // the compound retriever here produces a score for a doc based on the percentage of the queries that it was matched on and - // resolves ties based on actual score, and then the doc (we're forcing 1 shard for consistent results) - // so ideal rank would be: 6, 2, 1, 3, 4, 7 and with pagination, we'd just omit the first result - source.retriever( - new CompoundRetrieverWithRankDocs( - rankWindowSize, - Arrays.asList( - new RetrieverSource(standard0, null), - new RetrieverSource(standard1, null), - new RetrieverSource(knnRetrieverBuilder, null) - ) - ) - ); - // include some pagination as well - source.from(1); - SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source); - ElasticsearchAssertions.assertResponse(req, resp -> { - assertNull(resp.pointInTimeId()); - assertNotNull(resp.getHits().getTotalHits()); - assertThat(resp.getHits().getTotalHits().value, equalTo(6L)); - assertThat(resp.getHits().getTotalHits().relation, equalTo(TotalHits.Relation.EQUAL_TO)); - assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_2")); - assertThat(resp.getHits().getAt(1).getId(), equalTo("doc_1")); - assertThat(resp.getHits().getAt(2).getId(), equalTo("doc_3")); - assertThat(resp.getHits().getAt(3).getId(), equalTo("doc_4")); - assertThat(resp.getHits().getAt(4).getId(), equalTo("doc_7")); - }); - } - - public void testRankDocsRetrieverWithAggs() { - // same as above, but we only want to bring back the top result from each subsearch - // so that would be 1, 2, and 7 - // and final rank would be (based on score): 2, 1, 7 - // aggs should still account for the same docs as the testRankDocsRetriever test, i.e. all but doc_5 - final int rankWindowSize = 1; - SearchSourceBuilder source = new SearchSourceBuilder(); - StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(); - // this one retrieves docs 1, 4, and 6 - standard0.queryBuilder = QueryBuilders.boolQuery() - .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L)) - .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(9L)) - .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(8L)); - StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(); - // this one retrieves docs 2 and 6 due to prefilter - standard1.queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.termsQuery(ID_FIELD, "doc_2", "doc_3", "doc_6")).boost(20L); - standard1.preFilterQueryBuilders.add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); - // this one retrieves docs 7, 2, 3, and 6 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder( - VECTOR_FIELD, - new float[] { 3.0f, 3.0f, 3.0f }, - null, - 10, - 100, - null - ); - source.retriever( - new CompoundRetrieverWithRankDocs( - rankWindowSize, - Arrays.asList( - new RetrieverSource(standard0, null), - new RetrieverSource(standard1, null), - new RetrieverSource(knnRetrieverBuilder, null) - ) - ) - ); - source.size(1); - source.aggregation(new TermsAggregationBuilder("topic").field(TOPIC_FIELD)); - SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source); - ElasticsearchAssertions.assertResponse(req, resp -> { - assertNull(resp.pointInTimeId()); - assertNotNull(resp.getHits().getTotalHits()); - assertThat(resp.getHits().getTotalHits().value, equalTo(5L)); - assertThat(resp.getHits().getTotalHits().relation, equalTo(TotalHits.Relation.EQUAL_TO)); - assertThat(resp.getHits().getHits().length, equalTo(1)); - assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_2")); - assertNotNull(resp.getAggregations()); - assertNotNull(resp.getAggregations().get("topic")); - Terms terms = resp.getAggregations().get("topic"); - // doc_3 is not part of the final aggs computation as it is only retrieved through the knn retriever - // and is outside of the rank window - assertThat(terms.getBucketByKey("technology").getDocCount(), equalTo(2L)); - assertThat(terms.getBucketByKey("astronomy").getDocCount(), equalTo(1L)); - assertThat(terms.getBucketByKey("biology").getDocCount(), equalTo(1L)); - }); - } - - public void testRankDocsRetrieverWithCollapse() { - final int rankWindowSize = 100; - SearchSourceBuilder source = new SearchSourceBuilder(); - StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(); - // this one retrieves docs 1, 4, and 6 - standard0.queryBuilder = QueryBuilders.boolQuery() - .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L)) - .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(9L)) - .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(8L)); - StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(); - // this one retrieves docs 2 and 6 due to prefilter - standard1.queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.termsQuery(ID_FIELD, "doc_2", "doc_3", "doc_6")).boost(20L); - standard1.preFilterQueryBuilders.add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); - // this one retrieves docs 7, 2, 3, and 6 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder( - VECTOR_FIELD, - new float[] { 3.0f, 3.0f, 3.0f }, - null, - 10, - 100, - null - ); - // the compound retriever here produces a score for a doc based on the percentage of the queries that it was matched on and - // resolves ties based on actual score, and then the doc (we're forcing 1 shard for consistent results) - // so ideal rank would be: 6, 2, 1, 3, 4, 7 - // with collapsing on topic field we would have 6, 2, 1, 7 - source.retriever( - new CompoundRetrieverWithRankDocs( - rankWindowSize, - Arrays.asList( - new RetrieverSource(standard0, null), - new RetrieverSource(standard1, null), - new RetrieverSource(knnRetrieverBuilder, null) - ) - ) - ); - source.collapse( - new CollapseBuilder(TOPIC_FIELD).setInnerHits( - new InnerHitBuilder("a").addSort(new FieldSortBuilder(DOC_FIELD).order(SortOrder.DESC)).setSize(10) - ) - ); - source.fetchField(TOPIC_FIELD); - SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source); - ElasticsearchAssertions.assertResponse(req, resp -> { - assertNull(resp.pointInTimeId()); - assertNotNull(resp.getHits().getTotalHits()); - assertThat(resp.getHits().getTotalHits().value, equalTo(6L)); - assertThat(resp.getHits().getTotalHits().relation, equalTo(TotalHits.Relation.EQUAL_TO)); - assertThat(resp.getHits().getHits().length, equalTo(4)); - assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_6")); - assertThat(resp.getHits().getAt(1).getId(), equalTo("doc_2")); - assertThat(resp.getHits().getAt(1).field(TOPIC_FIELD).getValue().toString(), equalTo("astronomy")); - assertThat(resp.getHits().getAt(2).getId(), equalTo("doc_1")); - assertThat(resp.getHits().getAt(2).field(TOPIC_FIELD).getValue().toString(), equalTo("technology")); - assertThat(resp.getHits().getAt(2).getInnerHits().get("a").getAt(0).getId(), equalTo("doc_4")); - assertThat(resp.getHits().getAt(2).getInnerHits().get("a").getAt(1).getId(), equalTo("doc_3")); - assertThat(resp.getHits().getAt(2).getInnerHits().get("a").getAt(2).getId(), equalTo("doc_1")); - assertThat(resp.getHits().getAt(3).getId(), equalTo("doc_7")); - assertThat(resp.getHits().getAt(3).field(TOPIC_FIELD).getValue().toString(), equalTo("biology")); - }); - } - - public void testRankDocsRetrieverWithNestedCollapseAndAggs() { - final int rankWindowSize = 10; - SearchSourceBuilder source = new SearchSourceBuilder(); - StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(); - // this one retrieves docs 1 and 6 as doc_4 is collapsed to doc_1 - standard0.queryBuilder = QueryBuilders.boolQuery() - .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L)) - .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(9L)) - .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(8L)); - standard0.collapseBuilder = new CollapseBuilder(TOPIC_FIELD).setInnerHits( - new InnerHitBuilder("a").addSort(new FieldSortBuilder(DOC_FIELD).order(SortOrder.DESC)).setSize(10) - ); - StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(); - // this one retrieves docs 2 and 6 due to prefilter - standard1.queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.termsQuery(ID_FIELD, "doc_2", "doc_3", "doc_6")).boost(20L); - standard1.preFilterQueryBuilders.add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); - // this one retrieves docs 7, 2, 3, and 6 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder( - VECTOR_FIELD, - new float[] { 3.0f, 3.0f, 3.0f }, - null, - 10, - 100, - null - ); - // the compound retriever here produces a score for a doc based on the percentage of the queries that it was matched on and - // resolves ties based on actual score, and then the doc (we're forcing 1 shard for consistent results) - // so ideal rank would be: 6, 2, 1, 3, 4, 7 - source.retriever( - new CompoundRetrieverWithRankDocs( - rankWindowSize, - Arrays.asList( - new RetrieverSource(standard0, null), - new RetrieverSource(standard1, null), - new RetrieverSource(knnRetrieverBuilder, null) - ) - ) - ); - source.aggregation(new TermsAggregationBuilder("topic").field(TOPIC_FIELD)); - SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source); - ElasticsearchAssertions.assertResponse(req, resp -> { - assertNull(resp.pointInTimeId()); - assertNotNull(resp.getHits().getTotalHits()); - assertThat(resp.getHits().getTotalHits().value, equalTo(6L)); - assertThat(resp.getHits().getTotalHits().relation, equalTo(TotalHits.Relation.EQUAL_TO)); - assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_6")); - assertNotNull(resp.getAggregations()); - assertNotNull(resp.getAggregations().get("topic")); - Terms terms = resp.getAggregations().get("topic"); - // doc_3 is not part of the final aggs computation as it is only retrieved through the knn retriever - // and is outside of the rank window - assertThat(terms.getBucketByKey("technology").getDocCount(), equalTo(3L)); - assertThat(terms.getBucketByKey("astronomy").getDocCount(), equalTo(1L)); - assertThat(terms.getBucketByKey("biology").getDocCount(), equalTo(1L)); - }); - } - - public void testRankDocsRetrieverWithNestedQuery() { - final int rankWindowSize = 100; - SearchSourceBuilder source = new SearchSourceBuilder(); - StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(); - // this one retrieves docs 1, 4, and 6 - standard0.queryBuilder = QueryBuilders.nestedQuery("views", QueryBuilders.rangeQuery(LAST_30D_FIELD).gt(10L), ScoreMode.Avg); - StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(); - // this one retrieves docs 2 and 6 due to prefilter - standard1.queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.termsQuery(ID_FIELD, "doc_2", "doc_3", "doc_6")).boost(20L); - standard1.preFilterQueryBuilders.add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); - // this one retrieves docs 7, 2, 3, and 6 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder( - VECTOR_FIELD, - new float[] { 3.0f, 3.0f, 3.0f }, - null, - 10, - 100, - null - ); - // the compound retriever here produces a score for a doc based on the percentage of the queries that it was matched on and - // resolves ties based on actual score, and then the doc (we're forcing 1 shard for consistent results) - // so ideal rank would be: 6, 2, 1, 3, 4, 7 - source.retriever( - new CompoundRetrieverWithRankDocs( - rankWindowSize, - Arrays.asList( - new RetrieverSource(standard0, null), - new RetrieverSource(standard1, null), - new RetrieverSource(knnRetrieverBuilder, null) - ) - ) - ); - source.fetchField(TOPIC_FIELD); - SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source); - ElasticsearchAssertions.assertResponse(req, resp -> { - assertNull(resp.pointInTimeId()); - assertNotNull(resp.getHits().getTotalHits()); - assertThat(resp.getHits().getTotalHits().value, equalTo(6L)); - assertThat(resp.getHits().getTotalHits().relation, equalTo(TotalHits.Relation.EQUAL_TO)); - assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_6")); - assertThat(resp.getHits().getAt(1).getId(), equalTo("doc_2")); - assertThat(resp.getHits().getAt(2).getId(), equalTo("doc_1")); - assertThat(resp.getHits().getAt(3).getId(), equalTo("doc_3")); - assertThat(resp.getHits().getAt(4).getId(), equalTo("doc_4")); - assertThat(resp.getHits().getAt(5).getId(), equalTo("doc_7")); - }); - } - - public void testRankDocsRetrieverMultipleCompoundRetrievers() { - final int rankWindowSize = 100; - SearchSourceBuilder source = new SearchSourceBuilder(); - StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(); - // this one retrieves docs 1, 4, and 6 - standard0.queryBuilder = QueryBuilders.boolQuery() - .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L)) - .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(9L)) - .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(8L)); - StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(); - // this one retrieves docs 2 and 6 due to prefilter - standard1.queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.termsQuery(ID_FIELD, "doc_2", "doc_3", "doc_6")).boost(20L); - standard1.preFilterQueryBuilders.add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); - // this one retrieves docs 7, 2, 3, and 6 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder( - VECTOR_FIELD, - new float[] { 3.0f, 3.0f, 3.0f }, - null, - 10, - 100, - null - ); - // the compound retriever here produces a score for a doc based on the percentage of the queries that it was matched on and - // resolves ties based on actual score, rank, and then the doc (we're forcing 1 shard for consistent results) - // so ideal rank would be: 6, 2, 1, 4, 7, 3 - CompoundRetrieverWithRankDocs compoundRetriever1 = new CompoundRetrieverWithRankDocs( - rankWindowSize, - Arrays.asList( - new RetrieverSource(standard0, null), - new RetrieverSource(standard1, null), - new RetrieverSource(knnRetrieverBuilder, null) - ) - ); - // simple standard retriever that would have the doc_4 as its first (and only) result - StandardRetrieverBuilder standard2 = new StandardRetrieverBuilder(); - standard2.queryBuilder = QueryBuilders.queryStringQuery("aardvark").defaultField(TEXT_FIELD); - - // combining the two retrievers would bring doc_4 at the top as it would be the only one present in both doc sets - // the rest of the docs would be sorted based on their ranks as they have the same score (1/2) - source.retriever( - new CompoundRetrieverWithRankDocs( - rankWindowSize, - Arrays.asList(new RetrieverSource(compoundRetriever1, null), new RetrieverSource(standard2, null)) - ) - ); - - SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source); - ElasticsearchAssertions.assertResponse(req, resp -> { - assertNull(resp.pointInTimeId()); - assertNotNull(resp.getHits().getTotalHits()); - assertThat(resp.getHits().getTotalHits().value, equalTo(6L)); - assertThat(resp.getHits().getTotalHits().relation, equalTo(TotalHits.Relation.EQUAL_TO)); - assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_4")); - assertThat(resp.getHits().getAt(1).getId(), equalTo("doc_1")); - assertThat(resp.getHits().getAt(2).getId(), equalTo("doc_2")); - assertThat(resp.getHits().getAt(3).getId(), equalTo("doc_3")); - assertThat(resp.getHits().getAt(4).getId(), equalTo("doc_6")); - assertThat(resp.getHits().getAt(5).getId(), equalTo("doc_7")); - }); - } - - public void testRankDocsRetrieverDifferentNestedSorting() { - final int rankWindowSize = 100; - SearchSourceBuilder source = new SearchSourceBuilder(); - StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(); - // this one retrieves docs 1, 4, 6, 2 - standard0.queryBuilder = QueryBuilders.nestedQuery("views", QueryBuilders.rangeQuery(LAST_30D_FIELD).gt(0), ScoreMode.Avg); - standard0.sortBuilders = List.of( - new FieldSortBuilder(LAST_30D_FIELD).setNestedSort(new NestedSortBuilder("views")).order(SortOrder.DESC) - ); - StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(); - // this one retrieves docs 4, 7 - standard1.queryBuilder = QueryBuilders.nestedQuery("views", QueryBuilders.rangeQuery(ALL_TIME_FIELD).gt(0), ScoreMode.Avg); - standard1.sortBuilders = List.of( - new FieldSortBuilder(ALL_TIME_FIELD).setNestedSort(new NestedSortBuilder("views")).order(SortOrder.ASC) - ); - - source.retriever( - new CompoundRetrieverWithRankDocs( - rankWindowSize, - Arrays.asList(new RetrieverSource(standard0, null), new RetrieverSource(standard1, null)) - ) - ); - - SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source); - ElasticsearchAssertions.assertResponse(req, resp -> { - assertNull(resp.pointInTimeId()); - assertNotNull(resp.getHits().getTotalHits()); - assertThat(resp.getHits().getTotalHits().value, equalTo(5L)); - assertThat(resp.getHits().getTotalHits().relation, equalTo(TotalHits.Relation.EQUAL_TO)); - assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_4")); - assertThat(resp.getHits().getAt(1).getId(), equalTo("doc_1")); - assertThat(resp.getHits().getAt(2).getId(), equalTo("doc_2")); - assertThat(resp.getHits().getAt(3).getId(), equalTo("doc_6")); - assertThat(resp.getHits().getAt(4).getId(), equalTo("doc_7")); - }); - } - - class CompoundRetrieverWithRankDocs extends RetrieverBuilder { - - private final List sources; - private final int rankWindowSize; - - private CompoundRetrieverWithRankDocs(int rankWindowSize, List sources) { - this.rankWindowSize = rankWindowSize; - this.sources = Collections.unmodifiableList(sources); - } - - @Override - public boolean isCompound() { - return true; - } - - @Override - public QueryBuilder topDocsQuery() { - throw new UnsupportedOperationException("should not be called"); - } - - @Override - public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException { - if (ctx.getPointInTimeBuilder() == null) { - throw new IllegalStateException("PIT is required"); - } - - // Rewrite prefilters - boolean hasChanged = false; - var newPreFilters = rewritePreFilters(ctx); - hasChanged |= newPreFilters != preFilterQueryBuilders; - - // Rewrite retriever sources - List newRetrievers = new ArrayList<>(); - for (var entry : sources) { - RetrieverBuilder newRetriever = entry.retriever.rewrite(ctx); - if (newRetriever != entry.retriever) { - newRetrievers.add(new RetrieverSource(newRetriever, null)); - hasChanged |= newRetriever != entry.retriever; - } else if (newRetriever == entry.retriever) { - var sourceBuilder = entry.source != null - ? entry.source - : createSearchSourceBuilder(ctx.getPointInTimeBuilder(), newRetriever); - var rewrittenSource = sourceBuilder.rewrite(ctx); - newRetrievers.add(new RetrieverSource(newRetriever, rewrittenSource)); - hasChanged |= rewrittenSource != entry.source; - } - } - if (hasChanged) { - return new CompoundRetrieverWithRankDocs(rankWindowSize, newRetrievers); - } - - // execute searches - final SetOnce results = new SetOnce<>(); - final MultiSearchRequest multiSearchRequest = new MultiSearchRequest(); - for (var entry : sources) { - SearchRequest searchRequest = new SearchRequest().source(entry.source); - // The can match phase can reorder shards, so we disable it to ensure the stable ordering - searchRequest.setPreFilterShardSize(Integer.MAX_VALUE); - multiSearchRequest.add(searchRequest); - } - ctx.registerAsyncAction((client, listener) -> { - client.execute(TransportMultiSearchAction.TYPE, multiSearchRequest, new ActionListener<>() { - @Override - public void onResponse(MultiSearchResponse items) { - List topDocs = new ArrayList<>(); - for (int i = 0; i < items.getResponses().length; i++) { - var item = items.getResponses()[i]; - var rankDocs = getRankDocs(item.getResponse()); - sources.get(i).retriever().setRankDocs(rankDocs); - topDocs.add(rankDocs); - } - results.set(combineResults(topDocs)); - listener.onResponse(null); - } - - @Override - public void onFailure(Exception e) { - listener.onFailure(e); - } - }); - }); - - return new RankDocsRetrieverBuilder( - rankWindowSize, - newRetrievers.stream().map(s -> s.retriever).toList(), - results::get, - newPreFilters - ); - } - - @Override - public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) { - throw new UnsupportedOperationException("should not be called"); - } - - @Override - public String getName() { - return "compound_retriever"; - } - - @Override - protected void doToXContent(XContentBuilder builder, Params params) throws IOException { - - } - - @Override - protected boolean doEquals(Object o) { - return false; - } - - @Override - protected int doHashCode() { - return 0; - } - - private RankDoc[] getRankDocs(SearchResponse searchResponse) { - assert searchResponse != null; - int size = Math.min(rankWindowSize, searchResponse.getHits().getHits().length); - RankDoc[] docs = new RankDoc[size]; - for (int i = 0; i < size; i++) { - var hit = searchResponse.getHits().getAt(i); - long sortValue = (long) hit.getRawSortValues()[hit.getRawSortValues().length - 1]; - int doc = ShardDocSortField.decodeDoc(sortValue); - int shardRequestIndex = ShardDocSortField.decodeShardRequestIndex(sortValue); - docs[i] = new RankDoc(doc, hit.getScore(), shardRequestIndex); - docs[i].rank = i + 1; - } - return docs; - } - - record RankDocAndHitRatio(RankDoc rankDoc, float hitRatio) {} - - /** - * Combines the provided {@code rankResults} to return the final top documents. - */ - public RankDoc[] combineResults(List rankResults) { - int totalQueries = rankResults.size(); - final float step = 1.0f / totalQueries; - Map docsToRankResults = Maps.newMapWithExpectedSize(rankWindowSize); - for (var rankResult : rankResults) { - for (RankDoc scoreDoc : rankResult) { - docsToRankResults.compute(new RankDoc.RankKey(scoreDoc.doc, scoreDoc.shardIndex), (key, value) -> { - if (value == null) { - RankDoc res = new RankDoc(scoreDoc.doc, scoreDoc.score, scoreDoc.shardIndex); - res.rank = scoreDoc.rank; - return new RankDocAndHitRatio(res, step); - } else { - RankDoc res = new RankDoc(scoreDoc.doc, Math.max(scoreDoc.score, value.rankDoc.score), scoreDoc.shardIndex); - res.rank = Math.min(scoreDoc.rank, value.rankDoc.rank); - return new RankDocAndHitRatio(res, value.hitRatio + step); - } - }); - } - } - // sort the results based on hit ratio, then doc, then rank, and final tiebreaker is based on smaller doc id - RankDocAndHitRatio[] sortedResults = docsToRankResults.values().toArray(RankDocAndHitRatio[]::new); - Arrays.sort(sortedResults, (RankDocAndHitRatio doc1, RankDocAndHitRatio doc2) -> { - if (doc1.hitRatio != doc2.hitRatio) { - return doc1.hitRatio < doc2.hitRatio ? 1 : -1; - } - if (false == (Float.isNaN(doc1.rankDoc.score) || Float.isNaN(doc2.rankDoc.score)) - && (doc1.rankDoc.score != doc2.rankDoc.score)) { - return doc1.rankDoc.score < doc2.rankDoc.score ? 1 : -1; - } - if (doc1.rankDoc.rank != doc2.rankDoc.rank) { - return doc1.rankDoc.rank < doc2.rankDoc.rank ? -1 : 1; - } - return doc1.rankDoc.doc < doc2.rankDoc.doc ? -1 : 1; - }); - // trim the results if needed, otherwise each shard will always return `rank_window_size` results. - // pagination and all else will happen on the coordinator when combining the shard responses - RankDoc[] topResults = new RankDoc[Math.min(rankWindowSize, sortedResults.length)]; - for (int rank = 0; rank < topResults.length; ++rank) { - topResults[rank] = sortedResults[rank].rankDoc; - topResults[rank].rank = rank + 1; - topResults[rank].score = sortedResults[rank].hitRatio; - } - return topResults; - } - } - - private SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) { - var sourceBuilder = new SearchSourceBuilder().pointInTimeBuilder(pit).trackTotalHits(false).size(100); - retrieverBuilder.extractToSearchSourceBuilder(sourceBuilder, false); - - // Record the shard id in the sort result - List> sortBuilders = sourceBuilder.sorts() != null ? new ArrayList<>(sourceBuilder.sorts()) : new ArrayList<>(); - if (sortBuilders.isEmpty()) { - sortBuilders.add(new ScoreSortBuilder()); - } - sortBuilders.add(new FieldSortBuilder(FieldSortBuilder.SHARD_DOC_FIELD_NAME)); - sourceBuilder.sort(sortBuilders); - return sourceBuilder; - } -} diff --git a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java index 2e7bc44811bf6..be64d34dc8765 100644 --- a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java +++ b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java @@ -33,7 +33,6 @@ import java.util.Collection; import java.util.List; -import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_REPLICAS; import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_SHARDS; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; @@ -98,7 +97,7 @@ protected void setupIndex() { } } """; - createIndex(INDEX, Settings.builder().put(SETTING_NUMBER_OF_SHARDS, 1).put(SETTING_NUMBER_OF_REPLICAS, 0).build()); + createIndex(INDEX, Settings.builder().put(SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 5)).build()); admin().indices().preparePutMapping(INDEX).setSource(mapping, XContentType.JSON).get(); indexDoc(INDEX, "doc_1", DOC_FIELD, "doc_1", TOPIC_FIELD, "technology", TEXT_FIELD, "term"); indexDoc( @@ -167,8 +166,8 @@ public void testRRFPagination() { QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2", "doc_3", "doc_6")).boost(20L) ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); - // this one retrieves docs 3, 2, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 4.0f }, null, 10, 100, null); + // this one retrieves docs 2, 3, 6, and 7 + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null); source.retriever( new RRFRetrieverBuilder( Arrays.asList( @@ -214,8 +213,8 @@ public void testRRFWithAggs() { QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2", "doc_3", "doc_6")).boost(20L) ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); - // this one retrieves docs 3, 2, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 4.0f }, null, 10, 100, null); + // this one retrieves docs 2, 3, 6, and 7 + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null); source.retriever( new RRFRetrieverBuilder( Arrays.asList( @@ -266,8 +265,8 @@ public void testRRFWithCollapse() { QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2", "doc_3", "doc_6")).boost(20L) ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); - // this one retrieves docs 3, 2, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 4.0f }, null, 10, 100, null); + // this one retrieves docs 2, 3, 6, and 7 + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null); source.retriever( new RRFRetrieverBuilder( Arrays.asList( @@ -320,8 +319,8 @@ public void testRRFRetrieverWithCollapseAndAggs() { QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2", "doc_3", "doc_6")).boost(20L) ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); - // this one retrieves docs 3, 2, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 4.0f }, null, 10, 100, null); + // this one retrieves docs 2, 3, 6, and 7 + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null); source.retriever( new RRFRetrieverBuilder( Arrays.asList( @@ -383,8 +382,8 @@ public void testMultipleRRFRetrievers() { QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2", "doc_3", "doc_6")).boost(20L) ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); - // this one retrieves docs 3, 2, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 4.0f }, null, 10, 100, null); + // this one retrieves docs 2, 3, 6, and 7 + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null); source.retriever( new RRFRetrieverBuilder( Arrays.asList( @@ -446,8 +445,8 @@ public void testRRFExplainWithNamedRetrievers() { QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2", "doc_3", "doc_6")).boost(20L) ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); - // this one retrieves docs 3, 2, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 4.0f }, null, 10, 100, null); + // this one retrieves docs 2, 3, 6, and 7 + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null); source.retriever( new RRFRetrieverBuilder( Arrays.asList( @@ -474,13 +473,12 @@ public void testRRFExplainWithNamedRetrievers() { assertThat(resp.getHits().getAt(0).getExplanation().getDetails().length, equalTo(2)); var rrfDetails = resp.getHits().getAt(0).getExplanation().getDetails()[0]; assertThat(rrfDetails.getDetails().length, equalTo(3)); - assertThat(rrfDetails.getDescription(), containsString("computed for initial ranks [2, 1, 2]")); + assertThat(rrfDetails.getDescription(), containsString("computed for initial ranks [2, 1, 1]")); - assertThat(rrfDetails.getDetails()[0].getDescription(), containsString("for rank [2] in query at index [0]")); assertThat(rrfDetails.getDetails()[0].getDescription(), containsString("for rank [2] in query at index [0]")); assertThat(rrfDetails.getDetails()[0].getDescription(), containsString("[my_custom_retriever]")); assertThat(rrfDetails.getDetails()[1].getDescription(), containsString("for rank [1] in query at index [1]")); - assertThat(rrfDetails.getDetails()[2].getDescription(), containsString("for rank [2] in query at index [2]")); + assertThat(rrfDetails.getDetails()[2].getDescription(), containsString("for rank [1] in query at index [2]")); }); } @@ -503,8 +501,8 @@ public void testRRFExplainWithAnotherNestedRRF() { QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2", "doc_3", "doc_6")).boost(20L) ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); - // this one retrieves docs 3, 2, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 4.0f }, null, 10, 100, null); + // this one retrieves docs 2, 3, 6, and 7 + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null); RRFRetrieverBuilder nestedRRF = new RRFRetrieverBuilder( Arrays.asList( diff --git a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderNestedDocsIT.java b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderNestedDocsIT.java index 512874e5009f3..ea251917cfae2 100644 --- a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderNestedDocsIT.java +++ b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderNestedDocsIT.java @@ -12,6 +12,7 @@ import org.elasticsearch.action.search.SearchRequestBuilder; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.retriever.CompoundRetrieverBuilder; import org.elasticsearch.search.retriever.KnnRetrieverBuilder; @@ -21,8 +22,9 @@ import java.util.Arrays; -import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_REPLICAS; import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_SHARDS; +import static org.hamcrest.Matchers.closeTo; +import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.equalTo; public class RRFRetrieverBuilderNestedDocsIT extends RRFRetrieverBuilderIT { @@ -68,7 +70,7 @@ protected void setupIndex() { } } """; - createIndex(INDEX, Settings.builder().put(SETTING_NUMBER_OF_SHARDS, 1).put(SETTING_NUMBER_OF_REPLICAS, 0).build()); + createIndex(INDEX, Settings.builder().put(SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 5)).build()); admin().indices().preparePutMapping(INDEX).setSource(mapping, XContentType.JSON).get(); indexDoc(INDEX, "doc_1", DOC_FIELD, "doc_1", TOPIC_FIELD, "technology", TEXT_FIELD, "term", LAST_30D_FIELD, 100); indexDoc( @@ -134,9 +136,9 @@ public void testRRFRetrieverWithNestedQuery() { final int rankWindowSize = 100; final int rankConstant = 10; SearchSourceBuilder source = new SearchSourceBuilder(); - // this one retrieves docs 1, 4 + // this one retrieves docs 1 StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder( - QueryBuilders.nestedQuery("views", QueryBuilders.rangeQuery(LAST_30D_FIELD).gte(30L), ScoreMode.Avg) + QueryBuilders.nestedQuery("views", QueryBuilders.rangeQuery(LAST_30D_FIELD).gte(50L), ScoreMode.Avg) ); // this one retrieves docs 2 and 6 due to prefilter StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder( @@ -157,16 +159,21 @@ public void testRRFRetrieverWithNestedQuery() { ) ); source.fetchField(TOPIC_FIELD); + source.explain(true); SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source); ElasticsearchAssertions.assertResponse(req, resp -> { assertNull(resp.pointInTimeId()); assertNotNull(resp.getHits().getTotalHits()); - assertThat(resp.getHits().getTotalHits().value, equalTo(4L)); + assertThat(resp.getHits().getTotalHits().value, equalTo(3L)); assertThat(resp.getHits().getTotalHits().relation, equalTo(TotalHits.Relation.EQUAL_TO)); assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_6")); - assertThat(resp.getHits().getAt(1).getId(), equalTo("doc_1")); - assertThat(resp.getHits().getAt(2).getId(), equalTo("doc_2")); - assertThat(resp.getHits().getAt(3).getId(), equalTo("doc_4")); + assertThat((double) resp.getHits().getAt(0).getScore(), closeTo(0.1742, 1e-4)); + assertThat( + Arrays.stream(resp.getHits().getHits()).skip(1).map(SearchHit::getId).toList(), + containsInAnyOrder("doc_1", "doc_2") + ); + assertThat((double) resp.getHits().getAt(1).getScore(), closeTo(0.0909, 1e-4)); + assertThat((double) resp.getHits().getAt(2).getScore(), closeTo(0.0909, 1e-4)); }); } } diff --git a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/350_rrf_retriever_pagination.yml b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/350_rrf_retriever_pagination.yml index 47ba3658bb38d..d5d7a5de1dc71 100644 --- a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/350_rrf_retriever_pagination.yml +++ b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/350_rrf_retriever_pagination.yml @@ -1,6 +1,8 @@ setup: - skip: - features: close_to + features: + - close_to + - contains - requires: cluster_features: 'rrf_retriever_composition_supported' @@ -10,8 +12,6 @@ setup: indices.create: index: test body: - settings: - number_of_shards: 1 mappings: properties: number_val: @@ -81,35 +81,49 @@ setup: bool: { should: [ { - term: { - number_val: { - value: "1", - boost: 10.0 - } - } - }, - { - term: { - number_val: { - value: "2", - boost: 9.0 - } + constant_score: { + filter: { + term: { + number_val: { + value: "1" + } + } + }, + boost: 10.0 } - }, - { - term: { - number_val: { - value: "3", - boost: 8.0 - } + },{ + constant_score: { + filter: { + term: { + number_val: { + value: "2" + } + } + }, + boost: 9.0 + } }, + { + constant_score: { + filter: { + term: { + number_val: { + value: "3" + } + } + }, + boost: 8.0 } }, { - term: { - number_val: { - value: "4", - boost: 7.0 - } + constant_score: { + filter: { + term: { + number_val: { + value: "4" + } + } + }, + boost: 7.0 } } ] @@ -124,35 +138,51 @@ setup: bool: { should: [ { - term: { - char_val: { - value: "A", - boost: 10.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "A" + } + } + }, + boost: 10.0 } }, { - term: { - char_val: { - value: "B", - boost: 9.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "B" + } + } + }, + boost: 9.0 } }, { - term: { - char_val: { - value: "C", - boost: 8.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "C" + } + } + }, + boost: 8.0 } }, { - term: { - char_val: { - value: "D", - boost: 7.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "D" + } + } + }, + boost: 7.0 } } ] @@ -198,35 +228,49 @@ setup: bool: { should: [ { - term: { - number_val: { - value: "1", - boost: 10.0 - } - } - }, - { - term: { - number_val: { - value: "2", - boost: 9.0 - } + constant_score: { + filter: { + term: { + number_val: { + value: "1" + } + } + }, + boost: 10.0 } - }, - { - term: { - number_val: { - value: "3", - boost: 8.0 - } + },{ + constant_score: { + filter: { + term: { + number_val: { + value: "2" + } + } + }, + boost: 9.0 + } }, + { + constant_score: { + filter: { + term: { + number_val: { + value: "3" + } + } + }, + boost: 8.0 } }, { - term: { - number_val: { - value: "4", - boost: 7.0 - } + constant_score: { + filter: { + term: { + number_val: { + value: "4" + } + } + }, + boost: 7.0 } } ] @@ -241,35 +285,51 @@ setup: bool: { should: [ { - term: { - char_val: { - value: "A", - boost: 10.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "A" + } + } + }, + boost: 10.0 } }, { - term: { - char_val: { - value: "B", - boost: 9.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "B" + } + } + }, + boost: 9.0 } }, { - term: { - char_val: { - value: "C", - boost: 8.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "C" + } + } + }, + boost: 8.0 } }, { - term: { - char_val: { - value: "D", - boost: 7.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "D" + } + } + }, + boost: 7.0 } } ] @@ -306,35 +366,49 @@ setup: bool: { should: [ { - term: { - number_val: { - value: "1", - boost: 10.0 - } - } - }, - { - term: { - number_val: { - value: "2", - boost: 9.0 - } + constant_score: { + filter: { + term: { + number_val: { + value: "1" + } + } + }, + boost: 10.0 } - }, - { - term: { - number_val: { - value: "3", - boost: 8.0 - } + },{ + constant_score: { + filter: { + term: { + number_val: { + value: "2" + } + } + }, + boost: 9.0 + } }, + { + constant_score: { + filter: { + term: { + number_val: { + value: "3" + } + } + }, + boost: 8.0 } }, { - term: { - number_val: { - value: "4", - boost: 7.0 - } + constant_score: { + filter: { + term: { + number_val: { + value: "4" + } + } + }, + boost: 7.0 } } ] @@ -349,35 +423,51 @@ setup: bool: { should: [ { - term: { - char_val: { - value: "A", - boost: 10.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "A" + } + } + }, + boost: 10.0 } }, { - term: { - char_val: { - value: "B", - boost: 9.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "B" + } + } + }, + boost: 9.0 } }, { - term: { - char_val: { - value: "C", - boost: 8.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "C" + } + } + }, + boost: 8.0 } }, { - term: { - char_val: { - value: "D", - boost: 7.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "D" + } + } + }, + boost: 7.0 } } ] @@ -422,35 +512,49 @@ setup: bool: { should: [ { - term: { - number_val: { - value: "1", - boost: 10.0 - } - } - }, - { - term: { - number_val: { - value: "2", - boost: 9.0 - } + constant_score: { + filter: { + term: { + number_val: { + value: "1" + } + } + }, + boost: 10.0 } - }, - { - term: { - number_val: { - value: "3", - boost: 8.0 - } + },{ + constant_score: { + filter: { + term: { + number_val: { + value: "2" + } + } + }, + boost: 9.0 + } }, + { + constant_score: { + filter: { + term: { + number_val: { + value: "3" + } + } + }, + boost: 8.0 } }, { - term: { - number_val: { - value: "4", - boost: 7.0 - } + constant_score: { + filter: { + term: { + number_val: { + value: "4" + } + } + }, + boost: 7.0 } } ] @@ -465,35 +569,51 @@ setup: bool: { should: [ { - term: { - char_val: { - value: "D", - boost: 10.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "D" + } + } + }, + boost: 10.0 } }, { - term: { - char_val: { - value: "C", - boost: 9.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "C" + } + } + }, + boost: 9.0 } }, { - term: { - char_val: { - value: "A", - boost: 8.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "A" + } + } + }, + boost: 8.0 } }, { - term: { - char_val: { - value: "B", - boost: 7.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "B" + } + } + }, + boost: 7.0 } } ] @@ -533,35 +653,49 @@ setup: bool: { should: [ { - term: { - number_val: { - value: "1", - boost: 10.0 - } - } - }, - { - term: { - number_val: { - value: "2", - boost: 9.0 - } + constant_score: { + filter: { + term: { + number_val: { + value: "1" + } + } + }, + boost: 10.0 } - }, - { - term: { - number_val: { - value: "3", - boost: 8.0 - } + },{ + constant_score: { + filter: { + term: { + number_val: { + value: "2" + } + } + }, + boost: 9.0 + } }, + { + constant_score: { + filter: { + term: { + number_val: { + value: "3" + } + } + }, + boost: 8.0 } }, { - term: { - number_val: { - value: "4", - boost: 7.0 - } + constant_score: { + filter: { + term: { + number_val: { + value: "4" + } + } + }, + boost: 7.0 } } ] @@ -576,35 +710,51 @@ setup: bool: { should: [ { - term: { - char_val: { - value: "D", - boost: 10.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "D" + } + } + }, + boost: 10.0 } }, { - term: { - char_val: { - value: "C", - boost: 9.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "C" + } + } + }, + boost: 9.0 } }, { - term: { - char_val: { - value: "A", - boost: 8.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "A" + } + } + }, + boost: 8.0 } }, { - term: { - char_val: { - value: "B", - boost: 7.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "B" + } + } + }, + boost: 7.0 } } ] @@ -632,9 +782,9 @@ setup: "Pagination within interleaved results, different result set sizes, rank_window_size covering all results": # perform multiple searches with different "from" parameter, ensuring that results are consistent # rank_window_size covers the entire result set for both queries, so pagination should be consistent - # queryA has a result set of [5, 1] and + # queryA has a result set of [1] and # queryB has a result set of [4, 3, 1, 2] - # so for rank_constant=10, the expected order is [1, 4, 5, 3, 2] + # so for rank_constant=10, the expected order is [1, 4, 3, 2] - do: search: index: test @@ -645,19 +795,11 @@ setup: { retrievers: [ { - # this should clause would generate the result set [5, 1] + # this should clause would generate the result set [1] standard: { query: { bool: { should: [ - { - term: { - number_val: { - value: "5", - boost: 10.0 - } - } - }, { term: { number_val: { @@ -678,35 +820,51 @@ setup: bool: { should: [ { - term: { - char_val: { - value: "D", - boost: 10.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "D" + } + } + }, + boost: 10.0 } }, { - term: { - char_val: { - value: "C", - boost: 9.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "C" + } + } + }, + boost: 9.0 } }, { - term: { - char_val: { - value: "A", - boost: 8.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "A" + } + } + }, + boost: 8.0 } }, { - term: { - char_val: { - value: "B", - boost: 7.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "B" + } + } + }, + boost: 7.0 } } ] @@ -721,11 +879,11 @@ setup: from : 0 size : 2 - - match: { hits.total.value : 5 } + - match: { hits.total.value : 4 } - length: { hits.hits : 2 } - match: { hits.hits.0._id: "1" } # score for doc 1 is (1/12 + 1/13) - - close_to: {hits.hits.0._score: {value: 0.1602, error: 0.001}} + - close_to: {hits.hits.0._score: {value: 0.1678, error: 0.001}} - match: { hits.hits.1._id: "4" } # score for doc 4 is (1/11) - close_to: {hits.hits.1._score: {value: 0.0909, error: 0.001}} @@ -740,19 +898,11 @@ setup: { retrievers: [ { - # this should clause would generate the result set [5, 1] + # this should clause would generate the result set [1] standard: { query: { bool: { should: [ - { - term: { - number_val: { - value: "5", - boost: 10.0 - } - } - }, { term: { number_val: { @@ -773,35 +923,51 @@ setup: bool: { should: [ { - term: { - char_val: { - value: "D", - boost: 10.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "D" + } + } + }, + boost: 10.0 } }, { - term: { - char_val: { - value: "C", - boost: 9.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "C" + } + } + }, + boost: 9.0 } }, { - term: { - char_val: { - value: "A", - boost: 8.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "A" + } + } + }, + boost: 8.0 } }, { - term: { - char_val: { - value: "B", - boost: 7.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "B" + } + } + }, + boost: 7.0 } } ] @@ -816,14 +982,15 @@ setup: from : 2 size : 2 - - match: { hits.total.value : 5 } + - match: { hits.total.value : 4 } - length: { hits.hits : 2 } - - match: { hits.hits.0._id: "5" } - # score for doc 5 is (1/11) - - close_to: {hits.hits.0._score: {value: 0.0909, error: 0.001}} - - match: { hits.hits.1._id: "3" } + - match: { hits.hits.0._id: "3" } # score for doc 3 is (1/12) - - close_to: {hits.hits.1._score: {value: 0.0833, error: 0.001}} + - close_to: {hits.hits.0._score: {value: 0.0833, error: 0.001}} + - match: { hits.hits.1._id: "2" } + # score for doc 2 is (1/14) + - close_to: {hits.hits.1._score: {value: 0.0714, error: 0.001}} + - do: search: @@ -835,19 +1002,11 @@ setup: { retrievers: [ { - # this should clause would generate the result set [5, 1] + # this should clause would generate the result set [1] standard: { query: { bool: { should: [ - { - term: { - number_val: { - value: "5", - boost: 10.0 - } - } - }, { term: { number_val: { @@ -868,35 +1027,51 @@ setup: bool: { should: [ { - term: { - char_val: { - value: "D", - boost: 10.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "D" + } + } + }, + boost: 10.0 } }, { - term: { - char_val: { - value: "C", - boost: 9.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "C" + } + } + }, + boost: 9.0 } }, { - term: { - char_val: { - value: "A", - boost: 8.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "A" + } + } + }, + boost: 8.0 } }, { - term: { - char_val: { - value: "B", - boost: 7.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "B" + } + } + }, + boost: 7.0 } } ] @@ -911,12 +1086,8 @@ setup: from: 4 size: 2 - - match: { hits.total.value: 5 } - - length: { hits.hits: 1 } - - match: { hits.hits.0._id: "2" } - # score for doc 2 is (1/14) - - close_to: {hits.hits.0._score: {value: 0.0714, error: 0.001}} - + - match: { hits.total.value: 4 } + - length: { hits.hits: 0 } --- "Pagination within interleaved results, different result set sizes, rank_window_size not covering all results": @@ -943,19 +1114,27 @@ setup: bool: { should: [ { - term: { - number_val: { - value: "5", - boost: 10.0 - } + constant_score: { + filter: { + term: { + number_val: { + value: "5" + } + } + }, + boost: 10.0 } }, { - term: { - number_val: { - value: "1", - boost: 9.0 - } + constant_score: { + filter: { + term: { + number_val: { + value: "1" + } + } + }, + boost: 9.0 } } ] @@ -970,35 +1149,51 @@ setup: bool: { should: [ { - term: { - char_val: { - value: "D", - boost: 10.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "D" + } + } + }, + boost: 10.0 } }, { - term: { - char_val: { - value: "C", - boost: 9.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "C" + } + } + }, + boost: 9.0 } }, { - term: { - char_val: { - value: "A", - boost: 8.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "A" + } + } + }, + boost: 8.0 } }, { - term: { - char_val: { - value: "B", - boost: 7.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "B" + } + } + }, + boost: 7.0 } } ] @@ -1015,11 +1210,11 @@ setup: - match: { hits.total.value : 5 } - length: { hits.hits : 2 } - - match: { hits.hits.0._id: "4" } - # score for doc 4 is (1/11) + - contains: { hits.hits: { _id: "4" } } + - contains: { hits.hits: { _id: "5" } } + + # both docs have the same score (1/11) - close_to: {hits.hits.0._score: {value: 0.0909, error: 0.001}} - - match: { hits.hits.1._id: "5" } - # score for doc 5 is (1/11) - close_to: {hits.hits.1._score: {value: 0.0909, error: 0.001}} - do: @@ -1038,19 +1233,27 @@ setup: bool: { should: [ { - term: { - number_val: { - value: "5", - boost: 10.0 - } + constant_score: { + filter: { + term: { + number_val: { + value: "5" + } + } + }, + boost: 10.0 } }, { - term: { - number_val: { - value: "1", - boost: 9.0 - } + constant_score: { + filter: { + term: { + number_val: { + value: "1" + } + } + }, + boost: 9.0 } } ] @@ -1065,35 +1268,51 @@ setup: bool: { should: [ { - term: { - char_val: { - value: "D", - boost: 10.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "D" + } + } + }, + boost: 10.0 } }, { - term: { - char_val: { - value: "C", - boost: 9.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "C" + } + } + }, + boost: 9.0 } }, { - term: { - char_val: { - value: "A", - boost: 8.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "A" + } + } + }, + boost: 8.0 } }, { - term: { - char_val: { - value: "B", - boost: 7.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "B" + } + } + }, + boost: 7.0 } } ] diff --git a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/700_rrf_retriever_search_api_compatibility.yml b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/700_rrf_retriever_search_api_compatibility.yml index 1f7125377b892..517c162c33e95 100644 --- a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/700_rrf_retriever_search_api_compatibility.yml +++ b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/700_rrf_retriever_search_api_compatibility.yml @@ -1,7 +1,6 @@ setup: - skip: features: close_to - - requires: cluster_features: 'rrf_retriever_composition_supported' reason: 'test requires rrf retriever composition support' @@ -10,8 +9,6 @@ setup: indices.create: index: test body: - settings: - number_of_shards: 1 mappings: properties: text: @@ -42,7 +39,7 @@ setup: index: test id: "1" body: - text: "term term term term term term term term term" + text: "term1" vector: [1.0] - do: @@ -50,7 +47,7 @@ setup: index: test id: "2" body: - text: "term term term term term term term term" + text: "term2" text_to_highlight: "search for the truth" keyword: "biology" vector: [2.0] @@ -60,8 +57,8 @@ setup: index: test id: "3" body: - text: "term term term term term term term" - text_to_highlight: "nothing related but still a match" + text: "term3" + text_to_highlight: "nothing related" keyword: "technology" vector: [3.0] @@ -70,14 +67,14 @@ setup: index: test id: "4" body: - text: "term term term term term term" + text: "term4" vector: [4.0] - do: index: index: test id: "5" body: - text: "term term term term term" + text: "term5" text_to_highlight: "You know, for Search!" keyword: "technology" integer: 5 @@ -87,7 +84,7 @@ setup: index: test id: "6" body: - text: "term term term term" + text: "term6" keyword: "biology" integer: 6 vector: [6.0] @@ -96,27 +93,26 @@ setup: index: test id: "7" body: - text: "term term term" + text: "term7" keyword: "astronomy" - vector: [7.0] + vector: [77.0] nested: { views: 50} - do: index: index: test id: "8" body: - text: "term term" + text: "term8" keyword: "technology" - vector: [8.0] nested: { views: 100} - do: index: index: test id: "9" body: - text: "term" + text: "term9" + integer: 2 keyword: "technology" - vector: [9.0] nested: { views: 10} - do: indices.refresh: {} @@ -133,6 +129,7 @@ setup: rrf: retrievers: [ { + # this one retrieves docs 6, 5, 4 knn: { field: vector, query_vector: [ 6.0 ], @@ -141,10 +138,72 @@ setup: } }, { + # this one retrieves docs 4, 5, 1, 2, 6 standard: { query: { - term: { - text: term + bool: { + should: [ + { + constant_score: { + filter: { + term: { + text: term4 + } + }, + boost: 10.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term5 + } + }, + boost: 9.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term1 + } + }, + boost: 8.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term2 + } + }, + boost: 7.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term6 + } + }, + boost: 6.0 + } + }, + { + constant_score: { + filter: { + exists: { + field: text + } + }, + boost: 1 + } + } + ] } } } @@ -158,9 +217,13 @@ setup: terms: field: keyword - - match: { hits.hits.0._id: "5" } - - match: { hits.hits.1._id: "1" } + + - match: { hits.hits.0._id: "4" } + - close_to: { hits.hits.0._score: { value: 0.1678, error: 0.001 } } + - match: { hits.hits.1._id: "5" } + - close_to: { hits.hits.1._score: { value: 0.1666, error: 0.001 } } - match: { hits.hits.2._id: "6" } + - close_to: { hits.hits.2._score: { value: 0.1575, error: 0.001 } } - match: { aggregations.keyword_aggs.buckets.0.key: "technology" } - match: { aggregations.keyword_aggs.buckets.0.doc_count: 4 } @@ -181,6 +244,7 @@ setup: rrf: retrievers: [ { + # this one retrieves docs 6, 5, 4 knn: { field: vector, query_vector: [ 6.0 ], @@ -189,10 +253,72 @@ setup: } }, { + # this one retrieves docs 4, 5, 1, 2, 6 standard: { query: { - term: { - text: term + bool: { + should: [ + { + constant_score: { + filter: { + term: { + text: term4 + } + }, + boost: 10.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term5 + } + }, + boost: 9.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term1 + } + }, + boost: 8.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term2 + } + }, + boost: 7.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term6 + } + }, + boost: 6.0 + } + }, + { + constant_score: { + filter: { + exists: { + field: text + } + }, + boost: 1 + } + } + ] } } } @@ -208,12 +334,14 @@ setup: lang: painless source: "_score" - - - match: { hits.hits.0._id: "5" } - - match: { hits.hits.1._id: "1" } + - match: { hits.hits.0._id: "4" } + - close_to: { hits.hits.0._score: { value: 0.1678, error: 0.001 } } + - match: { hits.hits.1._id: "5" } + - close_to: { hits.hits.1._score: { value: 0.1666, error: 0.001 } } - match: { hits.hits.2._id: "6" } + - close_to: { hits.hits.2._score: { value: 0.1575, error: 0.001 } } - - close_to: { aggregations.max_score.value: { value: 0.15, error: 0.001 }} + - close_to: { aggregations.max_score.value: { value: 0.1678, error: 0.001 }} --- "rrf retriever with top-level collapse": @@ -228,6 +356,7 @@ setup: rrf: retrievers: [ { + # this one retrieves docs 6, 5, 4 knn: { field: vector, query_vector: [ 6.0 ], @@ -236,10 +365,72 @@ setup: } }, { + # this one retrieves docs 4, 5, 1, 2, 6 standard: { query: { - term: { - text: term + bool: { + should: [ + { + constant_score: { + filter: { + term: { + text: term4 + } + }, + boost: 10.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term5 + } + }, + boost: 9.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term1 + } + }, + boost: 8.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term2 + } + }, + boost: 7.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term6 + } + }, + boost: 6.0 + } + }, + { + constant_score: { + filter: { + exists: { + field: text + } + }, + boost: 1 + } + } + ] } } } @@ -250,18 +441,23 @@ setup: size: 3 collapse: { field: keyword, inner_hits: { name: sub_hits, size: 2 } } - - match: { hits.hits.0._id: "5" } - - match: { hits.hits.1._id: "1" } + - match: { hits.total : 9 } + + - match: { hits.hits.0._id: "4" } + - close_to: { hits.hits.0._score: { value: 0.1678, error: 0.001 } } + - match: { hits.hits.1._id: "5" } + - close_to: { hits.hits.1._score: { value: 0.1666, error: 0.001 } } - match: { hits.hits.2._id: "6" } + - close_to: { hits.hits.2._score: { value: 0.1575, error: 0.001 } } - - match: { hits.hits.0.inner_hits.sub_hits.hits.total : 4 } - length: { hits.hits.0.inner_hits.sub_hits.hits.hits : 2 } - - match: { hits.hits.0.inner_hits.sub_hits.hits.hits.0._id: "5" } - - match: { hits.hits.0.inner_hits.sub_hits.hits.hits.1._id: "3" } + - match: { hits.hits.0.inner_hits.sub_hits.hits.hits.0._id: "4" } + - match: { hits.hits.0.inner_hits.sub_hits.hits.hits.1._id: "1" } + - match: { hits.hits.1.inner_hits.sub_hits.hits.total : 4 } - length: { hits.hits.1.inner_hits.sub_hits.hits.hits : 2 } - - match: { hits.hits.1.inner_hits.sub_hits.hits.hits.0._id: "1" } - - match: { hits.hits.1.inner_hits.sub_hits.hits.hits.1._id: "4" } + - match: { hits.hits.1.inner_hits.sub_hits.hits.hits.0._id: "5" } + - match: { hits.hits.1.inner_hits.sub_hits.hits.hits.1._id: "3" } - length: { hits.hits.2.inner_hits.sub_hits.hits.hits: 2 } - match: { hits.hits.2.inner_hits.sub_hits.hits.hits.0._id: "6" } @@ -280,18 +476,132 @@ setup: rrf: retrievers: [ { - knn: { - field: vector, - query_vector: [ 6.0 ], - k: 3, - num_candidates: 10 + # this one retrieves docs 7, 3 + standard: { + query: { + bool: { + should: [ + { + constant_score: { + filter: { + term: { + text: term7 + } + }, + boost: 10.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term3 + } }, + boost: 9.0 + } + } + ] + } + } } }, { + # this one retrieves docs 1, 2, 3, 7 standard: { query: { - term: { - text: term + bool: { + should: [ + { + constant_score: { + filter: { + term: { + text: term1 + } + }, + boost: 10.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term2 + } + }, + boost: 9.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term3 + } + }, + boost: 8.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term4 + } + }, + boost: 7.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term5 + } + }, + boost: 6.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term6 + } + }, + boost: 5.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term7 + } + }, + boost: 4.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term8 + } + }, + boost: 3.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term9 + } + }, + boost: 2.0 + } + } + ] } }, collapse: { field: keyword, inner_hits: { name: sub_hits, size: 1 } } @@ -303,8 +613,9 @@ setup: size: 3 - match: { hits.hits.0._id: "7" } - - match: { hits.hits.1._id: "1" } - - match: { hits.hits.2._id: "6" } + - close_to: { hits.hits.0._score: { value: 0.1623, error: 0.001 } } + - match: { hits.hits.1._id: "3" } + - close_to: { hits.hits.1._score: { value: 0.1602, error: 0.001 } } --- "rrf retriever highlighting results": @@ -331,7 +642,7 @@ setup: standard: { query: { term: { - keyword: technology + text: term5 } } } @@ -349,7 +660,7 @@ setup: } } - - match: { hits.total : 5 } + - match: { hits.total : 2 } - match: { hits.hits.0._id: "5" } - match: { hits.hits.0.highlight.text_to_highlight.0: "You know, for Search!" } @@ -357,9 +668,6 @@ setup: - match: { hits.hits.1._id: "2" } - match: { hits.hits.1.highlight.text_to_highlight.0: "search for the truth" } - - match: { hits.hits.2._id: "3" } - - not_exists: hits.hits.2.highlight - --- "rrf retriever with custom nested sort": @@ -374,12 +682,103 @@ setup: retrievers: [ { # this one retrievers docs 1, 2, 3, .., 9 - # but due to sorting, it will revert the order to 6, 5, .., 9 which due to + # but due to sorting, it will revert the order to 6, 5, 9, ... which due to # rank_window_size: 2 will only return 6 and 5 standard: { query: { - term: { - text: term + bool: { + should: [ + { + constant_score: { + filter: { + term: { + text: term1 + } + }, + boost: 10.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term2 + } + }, + boost: 9.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term3 + } + }, + boost: 8.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term4 + } + }, + boost: 7.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term5 + } + }, + boost: 6.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term6 + } + }, + boost: 5.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term7 + } + }, + boost: 4.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term8 + } + }, + boost: 3.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term9 + } + }, + boost: 2.0 + } + } + ] } }, sort: [ @@ -410,7 +809,6 @@ setup: - length: {hits.hits: 2 } - match: { hits.hits.0._id: "6" } - - match: { hits.hits.1._id: "2" } --- "rrf retriever with nested query": @@ -427,7 +825,7 @@ setup: { knn: { field: vector, - query_vector: [ 7.0 ], + query_vector: [ 77.0 ], k: 1, num_candidates: 3 } From 3ee06081ce34bfeba441aee7e886fad724188759 Mon Sep 17 00:00:00 2001 From: Joe Gallo Date: Mon, 7 Oct 2024 15:44:22 -0300 Subject: [PATCH 03/22] IPinfo ASN and Country (#114192) (#114256) Adding the building blocks to support IPinfo ASN and Country data --- .../elasticsearch/ingest/geoip/Database.java | 19 +- .../ingest/geoip/IpDataLookupFactories.java | 3 +- .../ingest/geoip/IpinfoIpDataLookups.java | 235 ++++++++++++++++++ .../ingest/geoip/GeoIpProcessorTests.java | 13 +- .../geoip/IpinfoIpDataLookupsTests.java | 223 +++++++++++++++++ .../ingest/geoip/MaxMindSupportTests.java | 6 + .../src/test/resources/ipinfo/asn_sample.mmdb | Bin 0 -> 25210 bytes .../test/resources/ipinfo/ip_asn_sample.mmdb | Bin 0 -> 23456 bytes .../resources/ipinfo/ip_country_sample.mmdb | Bin 0 -> 32292 bytes 9 files changed, 495 insertions(+), 4 deletions(-) create mode 100644 modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/IpinfoIpDataLookups.java create mode 100644 modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/IpinfoIpDataLookupsTests.java create mode 100644 modules/ingest-geoip/src/test/resources/ipinfo/asn_sample.mmdb create mode 100644 modules/ingest-geoip/src/test/resources/ipinfo/ip_asn_sample.mmdb create mode 100644 modules/ingest-geoip/src/test/resources/ipinfo/ip_country_sample.mmdb diff --git a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/Database.java b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/Database.java index 52ca5eea52c1a..31d7a43e3869a 100644 --- a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/Database.java +++ b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/Database.java @@ -22,6 +22,10 @@ *

* A database has a set of properties that are valid to use with it (see {@link Database#properties()}), * as well as a list of default properties to use if no properties are specified (see {@link Database#defaultProperties()}). + *

+ * Some database providers have similar concepts but might have slightly different properties associated with those types. + * This can be accommodated, for example, by having a Foo value and a separate FooV2 value where the 'V' should be read as + * 'variant' or 'variation'. A V-less Database type is inherently the first variant/variation (i.e. V1). */ enum Database { @@ -137,6 +141,18 @@ enum Database { Property.MOBILE_COUNTRY_CODE, Property.MOBILE_NETWORK_CODE ) + ), + AsnV2( + Set.of( + Property.IP, + Property.ASN, + Property.ORGANIZATION_NAME, + Property.NETWORK, + Property.DOMAIN, + Property.COUNTRY_ISO_CODE, + Property.TYPE + ), + Set.of(Property.IP, Property.ASN, Property.ORGANIZATION_NAME, Property.NETWORK) ); private final Set properties; @@ -211,7 +227,8 @@ enum Property { MOBILE_COUNTRY_CODE, MOBILE_NETWORK_CODE, CONNECTION_TYPE, - USER_TYPE; + USER_TYPE, + TYPE; /** * Parses a string representation of a property into an actual Property instance. Not all properties that exist are diff --git a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/IpDataLookupFactories.java b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/IpDataLookupFactories.java index 990788978a0ca..3379fdff0633a 100644 --- a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/IpDataLookupFactories.java +++ b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/IpDataLookupFactories.java @@ -76,6 +76,7 @@ static Database getDatabase(final String databaseType) { return database; } + @Nullable static Function, IpDataLookup> getMaxmindLookup(final Database database) { return switch (database) { case City -> MaxmindIpDataLookups.City::new; @@ -86,6 +87,7 @@ static Function, IpDataLookup> getMaxmindLookup(final Dat case Domain -> MaxmindIpDataLookups.Domain::new; case Enterprise -> MaxmindIpDataLookups.Enterprise::new; case Isp -> MaxmindIpDataLookups.Isp::new; + default -> null; }; } @@ -97,7 +99,6 @@ static IpDataLookupFactory get(final String databaseType, final String databaseF final Function, IpDataLookup> factoryMethod = getMaxmindLookup(database); - // note: this can't presently be null, but keep this check -- it will be useful in the near future if (factoryMethod == null) { throw new IllegalArgumentException("Unsupported database type [" + databaseType + "] for file [" + databaseFile + "]"); } diff --git a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/IpinfoIpDataLookups.java b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/IpinfoIpDataLookups.java new file mode 100644 index 0000000000000..ac7f56468f37e --- /dev/null +++ b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/IpinfoIpDataLookups.java @@ -0,0 +1,235 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.ingest.geoip; + +import com.maxmind.db.DatabaseRecord; +import com.maxmind.db.MaxMindDbConstructor; +import com.maxmind.db.MaxMindDbParameter; +import com.maxmind.db.Reader; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.network.InetAddresses; +import org.elasticsearch.common.network.NetworkAddress; +import org.elasticsearch.core.Nullable; + +import java.io.IOException; +import java.net.InetAddress; +import java.util.HashMap; +import java.util.Locale; +import java.util.Map; +import java.util.Set; + +/** + * A collection of {@link IpDataLookup} implementations for IPinfo databases + */ +final class IpinfoIpDataLookups { + + private IpinfoIpDataLookups() { + // utility class + } + + private static final Logger logger = LogManager.getLogger(IpinfoIpDataLookups.class); + + /** + * Lax-ly parses a string that (ideally) looks like 'AS123' into a Long like 123L (or null, if such parsing isn't possible). + * @param asn a potentially empty (or null) ASN string that is expected to contain 'AS' and then a parsable long + * @return the parsed asn + */ + static Long parseAsn(final String asn) { + if (asn == null || Strings.hasText(asn) == false) { + return null; + } else { + String stripped = asn.toUpperCase(Locale.ROOT).replaceAll("AS", "").trim(); + try { + return Long.parseLong(stripped); + } catch (NumberFormatException e) { + logger.trace("Unable to parse non-compliant ASN string [{}]", asn); + return null; + } + } + } + + public record AsnResult( + Long asn, + @Nullable String country, // not present in the free asn database + String domain, + String name, + @Nullable String type // not present in the free asn database + ) { + @SuppressWarnings("checkstyle:RedundantModifier") + @MaxMindDbConstructor + public AsnResult( + @MaxMindDbParameter(name = "asn") String asn, + @Nullable @MaxMindDbParameter(name = "country") String country, + @MaxMindDbParameter(name = "domain") String domain, + @MaxMindDbParameter(name = "name") String name, + @Nullable @MaxMindDbParameter(name = "type") String type + ) { + this(parseAsn(asn), country, domain, name, type); + } + } + + public record CountryResult( + @MaxMindDbParameter(name = "continent") String continent, + @MaxMindDbParameter(name = "continent_name") String continentName, + @MaxMindDbParameter(name = "country") String country, + @MaxMindDbParameter(name = "country_name") String countryName + ) { + @MaxMindDbConstructor + public CountryResult {} + } + + static class Asn extends AbstractBase { + Asn(Set properties) { + super(properties, AsnResult.class); + } + + @Override + protected Map transform(final Result result) { + AsnResult response = result.result; + Long asn = response.asn; + String organizationName = response.name; + String network = result.network; + + Map data = new HashMap<>(); + for (Database.Property property : this.properties) { + switch (property) { + case IP -> data.put("ip", result.ip); + case ASN -> { + if (asn != null) { + data.put("asn", asn); + } + } + case ORGANIZATION_NAME -> { + if (organizationName != null) { + data.put("organization_name", organizationName); + } + } + case NETWORK -> { + if (network != null) { + data.put("network", network); + } + } + case COUNTRY_ISO_CODE -> { + if (response.country != null) { + data.put("country_iso_code", response.country); + } + } + case DOMAIN -> { + if (response.domain != null) { + data.put("domain", response.domain); + } + } + case TYPE -> { + if (response.type != null) { + data.put("type", response.type); + } + } + } + } + return data; + } + } + + static class Country extends AbstractBase { + Country(Set properties) { + super(properties, CountryResult.class); + } + + @Override + protected Map transform(final Result result) { + CountryResult response = result.result; + + Map data = new HashMap<>(); + for (Database.Property property : this.properties) { + switch (property) { + case IP -> data.put("ip", result.ip); + case COUNTRY_ISO_CODE -> { + String countryIsoCode = response.country; + if (countryIsoCode != null) { + data.put("country_iso_code", countryIsoCode); + } + } + case COUNTRY_NAME -> { + String countryName = response.countryName; + if (countryName != null) { + data.put("country_name", countryName); + } + } + case CONTINENT_CODE -> { + String continentCode = response.continent; + if (continentCode != null) { + data.put("continent_code", continentCode); + } + } + case CONTINENT_NAME -> { + String continentName = response.continentName; + if (continentName != null) { + data.put("continent_name", continentName); + } + } + } + } + return data; + } + } + + /** + * Just a little record holder -- there's the data that we receive via the binding to our record objects from the Reader via the + * getRecord call, but then we also need to capture the passed-in ip address that came from the caller as well as the network for + * the returned DatabaseRecord from the Reader. + */ + public record Result(T result, String ip, String network) {} + + /** + * The {@link IpinfoIpDataLookups.AbstractBase} is an abstract base implementation of {@link IpDataLookup} that + * provides common functionality for getting a {@link IpinfoIpDataLookups.Result} that wraps a record from a {@link IpDatabase}. + * + * @param the record type that will be wrapped and returned + */ + private abstract static class AbstractBase implements IpDataLookup { + + protected final Set properties; + protected final Class clazz; + + AbstractBase(final Set properties, final Class clazz) { + this.properties = Set.copyOf(properties); + this.clazz = clazz; + } + + @Override + public Set getProperties() { + return this.properties; + } + + @Override + public final Map getData(final IpDatabase ipDatabase, final String ipAddress) { + final Result response = ipDatabase.getResponse(ipAddress, this::lookup); + return (response == null || response.result == null) ? Map.of() : transform(response); + } + + @Nullable + private Result lookup(final Reader reader, final String ipAddress) throws IOException { + final InetAddress ip = InetAddresses.forString(ipAddress); + final DatabaseRecord record = reader.getRecord(ip, clazz); + final RESPONSE data = record.getData(); + return (data == null) ? null : new Result<>(data, NetworkAddress.format(ip), record.getNetwork().toString()); + } + + /** + * Extract the configured properties from the retrieved response + * @param response the non-null response that was retrieved + * @return a mapping of properties for the ip from the response + */ + protected abstract Map transform(Result response); + } +} diff --git a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorTests.java b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorTests.java index 793754ec316b2..46024cb6ad215 100644 --- a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorTests.java +++ b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorTests.java @@ -24,6 +24,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.concurrent.atomic.AtomicBoolean; import static org.elasticsearch.ingest.IngestDocumentMatcher.assertIngestDocument; @@ -64,8 +65,16 @@ public void testDatabasePropertyInvariants() { assertThat(Sets.difference(Database.Asn.properties(), Database.Isp.properties()), is(empty())); assertThat(Sets.difference(Database.Asn.defaultProperties(), Database.Isp.defaultProperties()), is(empty())); - // the enterprise database is like everything joined together - for (Database type : Database.values()) { + // the enterprise database is like these other databases joined together + for (Database type : Set.of( + Database.City, + Database.Country, + Database.Asn, + Database.AnonymousIp, + Database.ConnectionType, + Database.Domain, + Database.Isp + )) { assertThat(Sets.difference(type.properties(), Database.Enterprise.properties()), is(empty())); } // but in terms of the default fields, it's like a drop-in replacement for the city database diff --git a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/IpinfoIpDataLookupsTests.java b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/IpinfoIpDataLookupsTests.java new file mode 100644 index 0000000000000..905eb027626a1 --- /dev/null +++ b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/IpinfoIpDataLookupsTests.java @@ -0,0 +1,223 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.ingest.geoip; + +import com.maxmind.db.DatabaseRecord; +import com.maxmind.db.Networks; +import com.maxmind.db.Reader; + +import org.elasticsearch.common.network.NetworkAddress; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.set.Sets; +import org.elasticsearch.core.SuppressForbidden; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.watcher.ResourceWatcherService; +import org.junit.After; +import org.junit.Before; + +import java.io.File; +import java.io.IOException; +import java.net.InetAddress; +import java.nio.file.Path; +import java.util.Map; +import java.util.Set; +import java.util.function.BiConsumer; + +import static java.util.Map.entry; +import static org.elasticsearch.ingest.geoip.GeoIpTestUtils.copyDatabase; +import static org.elasticsearch.ingest.geoip.IpinfoIpDataLookups.parseAsn; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; +import static org.hamcrest.Matchers.startsWith; + +public class IpinfoIpDataLookupsTests extends ESTestCase { + + private ThreadPool threadPool; + private ResourceWatcherService resourceWatcherService; + + @Before + public void setup() { + threadPool = new TestThreadPool(ConfigDatabases.class.getSimpleName()); + Settings settings = Settings.builder().put("resource.reload.interval.high", TimeValue.timeValueMillis(100)).build(); + resourceWatcherService = new ResourceWatcherService(settings, threadPool); + } + + @After + public void cleanup() { + resourceWatcherService.close(); + threadPool.shutdownNow(); + } + + public void testDatabasePropertyInvariants() { + // the second ASN variant database is like a specialization of the ASN database + assertThat(Sets.difference(Database.Asn.properties(), Database.AsnV2.properties()), is(empty())); + assertThat(Database.Asn.defaultProperties(), equalTo(Database.AsnV2.defaultProperties())); + } + + public void testParseAsn() { + // expected case: "AS123" is 123 + assertThat(parseAsn("AS123"), equalTo(123L)); + // defensive cases: null and empty becomes null, this is not expected fwiw + assertThat(parseAsn(null), nullValue()); + assertThat(parseAsn(""), nullValue()); + // defensive cases: we strip whitespace and ignore case + assertThat(parseAsn(" as 456 "), equalTo(456L)); + // defensive cases: we ignore the absence of the 'AS' prefix + assertThat(parseAsn("123"), equalTo(123L)); + // bottom case: a non-parsable string is null + assertThat(parseAsn("anythingelse"), nullValue()); + } + + public void testAsn() throws IOException { + Path configDir = createTempDir(); + copyDatabase("ipinfo/ip_asn_sample.mmdb", configDir.resolve("ip_asn_sample.mmdb")); + copyDatabase("ipinfo/asn_sample.mmdb", configDir.resolve("asn_sample.mmdb")); + + GeoIpCache cache = new GeoIpCache(1000); // real cache to test purging of entries upon a reload + ConfigDatabases configDatabases = new ConfigDatabases(configDir, cache); + configDatabases.initialize(resourceWatcherService); + + // this is the 'free' ASN database (sample) + { + DatabaseReaderLazyLoader loader = configDatabases.getDatabase("ip_asn_sample.mmdb"); + IpDataLookup lookup = new IpinfoIpDataLookups.Asn(Set.of(Database.Property.values())); + Map data = lookup.getData(loader, "5.182.109.0"); + assertThat( + data, + equalTo( + Map.ofEntries( + entry("ip", "5.182.109.0"), + entry("organization_name", "M247 Europe SRL"), + entry("asn", 9009L), + entry("network", "5.182.109.0/24"), + entry("domain", "m247.com") + ) + ) + ); + } + + // this is the non-free or 'standard' ASN database (sample) + { + DatabaseReaderLazyLoader loader = configDatabases.getDatabase("asn_sample.mmdb"); + IpDataLookup lookup = new IpinfoIpDataLookups.Asn(Set.of(Database.Property.values())); + Map data = lookup.getData(loader, "23.53.116.0"); + assertThat( + data, + equalTo( + Map.ofEntries( + entry("ip", "23.53.116.0"), + entry("organization_name", "Akamai Technologies, Inc."), + entry("asn", 32787L), + entry("network", "23.53.116.0/24"), + entry("domain", "akamai.com"), + entry("type", "hosting"), + entry("country_iso_code", "US") + ) + ) + ); + } + } + + public void testAsnInvariants() { + Path configDir = createTempDir(); + copyDatabase("ipinfo/ip_asn_sample.mmdb", configDir.resolve("ip_asn_sample.mmdb")); + copyDatabase("ipinfo/asn_sample.mmdb", configDir.resolve("asn_sample.mmdb")); + + { + final Set expectedColumns = Set.of("network", "asn", "name", "domain"); + + Path databasePath = configDir.resolve("ip_asn_sample.mmdb"); + assertDatabaseInvariants(databasePath, (ip, row) -> { + assertThat(row.keySet(), equalTo(expectedColumns)); + String asn = (String) row.get("asn"); + assertThat(asn, startsWith("AS")); + assertThat(asn, equalTo(asn.trim())); + Long parsed = parseAsn(asn); + assertThat(parsed, notNullValue()); + assertThat(asn, equalTo("AS" + parsed)); // reverse it + }); + } + + { + final Set expectedColumns = Set.of("network", "asn", "name", "domain", "country", "type"); + + Path databasePath = configDir.resolve("asn_sample.mmdb"); + assertDatabaseInvariants(databasePath, (ip, row) -> { + assertThat(row.keySet(), equalTo(expectedColumns)); + String asn = (String) row.get("asn"); + assertThat(asn, startsWith("AS")); + assertThat(asn, equalTo(asn.trim())); + Long parsed = parseAsn(asn); + assertThat(parsed, notNullValue()); + assertThat(asn, equalTo("AS" + parsed)); // reverse it + }); + } + } + + public void testCountry() throws IOException { + Path configDir = createTempDir(); + copyDatabase("ipinfo/ip_country_sample.mmdb", configDir.resolve("ip_country_sample.mmdb")); + + GeoIpCache cache = new GeoIpCache(1000); // real cache to test purging of entries upon a reload + ConfigDatabases configDatabases = new ConfigDatabases(configDir, cache); + configDatabases.initialize(resourceWatcherService); + + // this is the 'free' Country database (sample) + { + DatabaseReaderLazyLoader loader = configDatabases.getDatabase("ip_country_sample.mmdb"); + IpDataLookup lookup = new IpinfoIpDataLookups.Country(Set.of(Database.Property.values())); + Map data = lookup.getData(loader, "4.221.143.168"); + assertThat( + data, + equalTo( + Map.ofEntries( + entry("ip", "4.221.143.168"), + entry("country_name", "South Africa"), + entry("country_iso_code", "ZA"), + entry("continent_name", "Africa"), + entry("continent_code", "AF") + ) + ) + ); + } + } + + private static void assertDatabaseInvariants(final Path databasePath, final BiConsumer> rowConsumer) { + try (Reader reader = new Reader(pathToFile(databasePath))) { + Networks networks = reader.networks(Map.class); + while (networks.hasNext()) { + DatabaseRecord dbr = networks.next(); + InetAddress address = dbr.getNetwork().getNetworkAddress(); + @SuppressWarnings("unchecked") + Map result = reader.get(address, Map.class); + try { + rowConsumer.accept(address, result); + } catch (AssertionError e) { + fail(e, "Assert failed for address [%s]", NetworkAddress.format(address)); + } catch (Exception e) { + fail(e, "Exception handling address [%s]", NetworkAddress.format(address)); + } + } + } catch (Exception e) { + fail(e); + } + } + + @SuppressForbidden(reason = "Maxmind API requires java.io.File") + private static File pathToFile(Path databasePath) { + return databasePath.toFile(); + } +} diff --git a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/MaxMindSupportTests.java b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/MaxMindSupportTests.java index 84ea5fd584352..3b12003637783 100644 --- a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/MaxMindSupportTests.java +++ b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/MaxMindSupportTests.java @@ -361,8 +361,14 @@ public class MaxMindSupportTests extends ESTestCase { private static final Set> KNOWN_UNSUPPORTED_RESPONSE_CLASSES = Set.of(IpRiskResponse.class); + private static final Set KNOWN_UNSUPPORTED_DATABASE_VARIANTS = Set.of(Database.AsnV2); + public void testMaxMindSupport() { for (Database databaseType : Database.values()) { + if (KNOWN_UNSUPPORTED_DATABASE_VARIANTS.contains(databaseType)) { + continue; + } + Class maxMindClass = TYPE_TO_MAX_MIND_CLASS.get(databaseType); Set supportedFields = TYPE_TO_SUPPORTED_FIELDS_MAP.get(databaseType); Set unsupportedFields = TYPE_TO_UNSUPPORTED_FIELDS_MAP.get(databaseType); diff --git a/modules/ingest-geoip/src/test/resources/ipinfo/asn_sample.mmdb b/modules/ingest-geoip/src/test/resources/ipinfo/asn_sample.mmdb new file mode 100644 index 0000000000000000000000000000000000000000..916a8252a5df1d5d2ea15dfb14061e55360d6cd0 GIT binary patch literal 25210 zcmbW71$Y}rw1s5~wGFi4w9Ph66UT{FF}P@s46+m1!Mdp|t!>4!5JLAF>d1s2Q5jVu;RAz1`-OZdg~OUM!ANOBZ8 znjAxJKyFB)Zdw0WavV9HoIq|&ZbD8ZCy|rMDdbdg8abVuL2gQJMs7~dBn@&4au$he zm*+d1+=`q-&LvC9d1M(mpIktelU8zT(ni`z2k9hTq?`1RUeZVU$qI5Ca$9mca(i+I za!0a~tRkz)8nTwGBkRe9WPofS8_6cJnGBLGWGlIdTug?@HnN>uLM|njk;};yYYw1Ke2PJ~CjeQ*T7#(KRd4np`F0|^{{tCIe`A|u=mi8-;?@YAzV{vpb?97-NW9!?%X9!VZW9*ugAQMPRFSn9`-$20ad@=m075_vLX z8zFWI%Bh-v8vN5W{|p{;CiJs3{cMY6RCYd{i+L~^@$*ogLpdMiZj=j9uH*3+k{3a{ znD!;)rHb-8E<^ls&A);{{-CBmMEzmtkI;Tpi$4baaq3TyPm)h* zc~2wn8S2mG_0jXFJx{)Xycbb^M!|VoK1O+&u~+iPSzd+yn&!U_{|(K56aHJ8|2F)0 zH2+=r?`i(~%C>v}{X>+Gls_8lOcdE?pJ@4?Qu~bj9LIb?`%CgGMR}~4*WYOVx6Jtt z`uDVd(BeNr|0(Z}Mju#y!Mq*8++X4UM*Da25Asj)FGbOQ0n#O56clQD5&UAtO7i-c za_A$ekIL%>qoI${{0*41At~A#OM9G_H=f!A#5bnB2|1CRq$tjFOh59bpe#h0n%4`a z!Jm$@fXB>$y{Q)8jQZx#XVNyb_!iJ-QQuP2XH(k>@j0~TYVlI&oT~+8u;;7T7_1jj z#5v3U75!$_{H@{JP_{#{GsmIDozPvH?`Dn%x>wVE@co)!!MtstZ>#JLvVF3>mbU}* zc0^t!?JBZb%d0`WR`ctaQ?Izeh4cf+YtZr<;WweQp)@lmNVaHst<)Bgi^-71vSGIG z+NmuemnzDBT}Ev=+F7CLD{*`eV;zXSg0eHpF(|80dKudVc9?djc6ok85{(8;d zjmPXx?t%E8%HC+QmbW+beVDVarte2>fARqGK=L5+U>tu4?L*1K6y^0CPVETtNb)H1 zXp3c|>^L2Zy6#0e4&^44<54a^If1z+k|&WTlcy-k_D)6qX_|jJbIyQ%rly}o{cQ3a z@?7MdNBew5QSU}~QOnNvi;#D5K3;GM{7W_eGUi`SUO{4B6kMg{T}}NO@>=pb@_OrmpOK#F#a|94f!qk9r-%&$SbR@3XK*F#@OJD|lIs5K(qr0fZ& zQwwT&E!0}cMdV^KgyY+2w=2qiSOR@1^<|pA9Qq2)UkQIF=5>%eYw=Zx@1pr(_}Tdr zVSbb}$u7kSTt|gH)MMmo)UlR!FByl{M?0Y?uOmsVAMssjr?fcco6O6Q1ISyGk8ivl z{<^$hILJJq??!ufExre}JrUna)Ay#n5A=O$@0X8n{0Z~}s2@lkL>`>aD?9{whoYRo z*kP~_r+ox@BzY8hG3#WD8^&(Qoc znRAxKvdKf5eh%`^W$ZlieDVUs?n1ec{zb5_L%CS_^4u-_+u7QG1(whkO_L?_Dft=lmm>djup0aVoCSW#_*dlDit^ZR5dW6?cgUYa`+Mj=z_)-uYWY7g{xkGn zH2qijzcKbZ`G*$&6Y;2Ihg>k)O0r9O@vPfj2=R+RahAb+CrCm+W6WO9m@H?vV`8&d|)ch)CPrZvV6UO>|u zs5g>Lh&OBTAp91_TFFJ^VlqUwk?rIX#VM!JUxvEwLs^b;4z(4qSE9t}?*zMp_Ri!g zau-F}FJa_$YJP+{QPM=bi*`5JqbQG$A-KSqX@iny9lIs*@{vhJ(sqd!gyIU+%FQmRFat@`v7wo;YynW#BtNHseZ-3|qX!?QF z4}yNMrXQm0X=2`By%wE_a<~>h0{)Sjf0VL|jwX*m{8%l19QEU&pODw5;T(%j%KJqp zGw&4Wr)qk3j-1Z;8Hk@r`z-Qo#c40nKUY!q=Xuo5$8i^+T#u6N$BVSQiy$rzE9Tsw#c!l`6Y_4ReG7Rjc^i4VqOA80YIl-% z;rP3?_&xCN&HK}Ftwr}U_kp}#^dPl|5Pz8VBjlsxW8~wCvi&EhJxM-AK8<>w(ej?9 z{v7!{;xA}%jMa>6zr4(xSIAe%*T~n&H^?{f9yf{hTd4PKly6bqq4qBM9{E1`0r?^M z5&1Fs3Hd4cnWDVzTZcD~f~ui$^p*f)872IgJScR0^t@cX=8^aK1KHUB3b`!n=k zH2qijzcKbZ`3L!@miL#k(Tt_IfOcU%UR-3cY&t{pOPDi)97&EMN28uGv^P)`*I&FL z3MlQdnm!Ktc+H;xe`C$xgvU(8d2kIEPex7?$`q6el&L85m^TgfbnSSad-0}>Z-)5h zv}bC02J|hc&my-ZXOo<7#dGM-Rg~9MYO!plw_Aq#=4a){=E(J-Lt!kPT#` zqNs0k%%kFFlrUpKvISaeKCgHY{Kbrg$TqT_T%st?b1CwdQD083AXkz*ksajD%BcUIa_h({kijT?r#mDlPI~n>ZTHdMD zPa{uPocSXCGm(FmvgP@l4gVa*&einusGX1a1+*_DFCs5il=WPKyi2KHMqW-{L0(B- zrD)8je+_vpc^!E@>U;*}2Ib3kZiIgm<2S>8fc7o0Z>4=3c{_QBqCEaiYIh<3Zrb;d z_iB0fA%4H|4fdft{z2wGL_UnXN3`P}h5s01kCRW3Pm)h5%KDzRShm=d`m^M7sN)@! z=e0bn1DT&4$Cv28OumBrSGD}tsJ{;V4cc##Z)tgND|-u!P4T;ozem0g?JJZIwEPd@ zf5iC5YtNekY6g!nnC|-@*72Yec$5v?==5==KMhZi1<&mf7bGTf&MG? z-xOzYofZE<{ZH~Q#S#lyfPJRdVks#^!LcPpD43Ha#VBnkNGzFyGJ^Rd$x+az&>l^W zAvYj5M9xOE$CBg7@rvTQOD0g;81^QbK9Tw)aEvU~TwM93Sm6*z=Vw+gkv?oG~l8wHCJ_Zl~_h zbSE_z=_Wm-S8+?swGuzI3fT21+tA;Z+>YFy+(A*E-;UHOaa^NVGx?e=Oj`8bZKY(%r{TmhKxZXtVX7U!~-Aem5@^(d;e+S}sQooD5 zTZ`XA?Ox>FNBe%oxgTnI4>I-;@*dXoN2osv{W02)Yw;(bKS}*5@@euJE$>-s&msSL zO@BezrC4hvFQI(Gyq96WLi<(nHS%>udHru7?@j7&k#CdlknfW3k?$+Y<3GUhdOv&w z|6>&|{f7CUlAmeEeU7{@H2+KHdan@hbE`aDf9gFm0K1!OsCCATJRq+L;-rvt}3HQ%M|GOX_r9_D*V zAL&<=$5&9>hTN9i4##h=e<3BX6Y^--+=Ka%aR>(cXm& zlbwq4JR-=8Qa8yiE#8fI5A~R)uU7W_PY{nY-Us_0lmz`G?0snW!`@ZdvfUK?wB~2v z4={HPxmJs>qc(`V^|W^*cPIByl=bgPZ7<}buY@Mb0`$pQ`&}H5N{D(FF5%`a4{$o7maqAAK?6dF5L% z7p3+h`4agu`3m_e`I@3U=5=askZ+=1%q~By$P`*6oL--$Q z{>RJ}>-|$r|BU+Q(7&Mlr567R`qz1X>v8bEW!`t>_xbpU9}xdh^M8W>GxL5Sf7Rl@ zA^yAO|Do)y_l5o!bqiUL&l_0?yNG@9z#ggjqhODwKSt9xfW4vew|+qp&1wIA zSRGC$>MKJ|hu7z=k0b{Ynbcrad#Em&>TmJwL$@>~5lEs-y+ z+9jaGqpYR=A4*`U$7^*stgD1dtg${?;>r@qD z^wx=vGs>1$C)XL($-cgUL@W}{#FB}$93vyJ%;Uq*s8gi*&l+7GpWDG2V|@`-V{4$f zy1cpvwRI;K7)?#pxLB{t>acsQeq3$Nl!LIab0Cd=O{djJxLgiDDy*un!bl_r!ik8f zD($cY+be_BwMKPIb8~wzP+i#;XbFalra*I`t+oau?RQu)!gj9{ZDehyQZf$7_1Z!K2P`}k!ZzQcrD@N5X<|c~GF6zuXs!}nk z=vBAV=k};R2qadQ$JVIWal0|KPE6<2 zfKj%xr9Xo{j2qSA&bVm=)sjhz35RoA>UUdR9;@FgYR#H|KQEuhAtrUG9@n}i8I73< zc4hGr>0y{NnAw<3p=5kOjeKjyR5QB6BCgtQvsoST>hm5Jl-{SfHjhs%%;x3B)K*NQ zHD)Ruj*F(OeS=s+o3v*&OtVD7jHs6MUBM`Ao@a!j39C`nBQ9>G-R87<{1{BpX`1to zmYpu24J}uq+gfAdV!BhYsBAfm4(zbF>kVuNOIsE<8ZE80i#cUl+m;(mZJ08$8N1DA z#dJ~4NQc{loHSzJ@!G|dX=pX3)+aI<_F+M=t?i%7raGDGPp0G=!noV97@cbM(8SsQ zW8B?7zaMok!M;~>k*Sn{%Ob9R0N(X~#eP0r9- zj30LMSQwM0Kb#m;ecak6B+u@m%SU+-1-DXa|Ui6%W*1H6lr=Nt@$z6xA6 zo7;h*6q|KJEYXz;$Hi3`WlaI)a3e%*=n}CliET=A^0lE~d^Rr!w7D-9NhQR6idbK#ob3#OB;1+XDRL|F27u=ZHsYBNe}d!sieHCsT~V9yS!us;$|k6 zj`bM}W9$3OOx#Sxj7n=ru1+szq0?=}rHICqgSPX7fNSu(-C_t@8;z-r{h5fmr}p<^ zVivZBd&REP6zhv+%&5_fbqintCymy&>L&EL1MA<87Tn^7p*eqR0sD?v{Q>Od%@H#m z7duWQX^mk!?-*IF1fwpt4s|z)TQ??B{w8NMSzC?33Wp6hO(!OkY)9Fi?AVIwge!5| zY~0^l#3mDu?K)s4I3A0{9!g~aE+K;cS?<7=;==`q?#bJK8)28viQW;<13~#LBW}KY zjOi$l&oPF$$+a{axREBpc8}kv3P*Z7F?&V#IB;jT^Q>v-2b-tu!i`Baf!-0@yx14y zgHp#>S*cN*UEfAH5j7T@X)~OP^knZeMm4U|Oo{6eI~i_);&POkYiz5p)J@Rg@`)*1 z6%>P!NUlky&D0v+x`RRZA>kWU)+Okmj)@MN&0gX6L@O%LK{TCRU0oF@|N0N_;bJsg z9&xJ-V9%&CV<~IQ#O6~QOQHL6Gb-CcXQwmUf-r|_A-jFJ>^_$m@wdgbn3*oz(vzuf z`9QcbDE(|7rHz*Uh|#(RlUY0lc1)936YX+Txa}@mw(Xq%e_S00&4JBabYly)5lk#I zkxW?=*ha)&C>$e*Gck?I5E`2y8$)+^E3p0KT9bnJKl)L2FglaX`E5?onRVC?7A4F9 zYzGN*9UcL(j*|V@yn>pRZ?V(q`Fo2(7^+3_2!`Rqtq@!O?2|C@aee7bN^ZqgaldZC zH3~r+#>w&$eQrz<&J<}6XCO<9D^)wC#X+O2w0KqS{CyseyTYgSm=rpP*JDnnLu@)V zn0|7Ada-Jwn1Z1ib0CwB^q9Dd#7(I3CMajRv%+rgQs=MCVV!uOBzU*uy#b|D!suOD+jnov{DE)$xk`QBQ78 z$vTBFoJl1fxBNc8OYB(0J$nWrcZ96IeR!g zDqjEGKCy?2^Qm7QOY#&fi^a36c-n5j2Ln6Ck2xh2hLJ@r32G z@dm#T?`b~I-|p?r$<7$w8vpjplQ!y8$$@@|Ry>Z%*GM_z{S`iU#FbnB;>4sdto1+o zaP^Q;I{FmM_|fNc6`_i7r#`1B{MQKoZKL|{IwPH(f9o%w5QfuVrTvf2c$3tA8KN8FE^GC;M^>o#0W>U$xInak!uj(EX5Axz&UVlZDTSK@a z;`Pc0A(xHzP$$5z8Fsw=+Hf-xQ?DVKz=xq&M18d?5O414f1_Ny%wQ@StplC$Sj4Eo zM-#Qtg)7X6J*sslg`v8WAI%(2zg@JB%P%ds8PD9j?x49YW7K0wX5WOxQzRCsQC+(j zZ{rTO9;t}fO;ek66o!~|@`1Bfe9Gb5Rv-i)YK!D9JSPls z;qs9U)6VJl+C6+bY>&peW0`PV?0xC}aKucj?I%>*AB(j28w->1D88S{4|=lOMPpqN zfBx|&Ck)kkcF?ie<2zWg=>g3It_P;6`2?}1oc zqJgL|MA=@>+AkaC*9Qd^H;_9lr=S zs_~ujYGfUq?#|r9Pfi%ljS9o(@VI>rF?{lRqU!6XSP@onJ@`D(h2KW-nJk41S!`_* zmonMm!!mS5(1&;+%-KUtOs)cl$LW$Ei>uIqt$pH4O7_vm+KDEYH#cU#dA8Ly)rwCD z_}E!#G~?G0tU@2wgxz8HA>#7aBl7?E^S5&Iu&W^R) z;m+S{hsAxP=1YAr)>(Y^-V2Rl}*(x!C-B5TlQ;27z@;H+BzG}nzb5Vb~>}S)q<)N?k4G;>=ep>!Al#h;-@7nC((|>=8-RlIeR#Li%WES z?D!Bj>k;ho_?VjLHs#MjJGDsLXidh`z1V~BQxQ~A>BIShP zOe8eybl{sF-kI7kq}&F@tvvf_x;?uS2Be7Hqd&V5ik9$AH2;d46Na@BPHJGYrIVIn zBb*e^U5~?O&p)xtN`)|->%pDGg`Zi)J{ZCcvmxCtf2OpiamVbi)DOhF#9gd^K<X%97d_%vj9tB>5``&nJOD`oA%2HY_&l*BJB_#v$hlPD67rA?z| zz`&buEGY)Mzejw;EjL0q3Z11UlcU1tvgdDY;zXn{tP!0u7X2>Hp>)b-7||(nv`W2r zsp0)nRSIF4y^7yi8|1F8Qe| zEbgwt9?mI<$1lIjC4K=DXRr+OHkw&$rh3ito^W5E8SfmxS7-5YaRu&U3>gb^q3kJ$ z&B%p!IiJ3GDGaCc(G_^AcX{wsHfJfuI+?KJRW)wKye(^MNf>s#(!}EiejLKDn;H2# zoa}jgF?8A7;)WK+W2m%;dD(7!!j#WUm|D&Ksc^TMDvxo;54B1ky92&ONAM=8`dM70 z7jI6H{HN!fFw`R9YW3juE56-|-zXOKBvY~VNxYh>2haq5G-_L@wCt@W`)h@0I$Yt6 zM6f1vONqizONmY6_kEZ6ja)QcA9F_4D^p|4StBj{EW#xl7ES2mNMWdP6z!BwuR{}+ z=+n~a?brvVlepZ+Jd9|#V^+Cc9>3y+m^U- zm{|XQ?b?(-)TN0V2iBik{c2lPiS^gol`^v*68T3qt#y#(^COwGgte5H`PjXPNsj1{8Z)B=@v69^Y*uMpYaB%qcE^{N+#L ztMM%nzyD_Mo1qvssZ5M#g9jlKFWI^2Wv4Kl-oO}m@ROJLlb=vEHUs=lD_>{$M>Gp# zviIe0YuM~jGk$24zv<=dVQxA2%Ht8O$Q6jcC)m|0mVBTUmbT8(fU655n_K5uVK@_4 zJOugFyP23zfmnYWzfaX zRSJKodB%rCxB5FCydX_&!nQAdxWiuytj7~-5YO7;7dt$S@vHb`b^Z>AU)&Ks&Bn|vU=BjkKuRm^D`}(4t zjZrflNyYG&W68ujRc7Kp(i&p@*_O^35s2fxX8i;29bph$KWOKi|7v~hmb?b zVdQ4yaB>7$OpYW+k)z2mx=9b|C4Hoy43I%GM25*SavO46ayxQ+atCrp zGD2bw$azpsR*;n>=7h|zBCE+7vX-nPqhvkVKsJ(1wo zYsr5-iD=sIpT;TJ%`kebS097w2`qLusfA4`=X0_0{S{lUr)Uo zdXJ_jsi&a#(%x0ePgBbvKdb3|)cc`t$m>JVXO01+!;m&HW;b$oau0G(McMAX5WhF| zeKdVvYWtDM3Hr&@Pa#hwPt)Q~hy5+}Gd2Azls}vLIpn!o{&~nhpE~YE-V^T2aUr#f zU|+21m%zVN^Dl#cxx+C`l(|At)FJ(=pxw)Qt|qS`uO+WT%=NTyAa5jZQj~piGvaQc zek*yKmVZ0)@1TAsc^7#%d5_{SdwlMren0tuqP*UN)E*)qCLh6d9;N*l`8fH6qAdR; z;+~@ZH2UKM+Rs3LHt!4F@jTK;NH1vRUu67C&|jwg3i&Genxeek>xg?p^WS95ThQOu z^mnMgOTLHv_i3|@vc3-;j?MhkKSs>&NS`QQ=6?$RGtK{;v0sp1BL6GeUz6V`%JTMH z{SI;8GyezjM=kCrYCj|H7o^{iens23#szI#*82zY|J2I=B_I2ag3(GVD1bjGpI=Z& zy$JeXO&>yiC^?MW3~|G0k06T`#q|qDQX8eX8OE<*47IW3IC8wAEH{DLM8r>`Jz2}& z9QqXMQ^{#sehKoY=l$VZz@LeP{wOez&P1Ap)P^)0Nwjqil8<%GCFhaz$pz#>#G+2o zR@vSy;crEI5m`z)$*mP-ybJMe>K;w^Ivm53)H&Z|e1KYz4556OcA27#+lJb<6-h+C%l%URC~=qok775*ybttNNU^4Cz?8F6bheHZv<-XDQJDTpH-fYgrE zhh*jR3p(I;YJL~|gyye) zn5G{N{Rkv+Z%1nSQSgsu{4wOQ7n}_J6r?klcPe=rdAedT zVnmW{KMV1X(LNjYIY_tDKNt3SNLSE5ANB>xmSr!bei8JGXuMZ%lji;jJX5)oq4_BF6wua_mKA@ z?mpW0E6REwfc~K7KLr0_#yz6xk2)M9F($IiQQxEQqp#HbC5)>=Ts7?) zEx#6eo#sa!jxm@^vL73ew*jdUDNd~k_EOr-TDchX7V68$X!AlL&w_5`5o{(HNT5-3FzxIeLehc=Jk+CGNr}! zB5qgeq77Nv8O1R#LhGZ}pVwsDHc}rz%ptTlk-L$*lY1a$Pi4!pdr{vT`aZPx)$;eF zwm*3Q;tr&JkQR5a!!dRk^+PrNFlvXBN03L7N0CRP-eYJVt0>!d9JS+7zkN?9Ivj;> zBK{<#zmQI5%qg(HL^>7eW~9@opH7|u?J}e@>7PZOtthW|4z+X1^ALYN?F+~Y6=nQI z)GkKcB}kVle=M#ql8BLUmm}{A+E;4&RnV`devPJI3;jCHzn(ESK);dpO^RdxpnnVU z-axt)=>=-H!M>gL9k3s!eJAX@lr7tLH~f2;cQ1J#dA}C-0JR4Z|B$lB)ob~WF#l2V zG4gT5KcU4vN&PAEY4RDwJxlvJ#c^VspLaL}bNwmS!%Ld~GRwR|zKZ^AYrqX@5d~N`9s&`}cEdUpO4&xqhVn74@&l zZ&2U2w7(<2Cx0M+B!41*CUI}F&R-R!{Tu50UGwd^@~6_q^ZpARSW^W^Gm!?Vyh4bg zj>00Ou}FjS^%M?a-caaV_l28j`NNq%0{O+7J`(;Y%^wYajPfVAHGLfP@yaipKu#nl zk(0^I$tjAm%~Mf+8ub!#x|Tmf*%SH^XCNJgG)v2$4Sx>OzDRSCwn3VQk zZIMUCmWF8 zn9ra1IQ*q~zpxp8O!HeAx*Lk z@o{C#_P0~F$PUFxh%fA-mVmtxX&q88we_&Ol`ZS(fuCevDz8tvj@qtd8gZF?eqok+ zA9T!FQJ1W5gTpbIYfbtC$lt`g-L$yfsqI1Ti8zjxEVnoOeN^6L_Iu%ejNP9+fIN^q zh&)(P)^iBzIh6Wg?@vA%{?SMeARU8r6U!V6`#7XC=pPUJgnYTe6XBnv z`KU);{}l36@-*^v#mPUR+?hz1Qa=m!*~*sX&w+og=AQ@ue9gZA{)L)<5&Vla{}N?y zZm*fklwWu`c?EeTc@=pzc@65jmbN`7uBU&4qHO1l4#(#9J>RV9w=nir@;35z@(%J& z@-FgjT=yQ@_mcOK_bX1BK>tBS*^Y;(J&f{?X!@hnA0x&3euDOsuwSJ86!|o?XOu1L zeU|!j(4VLMLOy>A)>Pq3)L$lF$>$foO6@hoy{_qRI2=>#__rAIHh$v%5OME-TZ8X{ zxYxq>kiKWk`>;Qtjabn}+0Kuke@y!m@>B9N@^eKQ{{^)#$*;(-QO`G8+_%)fb2z3B zq5T8eHJJ8~h{s%(b^fdpTT zc5~=cXiwGhr$H~#{OOFDLCz!%#Ld#;W>cR-&L!ue4>)#33$*x!D8B_`w$${kpf93c zs_9OLV_Fq;7h;Y#3ktNiIP=+b{F0;ny&)maHSA$XiOgo@^i+k=KO2;2IY!YGz(c zQEDwJrf69{R$g}n;#aCT>9@jP#qz5)eJ5&b$ej_lmi8{N+i9C*8yQzD!F(yQs9_Dr zYjx)Hq@O^{{z&VXzaI7g?QXJ%Od>C(Y+1IK`mSV}%pfkS#r47O*Zd8P+o)KAdn($b zGwpPu2X>;GeGfXDEBRSX*Zy z-KPAavtgfubQRLMTKswN&)5757=Iyo5%Moax{Qo2cE4xLau7n$MqZxAS)9-GR6}^Z7-0!M|Jc?}2}>=HCbZ zewKYe(;uYv5cx3qh~o5@Q0_6L7pXrE`w3;sc0Wn|De`IZ8N@wH`#D7!|2*^;^8Soj z@No^1C-Yy1{|fC_$=9^}*QvchzKQs^Xuqw+y#xJS>hF>7I~+5xR*ODR6xWmfNBIA7 zKIk8VCEzEF{}lG`v_FIWxw2(Dzo7ml`4#fN*7Coh{w?_(`91jq;(tW?mHtn#e@1)6 zJpV;ewomxKsTh$bul)!4C)zs#`F|;YutQPA3@)HHh%AIwq~#-c@DSz=C5LJGn<0O= z@@Me+vfg6GjwDARZnRcz4E3?(IC4BW0r3-&rqG`Rd$PkZa~t}bD+*tv!Bdf@5QF0&Lj;*8H0HxIUD8YAZ*K=MN513zKD}=!

Z>K2RvpuyP$Q=T2SY!E`k}B7Q?_vn?gL5Ibp&~&qSTI}b~Jemc`SJx>N=kG3FL|7Ns6+b zlM#1{=AX)#)5z1wGZ1&C7Izl)v!S1(>E}{E4?5?Ur+r8O}~-aO~}8QHv3uDdn@$YsNYWBLEfo2 z3v*J$%J{nxe-G_@$@|Fr$p^>>$%n{?6=gk-pq@uH|1tQFGwuoUN%AT3Y4REJS@JpZ zdGZDFMa5Yk(SMnIg?yEK4f_)2_~6&+zd^nU-3R?G&;hkSnME*?vLjI~ayB1}C zL;9WiALO6pU->w3tsw=xYX9`yN|z5xD0=50Z4Np3|hQk3V?!C%U_W=)SlZ_)f^@Ru`gg{H5B-b#HHxtiRGT%$Obl=rxndApD% z*+#}uZ#!*EQI_j~-l_Rr@DrN9j%C(E@21^DCbhT}wO(>pGL3Q>+F7!X>?b#n8_5B3 zlcH?rZm4H>>U(JVp30v4DDwAa{yyZss@s?OMcLNBep${|0I| zk~bmlX4u%`xP`{VFkG!9JfP7GKo;`OTrv3fe#ylRuC@l0T6@lfRI^lE0C^lYfwZl7A@YeF*%YP1&*dYh8(LnpX+Goc=$f%O#mm-o08j=lH7`1M3$0Ha%<8>x4#6 zqo67h^LxWSuTfmzYYud$;)(2lsL?1{)@1uec`DtTN}Jh4DrvY^gq-fM)8liaww$Rd ztB%CHVZYDgF^cO_=K5}PLsHe%+8C*eMpiT!wTdTqA`rO_?*cCL28@u|<4OSxFo$6~>b=Frei>#=uL0yf@Dq@CvwcF=(dxK7QK-{P8 z1pUt3je5{?Zlk!l$4qAv$xMHuyW4X1o6f$BQL?feH*TaljCw6|k+D2txLZA5RP1pE ze5f~XYt_2le%x`q3q`H8YTfd3<)QC-dis)yxa_?Qn(6VO^IT51FC?m?6LzzivI|2l zHwL6J*}0KJm{FBR8>&_s4QfCS+~pOQf;j-?@A<1qs8jB z#Eg`cxV=879AH#JJ6DM}9CEt@MsZbgAnxo;^&2G>ODY;#MM2C0AEt-ssJxvYLAM_( zr?{abkhWUCT(U+wZhNkYzIVBC zJ5^?u{T_?7%x^JbR=PhCw=#A;tpSXn7uATVMcZ%Jf~nvR2f_iYrsiZO(Ut0MPb51t z-B~VxW9xQ^|uD;GLi}PxAwRG&M zWyA(DS*r)Fl(S)V5DPRYW`t;^?Rat1_C)gs{4S$7+L2Brvk7|ztQoO{{y!#}SY?qO zD~&Z__ruDd&xsW#)&;Hr+v^m)Ywu7nf?^Mlt3RH?Y{TEa^+ri+x!4cNQ{^dKBbl_~ z*?cc74+fng`dHd3IyZ_cYZ^(!BzkVRuo*w8{b2nC}Jfp%)V!v5$M0&ADwTAr85Ehc` zN!q#bMzefApSaOgnJkKwCX(%anQS_NValYs`^5H(ooY=%tR>Q1hxS!0i?uY@RK|>0 zV{}6kc|CsIN>wM8nq0Ljs#;q9Ipw%) z!)Oe)aXbh_UA7}OI`J%&doMZ$TU&K+cf3!H z{sp{F{6%d!J71Y6ga;COd4tuK*|pD<`{X9mC|SM2YBOR&$xYjZP1yk~;VhQ#YPSnZ z0e6juHXdYoCs(oC>kW%}(iD$p+R`brz0FK=Q)sElbeT!9k}+kBrhYVB)VaptcA-Xh z04rHO9yKpF8Xi|zj7F`w-o)ub-F<6BI@tH{XiLh6nAvSCb}q;Gib}Dmx-b@4yrQYL zXU`Kk)jWO=mRe(6Z00?wwgeu!VmmKcGoji_L{rJm+Ju?xH0ov6-v>^3y=HPCzp7is z(i7W+xSlYvd*)V^Rl`R6FTl(pT>1Y)uqywX*Afa$ME%A=?owEEon2E!SN=` z^UZ2E&JJR`!+s&IK}VY;ZXcdw*xhG;V3f@MRty*(EbfrYiD$4Vp&T?aKVUwWCyZ99 zPQZc_4>V^yiN=j9udZo`v{Xhb%Ny$r*%g^Sqc@%EPbA|w$MtRMO7$83Dg$#{oaSV= zV_o7Ij0G+)nDewpv=>tk^SCD7(}v!b{ZL+BzgU&V^zvg*56eLlCkS{R@f^x+xbBcY zj2(TGmFzYL>}45UMOQopu|v3Vd2|m7*nsMOb2 zwpKP9_Kl+xQG*9_DXb2HbnN=XsUzSEqq@4zRJ+;U&QoerRjR`F?B|s@tcl~5JYm_6 z-`+9ABngN80rbwwOv381l9|#j%k0i#ep;9oT=g-PWmIcfMx;u1pr{{Pl6qXp3rJ6N zq}WHWCwb(dwJw7yF%R3F>v}QNfo3aXndx|!(U|TulZj1g#^G^Tlk7;Pdz6c%oPPl4 z=4?$elTBp%vX*#y=ECFRK#76I>3-4wqAD@v)=Vgug4k8$k^H`k?z_Rt7`1)LUHhy=+B7Vq zsxN`Z^?m0-#zD(eR}@b_G1b)-ZBJV-0UTO!!N#J+9FHr`0DoJFEIS|x@=wuw%DO3Sl#Q)hRo&BRd4)0Q)t zGD@1NqK%6q(elQIhRX65Lk#k2*%)*b9yqK`IND+kxa2lnyCETZ%yQzXTGFsWIY!Lc z?2L*7gs298J$Nrv)kw#lFJh0uZYvHlP4TYIQgNni$FoFrePelbRcU=?b4^7}q`_Xf zDDQS*s(9r?Q1kM0+U0ZmaGYygpX$M)%zE7Y)_OS;S7O=-#aOu#Z=F~teR#pa)75UM z+wE1SB{`zPvF}$j)a7#pP;E`z+9c;BPMh`Rm8%*mTa1h|Eze7Mo{ARXxg#DX>V}o4 z?Ok{*`9oN9HQ4nsR@|AfjFOs&w2g?`-(@A3uer{bk0f}1yGbu*#`2`KF{?&2!&ej4 zd94kVD_aaTutu!1d3jAa7I4(rl)G8%47q9N7LGRW+#auZnX9+@GV8Ii;(UoEi1Dqo zeT;BK9UwwzXArli)|Tz*Pjrtf=nY~uEbS0)V5RuDfiDrMv@?l)EGFN|j7Dl)7MM18 zdSC!TYWC5wcL;3HVYfVzq|!LYi3bzL+L;z->lqbk9J%{b299aWz)rE=j3?Gv*tNy6 zkFwF8N*Rl-RA<_3FI{C}Ly>!hxH4MrmM;sk9rW^jA9MxX*eqi5ZAR24_9v0V9@Q-0 zXmanoIM3q?gqRrEi19`)_8~Db=;<#>7;Z7aqVaAsEjMDPDfYZEEmph>S23dX_)3B) zEO)**?gU$$Jk^=;Y@*-H=uC)%&#;P<4ce{)2|lqMt*U$6m@Ej zYK(Z@^82uJ)O7dZ$ z(|x_EE~{U@a>zz5EXF|#O~E&wK+q|l^f^cCaDT|}K?Ms7?l(%7G?>XOZb=mLIK|tq zJilvB?z%X%h_i8hW}v4nk#eg24==DW>D$L5@lX_{v74&hDd*TNKxYK~V#Jqbdpq@o zEZ(GJElpMVWo&<%L5&z`S3sRjZ6`Nkm_%-$$BTQaOC<5s&f>F#*g30elxMhi65oo% zK};T#ZAWYk>h>zpCh>h|VNnF@vq;QakIV1GXHa!uqT}(S7C8D$jZZI@*<&p36Q}r0##k2P%oc4y zJ85G_}p{A{NBL_7#EqrBo>RaPTB>~y)o4Z5*ftF7)t#&mWh(27c7 z8jS<%@L@|FtS}SgI4e_sqTm!Q4&%}M%bPfmudI!g8?!5$OIy^V3+G7rVG^hF)v`Wu z(i2Y(%~7opy&m+V#=1l=UTQP!>#9U=weaK(iBk%?QS2L{270+c77GM-geenkH?4N7 zyN~Z{%`IIPzO^aIs7j~$dQsm*Y&-5Ud|NIHn`LEUJIi?;_*3>T|ES+Yyk1WT7nvQx z8~zehG+7q)`@>~;jMz0P*Zj9yq^7FEr$!IXYYm-g%Sv{1rxN^9(^%D9S=nGLi8j_$ zs1L1crpT)Vg7Gp>M<8xrCg=Z4YwhcxYr`QI&Ro@a6G7j}PYa{e_w#@Iu##aHmE-$7 z&#-GIVViT8;dRy}uAH;~>#Fvpak#~Y_n-@Bj~KpN#@A=s6L@>W7Z$#v*W!C+d0kBW z+=HGRkD3GMMP5Xi-hUqp`x3aj03HOm0BTKaN+sL4Ce^ZNTCUXmQgyGH<_Wl5W$28K zG7~>2h&$2*+~R-TmKro)z~{vd*V2YV`cDdH6AxgTIz*qRLOTQ`vHZkOyO19 z&9e5mb{Xma>-4s}R4gQX?Zd=u%%<>DjoR$g?ZsO1icw>KqnBOkE9>x~H$B*$RF~QU zZs~t?sV5v354MF9R$_;qfZa#0(D%;@egC^cd|MSSLNR<(X-lPXtdify?EN4HHH}lP z8~b*|-nb{Bzx-uBQ_a4d{a@o@cNXTKH{iwuuf~qUmzmlMn1Ab=U|EN!&0S{U$2-+G zy6|6@z;s(M$tYeh((TxHryB8+i364RDY6;O#OKW1tA+fn zq`smuii39UgB!X(~5Gx*xme#ZBSMzu4K9TAO;~cntR9hY))qHZ=FOSxNq8 zYi+`H#-naO)_usMu0#8Ou6K7x`~tY57heE-#rJ;v#rJOc@ksVwQ>>*j5^bqA@WQ^L zu^HQ$x~I0XaL~RdX=Bjw=Rg0!c#C_&9xWaUQSsJYn#CKM{K60wzYmJ{HT63VM#yxv zWB>4Z%dj5g6YTH8e_cv#V*U_D7*AU>BQ5p$XLY@a^BTSa;zLQwh~OtHd>W5s`{MXM z&wE1a+tqw9v31*Y{po?V>XYW$9&=qPy;gp3PbJS8 z^v~QxGM{^PU0YwGyM3+In~HZ84eA{9@aXz>Gi$b)8Eb8Jpx0_m^x_>dWh8pn;y**I z&6qvC_%)%Yr@gJN-O9w%iC*y|**lA^$GwOecB ZDf}pQ@6g3{X}p@J+t+3io2+w;{{xjsLht|p literal 0 HcmV?d00001 diff --git a/modules/ingest-geoip/src/test/resources/ipinfo/ip_country_sample.mmdb b/modules/ingest-geoip/src/test/resources/ipinfo/ip_country_sample.mmdb new file mode 100644 index 0000000000000000000000000000000000000000..88428315ee8d6d164a89a40a9e8fff42338e3e2f GIT binary patch literal 32292 zcmbW62Yi%8_qLY+fe?D{EWIb$&896NB-GH0NEcHU2qnP~s@TN>_J(2u3#eePSWvNm z3id7*R1{Feh6USqKQq^62MqfC-}m~Yy;yZ*F0d=?2D`%^uqRA{$uI>5U=XIl5KM#VFau`7EEtBpU~kw59tZov zelQ#6!2WOm90&)&!Egx7g+t*mI2?|EBjG4G8jgWu;W(%>Y0l+%cmkXNC&Ec^GMoaZ zLY-w(=R`Oio&@uB^xxEZ&sW>4R-ks7S|Ra>>{GF4$exKk%hJ>?LYuAl97msvJ`WZv zUgE^(qn|9j6ng|$cCOS?2?o4H&r?FsU(fLBVpiaebM z-_>fX)z%ZgM)tK%d=2__(yzz9!HM4}{U*m>OU^p!H)G!dZ*}r+vv#fK#5Ta&;T`Z! zxDjqr-Mg%9)_6DiJ@8(5pYrZ^@*Y5cQ2IlT{;>Q87JMqVSzFN8`zQ--im`}>r zej9s>?5D7whR-uBHcJe+!{}k?qpDFKi>@VOR_$Bozc1|zbke(C*B>c2kfc5B((s3GEA|oJ=5p&TBV>`YD}vX zlCQO!xu(Z-<7Zmm*jdU8JANp@z;{GPWsJ`evABDiQVSt>!ojS{M*U7L;9WA8{sA=?=Hpf zM!yH%3-5#XTYB@XqqA7)LH#a$SbRu~&eQWBk>65$RPC_ZX0?54k5T7w+3%}8f&V1j zB5eouQ?j4Ne#X-5&sOQrqCW?pcjCHJm9{zlc5+^n{*t4=Z0$PQ>qyG~hId8(Z6n`809r&)LneThrE4Tj#%Ky-b@1oB~(m!_ePvn2<_`BtQM(*eE3%Cb< z3BOYAUTfE_>+^Zdn~Ryxe#O4V{sta^-@=3NJ4@5=d$b?a|3~bf;2|gPXT^U({}uig zi`VUk_Pg@_!2T2drE^mo`)}DkYkT9zRkC)R#JI|iUPXRYV)2e%4Xrv%fHjn#=;YOu zU(5P+OC7zAboMQ-o}<^7-oW~C4PhhL*vV_6yr$^Q9KE^p7S^x3o_H&16VzI(C9AbT z_rtc9W*^!q-rn&$$nWU*o#@jUc2T^m6YqxJU3w2k?X(HbwlwSLCB3&=Ut)dWahCP;{NnndWjk7q&*#NVjR9~V9As%?gVBbl zPA>LPCq4{qI2@t8k!oYqMrjOhexohDdd441?l?=MMPuWYe?m+*bC@WtSZxyhCc`O~ zCTFVR(;WXq`O_W$B>8!cpKtBB0_la09+5vot%&$cILor0=Ha!t*=loQTHIXu^J2c4 zQ;D=Q=s#cf$=L4wSRlR3>ZZ>^au&hGiZ8)l3YS@${8JQP?)ayYa~eF|vff(dor!i9 zJX_jjYUkjufahA8`sbmY4=+&uO6*neLU@s->31>OCGb-1+gsR|J9$^A-j$AjmHey8 zTkYuApk1r@8b`kl{d(y)VBhG(Z?bm%%8IWeezOz51?^VFZ^K^i#5YL4-SO`r=T5j0 zZi06y|88rWeYgkxUU;A4_hUccJ{HT1dmQZv_@wf;IPs^@ zpO*fNqi>b}tm8jN-t+JU#kXN^hcCjH;LFN+6?+HVDeV=XufFbsw|J&yk+IbXrOitlsc`_aFa{tfm4_^q@bun)rT;P;mGb>6%d_oLcR#16$`-Wtt1 ze}TWk-;{sY+9v+H{6C2O>F9q+|J(XH`leRpN-!>#S6RWzRnV(Cdc5>%=+$8Ytf5#< z>_p4@dInyL&bxAL<!w3*;A)7jg6%(r2R2a`YnUvmJj9Ide7lo3V?Pvktoi&X;yFcB$-@*b878)cVal z7s17F30w-7!BgOJcq%*%p02UZu(qjtCi+?MY^rwX|9>yS!c<^6J8R%Byc}GhPGqhER9c z=#60$*c3KX|K`>QyaQsa9S!%t=4_h``DXn+R_Qo*d90&Wt zelQ#6!2aq#z}lwYK>34+4R-V)Xt|0HwRWS;Xv3Yn5sHsQ9|cFlF;3oCv~kL3&Sw1a z=+WmfLGg)bli*}alRriAsWHEi?pBo($(!!;Ro8j1a+1$i^>Q%}Js%dpLKuNF;7m9R z7Fn8cXRF>E^tr09Ggf6DezEMI)JmNE`SMRz^PX2Jc?+C)8QMa)NO_B~m%ycP89c?( z%xAgsPj&p$NXUIR3erGxQ+0xHJU*YKIqMfJs`Pdi4;*Hy)ty26#?2BUYDi=$? z#PKhce;Ijt7G};@z$@WZ@M^dkUIVX%Yv6V8dUyl85#9vX!gcUwcniE0-ezg$wqEny zfPOo?1KtTY!cFilcsINU-V5)8_rnL^gYY5vFnk0)3ZwV!G5p8j6Yxp61wI9zhR?vQ z@LBjAd>+04x54f3MfehY8Sa2P;VbY}_!@j2z5(BaZ^5_WJMdlj9(*5u06&Dg;79Od z_zCzmgaoy*ZKGw{ToL=ApKj%KPdk@$Nygb4{ASJt?^u+ z&ud;yQ}>W!KRf;}@_%*w-{c=A_jmXQ{1g5K|F-n%R`sb>a{M^?l^wr|d?lIOcvuZq zhY7F-OoTOIEm+&q)T)D4SM}>**LUI#q&IZ@M&vYxO<+^l3^rH47T7H<&3LV(w|4wC z)^4o3SG6tq?VNaf=^fN&sdZEvMNTK!Sy~tDu1>#hXx$a>f!z}(!DN^M1270vVF;$d zbeLgj)|IJovK&85PA}=bvHQT|U|-nJ($vXD%YpsXe*pGCI0z1gLtri(sy@T4-9%@h z>Tt&%;qw`7q@_2e8F#eu#$b=NH2OH{U*!0U z>9Yhbh0Bz83ifhKv#wL6pN4)qJOiEy&w^)Lnttb~-wMY+mz?w9`S1d`Qu(W}FSImu zE|Pw+<6k2GQpdl{+D$e|zryFMrZZmkN~=}93SJFY!)xHRK3}}Gle!Zm` z{|4zdqTdA9N_!4_o$Q;jZ-KYM+u(Y*0p1SpfOo=;P-`(`-v#fsG}=9A_rm+&{qOeW@kzJ^J_VnK&%mwlSxc{u*Wc889=@P@+pxDg z3rA6{#o^Zv9{T-U*-Qs?69N%F8vSmKjB~SZ_9X|ZG0vCI9M50fmNY8cw@#_gVkYz zrHR!*ON2FHE%mF7T?f|H`O;mAuV-nVhvzpS-VipjH1!)R-UMBT(`e0LbJ#-lv`6tR zVJp}gwt;@w7Pf=!VF%a|c7mN@7uXecvov$+&RlwE9zBUC#o}guDT?V##|Lb#=@*or zs&*3nL$cGb(_se8gjp~Qds&)#z0vxp&T-g%op?XAY{hf1`@;cnARJ_A>JCO5qJFuK zK2-iNV#DDGI1-M6qv06!nSec3_Bfxfnd-;utQn5CG~=FN^E}mYV_5z*I-{G z`&!L;pY=UP?fB?AZXowYOH=PA^;?U+4&Lm-x<@J>kh_m ziuv(($-kT2d*HqBJ}2*fv?V{0Sa{Kf_-vP5och|2Oo*@OLNv z2il*?`^(Y)wswnj?MWqTSBry{V|mr8pjCzO%GV}(V|%Syb(kQ#26iH>DXkrLEm#}Y zfpuX$<<`e;02{(aurX|6X~u7g)(kdRofg2zR zyHKq+ejj+8rI|-xw0_FV#?FELoxA~x4@4gX2g4ywUasbIh3~=l;RoiW{|3&_gI6aEGN)_()_#qwI}8C8$7e)Y<-b+)Tl!LJJAVKt}% z-khr^SQ@*A{6w|W)oK#21#3&2q*g~Qq*fQbp3}cRS_9Y+HiC`SuL*WjOEYdW>CMqw zz?M$D6Fv=wIC@9vozOeOF0iYU*G+ld(R;w2FbO8Z6c~U(m}=RI zc~wudzFB9w>SvIb>F8N#Va0nndT;rCh#d#}!hSFt=D_}N02~Mh!NHbhK0{PL7k#Ls z50gF|eFPi{N5Ro>3><4|>W))?oh@UJhbK7k31|~7Tg_JfWV9)8s->whP4N>Qe>yoQ z!919+yaMb(OVd9heFpkWN1r9V2wih9xpUxLI1d)X5;z~8Y-#$IqAgI}GDlx1f05%a zmcNADrEnQM1uloD!qY5UJ%)bO|0;5>mcAPM8h9;S1FwVE`+TjdD*py}BfJT&h3nwW@D_M0 zybZ308!WwbnfiAy-koqG+yw6uLz>6kK3|Rd#e2~2_4yLMlz*Sv`)Z1ry?OvX2p@tE z!$;twa5IcP*T=1IYCWO4PZHk(pMp=rXW&-&tflGyocceH{sP;w2A+yy_fG&vumeWGzc#oi4+ zgP&WP{4bQZ2VKv~XkR(;y^8PCbJSg~zF(|K&e!l8cmRH@yo1=^!SCS@@JIL)JOqD+ zzrbHD&76KiI}Cr&tF{-4Hf{jbRhm6gGp+VGGz2wz4#H zY_0lj9N#a$EqU!6y}k4f=pA7vY2DR2Yi)W430*8r-LCSx#e6Sra(YPbY5N;LNq%z7 z_xhUt0qH^PR2YJ3Fdb&VOqd13mZn}WwBE1}JWk{E#qI~Qr47Z-f&HZo@cG&d#vcd= zNtfp72}59REYIk}$Q=$xSen>K^&5pg+R?{QYiz7mLUeD&+dQ+b6X-huPE@@~*puND zI2BHVC&KCQB$#Ju>gTIo0eYdMN6=;{KGV@>p%+P??dWsl&m}g`(TmYaEZeMB-pOj) z&`aS0Sfi9Js=XpR4$J zj(@)V3mku?{8f&Bq5O-ipK!6GUt;YxUlF@Zd6#2f0k4Et!K9hQFG#e|J9+Z)r2dl$T0@q4iE zh4;bx;REnN_z-*;J^~+wn=Q>akD)ydpU`+uVsC*@!KW=vzh{)U)$yMt=Q;Si;xEMF zem&!a?ZjS$FTs~%c?mnvb}IiB>{sDy@OAhGd=tI}-?lXKd`ETOb^Q0_zfax=j{YIq zF2z4`^pDX$fuF+N@H6#ApL9nZ{Pv=Ej$Rn zgWp@4_5Gm!KRW(T@(+>sGyDbq3V(x#;qUMd_$T}e{%u*q2P?riSQ%DM+DKG$oFcpSi8cc^7FcW6MFzf|;!#?mh*cbML*)Rw8 zhXde1OY>|8>A4R^9|CjXP&f<@ha=!fI0}x2W8hdg4(gnlF~`FbEIqAlbPp$LK9h(~ zhErmBHKr;)4gEwo9i9a9U_LB>g)jnVsMa~yGvO?0Mb|_`xHylf4THi9se}>r#t=`^f?or1<$r@8{L}~=;x~c zdDb@d&qu#N`bz9o@IrVIyck{rFNK#`ntGS3&J~V-B{^3~zZ!cryarxtY5K2G-gW5L zJNgZ1H!6M;_F5;tPG{;N^jqMq@HV&}Zm=}t-Hvt#yi@fzVsC-pLF~!zl8m=>^*8boc=rIze4;~M}JNF>*#O5H{o0G zZTJp+7rqDIw>0zo0PRE7-R0;X$^Y2#Kau~b5XVO2%{=%}|DEu!i%{X77?N$Ch zN8d01YhvHP1MpjT5Pk>0hd;m{;ZN`o{2BfNe}%un!eoa^l)`oRpU04s+hYc*f`FVBBJR8Br zkb9BXRO2^8YYtn$mar9U4ckCJYzy1L_OJugT{Ls(1Uthnuq*5ayTcx^r{)+`OY-^J zA15Zm6d17d#_-}MCpD%errA1X4(aNjp{8e_n5lN9T9(>kwXoU=YP}e@H|zsd+tlc5 zY3zQ=&vyJA`TZS#0DT6+L5dH?9s+aWP&mxe)E}<=5$GcwU3V&RwBwHBKSw8RGK&%WdbYhEQxn|x=$XTj>%dBnmQ_z>gQ{idKJ01HB zcqTjxo(<1|E8w~CJWDgq`DhoY?n-Q)Q{shA-bKVNR^BDpmpbvwq+gDHg=Gib6EntD z@M^ePx%XpV1Fw~~2KzdAy``yfgW@+j{!Q}NlDE#$ZG+l}@a{9N_Fz}^GD zgkQnEa39*>R3v8M}&D)zRaztBKXI6JQOP2y4Px@OZJdS{=u)i(OBw zkKMqDH*S?k(Um={0!_&F-vWrS{SVt z^zwRR_kqWWeX;w&Z0P0ZsP)&lW5of`n}_Fn@j+^Xu|0nXcCI)Sdl(!Jy}S|FBgIkJ zqv06nPWGUY<<9itym^s9*)vLu3i1XQlq@VR zD_xp1IeR#!2@fAvvaoEHpQPeQ+2WGYIa5MO!C)XM5K2u7q-XYwdBL=_J^o5LBeVUr zM~o~f^*YXrc;)>aa>kGICyy9kTvQe*^p7jcD~l}f=S~TzLP}avAT1>+rDwU5nna<2 zH;n{k8oI{T=)({K@{6{@h7v zLCrLj6bNXjs2!cBf3uReXqwBqyb(C51AR0vX=kMD6k+kDA{!UwTGTKpPPVrFgC95pJlW6=i6C zDM@L1j+vpJfZ>QU~bnhlHg7@Xrz9?5ea zR#ZH*uw1v9;TDbbz-2Kti+4~~@1%PZDsQ!{lgQZ>8OP?DZdxhFN4 z;!M%MO;2ct&O(;*0-9P9B<)!sGuZP8H`VKLNXOx@zeC&0{mH|}fV02Ho~N67lMIcQl{;I7c|5D<0zO8I;26J*XR(qp$dE0 zS?P&-mr*Cc4pQzPy(n*1Q_LV4SxKp>N#z}^XBJWMq_R?z^#4G*F1xphQP+D&?yh+Q z*sThtWNI1Os(_tTxu0eZowL)XPbIIcw@-l-Zx}O5xtmteI_%>N2K6G*t3WrwRIqL^ z_1L==)LBuwntdeYepW>jGrJYkZPu|#V_JB}-Ysw4ncfSMrAJ+N8975o>w*r=o1dpw zj<=RTlHR33?^cJriN##AVeKpVYqyWrn<{&}-iX-?7nGIe&GlY3nxO8Z_Bq2lGwN9G z1-+Mtj&kj`tJN=OtUq~BPHEn&GJ=|6 zdb(aECO_(#eTcpavh}JToIN++Tv+weOU>K8fZje@mE*Y=CTHAWEoz+JIeGI-N+Y_} z0q;WVr4z`~7Gm4aL-YmLw#c8{R$DhIpszt7l%zj1lpAx77~uE`TK(~PMT@=No)SpQ z(5zCD^g{CvLCmv9Ai9>}+UlXh7cPm+%P(12I@6qN?{%BzeRK4}3Iuuvj`Y3TK=0_= zZE_$(Pb8#+9n>q-ae_ze+aT}p3@XhlF3|FIgS-RkeYNzF3zV3-bm_ZJhQRR29}Sc*?C_aZMb(iva(Dc z$9MX4-tSNDe9(KsS!qe>8pB(bPKDzgvFl^?;T}A;MBgp@O4mnbCS$H(DV5zt-A+h(7n@ z78jK*vp*a))h4!Zz7`ysQC<_bg2|a7j$V3~xM}694iy%rg!5(u^+;v}%U6ArT0vp2 zCif_3$S_^OAtl8#{lmOJa;LS?+a)EOnW;ZgBPqeKK26?ST`{$y#h1^P%ANM=oE=cX z7N&1jK}6FG=rK_&CR8xP@?u9noYpFq7S7kB4(nT^KOz)7Qp~7mhTgLOrCv~vCmb** zyL=4&jibU-T|U3)Hs+4>ezF!9nuG04GAkTPNlyW&1wy)>&X7gNKj!DHW-qu+7f z4rYZjv(iH0^n&!%Xswu_S`|*ydC<|1p)O2$vAj?~H_zKkW@jo|HUEFASP;s{*TRpg zXw;1VWsZ5;Jk2pH?}&=1n&N-0I;X$)T>8(lUzhTwt3X~LFI;fsM5Ai@vCCF;d?Ml0 z)Pk`7#-t~mAD!nB;<4);{iMB<6VXdH-Mf1k;neh4JxBb{dS=;O;#cYR#VV%HsGy=1 z|79biKPYm&zbECEM(kf*rnL;F>t!2A$Wk{3p9&Gwiv%`y7p zvyp%y1->8S++pg#V9qMqd&CDAd|?4C+jW^|ZY&s=0q&HAnpa zuh-1f3#COu;mnlOwCK!Y!ZAzz<89bVrPB2+E(`~CH=?ipnD!qp!{}oel&iBcsHm8; z(md#0Uu||UB{dk1*ySD}9eak9EdN2(TD1bP%q!i zC`(@#rygQz=zpl^Ep$d8-+Z5<^UF74F!*0TA1auizMpy#=7%Eg{G#H~mn?etqQByD zykAZBr$clHQ^OhF2QQSCWT>KyAn3=9~6DpixKJ?K!#>9WG82y{kG@p)KFsxs4dc69Ldd$(aw8oniZ#o(=x)LeBN_MN)@m7A5X(OC3FnT@P5+=Q@Kz_s-d)F-O;0e)H@YZ z^9yw8GxcGNJ>ax(Os#m1(Z7H=ze~&K7)lRj=?7Ja*9dcrsmG{z^i%fUbkV+9`5EE# zqZS>?%FmRTf2=hhU9Xi>&->d z>C)*lO7yS&Wz*;7%`PdOz9>?Ag2nm;}IU$PeD&6__rk~D8#Vg9hf$by2>qWNVS zX?ISf_&?Hzs!DWT@wxgHHFII!%*cXDRn&E8aY&ki EKQa5Jod5s; literal 0 HcmV?d00001 From 46b069621e938a25371f96d3a0aaf5c8c1a9bb03 Mon Sep 17 00:00:00 2001 From: Tim Brooks Date: Mon, 7 Oct 2024 12:47:58 -0600 Subject: [PATCH 04/22] Fix issue with releasing resources in bulk tests (#114186) (#114258) A recent commit incidentally changed a release resources call from doBefore to doAfter. Several tests depending on resources being released synchronously which requires doBefore. Closes #114181 Closes #114182 --- .../org/elasticsearch/action/bulk/IncrementalBulkService.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/src/main/java/org/elasticsearch/action/bulk/IncrementalBulkService.java b/server/src/main/java/org/elasticsearch/action/bulk/IncrementalBulkService.java index 8f36fed380c48..78c1d65ede95e 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/IncrementalBulkService.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/IncrementalBulkService.java @@ -194,7 +194,7 @@ public void lastItems(List> items, Releasable releasable, Act releasables.clear(); // We do not need to set this back to false as this will be the last request. bulkInProgress = true; - client.bulk(bulkRequest, ActionListener.runAfter(new ActionListener<>() { + client.bulk(bulkRequest, ActionListener.runBefore(new ActionListener<>() { private final boolean isFirstRequest = incrementalRequestSubmitted == false; From 2356006e50cfe32d0cc72291d4d0e2eaf77117c7 Mon Sep 17 00:00:00 2001 From: Iraklis Psaroudakis Date: Mon, 7 Oct 2024 22:22:11 +0300 Subject: [PATCH 05/22] Fast refresh indices should use search shards (#113478) (#114259) Fast refresh indices should now behave like non fast refresh indices in how they execute (m)gets and searches. I.e., they should use the search shards. For BWC, we define a new transport version. We expect search shards to be upgraded first, before promotable shards. Until the cluster is fully upgraded, the promotable shards (whether upgraded or not) will still receive and execute gets/searches locally. Relates ES-9573 Relates ES-9579 --- .../org/elasticsearch/TransportVersions.java | 1 + .../refresh/TransportShardRefreshAction.java | 32 +++++++------------ ...ansportUnpromotableShardRefreshAction.java | 15 +++++++++ .../action/get/TransportGetAction.java | 3 +- .../get/TransportShardMultiGetAction.java | 3 +- .../support/replication/PostWriteRefresh.java | 9 ++---- .../cluster/routing/OperationRouting.java | 9 +++++- .../index/cache/bitset/BitsetFilterCache.java | 2 +- .../routing/IndexRoutingTableTests.java | 24 ++++++++------ .../cache/bitset/BitSetFilterCacheTests.java | 2 +- 10 files changed, 56 insertions(+), 44 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 72485bb873d57..3ef128c4b1d3a 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -235,6 +235,7 @@ static TransportVersion def(int id) { public static final TransportVersion SEARCH_FAILURE_STATS = def(8_759_00_0); public static final TransportVersion INGEST_GEO_DATABASE_PROVIDERS = def(8_760_00_0); public static final TransportVersion DATE_TIME_DOC_VALUES_LOCALES = def(8_761_00_0); + public static final TransportVersion FAST_REFRESH_RCO = def(8_762_00_0); /* * STOP! READ THIS FIRST! No, really, diff --git a/server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportShardRefreshAction.java b/server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportShardRefreshAction.java index 7857e9a22e9b9..cb667400240f0 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportShardRefreshAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportShardRefreshAction.java @@ -23,7 +23,6 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.shard.IndexShard; import org.elasticsearch.indices.IndicesService; import org.elasticsearch.injection.guice.Inject; @@ -120,27 +119,18 @@ public void onPrimaryOperationComplete( ActionListener listener ) { assert replicaRequest.primaryRefreshResult.refreshed() : "primary has not refreshed"; - boolean fastRefresh = IndexSettings.INDEX_FAST_REFRESH_SETTING.get( - clusterService.state().metadata().index(indexShardRoutingTable.shardId().getIndex()).getSettings() + UnpromotableShardRefreshRequest unpromotableReplicaRequest = new UnpromotableShardRefreshRequest( + indexShardRoutingTable, + replicaRequest.primaryRefreshResult.primaryTerm(), + replicaRequest.primaryRefreshResult.generation(), + false + ); + transportService.sendRequest( + transportService.getLocalNode(), + TransportUnpromotableShardRefreshAction.NAME, + unpromotableReplicaRequest, + new ActionListenerResponseHandler<>(listener.safeMap(r -> null), in -> ActionResponse.Empty.INSTANCE, refreshExecutor) ); - - // Indices marked with fast refresh do not rely on refreshing the unpromotables - if (fastRefresh) { - listener.onResponse(null); - } else { - UnpromotableShardRefreshRequest unpromotableReplicaRequest = new UnpromotableShardRefreshRequest( - indexShardRoutingTable, - replicaRequest.primaryRefreshResult.primaryTerm(), - replicaRequest.primaryRefreshResult.generation(), - false - ); - transportService.sendRequest( - transportService.getLocalNode(), - TransportUnpromotableShardRefreshAction.NAME, - unpromotableReplicaRequest, - new ActionListenerResponseHandler<>(listener.safeMap(r -> null), in -> ActionResponse.Empty.INSTANCE, refreshExecutor) - ); - } } } } diff --git a/server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportUnpromotableShardRefreshAction.java b/server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportUnpromotableShardRefreshAction.java index 6c24ec2d17604..f91a983d47885 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportUnpromotableShardRefreshAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportUnpromotableShardRefreshAction.java @@ -24,6 +24,9 @@ import java.util.List; +import static org.elasticsearch.TransportVersions.FAST_REFRESH_RCO; +import static org.elasticsearch.index.IndexSettings.INDEX_FAST_REFRESH_SETTING; + public class TransportUnpromotableShardRefreshAction extends TransportBroadcastUnpromotableAction< UnpromotableShardRefreshRequest, ActionResponse.Empty> { @@ -73,6 +76,18 @@ protected void unpromotableShardOperation( return; } + // During an upgrade to FAST_REFRESH_RCO, we expect search shards to be first upgraded before the primary is upgraded. Thus, + // when the primary is upgraded, and starts to deliver unpromotable refreshes, we expect the search shards to be upgraded already. + // Note that the fast refresh setting is final. + // TODO: remove assertion (ES-9563) + assert INDEX_FAST_REFRESH_SETTING.get(shard.indexSettings().getSettings()) == false + || transportService.getLocalNodeConnection().getTransportVersion().onOrAfter(FAST_REFRESH_RCO) + : "attempted to refresh a fast refresh search shard " + + shard + + " on transport version " + + transportService.getLocalNodeConnection().getTransportVersion() + + " (before FAST_REFRESH_RCO)"; + ActionListener.run(responseListener, listener -> { shard.waitForPrimaryTermAndGeneration( request.getPrimaryTerm(), diff --git a/server/src/main/java/org/elasticsearch/action/get/TransportGetAction.java b/server/src/main/java/org/elasticsearch/action/get/TransportGetAction.java index 189aa1c95d865..99eac250641ae 100644 --- a/server/src/main/java/org/elasticsearch/action/get/TransportGetAction.java +++ b/server/src/main/java/org/elasticsearch/action/get/TransportGetAction.java @@ -125,11 +125,10 @@ protected void asyncShardOperation(GetRequest request, ShardId shardId, ActionLi IndexService indexService = indicesService.indexServiceSafe(shardId.getIndex()); IndexShard indexShard = indexService.getShard(shardId.id()); if (indexShard.routingEntry().isPromotableToPrimary() == false) { - assert indexShard.indexSettings().isFastRefresh() == false - : "a search shard should not receive a TransportGetAction for an index with fast refresh"; handleGetOnUnpromotableShard(request, indexShard, listener); return; } + // TODO: adapt assertion to assert only that it is not stateless (ES-9563) assert DiscoveryNode.isStateless(clusterService.getSettings()) == false || indexShard.indexSettings().isFastRefresh() : "in Stateless a promotable to primary shard can receive a TransportGetAction only if an index has the fast refresh setting"; if (request.realtime()) { // we are not tied to a refresh cycle here anyway diff --git a/server/src/main/java/org/elasticsearch/action/get/TransportShardMultiGetAction.java b/server/src/main/java/org/elasticsearch/action/get/TransportShardMultiGetAction.java index 8d5760307c3fe..633e7ef6793ab 100644 --- a/server/src/main/java/org/elasticsearch/action/get/TransportShardMultiGetAction.java +++ b/server/src/main/java/org/elasticsearch/action/get/TransportShardMultiGetAction.java @@ -124,11 +124,10 @@ protected void asyncShardOperation(MultiGetShardRequest request, ShardId shardId IndexService indexService = indicesService.indexServiceSafe(shardId.getIndex()); IndexShard indexShard = indexService.getShard(shardId.id()); if (indexShard.routingEntry().isPromotableToPrimary() == false) { - assert indexShard.indexSettings().isFastRefresh() == false - : "a search shard should not receive a TransportShardMultiGetAction for an index with fast refresh"; handleMultiGetOnUnpromotableShard(request, indexShard, listener); return; } + // TODO: adapt assertion to assert only that it is not stateless (ES-9563) assert DiscoveryNode.isStateless(clusterService.getSettings()) == false || indexShard.indexSettings().isFastRefresh() : "in Stateless a promotable to primary shard can receive a TransportShardMultiGetAction only if an index has " + "the fast refresh setting"; diff --git a/server/src/main/java/org/elasticsearch/action/support/replication/PostWriteRefresh.java b/server/src/main/java/org/elasticsearch/action/support/replication/PostWriteRefresh.java index 683c3589c893d..7414aeeb2c405 100644 --- a/server/src/main/java/org/elasticsearch/action/support/replication/PostWriteRefresh.java +++ b/server/src/main/java/org/elasticsearch/action/support/replication/PostWriteRefresh.java @@ -19,7 +19,6 @@ import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.engine.Engine; import org.elasticsearch.index.shard.IndexShard; import org.elasticsearch.index.translog.Translog; @@ -53,9 +52,7 @@ public void refreshShard( case WAIT_UNTIL -> waitUntil(indexShard, location, new ActionListener<>() { @Override public void onResponse(Boolean forced) { - // Fast refresh indices do not depend on the unpromotables being refreshed - boolean fastRefresh = IndexSettings.INDEX_FAST_REFRESH_SETTING.get(indexShard.indexSettings().getSettings()); - if (location != null && (indexShard.routingEntry().isSearchable() == false && fastRefresh == false)) { + if (location != null && indexShard.routingEntry().isSearchable() == false) { refreshUnpromotables(indexShard, location, listener, forced, postWriteRefreshTimeout); } else { listener.onResponse(forced); @@ -68,9 +65,7 @@ public void onFailure(Exception e) { } }); case IMMEDIATE -> immediate(indexShard, listener.delegateFailureAndWrap((l, r) -> { - // Fast refresh indices do not depend on the unpromotables being refreshed - boolean fastRefresh = IndexSettings.INDEX_FAST_REFRESH_SETTING.get(indexShard.indexSettings().getSettings()); - if (indexShard.getReplicationGroup().getRoutingTable().unpromotableShards().size() > 0 && fastRefresh == false) { + if (indexShard.getReplicationGroup().getRoutingTable().unpromotableShards().size() > 0) { sendUnpromotableRequests(indexShard, r.generation(), true, l, postWriteRefreshTimeout); } else { l.onResponse(true); diff --git a/server/src/main/java/org/elasticsearch/cluster/routing/OperationRouting.java b/server/src/main/java/org/elasticsearch/cluster/routing/OperationRouting.java index f7812d284f2af..9120e25b443d7 100644 --- a/server/src/main/java/org/elasticsearch/cluster/routing/OperationRouting.java +++ b/server/src/main/java/org/elasticsearch/cluster/routing/OperationRouting.java @@ -32,6 +32,7 @@ import java.util.Set; import java.util.stream.Collectors; +import static org.elasticsearch.TransportVersions.FAST_REFRESH_RCO; import static org.elasticsearch.index.IndexSettings.INDEX_FAST_REFRESH_SETTING; public class OperationRouting { @@ -305,8 +306,14 @@ public ShardId shardId(ClusterState clusterState, String index, String id, @Null } public static boolean canSearchShard(ShardRouting shardRouting, ClusterState clusterState) { + // TODO: remove if and always return isSearchable (ES-9563) if (INDEX_FAST_REFRESH_SETTING.get(clusterState.metadata().index(shardRouting.index()).getSettings())) { - return shardRouting.isPromotableToPrimary(); + // Until all the cluster is upgraded, we send searches/gets to the primary (even if it has been upgraded) to execute locally. + if (clusterState.getMinTransportVersion().onOrAfter(FAST_REFRESH_RCO)) { + return shardRouting.isSearchable(); + } else { + return shardRouting.isPromotableToPrimary(); + } } else { return shardRouting.isSearchable(); } diff --git a/server/src/main/java/org/elasticsearch/index/cache/bitset/BitsetFilterCache.java b/server/src/main/java/org/elasticsearch/index/cache/bitset/BitsetFilterCache.java index c19e3ca353569..3b37afc3b297b 100644 --- a/server/src/main/java/org/elasticsearch/index/cache/bitset/BitsetFilterCache.java +++ b/server/src/main/java/org/elasticsearch/index/cache/bitset/BitsetFilterCache.java @@ -105,7 +105,7 @@ static boolean shouldLoadRandomAccessFiltersEagerly(IndexSettings settings) { boolean loadFiltersEagerlySetting = settings.getValue(INDEX_LOAD_RANDOM_ACCESS_FILTERS_EAGERLY_SETTING); boolean isStateless = DiscoveryNode.isStateless(settings.getNodeSettings()); if (isStateless) { - return DiscoveryNode.hasRole(settings.getNodeSettings(), DiscoveryNodeRole.INDEX_ROLE) + return DiscoveryNode.hasRole(settings.getNodeSettings(), DiscoveryNodeRole.SEARCH_ROLE) && loadFiltersEagerlySetting && INDEX_FAST_REFRESH_SETTING.get(settings.getSettings()); } else { diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/IndexRoutingTableTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/IndexRoutingTableTests.java index 21b30557cafea..6a7f4bb27a324 100644 --- a/server/src/test/java/org/elasticsearch/cluster/routing/IndexRoutingTableTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/routing/IndexRoutingTableTests.java @@ -9,6 +9,7 @@ package org.elasticsearch.cluster.routing; +import org.elasticsearch.TransportVersion; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.common.UUIDs; import org.elasticsearch.common.settings.Settings; @@ -19,6 +20,7 @@ import java.util.List; +import static org.elasticsearch.TransportVersions.FAST_REFRESH_RCO; import static org.elasticsearch.index.IndexSettings.INDEX_FAST_REFRESH_SETTING; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; @@ -27,16 +29,22 @@ public class IndexRoutingTableTests extends ESTestCase { public void testReadyForSearch() { - innerReadyForSearch(false); - innerReadyForSearch(true); + innerReadyForSearch(false, false); + innerReadyForSearch(false, true); + innerReadyForSearch(true, false); + innerReadyForSearch(true, true); } - private void innerReadyForSearch(boolean fastRefresh) { + // TODO: remove if (fastRefresh && beforeFastRefreshRCO) branches (ES-9563) + private void innerReadyForSearch(boolean fastRefresh, boolean beforeFastRefreshRCO) { Index index = new Index(randomIdentifier(), UUIDs.randomBase64UUID()); ClusterState clusterState = mock(ClusterState.class, Mockito.RETURNS_DEEP_STUBS); when(clusterState.metadata().index(any(Index.class)).getSettings()).thenReturn( Settings.builder().put(INDEX_FAST_REFRESH_SETTING.getKey(), fastRefresh).build() ); + when(clusterState.getMinTransportVersion()).thenReturn( + beforeFastRefreshRCO ? TransportVersion.fromId(FAST_REFRESH_RCO.id() - 1_00_0) : TransportVersion.current() + ); // 2 primaries that are search and index ShardId p1 = new ShardId(index, 0); IndexShardRoutingTable shardTable1 = new IndexShardRoutingTable( @@ -55,7 +63,7 @@ private void innerReadyForSearch(boolean fastRefresh) { shardTable1 = new IndexShardRoutingTable(p1, List.of(getShard(p1, true, ShardRoutingState.STARTED, ShardRouting.Role.INDEX_ONLY))); shardTable2 = new IndexShardRoutingTable(p2, List.of(getShard(p2, true, ShardRoutingState.STARTED, ShardRouting.Role.INDEX_ONLY))); indexRoutingTable = new IndexRoutingTable(index, new IndexShardRoutingTable[] { shardTable1, shardTable2 }); - if (fastRefresh) { + if (fastRefresh && beforeFastRefreshRCO) { assertTrue(indexRoutingTable.readyForSearch(clusterState)); } else { assertFalse(indexRoutingTable.readyForSearch(clusterState)); @@ -91,7 +99,7 @@ private void innerReadyForSearch(boolean fastRefresh) { ) ); indexRoutingTable = new IndexRoutingTable(index, new IndexShardRoutingTable[] { shardTable1, shardTable2 }); - if (fastRefresh) { + if (fastRefresh && beforeFastRefreshRCO) { assertTrue(indexRoutingTable.readyForSearch(clusterState)); } else { assertFalse(indexRoutingTable.readyForSearch(clusterState)); @@ -118,8 +126,6 @@ private void innerReadyForSearch(boolean fastRefresh) { assertTrue(indexRoutingTable.readyForSearch(clusterState)); // 2 unassigned primaries that are index only with some replicas that are all available - // Fast refresh indices do not support replicas so this can not practically happen. If we add support we will want to ensure - // that readyForSearch allows for searching replicas when the index shard is not available. shardTable1 = new IndexShardRoutingTable( p1, List.of( @@ -137,8 +143,8 @@ private void innerReadyForSearch(boolean fastRefresh) { ) ); indexRoutingTable = new IndexRoutingTable(index, new IndexShardRoutingTable[] { shardTable1, shardTable2 }); - if (fastRefresh) { - assertFalse(indexRoutingTable.readyForSearch(clusterState)); // if we support replicas for fast refreshes this needs to change + if (fastRefresh && beforeFastRefreshRCO) { + assertFalse(indexRoutingTable.readyForSearch(clusterState)); } else { assertTrue(indexRoutingTable.readyForSearch(clusterState)); } diff --git a/server/src/test/java/org/elasticsearch/index/cache/bitset/BitSetFilterCacheTests.java b/server/src/test/java/org/elasticsearch/index/cache/bitset/BitSetFilterCacheTests.java index 77635fd0312f8..4cb3ce418f761 100644 --- a/server/src/test/java/org/elasticsearch/index/cache/bitset/BitSetFilterCacheTests.java +++ b/server/src/test/java/org/elasticsearch/index/cache/bitset/BitSetFilterCacheTests.java @@ -276,7 +276,7 @@ public void testShouldLoadRandomAccessFiltersEagerly() { for (var isStateless : values) { if (isStateless) { assertEquals( - loadFiltersEagerly && indexFastRefresh && hasIndexRole, + loadFiltersEagerly && indexFastRefresh && hasIndexRole == false, BitsetFilterCache.shouldLoadRandomAccessFiltersEagerly( bitsetFilterCacheSettings(isStateless, hasIndexRole, loadFiltersEagerly, indexFastRefresh) ) From 5b8f9c12d2b4495b9f75b25e046f394c3d1f0719 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Mon, 7 Oct 2024 15:58:10 -0400 Subject: [PATCH 06/22] [ML] Default inference endpoint for ELSER (#114164) Co-authored-by: Elastic Machine --- docs/changelog/113873.yaml | 5 + docs/reference/rest-api/usage.asciidoc | 7 +- .../inference/InferenceService.java | 9 + .../inference/UnparsedModel.java | 24 ++ .../test/cluster/FeatureFlag.java | 3 +- x-pack/plugin/build.gradle | 1 + .../StartTrainedModelDeploymentAction.java | 4 +- .../AdaptiveAllocationsSettings.java | 4 +- .../AdaptiveAllocationSettingsTests.java | 2 +- .../xpack/inference/CustomElandModelIT.java | 2 +- .../xpack/inference/DefaultElserIT.java | 70 ++++++ .../inference/InferenceBaseRestTest.java | 21 +- .../xpack/inference/InferenceCrudIT.java | 8 +- .../MockDenseInferenceServiceIT.java | 11 +- .../MockSparseInferenceServiceIT.java | 15 +- .../xpack/inference/TextEmbeddingCrudIT.java | 4 +- .../integration/ModelRegistryIT.java | 215 +++++++++++++++++- .../inference/DefaultElserFeatureFlag.java | 21 ++ .../xpack/inference/InferencePlugin.java | 3 + ...ransportDeleteInferenceEndpointAction.java | 3 +- .../TransportGetInferenceModelAction.java | 3 +- .../action/TransportInferenceAction.java | 39 ++-- .../ShardBulkInferenceActionFilter.java | 5 +- .../inference/registry/ModelRegistry.java | 161 ++++++++++--- .../BaseElasticsearchInternalService.java | 28 ++- .../ElasticsearchInternalService.java | 106 +++++++-- .../ShardBulkInferenceActionFilterTests.java | 6 +- .../registry/ModelRegistryTests.java | 90 +++++++- .../test/inference/inference_crud.yml | 9 +- 29 files changed, 745 insertions(+), 134 deletions(-) create mode 100644 docs/changelog/113873.yaml create mode 100644 server/src/main/java/org/elasticsearch/inference/UnparsedModel.java create mode 100644 x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/DefaultElserIT.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/DefaultElserFeatureFlag.java diff --git a/docs/changelog/113873.yaml b/docs/changelog/113873.yaml new file mode 100644 index 0000000000000..ac52aaf94d518 --- /dev/null +++ b/docs/changelog/113873.yaml @@ -0,0 +1,5 @@ +pr: 113873 +summary: Default inference endpoint for ELSER +area: Machine Learning +type: enhancement +issues: [] diff --git a/docs/reference/rest-api/usage.asciidoc b/docs/reference/rest-api/usage.asciidoc index 4a8895807f2fa..4dcf0d328e4f1 100644 --- a/docs/reference/rest-api/usage.asciidoc +++ b/docs/reference/rest-api/usage.asciidoc @@ -206,7 +206,12 @@ GET /_xpack/usage "inference": { "available" : true, "enabled" : true, - "models" : [] + "models" : [{ + "service": "elasticsearch", + "task_type": "SPARSE_EMBEDDING", + "count": 1 + } + ] }, "logstash" : { "available" : true, diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceService.java b/server/src/main/java/org/elasticsearch/inference/InferenceService.java index aba644b392cec..cbbfef2cc65fa 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceService.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceService.java @@ -191,4 +191,13 @@ default Set supportedStreamingTasks() { default boolean canStream(TaskType taskType) { return supportedStreamingTasks().contains(taskType); } + + /** + * A service can define default configurations that can be + * used out of the box without creating an endpoint first. + * @return Default configurations provided by this service + */ + default List defaultConfigs() { + return List.of(); + } } diff --git a/server/src/main/java/org/elasticsearch/inference/UnparsedModel.java b/server/src/main/java/org/elasticsearch/inference/UnparsedModel.java new file mode 100644 index 0000000000000..30a7c6aa2bf9c --- /dev/null +++ b/server/src/main/java/org/elasticsearch/inference/UnparsedModel.java @@ -0,0 +1,24 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.inference; + +import java.util.Map; + +/** + * Semi parsed model where inference entity id, task type and service + * are known but the settings are not parsed. + */ +public record UnparsedModel( + String inferenceEntityId, + TaskType taskType, + String service, + Map settings, + Map secrets +) {} diff --git a/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java b/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java index 7df791bf11559..5adf01a2a0e7d 100644 --- a/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java +++ b/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java @@ -19,7 +19,8 @@ public enum FeatureFlag { TIME_SERIES_MODE("es.index_mode_feature_flag_registered=true", Version.fromString("8.0.0"), null), FAILURE_STORE_ENABLED("es.failure_store_feature_flag_enabled=true", Version.fromString("8.12.0"), null), CHUNKING_SETTINGS_ENABLED("es.inference_chunking_settings_feature_flag_enabled=true", Version.fromString("8.16.0"), null), - INFERENCE_SCALE_TO_ZERO("es.inference_scale_to_zero_feature_flag_enabled=true", Version.fromString("8.16.0"), null); + INFERENCE_SCALE_TO_ZERO("es.inference_scale_to_zero_feature_flag_enabled=true", Version.fromString("8.16.0"), null), + INFERENCE_DEFAULT_ELSER("es.inference_default_elser_feature_flag_enabled=true", Version.fromString("8.16.0"), null); public final String systemProperty; public final Version from; diff --git a/x-pack/plugin/build.gradle b/x-pack/plugin/build.gradle index b90e8a22ea6c6..feddc7bcfba3f 100644 --- a/x-pack/plugin/build.gradle +++ b/x-pack/plugin/build.gradle @@ -201,5 +201,6 @@ tasks.named("precommit").configure { tasks.named("yamlRestTestV7CompatTransform").configure({ task -> task.skipTest("security/10_forbidden/Test bulk response with invalid credentials", "warning does not exist for compatibility") + task.skipTest("inference/inference_crud/Test get all", "Assertions on number of inference models break due to default configs") }) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java index 4a570bfde99a4..34ebdcb7f9f9f 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java @@ -237,7 +237,9 @@ public int computeNumberOfAllocations() { if (numberOfAllocations != null) { return numberOfAllocations; } else { - if (adaptiveAllocationsSettings == null || adaptiveAllocationsSettings.getMinNumberOfAllocations() == null) { + if (adaptiveAllocationsSettings == null + || adaptiveAllocationsSettings.getMinNumberOfAllocations() == null + || adaptiveAllocationsSettings.getMinNumberOfAllocations() == 0) { return DEFAULT_NUM_ALLOCATIONS; } else { return adaptiveAllocationsSettings.getMinNumberOfAllocations(); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/AdaptiveAllocationsSettings.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/AdaptiveAllocationsSettings.java index 19af6a3a4ef4c..d4eace8e96621 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/AdaptiveAllocationsSettings.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/AdaptiveAllocationsSettings.java @@ -147,8 +147,8 @@ public AdaptiveAllocationsSettings merge(AdaptiveAllocationsSettings updates) { public ActionRequestValidationException validate() { ActionRequestValidationException validationException = new ActionRequestValidationException(); boolean hasMinNumberOfAllocations = (minNumberOfAllocations != null && minNumberOfAllocations != -1); - if (hasMinNumberOfAllocations && minNumberOfAllocations < 1) { - validationException.addValidationError("[" + MIN_NUMBER_OF_ALLOCATIONS + "] must be a positive integer or null"); + if (hasMinNumberOfAllocations && minNumberOfAllocations < 0) { + validationException.addValidationError("[" + MIN_NUMBER_OF_ALLOCATIONS + "] must be a non-negative integer or null"); } boolean hasMaxNumberOfAllocations = (maxNumberOfAllocations != null && maxNumberOfAllocations != -1); if (hasMaxNumberOfAllocations && maxNumberOfAllocations < 1) { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/AdaptiveAllocationSettingsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/AdaptiveAllocationSettingsTests.java index c86648f10f08b..d59fbb2a24ee0 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/AdaptiveAllocationSettingsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/AdaptiveAllocationSettingsTests.java @@ -17,7 +17,7 @@ public class AdaptiveAllocationSettingsTests extends AbstractWireSerializingTest public static AdaptiveAllocationsSettings testInstance() { return new AdaptiveAllocationsSettings( randomBoolean() ? null : randomBoolean(), - randomBoolean() ? null : randomIntBetween(1, 2), + randomBoolean() ? null : randomIntBetween(0, 2), randomBoolean() ? null : randomIntBetween(2, 4) ); } diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CustomElandModelIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CustomElandModelIT.java index 65b7a138e7e1e..c05d08fa33692 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CustomElandModelIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CustomElandModelIT.java @@ -85,7 +85,7 @@ public void testSparse() throws IOException { var inferenceId = "sparse-inf"; putModel(inferenceId, inferenceConfig, TaskType.SPARSE_EMBEDDING); - var results = inferOnMockService(inferenceId, List.of("washing", "machine")); + var results = infer(inferenceId, List.of("washing", "machine")); deleteModel(inferenceId); assertNotNull(results.get("sparse_embedding")); } diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/DefaultElserIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/DefaultElserIT.java new file mode 100644 index 0000000000000..5d84aad4b7344 --- /dev/null +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/DefaultElserIT.java @@ -0,0 +1,70 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference; + +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService; +import org.hamcrest.Matchers; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.oneOf; + +public class DefaultElserIT extends InferenceBaseRestTest { + + private TestThreadPool threadPool; + + @Before + public void createThreadPool() { + threadPool = new TestThreadPool(DefaultElserIT.class.getSimpleName()); + } + + @After + public void tearDown() throws Exception { + threadPool.close(); + super.tearDown(); + } + + @SuppressWarnings("unchecked") + public void testInferCreatesDefaultElser() throws IOException { + assumeTrue("Default config requires a feature flag", DefaultElserFeatureFlag.isEnabled()); + var model = getModel(ElasticsearchInternalService.DEFAULT_ELSER_ID); + assertDefaultElserConfig(model); + + var inputs = List.of("Hello World", "Goodnight moon"); + var queryParams = Map.of("timeout", "120s"); + var results = infer(ElasticsearchInternalService.DEFAULT_ELSER_ID, TaskType.SPARSE_EMBEDDING, inputs, queryParams); + var embeddings = (List>) results.get("sparse_embedding"); + assertThat(results.toString(), embeddings, hasSize(2)); + } + + @SuppressWarnings("unchecked") + private static void assertDefaultElserConfig(Map modelConfig) { + assertEquals(modelConfig.toString(), ElasticsearchInternalService.DEFAULT_ELSER_ID, modelConfig.get("inference_id")); + assertEquals(modelConfig.toString(), ElasticsearchInternalService.NAME, modelConfig.get("service")); + assertEquals(modelConfig.toString(), TaskType.SPARSE_EMBEDDING.toString(), modelConfig.get("task_type")); + + var serviceSettings = (Map) modelConfig.get("service_settings"); + assertThat(modelConfig.toString(), serviceSettings.get("model_id"), is(oneOf(".elser_model_2", ".elser_model_2_linux-x86_64"))); + assertEquals(modelConfig.toString(), 1, serviceSettings.get("num_threads")); + + var adaptiveAllocations = (Map) serviceSettings.get("adaptive_allocations"); + assertThat( + modelConfig.toString(), + adaptiveAllocations, + Matchers.is(Map.of("enabled", true, "min_number_of_allocations", 1, "max_number_of_allocations", 8)) + ); + } +} diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java index 3fa6159661b7e..f82b6f155c0a0 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java @@ -270,7 +270,7 @@ protected Map deployE5TrainedModels() throws IOException { @SuppressWarnings("unchecked") protected Map getModel(String modelId) throws IOException { - var endpoint = Strings.format("_inference/%s", modelId); + var endpoint = Strings.format("_inference/%s?error_trace", modelId); return ((List>) getInternal(endpoint).get("endpoints")).get(0); } @@ -293,9 +293,9 @@ private Map getInternal(String endpoint) throws IOException { return entityAsMap(response); } - protected Map inferOnMockService(String modelId, List input) throws IOException { + protected Map infer(String modelId, List input) throws IOException { var endpoint = Strings.format("_inference/%s", modelId); - return inferOnMockServiceInternal(endpoint, input); + return inferInternal(endpoint, input, Map.of()); } protected Deque streamInferOnMockService(String modelId, TaskType taskType, List input) throws Exception { @@ -324,14 +324,23 @@ public void onFailure(Exception exception) { return responseConsumer.events(); } - protected Map inferOnMockService(String modelId, TaskType taskType, List input) throws IOException { + protected Map infer(String modelId, TaskType taskType, List input) throws IOException { var endpoint = Strings.format("_inference/%s/%s", taskType, modelId); - return inferOnMockServiceInternal(endpoint, input); + return inferInternal(endpoint, input, Map.of()); + } + + protected Map infer(String modelId, TaskType taskType, List input, Map queryParameters) + throws IOException { + var endpoint = Strings.format("_inference/%s/%s?error_trace", taskType, modelId); + return inferInternal(endpoint, input, queryParameters); } - private Map inferOnMockServiceInternal(String endpoint, List input) throws IOException { + private Map inferInternal(String endpoint, List input, Map queryParameters) throws IOException { var request = new Request("POST", endpoint); request.setJsonEntity(jsonBody(input)); + if (queryParameters.isEmpty() == false) { + request.addParameters(queryParameters); + } var response = client().performRequest(request); assertOkOrCreated(response); return entityAsMap(response); diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java index 92affbc043669..5a84fd8985504 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java @@ -38,10 +38,12 @@ public void testGet() throws IOException { } var getAllModels = getAllModels(); - assertThat(getAllModels, hasSize(9)); + int numModels = DefaultElserFeatureFlag.isEnabled() ? 10 : 9; + assertThat(getAllModels, hasSize(numModels)); var getSparseModels = getModels("_all", TaskType.SPARSE_EMBEDDING); - assertThat(getSparseModels, hasSize(5)); + int numSparseModels = DefaultElserFeatureFlag.isEnabled() ? 6 : 5; + assertThat(getSparseModels, hasSize(numSparseModels)); for (var sparseModel : getSparseModels) { assertEquals("sparse_embedding", sparseModel.get("task_type")); } @@ -99,7 +101,7 @@ public void testApisWithoutTaskType() throws IOException { assertEquals(modelId, singleModel.get("inference_id")); assertEquals(TaskType.SPARSE_EMBEDDING.toString(), singleModel.get("task_type")); - var inference = inferOnMockService(modelId, List.of(randomAlphaOfLength(10))); + var inference = infer(modelId, List.of(randomAlphaOfLength(10))); assertNonEmptyInferenceResults(inference, 1, TaskType.SPARSE_EMBEDDING); deleteModel(modelId); } diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockDenseInferenceServiceIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockDenseInferenceServiceIT.java index 5f6bad5687407..1077bfec8bbbd 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockDenseInferenceServiceIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockDenseInferenceServiceIT.java @@ -28,15 +28,12 @@ public void testMockService() throws IOException { } List input = List.of(randomAlphaOfLength(10)); - var inference = inferOnMockService(inferenceEntityId, input); + var inference = infer(inferenceEntityId, input); assertNonEmptyInferenceResults(inference, 1, TaskType.TEXT_EMBEDDING); // Same input should return the same result - assertEquals(inference, inferOnMockService(inferenceEntityId, input)); + assertEquals(inference, infer(inferenceEntityId, input)); // Different input values should not - assertNotEquals( - inference, - inferOnMockService(inferenceEntityId, randomValueOtherThan(input, () -> List.of(randomAlphaOfLength(10)))) - ); + assertNotEquals(inference, infer(inferenceEntityId, randomValueOtherThan(input, () -> List.of(randomAlphaOfLength(10))))); } public void testMockServiceWithMultipleInputs() throws IOException { @@ -44,7 +41,7 @@ public void testMockServiceWithMultipleInputs() throws IOException { putModel(inferenceEntityId, mockDenseServiceModelConfig(), TaskType.TEXT_EMBEDDING); // The response is randomly generated, the input can be anything - var inference = inferOnMockService( + var inference = infer( inferenceEntityId, TaskType.TEXT_EMBEDDING, List.of(randomAlphaOfLength(5), randomAlphaOfLength(10), randomAlphaOfLength(15)) diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockSparseInferenceServiceIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockSparseInferenceServiceIT.java index 24ba2708f5de4..9a17d8edc0768 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockSparseInferenceServiceIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockSparseInferenceServiceIT.java @@ -30,15 +30,12 @@ public void testMockService() throws IOException { } List input = List.of(randomAlphaOfLength(10)); - var inference = inferOnMockService(inferenceEntityId, input); + var inference = infer(inferenceEntityId, input); assertNonEmptyInferenceResults(inference, 1, TaskType.SPARSE_EMBEDDING); // Same input should return the same result - assertEquals(inference, inferOnMockService(inferenceEntityId, input)); + assertEquals(inference, infer(inferenceEntityId, input)); // Different input values should not - assertNotEquals( - inference, - inferOnMockService(inferenceEntityId, randomValueOtherThan(input, () -> List.of(randomAlphaOfLength(10)))) - ); + assertNotEquals(inference, infer(inferenceEntityId, randomValueOtherThan(input, () -> List.of(randomAlphaOfLength(10))))); } public void testMockServiceWithMultipleInputs() throws IOException { @@ -46,7 +43,7 @@ public void testMockServiceWithMultipleInputs() throws IOException { putModel(inferenceEntityId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING); // The response is randomly generated, the input can be anything - var inference = inferOnMockService( + var inference = infer( inferenceEntityId, TaskType.SPARSE_EMBEDDING, List.of(randomAlphaOfLength(5), randomAlphaOfLength(10), randomAlphaOfLength(15)) @@ -84,7 +81,7 @@ public void testMockService_DoesNotReturnHiddenField_InModelResponses() throws I } // The response is randomly generated, the input can be anything - var inference = inferOnMockService(inferenceEntityId, List.of(randomAlphaOfLength(10))); + var inference = infer(inferenceEntityId, List.of(randomAlphaOfLength(10))); assertNonEmptyInferenceResults(inference, 1, TaskType.SPARSE_EMBEDDING); } @@ -102,7 +99,7 @@ public void testMockService_DoesReturnHiddenField_InModelResponses() throws IOEx } // The response is randomly generated, the input can be anything - var inference = inferOnMockService(inferenceEntityId, List.of(randomAlphaOfLength(10))); + var inference = infer(inferenceEntityId, List.of(randomAlphaOfLength(10))); assertNonEmptyInferenceResults(inference, 1, TaskType.SPARSE_EMBEDDING); } } diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/TextEmbeddingCrudIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/TextEmbeddingCrudIT.java index 01e8c30e3bf27..8d9c859f129cb 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/TextEmbeddingCrudIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/TextEmbeddingCrudIT.java @@ -38,7 +38,7 @@ public void testPutE5Small_withPlatformAgnosticVariant() throws IOException { var models = getTrainedModel("_all"); assertThat(models.toString(), containsString("deployment_id=" + inferenceEntityId)); - Map results = inferOnMockService( + Map results = infer( inferenceEntityId, TaskType.TEXT_EMBEDDING, List.of("hello world", "this is the second document") @@ -57,7 +57,7 @@ public void testPutE5Small_withPlatformSpecificVariant() throws IOException { var models = getTrainedModel("_all"); assertThat(models.toString(), containsString("deployment_id=" + inferenceEntityId)); - Map results = inferOnMockService( + Map results = infer( inferenceEntityId, TaskType.TEXT_EMBEDDING, List.of("hello world", "this is the second document") diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java index ea8b32f36f54c..8e68ca9dfa565 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java @@ -20,6 +20,7 @@ import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.reindex.ReindexPlugin; import org.elasticsearch.test.ESSingleNodeTestCase; @@ -38,6 +39,7 @@ import java.util.ArrayList; import java.util.Collection; import java.util.Comparator; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.CountDownLatch; @@ -110,7 +112,7 @@ public void testGetModel() throws Exception { assertThat(putModelHolder.get(), is(true)); // now get the model - AtomicReference modelHolder = new AtomicReference<>(); + AtomicReference modelHolder = new AtomicReference<>(); blockingCall(listener -> modelRegistry.getModelWithSecrets(inferenceEntityId, listener), modelHolder, exceptionHolder); assertThat(exceptionHolder.get(), is(nullValue())); assertThat(modelHolder.get(), not(nullValue())); @@ -168,7 +170,7 @@ public void testDeleteModel() throws Exception { // get should fail deleteResponseHolder.set(false); - AtomicReference modelHolder = new AtomicReference<>(); + AtomicReference modelHolder = new AtomicReference<>(); blockingCall(listener -> modelRegistry.getModelWithSecrets("model1", listener), modelHolder, exceptionHolder); assertThat(exceptionHolder.get(), not(nullValue())); @@ -194,7 +196,7 @@ public void testGetModelsByTaskType() throws InterruptedException { } AtomicReference exceptionHolder = new AtomicReference<>(); - AtomicReference> modelHolder = new AtomicReference<>(); + AtomicReference> modelHolder = new AtomicReference<>(); blockingCall(listener -> modelRegistry.getModelsByTaskType(TaskType.SPARSE_EMBEDDING, listener), modelHolder, exceptionHolder); assertThat(modelHolder.get(), hasSize(3)); var sparseIds = sparseAndTextEmbeddingModels.stream() @@ -235,8 +237,9 @@ public void testGetAllModels() throws InterruptedException { assertNull(exceptionHolder.get()); } - AtomicReference> modelHolder = new AtomicReference<>(); + AtomicReference> modelHolder = new AtomicReference<>(); blockingCall(listener -> modelRegistry.getAllModels(listener), modelHolder, exceptionHolder); + assertNull(exceptionHolder.get()); assertThat(modelHolder.get(), hasSize(modelCount)); var getAllModels = modelHolder.get(); @@ -264,15 +267,213 @@ public void testGetModelWithSecrets() throws InterruptedException { assertThat(putModelHolder.get(), is(true)); assertNull(exceptionHolder.get()); - AtomicReference modelHolder = new AtomicReference<>(); + AtomicReference modelHolder = new AtomicReference<>(); blockingCall(listener -> modelRegistry.getModelWithSecrets(inferenceEntityId, listener), modelHolder, exceptionHolder); assertThat(modelHolder.get().secrets().keySet(), hasSize(1)); var secretSettings = (Map) modelHolder.get().secrets().get("secret_settings"); assertThat(secretSettings.get("secret"), equalTo(secret)); + assertReturnModelIsModifiable(modelHolder.get()); // get model without secrets blockingCall(listener -> modelRegistry.getModel(inferenceEntityId, listener), modelHolder, exceptionHolder); assertThat(modelHolder.get().secrets().keySet(), empty()); + assertReturnModelIsModifiable(modelHolder.get()); + } + + public void testGetAllModels_WithDefaults() throws Exception { + var service = "foo"; + var secret = "abc"; + int configuredModelCount = 10; + int defaultModelCount = 2; + int totalModelCount = 12; + + var defaultConfigs = new HashMap(); + for (int i = 0; i < defaultModelCount; i++) { + var id = "default-" + i; + defaultConfigs.put(id, createUnparsedConfig(id, randomFrom(TaskType.values()), service, secret)); + } + defaultConfigs.values().forEach(modelRegistry::addDefaultConfiguration); + + AtomicReference putModelHolder = new AtomicReference<>(); + AtomicReference exceptionHolder = new AtomicReference<>(); + + var createdModels = new HashMap(); + for (int i = 0; i < configuredModelCount; i++) { + var id = randomAlphaOfLength(5) + i; + var model = createModel(id, randomFrom(TaskType.values()), service); + createdModels.put(id, model); + blockingCall(listener -> modelRegistry.storeModel(model, listener), putModelHolder, exceptionHolder); + assertThat(putModelHolder.get(), is(true)); + assertNull(exceptionHolder.get()); + } + + AtomicReference> modelHolder = new AtomicReference<>(); + blockingCall(listener -> modelRegistry.getAllModels(listener), modelHolder, exceptionHolder); + assertNull(exceptionHolder.get()); + assertThat(modelHolder.get(), hasSize(totalModelCount)); + var getAllModels = modelHolder.get(); + assertReturnModelIsModifiable(modelHolder.get().get(0)); + + // sort in the same order as the returned models + var ids = new ArrayList<>(defaultConfigs.keySet().stream().toList()); + ids.addAll(createdModels.keySet().stream().toList()); + ids.sort(String::compareTo); + for (int i = 0; i < totalModelCount; i++) { + var id = ids.get(i); + assertEquals(id, getAllModels.get(i).inferenceEntityId()); + if (id.startsWith("default")) { + assertEquals(defaultConfigs.get(id).taskType(), getAllModels.get(i).taskType()); + assertEquals(defaultConfigs.get(id).service(), getAllModels.get(i).service()); + } else { + assertEquals(createdModels.get(id).getTaskType(), getAllModels.get(i).taskType()); + assertEquals(createdModels.get(id).getConfigurations().getService(), getAllModels.get(i).service()); + } + } + } + + public void testGetAllModels_OnlyDefaults() throws Exception { + var service = "foo"; + var secret = "abc"; + int defaultModelCount = 2; + + var defaultConfigs = new HashMap(); + for (int i = 0; i < defaultModelCount; i++) { + var id = "default-" + i; + defaultConfigs.put(id, createUnparsedConfig(id, randomFrom(TaskType.values()), service, secret)); + } + defaultConfigs.values().forEach(modelRegistry::addDefaultConfiguration); + + AtomicReference exceptionHolder = new AtomicReference<>(); + AtomicReference> modelHolder = new AtomicReference<>(); + blockingCall(listener -> modelRegistry.getAllModels(listener), modelHolder, exceptionHolder); + assertNull(exceptionHolder.get()); + assertThat(modelHolder.get(), hasSize(2)); + var getAllModels = modelHolder.get(); + assertReturnModelIsModifiable(modelHolder.get().get(0)); + + // sort in the same order as the returned models + var ids = new ArrayList<>(defaultConfigs.keySet().stream().toList()); + ids.sort(String::compareTo); + for (int i = 0; i < defaultModelCount; i++) { + var id = ids.get(i); + assertEquals(id, getAllModels.get(i).inferenceEntityId()); + assertEquals(defaultConfigs.get(id).taskType(), getAllModels.get(i).taskType()); + assertEquals(defaultConfigs.get(id).service(), getAllModels.get(i).service()); + } + } + + public void testGet_WithDefaults() throws InterruptedException { + var service = "foo"; + var secret = "abc"; + + var defaultSparse = createUnparsedConfig("default-sparse", TaskType.SPARSE_EMBEDDING, service, secret); + var defaultText = createUnparsedConfig("default-text", TaskType.TEXT_EMBEDDING, service, secret); + + modelRegistry.addDefaultConfiguration(defaultSparse); + modelRegistry.addDefaultConfiguration(defaultText); + + AtomicReference putModelHolder = new AtomicReference<>(); + AtomicReference exceptionHolder = new AtomicReference<>(); + + var configured1 = createModel(randomAlphaOfLength(5) + 1, randomFrom(TaskType.values()), service); + var configured2 = createModel(randomAlphaOfLength(5) + 1, randomFrom(TaskType.values()), service); + blockingCall(listener -> modelRegistry.storeModel(configured1, listener), putModelHolder, exceptionHolder); + assertThat(putModelHolder.get(), is(true)); + blockingCall(listener -> modelRegistry.storeModel(configured2, listener), putModelHolder, exceptionHolder); + assertThat(putModelHolder.get(), is(true)); + assertNull(exceptionHolder.get()); + + AtomicReference modelHolder = new AtomicReference<>(); + blockingCall(listener -> modelRegistry.getModel("default-sparse", listener), modelHolder, exceptionHolder); + assertEquals("default-sparse", modelHolder.get().inferenceEntityId()); + assertEquals(TaskType.SPARSE_EMBEDDING, modelHolder.get().taskType()); + assertReturnModelIsModifiable(modelHolder.get()); + + blockingCall(listener -> modelRegistry.getModel("default-text", listener), modelHolder, exceptionHolder); + assertEquals("default-text", modelHolder.get().inferenceEntityId()); + assertEquals(TaskType.TEXT_EMBEDDING, modelHolder.get().taskType()); + + blockingCall(listener -> modelRegistry.getModel(configured1.getInferenceEntityId(), listener), modelHolder, exceptionHolder); + assertEquals(configured1.getInferenceEntityId(), modelHolder.get().inferenceEntityId()); + assertEquals(configured1.getTaskType(), modelHolder.get().taskType()); + } + + public void testGetByTaskType_WithDefaults() throws Exception { + var service = "foo"; + var secret = "abc"; + + var defaultSparse = createUnparsedConfig("default-sparse", TaskType.SPARSE_EMBEDDING, service, secret); + var defaultText = createUnparsedConfig("default-text", TaskType.TEXT_EMBEDDING, service, secret); + var defaultChat = createUnparsedConfig("default-chat", TaskType.COMPLETION, service, secret); + + modelRegistry.addDefaultConfiguration(defaultSparse); + modelRegistry.addDefaultConfiguration(defaultText); + modelRegistry.addDefaultConfiguration(defaultChat); + + AtomicReference putModelHolder = new AtomicReference<>(); + AtomicReference exceptionHolder = new AtomicReference<>(); + + var configuredSparse = createModel("configured-sparse", TaskType.SPARSE_EMBEDDING, service); + var configuredText = createModel("configured-text", TaskType.TEXT_EMBEDDING, service); + var configuredRerank = createModel("configured-rerank", TaskType.RERANK, service); + blockingCall(listener -> modelRegistry.storeModel(configuredSparse, listener), putModelHolder, exceptionHolder); + assertThat(putModelHolder.get(), is(true)); + blockingCall(listener -> modelRegistry.storeModel(configuredText, listener), putModelHolder, exceptionHolder); + assertThat(putModelHolder.get(), is(true)); + blockingCall(listener -> modelRegistry.storeModel(configuredRerank, listener), putModelHolder, exceptionHolder); + assertThat(putModelHolder.get(), is(true)); + assertNull(exceptionHolder.get()); + + AtomicReference> modelHolder = new AtomicReference<>(); + blockingCall(listener -> modelRegistry.getModelsByTaskType(TaskType.SPARSE_EMBEDDING, listener), modelHolder, exceptionHolder); + if (exceptionHolder.get() != null) { + throw exceptionHolder.get(); + } + assertNull(exceptionHolder.get()); + assertThat(modelHolder.get(), hasSize(2)); + assertEquals("configured-sparse", modelHolder.get().get(0).inferenceEntityId()); + assertEquals("default-sparse", modelHolder.get().get(1).inferenceEntityId()); + + blockingCall(listener -> modelRegistry.getModelsByTaskType(TaskType.TEXT_EMBEDDING, listener), modelHolder, exceptionHolder); + assertThat(modelHolder.get(), hasSize(2)); + assertEquals("configured-text", modelHolder.get().get(0).inferenceEntityId()); + assertEquals("default-text", modelHolder.get().get(1).inferenceEntityId()); + assertReturnModelIsModifiable(modelHolder.get().get(0)); + + blockingCall(listener -> modelRegistry.getModelsByTaskType(TaskType.RERANK, listener), modelHolder, exceptionHolder); + assertThat(modelHolder.get(), hasSize(1)); + assertEquals("configured-rerank", modelHolder.get().get(0).inferenceEntityId()); + + blockingCall(listener -> modelRegistry.getModelsByTaskType(TaskType.COMPLETION, listener), modelHolder, exceptionHolder); + assertThat(modelHolder.get(), hasSize(1)); + assertEquals("default-chat", modelHolder.get().get(0).inferenceEntityId()); + assertReturnModelIsModifiable(modelHolder.get().get(0)); + } + + @SuppressWarnings("unchecked") + private void assertReturnModelIsModifiable(UnparsedModel unparsedModel) { + var settings = unparsedModel.settings(); + if (settings != null) { + var serviceSettings = (Map) settings.get("service_settings"); + if (serviceSettings != null && serviceSettings.size() > 0) { + var itr = serviceSettings.entrySet().iterator(); + itr.next(); + itr.remove(); + } + + var taskSettings = (Map) settings.get("task_settings"); + if (taskSettings != null && taskSettings.size() > 0) { + var itr = taskSettings.entrySet().iterator(); + itr.next(); + itr.remove(); + } + + if (unparsedModel.secrets() != null && unparsedModel.secrets().size() > 0) { + var itr = unparsedModel.secrets().entrySet().iterator(); + itr.next(); + itr.remove(); + } + } } private Model buildElserModelConfig(String inferenceEntityId, TaskType taskType) { @@ -327,6 +528,10 @@ public static Model createModelWithSecrets(String inferenceEntityId, TaskType ta ); } + public static UnparsedModel createUnparsedConfig(String inferenceEntityId, TaskType taskType, String service, String secret) { + return new UnparsedModel(inferenceEntityId, taskType, service, Map.of("a", "b"), Map.of("secret", secret)); + } + private static class TestModelOfAnyKind extends ModelConfigurations { record TestModelServiceSettings() implements ServiceSettings { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/DefaultElserFeatureFlag.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/DefaultElserFeatureFlag.java new file mode 100644 index 0000000000000..2a764dabd62ae --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/DefaultElserFeatureFlag.java @@ -0,0 +1,21 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference; + +import org.elasticsearch.common.util.FeatureFlag; + +public class DefaultElserFeatureFlag { + + private DefaultElserFeatureFlag() {} + + private static final FeatureFlag FEATURE_FLAG = new FeatureFlag("inference_default_elser"); + + public static boolean isEnabled() { + return FEATURE_FLAG.isEnabled(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 0ab395f4bfa39..dbb9130ab91e1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -210,6 +210,9 @@ public Collection createComponents(PluginServices services) { // reference correctly var registry = new InferenceServiceRegistry(inferenceServices, factoryContext); registry.init(services.client()); + for (var service : registry.getServices().values()) { + service.defaultConfigs().forEach(modelRegistry::addDefaultConfiguration); + } inferenceServiceRegistry.set(registry); var actionFilter = new ShardBulkInferenceActionFilter(registry, modelRegistry); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java index 3c893f8870627..829a6b6c67ff9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java @@ -24,6 +24,7 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.injection.guice.Inject; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; @@ -91,7 +92,7 @@ private void doExecuteForked( ClusterState state, ActionListener masterListener ) { - SubscribableListener.newForked(modelConfigListener -> { + SubscribableListener.newForked(modelConfigListener -> { // Get the model from the registry modelRegistry.getModel(request.getInferenceEndpointId(), modelConfigListener); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java index a1f33afa05b5c..5ee1e40869dbc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java @@ -17,6 +17,7 @@ import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.injection.guice.Inject; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; @@ -112,7 +113,7 @@ private void getModelsByTaskType(TaskType taskType, ActionListener unparsedModels) { + private GetInferenceModelAction.Response parseModels(List unparsedModels) { var parsedModels = new ArrayList(); for (var unparsedModel : unparsedModels) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java index d2a73b7df77c1..4045734546596 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java @@ -19,6 +19,7 @@ import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.injection.guice.Inject; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; @@ -64,30 +65,16 @@ public TransportInferenceAction( @Override protected void doExecute(Task task, InferenceAction.Request request, ActionListener listener) { - ActionListener getModelListener = listener.delegateFailureAndWrap((delegate, unparsedModel) -> { + ActionListener getModelListener = listener.delegateFailureAndWrap((delegate, unparsedModel) -> { var service = serviceRegistry.getService(unparsedModel.service()); if (service.isEmpty()) { - delegate.onFailure( - new ElasticsearchStatusException( - "Unknown service [{}] for model [{}]. ", - RestStatus.INTERNAL_SERVER_ERROR, - unparsedModel.service(), - unparsedModel.inferenceEntityId() - ) - ); + listener.onFailure(unknownServiceException(unparsedModel.service(), request.getInferenceEntityId())); return; } if (request.getTaskType().isAnyOrSame(unparsedModel.taskType()) == false) { // not the wildcard task type and not the model task type - delegate.onFailure( - new ElasticsearchStatusException( - "Incompatible task_type, the requested type [{}] does not match the model type [{}]", - RestStatus.BAD_REQUEST, - request.getTaskType(), - unparsedModel.taskType() - ) - ); + listener.onFailure(incompatibleTaskTypeException(request.getTaskType(), unparsedModel.taskType())); return; } @@ -98,7 +85,6 @@ protected void doExecute(Task task, InferenceAction.Request request, ActionListe unparsedModel.settings(), unparsedModel.secrets() ); - inferenceStats.incrementRequestCount(model); inferOnService(model, request, service.get(), delegate); }); @@ -112,6 +98,7 @@ private void inferOnService( ActionListener listener ) { if (request.isStreaming() == false || service.canStream(request.getTaskType())) { + inferenceStats.incrementRequestCount(model); service.infer( model, request.getQuery(), @@ -160,5 +147,19 @@ private ActionListener createListener( }); } return listener.delegateFailureAndWrap((l, inferenceResults) -> l.onResponse(new InferenceAction.Response(inferenceResults))); - }; + } + + private static ElasticsearchStatusException unknownServiceException(String service, String inferenceId) { + return new ElasticsearchStatusException("Unknown service [{}] for model [{}]. ", RestStatus.BAD_REQUEST, service, inferenceId); + } + + private static ElasticsearchStatusException incompatibleTaskTypeException(TaskType requested, TaskType expected) { + return new ElasticsearchStatusException( + "Incompatible task_type, the requested type [{}] does not match the model type [{}]", + RestStatus.BAD_REQUEST, + requested, + expected + ); + } + } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index ade0748ef10bf..a4eb94c2674d1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -35,6 +35,7 @@ import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; @@ -211,9 +212,9 @@ private void executeShardBulkInferenceAsync( final Releasable onFinish ) { if (inferenceProvider == null) { - ActionListener modelLoadingListener = new ActionListener<>() { + ActionListener modelLoadingListener = new ActionListener<>() { @Override - public void onResponse(ModelRegistry.UnparsedModel unparsedModel) { + public void onResponse(UnparsedModel unparsedModel) { var service = inferenceServiceRegistry.getService(unparsedModel.service()); if (service.isEmpty() == false) { var provider = new InferenceProvider( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java index a6e4fcae7169f..d756c0ef26f14 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java @@ -32,6 +32,7 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHits; @@ -48,6 +49,8 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; +import java.util.Comparator; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.function.Function; @@ -58,32 +61,19 @@ public class ModelRegistry { public record ModelConfigMap(Map config, Map secrets) {} - /** - * Semi parsed model where inference entity id, task type and service - * are known but the settings are not parsed. - */ - public record UnparsedModel( - String inferenceEntityId, - TaskType taskType, - String service, - Map settings, - Map secrets - ) { - - public static UnparsedModel unparsedModelFromMap(ModelConfigMap modelConfigMap) { - if (modelConfigMap.config() == null) { - throw new ElasticsearchStatusException("Missing config map", RestStatus.BAD_REQUEST); - } - String inferenceEntityId = ServiceUtils.removeStringOrThrowIfNull( - modelConfigMap.config(), - ModelConfigurations.INDEX_ONLY_ID_FIELD_NAME - ); - String service = ServiceUtils.removeStringOrThrowIfNull(modelConfigMap.config(), ModelConfigurations.SERVICE); - String taskTypeStr = ServiceUtils.removeStringOrThrowIfNull(modelConfigMap.config(), TaskType.NAME); - TaskType taskType = TaskType.fromString(taskTypeStr); - - return new UnparsedModel(inferenceEntityId, taskType, service, modelConfigMap.config(), modelConfigMap.secrets()); + public static UnparsedModel unparsedModelFromMap(ModelConfigMap modelConfigMap) { + if (modelConfigMap.config() == null) { + throw new ElasticsearchStatusException("Missing config map", RestStatus.BAD_REQUEST); } + String inferenceEntityId = ServiceUtils.removeStringOrThrowIfNull( + modelConfigMap.config(), + ModelConfigurations.INDEX_ONLY_ID_FIELD_NAME + ); + String service = ServiceUtils.removeStringOrThrowIfNull(modelConfigMap.config(), ModelConfigurations.SERVICE); + String taskTypeStr = ServiceUtils.removeStringOrThrowIfNull(modelConfigMap.config(), TaskType.NAME); + TaskType taskType = TaskType.fromString(taskTypeStr); + + return new UnparsedModel(inferenceEntityId, taskType, service, modelConfigMap.config(), modelConfigMap.secrets()); } private static final String TASK_TYPE_FIELD = "task_type"; @@ -91,9 +81,27 @@ public static UnparsedModel unparsedModelFromMap(ModelConfigMap modelConfigMap) private static final Logger logger = LogManager.getLogger(ModelRegistry.class); private final OriginSettingClient client; + private Map defaultConfigs; public ModelRegistry(Client client) { this.client = new OriginSettingClient(client, ClientHelper.INFERENCE_ORIGIN); + this.defaultConfigs = new HashMap<>(); + } + + public void addDefaultConfiguration(UnparsedModel serviceDefaultConfig) { + if (defaultConfigs.containsKey(serviceDefaultConfig.inferenceEntityId())) { + throw new IllegalStateException( + "Cannot add default endpoint to the inference endpoint registry with duplicate inference id [" + + serviceDefaultConfig.inferenceEntityId() + + "] declared by service [" + + serviceDefaultConfig.service() + + "]. The inference Id is already use by [" + + defaultConfigs.get(serviceDefaultConfig.inferenceEntityId()).service() + + "] service." + ); + } + + defaultConfigs.put(serviceDefaultConfig.inferenceEntityId(), serviceDefaultConfig); } /** @@ -102,6 +110,11 @@ public ModelRegistry(Client client) { * @param listener Model listener */ public void getModelWithSecrets(String inferenceEntityId, ActionListener listener) { + if (defaultConfigs.containsKey(inferenceEntityId)) { + listener.onResponse(deepCopyDefaultConfig(defaultConfigs.get(inferenceEntityId))); + return; + } + ActionListener searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> { // There should be a hit for the configurations and secrets if (searchResponse.getHits().getHits().length == 0) { @@ -109,7 +122,7 @@ public void getModelWithSecrets(String inferenceEntityId, ActionListener listener) { + if (defaultConfigs.containsKey(inferenceEntityId)) { + listener.onResponse(deepCopyDefaultConfig(defaultConfigs.get(inferenceEntityId))); + return; + } + ActionListener searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> { // There should be a hit for the configurations and secrets if (searchResponse.getHits().getHits().length == 0) { @@ -135,7 +153,7 @@ public void getModel(String inferenceEntityId, ActionListener lis return; } - var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(UnparsedModel::unparsedModelFromMap).toList(); + var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(ModelRegistry::unparsedModelFromMap).toList(); assert modelConfigs.size() == 1; delegate.onResponse(modelConfigs.get(0)); }); @@ -162,14 +180,29 @@ private ResourceNotFoundException inferenceNotFoundException(String inferenceEnt */ public void getModelsByTaskType(TaskType taskType, ActionListener> listener) { ActionListener searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> { + var defaultConfigsForTaskType = defaultConfigs.values() + .stream() + .filter(m -> m.taskType() == taskType) + .map(ModelRegistry::deepCopyDefaultConfig) + .toList(); + // Not an error if no models of this task_type - if (searchResponse.getHits().getHits().length == 0) { + if (searchResponse.getHits().getHits().length == 0 && defaultConfigsForTaskType.isEmpty()) { delegate.onResponse(List.of()); return; } - var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(UnparsedModel::unparsedModelFromMap).toList(); - delegate.onResponse(modelConfigs); + var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(ModelRegistry::unparsedModelFromMap).toList(); + + if (defaultConfigsForTaskType.isEmpty() == false) { + var allConfigs = new ArrayList(); + allConfigs.addAll(modelConfigs); + allConfigs.addAll(defaultConfigsForTaskType); + allConfigs.sort(Comparator.comparing(UnparsedModel::inferenceEntityId)); + delegate.onResponse(allConfigs); + } else { + delegate.onResponse(modelConfigs); + } }); QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.termsQuery(TASK_TYPE_FIELD, taskType.toString())); @@ -191,14 +224,19 @@ public void getModelsByTaskType(TaskType taskType, ActionListener> listener) { ActionListener searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> { - // Not an error if no models of this task_type - if (searchResponse.getHits().getHits().length == 0) { + var defaults = defaultConfigs.values().stream().map(ModelRegistry::deepCopyDefaultConfig).toList(); + + if (searchResponse.getHits().getHits().length == 0 && defaults.isEmpty()) { delegate.onResponse(List.of()); return; } - var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(UnparsedModel::unparsedModelFromMap).toList(); - delegate.onResponse(modelConfigs); + var foundConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(ModelRegistry::unparsedModelFromMap).toList(); + var allConfigs = new ArrayList(); + allConfigs.addAll(foundConfigs); + allConfigs.addAll(defaults); + allConfigs.sort(Comparator.comparing(UnparsedModel::inferenceEntityId)); + delegate.onResponse(allConfigs); }); // In theory the index should only contain model config documents @@ -216,7 +254,7 @@ public void getAllModels(ActionListener> listener) { client.search(modelSearch, searchListener); } - private List parseHitsAsModels(SearchHits hits) { + private ArrayList parseHitsAsModels(SearchHits hits) { var modelConfigs = new ArrayList(); for (var hit : hits) { modelConfigs.add(new ModelConfigMap(hit.getSourceAsMap(), Map.of())); @@ -393,4 +431,57 @@ private static IndexRequest createIndexRequest(String docId, String indexName, T private QueryBuilder documentIdQuery(String inferenceEntityId) { return QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds(Model.documentId(inferenceEntityId))); } + + static UnparsedModel deepCopyDefaultConfig(UnparsedModel other) { + // Because the default config uses immutable maps + return new UnparsedModel( + other.inferenceEntityId(), + other.taskType(), + other.service(), + copySettingsMap(other.settings()), + copySecretsMap(other.secrets()) + ); + } + + @SuppressWarnings("unchecked") + static Map copySettingsMap(Map other) { + var result = new HashMap(); + + var serviceSettings = (Map) other.get(ModelConfigurations.SERVICE_SETTINGS); + if (serviceSettings != null) { + var copiedServiceSettings = copyMap1LevelDeep(serviceSettings); + result.put(ModelConfigurations.SERVICE_SETTINGS, copiedServiceSettings); + } + + var taskSettings = (Map) other.get(ModelConfigurations.TASK_SETTINGS); + if (taskSettings != null) { + var copiedTaskSettings = copyMap1LevelDeep(taskSettings); + result.put(ModelConfigurations.TASK_SETTINGS, copiedTaskSettings); + } + + var chunkSettings = (Map) other.get(ModelConfigurations.CHUNKING_SETTINGS); + if (chunkSettings != null) { + var copiedChunkSettings = copyMap1LevelDeep(chunkSettings); + result.put(ModelConfigurations.CHUNKING_SETTINGS, copiedChunkSettings); + } + + return result; + } + + static Map copySecretsMap(Map other) { + return copyMap1LevelDeep(other); + } + + @SuppressWarnings("unchecked") + static Map copyMap1LevelDeep(Map other) { + var result = new HashMap(); + for (var entry : other.entrySet()) { + if (entry.getValue() instanceof Map) { + result.put(entry.getKey(), new HashMap<>((Map) entry.getValue())); + } else { + result.put(entry.getKey(), entry.getValue()); + } + } + return result; + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java index 23e806e01300a..0dd41db2f016c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java @@ -10,6 +10,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.SubscribableListener; @@ -31,6 +32,7 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate; import org.elasticsearch.xpack.core.ml.utils.MlPlatformArchitecturesUtil; +import org.elasticsearch.xpack.inference.DefaultElserFeatureFlag; import org.elasticsearch.xpack.inference.InferencePlugin; import java.io.IOException; @@ -80,7 +82,6 @@ public BaseElasticsearchInternalService( @Override public void start(Model model, ActionListener finalListener) { if (model instanceof ElasticsearchInternalModel esModel) { - if (supportedTaskTypes().contains(model.getTaskType()) == false) { finalListener.onFailure( new IllegalStateException(TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name())) @@ -149,7 +150,7 @@ public void putModel(Model model, ActionListener listener) { } } - private void putBuiltInModel(String modelId, ActionListener listener) { + protected void putBuiltInModel(String modelId, ActionListener listener) { var input = new TrainedModelInput(List.of("text_field")); // by convention text_field is used var config = TrainedModelConfig.builder().setInput(input).setModelId(modelId).validate(true).build(); PutTrainedModelAction.Request putRequest = new PutTrainedModelAction.Request(config, false, true); @@ -258,4 +259,27 @@ public static InferModelAction.Request buildInferenceRequest( request.setChunked(chunk); return request; } + + protected abstract boolean isDefaultId(String inferenceId); + + protected void maybeStartDeployment( + ElasticsearchInternalModel model, + Exception e, + InferModelAction.Request request, + ActionListener listener + ) { + if (DefaultElserFeatureFlag.isEnabled() == false) { + listener.onFailure(e); + return; + } + + if (isDefaultId(model.getInferenceEntityId()) && ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) { + this.start( + model, + listener.delegateFailureAndWrap((l, started) -> { client.execute(InferModelAction.INSTANCE, request, listener); }) + ); + } else { + listener.onFailure(e); + } + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index e274c641e30be..dd14e16412996 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -26,6 +26,7 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; @@ -73,6 +74,8 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86 ); + public static final String DEFAULT_ELSER_ID = ".elser-2"; + private static final Logger logger = LogManager.getLogger(ElasticsearchInternalService.class); private static final DeprecationLogger DEPRECATION_LOGGER = DeprecationLogger.getLogger(ElasticsearchInternalService.class); @@ -100,6 +103,17 @@ public void parseRequestConfig( Map config, ActionListener modelListener ) { + if (inferenceEntityId.equals(DEFAULT_ELSER_ID)) { + modelListener.onFailure( + new ElasticsearchStatusException( + "[{}] is a reserved inference Id. Cannot create a new inference endpoint with a reserved Id", + RestStatus.BAD_REQUEST, + inferenceEntityId + ) + ); + return; + } + try { Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMap(config, ModelConfigurations.TASK_SETTINGS); @@ -459,20 +473,24 @@ public void infer( TimeValue timeout, ActionListener listener ) { - var taskType = model.getConfigurations().getTaskType(); - if (TaskType.TEXT_EMBEDDING.equals(taskType)) { - inferTextEmbedding(model, input, inputType, timeout, listener); - } else if (TaskType.RERANK.equals(taskType)) { - inferRerank(model, query, input, inputType, timeout, taskSettings, listener); - } else if (TaskType.SPARSE_EMBEDDING.equals(taskType)) { - inferSparseEmbedding(model, input, inputType, timeout, listener); + if (model instanceof ElasticsearchInternalModel esModel) { + var taskType = model.getConfigurations().getTaskType(); + if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + inferTextEmbedding(esModel, input, inputType, timeout, listener); + } else if (TaskType.RERANK.equals(taskType)) { + inferRerank(esModel, query, input, inputType, timeout, taskSettings, listener); + } else if (TaskType.SPARSE_EMBEDDING.equals(taskType)) { + inferSparseEmbedding(esModel, input, inputType, timeout, listener); + } else { + throw new ElasticsearchStatusException(TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), RestStatus.BAD_REQUEST); + } } else { - throw new ElasticsearchStatusException(TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), RestStatus.BAD_REQUEST); + listener.onFailure(notElasticsearchModelException(model)); } } public void inferTextEmbedding( - Model model, + ElasticsearchInternalModel model, List inputs, InputType inputType, TimeValue timeout, @@ -487,17 +505,19 @@ public void inferTextEmbedding( false ); - client.execute( - InferModelAction.INSTANCE, - request, - listener.delegateFailureAndWrap( - (l, inferenceResult) -> l.onResponse(InferenceTextEmbeddingFloatResults.of(inferenceResult.getInferenceResults())) - ) + ActionListener mlResultsListener = listener.delegateFailureAndWrap( + (l, inferenceResult) -> l.onResponse(InferenceTextEmbeddingFloatResults.of(inferenceResult.getInferenceResults())) + ); + + var maybeDeployListener = mlResultsListener.delegateResponse( + (l, exception) -> maybeStartDeployment(model, exception, request, mlResultsListener) ); + + client.execute(InferModelAction.INSTANCE, request, maybeDeployListener); } public void inferSparseEmbedding( - Model model, + ElasticsearchInternalModel model, List inputs, InputType inputType, TimeValue timeout, @@ -512,17 +532,19 @@ public void inferSparseEmbedding( false ); - client.execute( - InferModelAction.INSTANCE, - request, - listener.delegateFailureAndWrap( - (l, inferenceResult) -> l.onResponse(SparseEmbeddingResults.of(inferenceResult.getInferenceResults())) - ) + ActionListener mlResultsListener = listener.delegateFailureAndWrap( + (l, inferenceResult) -> l.onResponse(SparseEmbeddingResults.of(inferenceResult.getInferenceResults())) + ); + + var maybeDeployListener = mlResultsListener.delegateResponse( + (l, exception) -> maybeStartDeployment(model, exception, request, mlResultsListener) ); + + client.execute(InferModelAction.INSTANCE, request, maybeDeployListener); } public void inferRerank( - Model model, + ElasticsearchInternalModel model, String query, List inputs, InputType inputType, @@ -671,4 +693,42 @@ private RankedDocsResults textSimilarityResultsToRankedDocs( Collections.sort(rankings); return new RankedDocsResults(rankings); } + + @Override + public List defaultConfigs() { + // TODO Chunking settings + Map elserSettings = Map.of( + ModelConfigurations.SERVICE_SETTINGS, + Map.of( + ElasticsearchInternalServiceSettings.MODEL_ID, + ElserModels.ELSER_V2_MODEL, // TODO pick model depending on platform + ElasticsearchInternalServiceSettings.NUM_THREADS, + 1, + ElasticsearchInternalServiceSettings.ADAPTIVE_ALLOCATIONS, + Map.of( + "enabled", + Boolean.TRUE, + "min_number_of_allocations", + 1, + "max_number_of_allocations", + 8 // no max? + ) + ) + ); + + return List.of( + new UnparsedModel( + DEFAULT_ELSER_ID, + TaskType.SPARSE_EMBEDDING, + NAME, + elserSettings, + Map.of() // no secrets + ) + ); + } + + @Override + protected boolean isDefaultId(String inferenceId) { + return DEFAULT_ELSER_ID.equals(inferenceId); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index d78ea7933e836..770e6e3cb9cf4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -26,6 +26,7 @@ import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; import org.elasticsearch.test.ESTestCase; @@ -266,12 +267,11 @@ private static ShardBulkInferenceActionFilter createFilter(ThreadPool threadPool ModelRegistry modelRegistry = mock(ModelRegistry.class); Answer unparsedModelAnswer = invocationOnMock -> { String id = (String) invocationOnMock.getArguments()[0]; - ActionListener listener = (ActionListener) invocationOnMock - .getArguments()[1]; + ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1]; var model = modelMap.get(id); if (model != null) { listener.onResponse( - new ModelRegistry.UnparsedModel( + new UnparsedModel( model.getInferenceEntityId(), model.getTaskType(), model.getServiceSettings().model(), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java index fbd8ccd621559..75c370fd4d3fb 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java @@ -23,6 +23,7 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.engine.VersionConflictEngineException; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHits; import org.elasticsearch.search.SearchResponseUtils; @@ -38,9 +39,12 @@ import java.util.concurrent.TimeUnit; import static org.elasticsearch.core.Strings.format; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.sameInstance; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; @@ -68,7 +72,7 @@ public void testGetUnparsedModelMap_ThrowsResourceNotFound_WhenNoHitsReturned() var registry = new ModelRegistry(client); - var listener = new PlainActionFuture(); + var listener = new PlainActionFuture(); registry.getModelWithSecrets("1", listener); ResourceNotFoundException exception = expectThrows(ResourceNotFoundException.class, () -> listener.actionGet(TIMEOUT)); @@ -82,7 +86,7 @@ public void testGetUnparsedModelMap_ThrowsIllegalArgumentException_WhenInvalidIn var registry = new ModelRegistry(client); - var listener = new PlainActionFuture(); + var listener = new PlainActionFuture(); registry.getModelWithSecrets("1", listener); IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> listener.actionGet(TIMEOUT)); @@ -99,7 +103,7 @@ public void testGetUnparsedModelMap_ThrowsIllegalStateException_WhenUnableToFind var registry = new ModelRegistry(client); - var listener = new PlainActionFuture(); + var listener = new PlainActionFuture(); registry.getModelWithSecrets("1", listener); IllegalStateException exception = expectThrows(IllegalStateException.class, () -> listener.actionGet(TIMEOUT)); @@ -116,7 +120,7 @@ public void testGetUnparsedModelMap_ThrowsIllegalStateException_WhenUnableToFind var registry = new ModelRegistry(client); - var listener = new PlainActionFuture(); + var listener = new PlainActionFuture(); registry.getModelWithSecrets("1", listener); IllegalStateException exception = expectThrows(IllegalStateException.class, () -> listener.actionGet(TIMEOUT)); @@ -150,7 +154,7 @@ public void testGetModelWithSecrets() { var registry = new ModelRegistry(client); - var listener = new PlainActionFuture(); + var listener = new PlainActionFuture(); registry.getModelWithSecrets("1", listener); var modelConfig = listener.actionGet(TIMEOUT); @@ -179,7 +183,7 @@ public void testGetModelNoSecrets() { var registry = new ModelRegistry(client); - var listener = new PlainActionFuture(); + var listener = new PlainActionFuture(); registry.getModel("1", listener); registry.getModel("1", listener); @@ -288,6 +292,80 @@ public void testStoreModel_ThrowsException_WhenFailureIsNotAVersionConflict() { ); } + @SuppressWarnings("unchecked") + public void testDeepCopyDefaultConfig() { + { + var toCopy = new UnparsedModel("tocopy", randomFrom(TaskType.values()), "service-a", Map.of(), Map.of()); + var copied = ModelRegistry.deepCopyDefaultConfig(toCopy); + assertThat(copied, not(sameInstance(toCopy))); + assertThat(copied.taskType(), is(toCopy.taskType())); + assertThat(copied.service(), is(toCopy.service())); + assertThat(copied.secrets(), not(sameInstance(toCopy.secrets()))); + assertThat(copied.secrets(), is(toCopy.secrets())); + // Test copied is a modifiable map + copied.secrets().put("foo", "bar"); + + assertThat(copied.settings(), not(sameInstance(toCopy.settings()))); + assertThat(copied.settings(), is(toCopy.settings())); + // Test copied is a modifiable map + copied.settings().put("foo", "bar"); + } + + { + Map secretsMap = Map.of("secret", "value"); + Map chunking = Map.of("strategy", "word"); + Map task = Map.of("user", "name"); + Map service = Map.of("num_threads", 1, "adaptive_allocations", Map.of("enabled", true)); + Map settings = Map.of("chunking_settings", chunking, "service_settings", service, "task_settings", task); + + var toCopy = new UnparsedModel("tocopy", randomFrom(TaskType.values()), "service-a", settings, secretsMap); + var copied = ModelRegistry.deepCopyDefaultConfig(toCopy); + assertThat(copied, not(sameInstance(toCopy))); + + assertThat(copied.secrets(), not(sameInstance(toCopy.secrets()))); + assertThat(copied.secrets(), is(toCopy.secrets())); + // Test copied is a modifiable map + copied.secrets().remove("secret"); + + assertThat(copied.settings(), not(sameInstance(toCopy.settings()))); + assertThat(copied.settings(), is(toCopy.settings())); + // Test copied is a modifiable map + var chunkOut = (Map) copied.settings().get("chunking_settings"); + assertThat(chunkOut, is(chunking)); + chunkOut.remove("strategy"); + + var taskOut = (Map) copied.settings().get("task_settings"); + assertThat(taskOut, is(task)); + taskOut.remove("user"); + + var serviceOut = (Map) copied.settings().get("service_settings"); + assertThat(serviceOut, is(service)); + var adaptiveOut = (Map) serviceOut.remove("adaptive_allocations"); + assertThat(adaptiveOut, is(Map.of("enabled", true))); + adaptiveOut.remove("enabled"); + } + } + + public void testDuplicateDefaultIds() { + var client = mockBulkClient(); + var registry = new ModelRegistry(client); + + var id = "my-inference"; + + registry.addDefaultConfiguration(new UnparsedModel(id, randomFrom(TaskType.values()), "service-a", Map.of(), Map.of())); + var ise = expectThrows( + IllegalStateException.class, + () -> registry.addDefaultConfiguration(new UnparsedModel(id, randomFrom(TaskType.values()), "service-b", Map.of(), Map.of())) + ); + assertThat( + ise.getMessage(), + containsString( + "Cannot add default endpoint to the inference endpoint registry with duplicate inference id [my-inference] declared by " + + "service [service-b]. The inference Id is already use by [service-a] service." + ) + ); + } + private Client mockBulkClient() { var client = mockClient(); when(client.prepareBulk()).thenReturn(new BulkRequestBuilder(client)); diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/inference/inference_crud.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/inference/inference_crud.yml index 6aec721b35418..11be68cc764e2 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/inference/inference_crud.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/inference/inference_crud.yml @@ -44,15 +44,18 @@ - do: inference.get: inference_id: "*" - - length: { endpoints: 0} + - length: { endpoints: 1} + - match: { endpoints.0.inference_id: ".elser-2" } - do: inference.get: inference_id: _all - - length: { endpoints: 0} + - length: { endpoints: 1} + - match: { endpoints.0.inference_id: ".elser-2" } - do: inference.get: inference_id: "" - - length: { endpoints: 0} + - length: { endpoints: 1} + - match: { endpoints.0.inference_id: ".elser-2" } From c84471a838ab4da59ce2e3104a41fc5ec57447e2 Mon Sep 17 00:00:00 2001 From: Chris Hegarty <62058229+ChrisHegarty@users.noreply.github.com> Date: Mon, 7 Oct 2024 21:03:17 +0100 Subject: [PATCH 07/22] [8.x] Allow incubating Panama Vector in simdvec, and add vectorized ipByteBin (#114245) * Allow incubating Panama Vector in simdvec, and add vectorized ipByteBin (#112933) Add support for vectorized ipByteBin. The structure of the implementation and loading framework mirror that of Lucene, but is simplified by avoiding reflective loading since ES has support for a MRJar section for 21. For now, we just disable warnings-as-errors in this small sourceset, since -Xlint:-incubating is only support since JDK 22. The number of source files is small here. Will investigate how to assert that just the single incubating warning is emitted by javac, at a later point. * fix build runtime java check * Fix Gradle configuration in idea for :libs:simdvec (#114251) --------- Co-authored-by: Elastic Machine Co-authored-by: Rene Groeschke --- docs/changelog/112933.yaml | 5 + libs/simdvec/build.gradle | 15 ++ libs/simdvec/src/main/java/module-info.java | 1 + .../elasticsearch/simdvec/ESVectorUtil.java | 27 ++++ .../DefaultESVectorUtilSupport.java | 39 +++++ .../DefaultESVectorizationProvider.java | 23 +++ .../vectorization/ESVectorUtilSupport.java | 17 ++ .../ESVectorizationProvider.java | 38 +++++ .../ESVectorizationProvider.java | 87 ++++++++++ .../PanamaESVectorUtilSupport.java | 153 ++++++++++++++++++ .../PanamaESVectorizationProvider.java | 24 +++ .../simdvec/ESVectorUtilTests.java | 130 +++++++++++++++ .../vectorization/BaseVectorizationTests.java | 29 ++++ 13 files changed, 588 insertions(+) create mode 100644 docs/changelog/112933.yaml create mode 100644 libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java create mode 100644 libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java create mode 100644 libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorizationProvider.java create mode 100644 libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java create mode 100644 libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java create mode 100644 libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java create mode 100644 libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java create mode 100644 libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorizationProvider.java create mode 100644 libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java create mode 100644 libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/BaseVectorizationTests.java diff --git a/docs/changelog/112933.yaml b/docs/changelog/112933.yaml new file mode 100644 index 0000000000000..222cd5aadf739 --- /dev/null +++ b/docs/changelog/112933.yaml @@ -0,0 +1,5 @@ +pr: 112933 +summary: "Allow incubating Panama Vector in simdvec, and add vectorized `ipByteBin`" +area: Search +type: enhancement +issues: [] diff --git a/libs/simdvec/build.gradle b/libs/simdvec/build.gradle index 5a523a19d4b68..eee56be72d0bf 100644 --- a/libs/simdvec/build.gradle +++ b/libs/simdvec/build.gradle @@ -7,6 +7,7 @@ * License v3.0 only", or the "Server Side Public License, v 1". */ +import org.elasticsearch.gradle.internal.info.BuildParams import org.elasticsearch.gradle.internal.precommit.CheckForbiddenApisTask apply plugin: 'elasticsearch.publish' @@ -23,6 +24,20 @@ dependencies { } } +// compileMain21Java does not exist within idea (see MrJarPlugin) so we cannot reference directly by name +tasks.matching { it.name == "compileMain21Java" }.configureEach { + options.compilerArgs << '--add-modules=jdk.incubator.vector' + // we remove Werror, since incubating suppression (-Xlint:-incubating) + // is only support since JDK 22 + options.compilerArgs -= '-Werror' +} + +tasks.named('test').configure { + if (BuildParams.getRuntimeJavaVersion().majorVersion.toInteger() >= 21) { + jvmArgs '--add-modules=jdk.incubator.vector' + } +} + tasks.withType(CheckForbiddenApisTask).configureEach { replaceSignatureFiles 'jdk-signatures' } diff --git a/libs/simdvec/src/main/java/module-info.java b/libs/simdvec/src/main/java/module-info.java index 64e685ba3cbb5..44f6e39d5dbab 100644 --- a/libs/simdvec/src/main/java/module-info.java +++ b/libs/simdvec/src/main/java/module-info.java @@ -10,6 +10,7 @@ module org.elasticsearch.simdvec { requires org.elasticsearch.nativeaccess; requires org.apache.lucene.core; + requires org.elasticsearch.logging; exports org.elasticsearch.simdvec to org.elasticsearch.server; } diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java new file mode 100644 index 0000000000000..91193d5fa6eaf --- /dev/null +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java @@ -0,0 +1,27 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.simdvec; + +import org.elasticsearch.simdvec.internal.vectorization.ESVectorUtilSupport; +import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider; + +import static org.elasticsearch.simdvec.internal.vectorization.ESVectorUtilSupport.B_QUERY; + +public class ESVectorUtil { + + private static final ESVectorUtilSupport IMPL = ESVectorizationProvider.getInstance().getVectorUtilSupport(); + + public static long ipByteBinByte(byte[] q, byte[] d) { + if (q.length != d.length * B_QUERY) { + throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= " + B_QUERY + " x " + d.length); + } + return IMPL.ipByteBinByte(q, d); + } +} diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java new file mode 100644 index 0000000000000..4a08096119d6a --- /dev/null +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java @@ -0,0 +1,39 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.simdvec.internal.vectorization; + +import org.apache.lucene.util.BitUtil; + +final class DefaultESVectorUtilSupport implements ESVectorUtilSupport { + + DefaultESVectorUtilSupport() {} + + @Override + public long ipByteBinByte(byte[] q, byte[] d) { + return ipByteBinByteImpl(q, d); + } + + public static long ipByteBinByteImpl(byte[] q, byte[] d) { + long ret = 0; + int size = d.length; + for (int i = 0; i < B_QUERY; i++) { + int r = 0; + long subRet = 0; + for (final int upperBound = d.length & -Integer.BYTES; r < upperBound; r += Integer.BYTES) { + subRet += Integer.bitCount((int) BitUtil.VH_NATIVE_INT.get(q, i * size + r) & (int) BitUtil.VH_NATIVE_INT.get(d, r)); + } + for (; r < d.length; r++) { + subRet += Integer.bitCount((q[i * size + r] & d[r]) & 0xFF); + } + ret += subRet << i; + } + return ret; + } +} diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorizationProvider.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorizationProvider.java new file mode 100644 index 0000000000000..6c0f7ed146b86 --- /dev/null +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorizationProvider.java @@ -0,0 +1,23 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.simdvec.internal.vectorization; + +final class DefaultESVectorizationProvider extends ESVectorizationProvider { + private final ESVectorUtilSupport vectorUtilSupport; + + DefaultESVectorizationProvider() { + vectorUtilSupport = new DefaultESVectorUtilSupport(); + } + + @Override + public ESVectorUtilSupport getVectorUtilSupport() { + return vectorUtilSupport; + } +} diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java new file mode 100644 index 0000000000000..d7611173ca693 --- /dev/null +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java @@ -0,0 +1,17 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.simdvec.internal.vectorization; + +public interface ESVectorUtilSupport { + + short B_QUERY = 4; + + long ipByteBinByte(byte[] q, byte[] d); +} diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java new file mode 100644 index 0000000000000..e541c10e145bf --- /dev/null +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java @@ -0,0 +1,38 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.simdvec.internal.vectorization; + +import java.util.Objects; + +public abstract class ESVectorizationProvider { + + public static ESVectorizationProvider getInstance() { + return Objects.requireNonNull( + ESVectorizationProvider.Holder.INSTANCE, + "call to getInstance() from subclass of VectorizationProvider" + ); + } + + ESVectorizationProvider() {} + + public abstract ESVectorUtilSupport getVectorUtilSupport(); + + // visible for tests + static ESVectorizationProvider lookup(boolean testMode) { + return new DefaultESVectorizationProvider(); + } + + /** This static holder class prevents classloading deadlock. */ + private static final class Holder { + private Holder() {} + + static final ESVectorizationProvider INSTANCE = lookup(false); + } +} diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java new file mode 100644 index 0000000000000..5b7aab7ddfa48 --- /dev/null +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java @@ -0,0 +1,87 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.simdvec.internal.vectorization; + +import org.apache.lucene.util.Constants; +import org.elasticsearch.logging.LogManager; +import org.elasticsearch.logging.Logger; + +import java.util.Locale; +import java.util.Objects; +import java.util.Optional; + +public abstract class ESVectorizationProvider { + + protected static final Logger logger = LogManager.getLogger(ESVectorizationProvider.class); + + public static ESVectorizationProvider getInstance() { + return Objects.requireNonNull( + ESVectorizationProvider.Holder.INSTANCE, + "call to getInstance() from subclass of VectorizationProvider" + ); + } + + ESVectorizationProvider() {} + + public abstract ESVectorUtilSupport getVectorUtilSupport(); + + // visible for tests + static ESVectorizationProvider lookup(boolean testMode) { + final int runtimeVersion = Runtime.version().feature(); + assert runtimeVersion >= 21; + if (runtimeVersion <= 23) { + // only use vector module with Hotspot VM + if (Constants.IS_HOTSPOT_VM == false) { + logger.warn("Java runtime is not using Hotspot VM; Java vector incubator API can't be enabled."); + return new DefaultESVectorizationProvider(); + } + // is the incubator module present and readable (JVM providers may to exclude them or it is + // build with jlink) + final var vectorMod = lookupVectorModule(); + if (vectorMod.isEmpty()) { + logger.warn( + "Java vector incubator module is not readable. " + + "For optimal vector performance, pass '--add-modules jdk.incubator.vector' to enable Vector API." + ); + return new DefaultESVectorizationProvider(); + } + vectorMod.ifPresent(ESVectorizationProvider.class.getModule()::addReads); + var impl = new PanamaESVectorizationProvider(); + logger.info( + String.format( + Locale.ENGLISH, + "Java vector incubator API enabled; uses preferredBitSize=%d", + PanamaESVectorUtilSupport.VECTOR_BITSIZE + ) + ); + return impl; + } else { + logger.warn( + "You are running with unsupported Java " + + runtimeVersion + + ". To make full use of the Vector API, please update Elasticsearch." + ); + } + return new DefaultESVectorizationProvider(); + } + + private static Optional lookupVectorModule() { + return Optional.ofNullable(ESVectorizationProvider.class.getModule().getLayer()) + .orElse(ModuleLayer.boot()) + .findModule("jdk.incubator.vector"); + } + + /** This static holder class prevents classloading deadlock. */ + private static final class Holder { + private Holder() {} + + static final ESVectorizationProvider INSTANCE = lookup(false); + } +} diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java new file mode 100644 index 0000000000000..0e5827d046736 --- /dev/null +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java @@ -0,0 +1,153 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.simdvec.internal.vectorization; + +import jdk.incubator.vector.ByteVector; +import jdk.incubator.vector.IntVector; +import jdk.incubator.vector.LongVector; +import jdk.incubator.vector.VectorOperators; +import jdk.incubator.vector.VectorShape; +import jdk.incubator.vector.VectorSpecies; + +import org.apache.lucene.util.Constants; + +public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport { + + static final int VECTOR_BITSIZE; + + /** Whether integer vectors can be trusted to actually be fast. */ + static final boolean HAS_FAST_INTEGER_VECTORS; + + static { + // default to platform supported bitsize + VECTOR_BITSIZE = VectorShape.preferredShape().vectorBitSize(); + + // hotspot misses some SSE intrinsics, workaround it + // to be fair, they do document this thing only works well with AVX2/AVX3 and Neon + boolean isAMD64withoutAVX2 = Constants.OS_ARCH.equals("amd64") && VECTOR_BITSIZE < 256; + HAS_FAST_INTEGER_VECTORS = isAMD64withoutAVX2 == false; + } + + @Override + public long ipByteBinByte(byte[] q, byte[] d) { + // 128 / 8 == 16 + if (d.length >= 16 && HAS_FAST_INTEGER_VECTORS) { + if (VECTOR_BITSIZE >= 256) { + return ipByteBin256(q, d); + } else if (VECTOR_BITSIZE == 128) { + return ipByteBin128(q, d); + } + } + return DefaultESVectorUtilSupport.ipByteBinByteImpl(q, d); + } + + private static final VectorSpecies BYTE_SPECIES_128 = ByteVector.SPECIES_128; + private static final VectorSpecies BYTE_SPECIES_256 = ByteVector.SPECIES_256; + + static long ipByteBin256(byte[] q, byte[] d) { + long subRet0 = 0; + long subRet1 = 0; + long subRet2 = 0; + long subRet3 = 0; + int i = 0; + + if (d.length >= ByteVector.SPECIES_256.vectorByteSize() * 2) { + int limit = ByteVector.SPECIES_256.loopBound(d.length); + var sum0 = LongVector.zero(LongVector.SPECIES_256); + var sum1 = LongVector.zero(LongVector.SPECIES_256); + var sum2 = LongVector.zero(LongVector.SPECIES_256); + var sum3 = LongVector.zero(LongVector.SPECIES_256); + for (; i < limit; i += ByteVector.SPECIES_256.length()) { + var vq0 = ByteVector.fromArray(BYTE_SPECIES_256, q, i).reinterpretAsLongs(); + var vq1 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + d.length).reinterpretAsLongs(); + var vq2 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + d.length * 2).reinterpretAsLongs(); + var vq3 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + d.length * 3).reinterpretAsLongs(); + var vd = ByteVector.fromArray(BYTE_SPECIES_256, d, i).reinterpretAsLongs(); + sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum3 = sum3.add(vq3.and(vd).lanewise(VectorOperators.BIT_COUNT)); + } + subRet0 += sum0.reduceLanes(VectorOperators.ADD); + subRet1 += sum1.reduceLanes(VectorOperators.ADD); + subRet2 += sum2.reduceLanes(VectorOperators.ADD); + subRet3 += sum3.reduceLanes(VectorOperators.ADD); + } + + if (d.length - i >= ByteVector.SPECIES_128.vectorByteSize()) { + var sum0 = LongVector.zero(LongVector.SPECIES_128); + var sum1 = LongVector.zero(LongVector.SPECIES_128); + var sum2 = LongVector.zero(LongVector.SPECIES_128); + var sum3 = LongVector.zero(LongVector.SPECIES_128); + int limit = ByteVector.SPECIES_128.loopBound(d.length); + for (; i < limit; i += ByteVector.SPECIES_128.length()) { + var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsLongs(); + var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + d.length).reinterpretAsLongs(); + var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + d.length * 2).reinterpretAsLongs(); + var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + d.length * 3).reinterpretAsLongs(); + var vd = ByteVector.fromArray(BYTE_SPECIES_128, d, i).reinterpretAsLongs(); + sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum3 = sum3.add(vq3.and(vd).lanewise(VectorOperators.BIT_COUNT)); + } + subRet0 += sum0.reduceLanes(VectorOperators.ADD); + subRet1 += sum1.reduceLanes(VectorOperators.ADD); + subRet2 += sum2.reduceLanes(VectorOperators.ADD); + subRet3 += sum3.reduceLanes(VectorOperators.ADD); + } + // tail as bytes + for (; i < d.length; i++) { + subRet0 += Integer.bitCount((q[i] & d[i]) & 0xFF); + subRet1 += Integer.bitCount((q[i + d.length] & d[i]) & 0xFF); + subRet2 += Integer.bitCount((q[i + 2 * d.length] & d[i]) & 0xFF); + subRet3 += Integer.bitCount((q[i + 3 * d.length] & d[i]) & 0xFF); + } + return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3); + } + + public static long ipByteBin128(byte[] q, byte[] d) { + long subRet0 = 0; + long subRet1 = 0; + long subRet2 = 0; + long subRet3 = 0; + int i = 0; + + var sum0 = IntVector.zero(IntVector.SPECIES_128); + var sum1 = IntVector.zero(IntVector.SPECIES_128); + var sum2 = IntVector.zero(IntVector.SPECIES_128); + var sum3 = IntVector.zero(IntVector.SPECIES_128); + int limit = ByteVector.SPECIES_128.loopBound(d.length); + for (; i < limit; i += ByteVector.SPECIES_128.length()) { + var vd = ByteVector.fromArray(BYTE_SPECIES_128, d, i).reinterpretAsInts(); + var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsInts(); + var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + d.length).reinterpretAsInts(); + var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + d.length * 2).reinterpretAsInts(); + var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + d.length * 3).reinterpretAsInts(); + sum0 = sum0.add(vd.and(vq0).lanewise(VectorOperators.BIT_COUNT)); + sum1 = sum1.add(vd.and(vq1).lanewise(VectorOperators.BIT_COUNT)); + sum2 = sum2.add(vd.and(vq2).lanewise(VectorOperators.BIT_COUNT)); + sum3 = sum3.add(vd.and(vq3).lanewise(VectorOperators.BIT_COUNT)); + } + subRet0 += sum0.reduceLanes(VectorOperators.ADD); + subRet1 += sum1.reduceLanes(VectorOperators.ADD); + subRet2 += sum2.reduceLanes(VectorOperators.ADD); + subRet3 += sum3.reduceLanes(VectorOperators.ADD); + // tail as bytes + for (; i < d.length; i++) { + int dValue = d[i]; + subRet0 += Integer.bitCount((dValue & q[i]) & 0xFF); + subRet1 += Integer.bitCount((dValue & q[i + d.length]) & 0xFF); + subRet2 += Integer.bitCount((dValue & q[i + 2 * d.length]) & 0xFF); + subRet3 += Integer.bitCount((dValue & q[i + 3 * d.length]) & 0xFF); + } + return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3); + } +} diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorizationProvider.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorizationProvider.java new file mode 100644 index 0000000000000..62d25d79487ed --- /dev/null +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorizationProvider.java @@ -0,0 +1,24 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.simdvec.internal.vectorization; + +final class PanamaESVectorizationProvider extends ESVectorizationProvider { + + private final ESVectorUtilSupport vectorUtilSupport; + + PanamaESVectorizationProvider() { + vectorUtilSupport = new PanamaESVectorUtilSupport(); + } + + @Override + public ESVectorUtilSupport getVectorUtilSupport() { + return vectorUtilSupport; + } +} diff --git a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java new file mode 100644 index 0000000000000..0dbc41c0c1055 --- /dev/null +++ b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java @@ -0,0 +1,130 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.simdvec; + +import org.elasticsearch.simdvec.internal.vectorization.BaseVectorizationTests; +import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider; + +import java.util.Arrays; + +import static org.elasticsearch.simdvec.internal.vectorization.ESVectorUtilSupport.B_QUERY; + +public class ESVectorUtilTests extends BaseVectorizationTests { + + static final ESVectorizationProvider defaultedProvider = BaseVectorizationTests.defaultProvider(); + static final ESVectorizationProvider defOrPanamaProvider = BaseVectorizationTests.maybePanamaProvider(); + + public void testIpByteBinInvariants() { + int iterations = atLeast(10); + for (int i = 0; i < iterations; i++) { + int size = randomIntBetween(1, 10); + var d = new byte[size]; + var q = new byte[size * B_QUERY - 1]; + expectThrows(IllegalArgumentException.class, () -> ESVectorUtil.ipByteBinByte(q, d)); + } + } + + public void testBasicIpByteBin() { + testBasicIpByteBinImpl(ESVectorUtil::ipByteBinByte); + testBasicIpByteBinImpl(defaultedProvider.getVectorUtilSupport()::ipByteBinByte); + testBasicIpByteBinImpl(defOrPanamaProvider.getVectorUtilSupport()::ipByteBinByte); + } + + interface IpByteBin { + long apply(byte[] q, byte[] d); + } + + void testBasicIpByteBinImpl(IpByteBin ipByteBinFunc) { + assertEquals(15L, ipByteBinFunc.apply(new byte[] { 1, 1, 1, 1 }, new byte[] { 1 })); + assertEquals(30L, ipByteBinFunc.apply(new byte[] { 1, 2, 1, 2, 1, 2, 1, 2 }, new byte[] { 1, 2 })); + + var d = new byte[] { 1, 2, 3 }; + var q = new byte[] { 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3 }; + assert scalarIpByteBin(q, d) == 60L; // 4 + 8 + 16 + 32 + assertEquals(60L, ipByteBinFunc.apply(q, d)); + + d = new byte[] { 1, 2, 3, 4 }; + q = new byte[] { 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4 }; + assert scalarIpByteBin(q, d) == 75L; // 5 + 10 + 20 + 40 + assertEquals(75L, ipByteBinFunc.apply(q, d)); + + d = new byte[] { 1, 2, 3, 4, 5 }; + q = new byte[] { 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5 }; + assert scalarIpByteBin(q, d) == 105L; // 7 + 14 + 28 + 56 + assertEquals(105L, ipByteBinFunc.apply(q, d)); + + d = new byte[] { 1, 2, 3, 4, 5, 6 }; + q = new byte[] { 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6 }; + assert scalarIpByteBin(q, d) == 135L; // 9 + 18 + 36 + 72 + assertEquals(135L, ipByteBinFunc.apply(q, d)); + + d = new byte[] { 1, 2, 3, 4, 5, 6, 7 }; + q = new byte[] { 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7 }; + assert scalarIpByteBin(q, d) == 180L; // 12 + 24 + 48 + 96 + assertEquals(180L, ipByteBinFunc.apply(q, d)); + + d = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8 }; + q = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8 }; + assert scalarIpByteBin(q, d) == 195L; // 13 + 26 + 52 + 104 + assertEquals(195L, ipByteBinFunc.apply(q, d)); + + d = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9 }; + q = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9 }; + assert scalarIpByteBin(q, d) == 225L; // 15 + 30 + 60 + 120 + assertEquals(225L, ipByteBinFunc.apply(q, d)); + } + + public void testIpByteBin() { + testIpByteBinImpl(ESVectorUtil::ipByteBinByte); + testIpByteBinImpl(defaultedProvider.getVectorUtilSupport()::ipByteBinByte); + testIpByteBinImpl(defOrPanamaProvider.getVectorUtilSupport()::ipByteBinByte); + } + + void testIpByteBinImpl(IpByteBin ipByteBinFunc) { + int iterations = atLeast(50); + for (int i = 0; i < iterations; i++) { + int size = random().nextInt(5000); + var d = new byte[size]; + var q = new byte[size * B_QUERY]; + random().nextBytes(d); + random().nextBytes(q); + assertEquals(scalarIpByteBin(q, d), ipByteBinFunc.apply(q, d)); + + Arrays.fill(d, Byte.MAX_VALUE); + Arrays.fill(q, Byte.MAX_VALUE); + assertEquals(scalarIpByteBin(q, d), ipByteBinFunc.apply(q, d)); + + Arrays.fill(d, Byte.MIN_VALUE); + Arrays.fill(q, Byte.MIN_VALUE); + assertEquals(scalarIpByteBin(q, d), ipByteBinFunc.apply(q, d)); + } + } + + static int scalarIpByteBin(byte[] q, byte[] d) { + int res = 0; + for (int i = 0; i < B_QUERY; i++) { + res += (popcount(q, i * d.length, d, d.length) << i); + } + return res; + } + + public static int popcount(byte[] a, int aOffset, byte[] b, int length) { + int res = 0; + for (int j = 0; j < length; j++) { + int value = (a[aOffset + j] & b[j]) & 0xFF; + for (int k = 0; k < Byte.SIZE; k++) { + if ((value & (1 << k)) != 0) { + ++res; + } + } + } + return res; + } +} diff --git a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/BaseVectorizationTests.java b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/BaseVectorizationTests.java new file mode 100644 index 0000000000000..f2bc8a11b04aa --- /dev/null +++ b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/BaseVectorizationTests.java @@ -0,0 +1,29 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.simdvec.internal.vectorization; + +import org.elasticsearch.test.ESTestCase; +import org.junit.Before; + +public class BaseVectorizationTests extends ESTestCase { + + @Before + public void sanity() { + assert Runtime.version().feature() < 21 || ModuleLayer.boot().findModule("jdk.incubator.vector").isPresent(); + } + + public static ESVectorizationProvider defaultProvider() { + return new DefaultESVectorizationProvider(); + } + + public static ESVectorizationProvider maybePanamaProvider() { + return ESVectorizationProvider.lookup(true); + } +} From 8da430bd6c1d676c0b853dd106f2c93fceab5ed0 Mon Sep 17 00:00:00 2001 From: Pat Whelan Date: Mon, 7 Oct 2024 16:08:52 -0400 Subject: [PATCH 08/22] [ML] Add Streaming Inference spec (#113812) (#114262) API for `/_inference/{task_type}/{inference_id}/_stream` and `/_inference/{inference_id}/_stream` Request is `application/json` Response is `text/event-stream` --- docs/changelog/113812.yaml | 5 ++ .../api/inference.stream_inference.json | 49 +++++++++++++++++++ 2 files changed, 54 insertions(+) create mode 100644 docs/changelog/113812.yaml create mode 100644 rest-api-spec/src/main/resources/rest-api-spec/api/inference.stream_inference.json diff --git a/docs/changelog/113812.yaml b/docs/changelog/113812.yaml new file mode 100644 index 0000000000000..04498b4ae5f7e --- /dev/null +++ b/docs/changelog/113812.yaml @@ -0,0 +1,5 @@ +pr: 113812 +summary: Add Streaming Inference spec +area: Machine Learning +type: enhancement +issues: [] diff --git a/rest-api-spec/src/main/resources/rest-api-spec/api/inference.stream_inference.json b/rest-api-spec/src/main/resources/rest-api-spec/api/inference.stream_inference.json new file mode 100644 index 0000000000000..32b4b2f311837 --- /dev/null +++ b/rest-api-spec/src/main/resources/rest-api-spec/api/inference.stream_inference.json @@ -0,0 +1,49 @@ +{ + "inference.stream_inference":{ + "documentation":{ + "url":"https://www.elastic.co/guide/en/elasticsearch/reference/master/post-stream-inference-api.html", + "description":"Perform streaming inference" + }, + "stability":"experimental", + "visibility":"public", + "headers":{ + "accept": [ "text/event-stream"], + "content_type": ["application/json"] + }, + "url":{ + "paths":[ + { + "path":"/_inference/{inference_id}/_stream", + "methods":[ + "POST" + ], + "parts":{ + "inference_id":{ + "type":"string", + "description":"The inference Id" + } + } + }, + { + "path":"/_inference/{task_type}/{inference_id}/_stream", + "methods":[ + "POST" + ], + "parts":{ + "task_type":{ + "type":"string", + "description":"The task type" + }, + "inference_id":{ + "type":"string", + "description":"The inference Id" + } + } + } + ] + }, + "body":{ + "description":"The inference payload" + } + } +} From 80f1c1b7614f53c9beaf26de015e0dde7a39b5c7 Mon Sep 17 00:00:00 2001 From: Pat Whelan Date: Mon, 7 Oct 2024 16:23:29 -0400 Subject: [PATCH 09/22] [ML] Stream Cohere Completion (#114080) (#114261) Implement and enable streaming for Cohere chat completions (v1). Includes processor for ND JSON streaming responses. Co-authored-by: Elastic Machine --- docs/changelog/114080.yaml | 5 + .../inference/common/DelegatingProcessor.java | 4 +- .../cohere/CohereResponseHandler.java | 23 ++- .../cohere/CohereStreamingProcessor.java | 101 ++++++++++ .../CohereCompletionRequestManager.java | 9 +- .../CohereEmbeddingsRequestManager.java | 2 +- .../sender/CohereRerankRequestManager.java | 2 +- .../completion/CohereCompletionRequest.java | 15 +- .../CohereCompletionRequestEntity.java | 8 +- .../NewlineDelimitedByteProcessor.java | 67 +++++++ .../services/cohere/CohereService.java | 6 + .../cohere/CohereResponseHandlerTests.java | 2 +- .../cohere/CohereStreamingProcessorTests.java | 189 ++++++++++++++++++ .../CohereCompletionRequestEntityTests.java | 8 +- .../cohere/CohereCompletionRequestTests.java | 8 +- .../NewlineDelimitedByteProcessorTests.java | 112 +++++++++++ .../services/cohere/CohereServiceTests.java | 50 +++++ 17 files changed, 585 insertions(+), 26 deletions(-) create mode 100644 docs/changelog/114080.yaml create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereStreamingProcessor.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/NewlineDelimitedByteProcessor.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/cohere/CohereStreamingProcessorTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/streaming/NewlineDelimitedByteProcessorTests.java diff --git a/docs/changelog/114080.yaml b/docs/changelog/114080.yaml new file mode 100644 index 0000000000000..395768c46369a --- /dev/null +++ b/docs/changelog/114080.yaml @@ -0,0 +1,5 @@ +pr: 114080 +summary: Stream Cohere Completion +area: Machine Learning +type: enhancement +issues: [] diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java index 9af5668ecf75b..fc2d890dd89e6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java @@ -21,7 +21,7 @@ public abstract class DelegatingProcessor implements Flow.Processor { private static final Logger log = LogManager.getLogger(DelegatingProcessor.class); private final AtomicLong pendingRequests = new AtomicLong(); - private final AtomicBoolean isClosed = new AtomicBoolean(false); + protected final AtomicBoolean isClosed = new AtomicBoolean(false); private Flow.Subscriber downstream; private Flow.Subscription upstream; @@ -49,7 +49,7 @@ private Flow.Subscription forwardingSubscription() { @Override public void request(long n) { if (isClosed.get()) { - downstream.onComplete(); // shouldn't happen, but reinforce that we're no longer listening + downstream.onComplete(); } else if (upstream != null) { upstream.request(n); } else { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandler.java index b5af0b474834f..3579cd4100bbb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandler.java @@ -8,14 +8,19 @@ package org.elasticsearch.xpack.inference.external.cohere; import org.apache.logging.log4j.Logger; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler; import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; import org.elasticsearch.xpack.inference.external.http.retry.RetryException; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.external.response.cohere.CohereErrorResponseEntity; +import org.elasticsearch.xpack.inference.external.response.streaming.NewlineDelimitedByteProcessor; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import java.util.concurrent.Flow; + import static org.elasticsearch.xpack.inference.external.http.HttpUtils.checkForEmptyBody; /** @@ -33,9 +38,11 @@ public class CohereResponseHandler extends BaseResponseHandler { static final String TEXTS_ARRAY_TOO_LARGE_MESSAGE_MATCHER = "invalid request: total number of texts must be at most"; static final String TEXTS_ARRAY_ERROR_MESSAGE = "Received a texts array too large response"; + private final boolean canHandleStreamingResponse; - public CohereResponseHandler(String requestType, ResponseParser parseFunction) { + public CohereResponseHandler(String requestType, ResponseParser parseFunction, boolean canHandleStreamingResponse) { super(requestType, parseFunction, CohereErrorResponseEntity::fromResponse); + this.canHandleStreamingResponse = canHandleStreamingResponse; } @Override @@ -45,6 +52,20 @@ public void validateResponse(ThrottlerManager throttlerManager, Logger logger, R checkForEmptyBody(throttlerManager, logger, request, result); } + @Override + public boolean canHandleStreamingResponses() { + return canHandleStreamingResponse; + } + + @Override + public InferenceServiceResults parseResult(Request request, Flow.Publisher flow) { + var ndProcessor = new NewlineDelimitedByteProcessor(); + var cohereProcessor = new CohereStreamingProcessor(); + flow.subscribe(ndProcessor); + ndProcessor.subscribe(cohereProcessor); + return new StreamingChatCompletionResults(cohereProcessor); + } + /** * Validates the status code throws an RetryException if not in the range [200, 300). * diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereStreamingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereStreamingProcessor.java new file mode 100644 index 0000000000000..2516a647a91fd --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereStreamingProcessor.java @@ -0,0 +1,101 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.cohere; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults; +import org.elasticsearch.xpack.inference.common.DelegatingProcessor; + +import java.io.IOException; +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.Map; +import java.util.Optional; + +class CohereStreamingProcessor extends DelegatingProcessor, StreamingChatCompletionResults.Results> { + private static final Logger log = LogManager.getLogger(CohereStreamingProcessor.class); + + @Override + protected void next(Deque item) throws Exception { + if (item.isEmpty()) { + // discard empty result and go to the next + upstream().request(1); + return; + } + + var results = new ArrayDeque(item.size()); + for (String json : item) { + try (var jsonParser = jsonParser(json)) { + var responseMap = jsonParser.map(); + var eventType = (String) responseMap.get("event_type"); + switch (eventType) { + case "text-generation" -> parseText(responseMap).ifPresent(results::offer); + case "stream-end" -> validateResponse(responseMap); + case "stream-start", "search-queries-generation", "search-results", "citation-generation", "tool-calls-generation", + "tool-calls-chunk" -> { + log.debug("Skipping event type [{}] for line [{}].", eventType, item); + } + default -> throw new IOException("Unknown eventType found: " + eventType); + } + } catch (ElasticsearchStatusException e) { + throw e; + } catch (Exception e) { + log.warn("Failed to parse json from cohere: {}", json); + throw e; + } + } + + if (results.isEmpty()) { + upstream().request(1); + } else { + downstream().onNext(new StreamingChatCompletionResults.Results(results)); + } + } + + private static XContentParser jsonParser(String line) throws IOException { + return XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, line); + } + + private Optional parseText(Map responseMap) throws IOException { + var text = (String) responseMap.get("text"); + if (text != null) { + return Optional.of(new StreamingChatCompletionResults.Result(text)); + } else { + throw new IOException("Null text found in text-generation cohere event"); + } + } + + private void validateResponse(Map responseMap) { + var finishReason = (String) responseMap.get("finish_reason"); + switch (finishReason) { + case "ERROR", "ERROR_TOXIC" -> throw new ElasticsearchStatusException( + "Cohere stopped the stream due to an error: {}", + RestStatus.INTERNAL_SERVER_ERROR, + parseErrorMessage(responseMap) + ); + case "ERROR_LIMIT" -> throw new ElasticsearchStatusException( + "Cohere stopped the stream due to an error: {}", + RestStatus.TOO_MANY_REQUESTS, + parseErrorMessage(responseMap) + ); + } + } + + @SuppressWarnings("unchecked") + private String parseErrorMessage(Map responseMap) { + var innerResponseMap = (Map) responseMap.get("response"); + return (String) innerResponseMap.get("text"); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereCompletionRequestManager.java index 423093a14a9f0..ae46fbe0fef87 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereCompletionRequestManager.java @@ -19,7 +19,6 @@ import org.elasticsearch.xpack.inference.external.response.cohere.CohereCompletionResponseEntity; import org.elasticsearch.xpack.inference.services.cohere.completion.CohereCompletionModel; -import java.util.List; import java.util.Objects; import java.util.function.Supplier; @@ -30,7 +29,7 @@ public class CohereCompletionRequestManager extends CohereRequestManager { private static final ResponseHandler HANDLER = createCompletionHandler(); private static ResponseHandler createCompletionHandler() { - return new CohereResponseHandler("cohere completion", CohereCompletionResponseEntity::fromResponse); + return new CohereResponseHandler("cohere completion", CohereCompletionResponseEntity::fromResponse, true); } public static CohereCompletionRequestManager of(CohereCompletionModel model, ThreadPool threadPool) { @@ -51,8 +50,10 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - List docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs(); - CohereCompletionRequest request = new CohereCompletionRequest(docsInput, model); + var docsOnly = DocumentsOnlyInput.of(inferenceInputs); + var docsInput = docsOnly.getInputs(); + var stream = docsOnly.stream(); + CohereCompletionRequest request = new CohereCompletionRequest(docsInput, model, stream); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereEmbeddingsRequestManager.java index 402f91a0838dc..80617ea56e63c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereEmbeddingsRequestManager.java @@ -28,7 +28,7 @@ public class CohereEmbeddingsRequestManager extends CohereRequestManager { private static final ResponseHandler HANDLER = createEmbeddingsHandler(); private static ResponseHandler createEmbeddingsHandler() { - return new CohereResponseHandler("cohere text embedding", CohereEmbeddingsResponseEntity::fromResponse); + return new CohereResponseHandler("cohere text embedding", CohereEmbeddingsResponseEntity::fromResponse, false); } public static CohereEmbeddingsRequestManager of(CohereEmbeddingsModel model, ThreadPool threadPool) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereRerankRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereRerankRequestManager.java index 9d565e7124b03..d27812b17399b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereRerankRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereRerankRequestManager.java @@ -27,7 +27,7 @@ public class CohereRerankRequestManager extends CohereRequestManager { private static final ResponseHandler HANDLER = createCohereResponseHandler(); private static ResponseHandler createCohereResponseHandler() { - return new CohereResponseHandler("cohere rerank", (request, response) -> CohereRankedResponseEntity.fromResponse(response)); + return new CohereResponseHandler("cohere rerank", (request, response) -> CohereRankedResponseEntity.fromResponse(response), false); } public static CohereRerankRequestManager of(CohereRerankModel model, ThreadPool threadPool) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/completion/CohereCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/completion/CohereCompletionRequest.java index f68f919a7d85b..2172dcd4d791f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/completion/CohereCompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/completion/CohereCompletionRequest.java @@ -25,22 +25,20 @@ import java.util.Objects; public class CohereCompletionRequest extends CohereRequest { - private final CohereAccount account; - private final List input; - private final String modelId; - private final String inferenceEntityId; + private final boolean stream; - public CohereCompletionRequest(List input, CohereCompletionModel model) { + public CohereCompletionRequest(List input, CohereCompletionModel model, boolean stream) { Objects.requireNonNull(model); this.account = CohereAccount.of(model, CohereCompletionRequest::buildDefaultUri); this.input = Objects.requireNonNull(input); this.modelId = model.getServiceSettings().modelId(); this.inferenceEntityId = model.getInferenceEntityId(); + this.stream = stream; } @Override @@ -48,7 +46,7 @@ public HttpRequest createHttpRequest() { HttpPost httpPost = new HttpPost(account.uri()); ByteArrayEntity byteEntity = new ByteArrayEntity( - Strings.toString(new CohereCompletionRequestEntity(input, modelId)).getBytes(StandardCharsets.UTF_8) + Strings.toString(new CohereCompletionRequestEntity(input, modelId, isStreaming())).getBytes(StandardCharsets.UTF_8) ); httpPost.setEntity(byteEntity); @@ -62,6 +60,11 @@ public String getInferenceEntityId() { return inferenceEntityId; } + @Override + public boolean isStreaming() { + return stream; + } + @Override public URI getURI() { return account.uri(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/completion/CohereCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/completion/CohereCompletionRequestEntity.java index 8cb3dc6e3c8e8..b834e4335d73c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/completion/CohereCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/completion/CohereCompletionRequestEntity.java @@ -15,11 +15,11 @@ import java.util.List; import java.util.Objects; -public record CohereCompletionRequestEntity(List input, @Nullable String model) implements ToXContentObject { +public record CohereCompletionRequestEntity(List input, @Nullable String model, boolean stream) implements ToXContentObject { private static final String MESSAGE_FIELD = "message"; - private static final String MODEL = "model"; + private static final String STREAM = "stream"; public CohereCompletionRequestEntity { Objects.requireNonNull(input); @@ -36,6 +36,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(MODEL, model); } + if (stream) { + builder.field(STREAM, true); + } + builder.endObject(); return builder; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/NewlineDelimitedByteProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/NewlineDelimitedByteProcessor.java new file mode 100644 index 0000000000000..7c44b202a816b --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/NewlineDelimitedByteProcessor.java @@ -0,0 +1,67 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.streaming; + +import org.elasticsearch.xpack.inference.common.DelegatingProcessor; +import org.elasticsearch.xpack.inference.external.http.HttpResult; + +import java.nio.charset.StandardCharsets; +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.regex.Pattern; + +/** + * Processes HttpResult bytes into lines separated by newlines, delimited by either line-feed or carriage-return line-feed. + * Downstream is responsible for validating the structure of the lines after they have been separated. + * Because Upstream (Apache) can send us a single line split between two HttpResults, this processor will aggregate bytes from the last + * HttpResult and append them to the front of the next HttpResult. + * When onComplete is called, the last batch is always flushed to the downstream onNext. + */ +public class NewlineDelimitedByteProcessor extends DelegatingProcessor> { + private static final Pattern END_OF_LINE_REGEX = Pattern.compile("\\n|\\r\\n"); + private volatile String previousTokens = ""; + + @Override + protected void next(HttpResult item) { + // discard empty result and go to the next + if (item.isBodyEmpty()) { + upstream().request(1); + return; + } + + var body = previousTokens + new String(item.body(), StandardCharsets.UTF_8); + var lines = END_OF_LINE_REGEX.split(body, -1); // -1 because we actually want trailing empty strings + + var results = new ArrayDeque(lines.length); + for (var i = 0; i < lines.length - 1; i++) { + var line = lines[i].trim(); + if (line.isBlank() == false) { + results.offer(line); + } + } + + previousTokens = lines[lines.length - 1].trim(); + + if (results.isEmpty()) { + upstream().request(1); + } else { + downstream().onNext(results); + } + } + + @Override + public void onComplete() { + if (previousTokens.isBlank()) { + super.onComplete(); + } else if (isClosed.compareAndSet(false, true)) { + var results = new ArrayDeque(1); + results.offer(previousTokens); + downstream().onNext(results); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java index 728a4ac137dff..3ba93dd8d1b66 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java @@ -39,6 +39,7 @@ import java.util.List; import java.util.Map; +import java.util.Set; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; @@ -288,4 +289,9 @@ static SimilarityMeasure defaultSimilarity() { public TransportVersion getMinimalSupportedVersion() { return TransportVersions.ML_INFERENCE_RATE_LIMIT_SETTINGS_ADDED; } + + @Override + public Set supportedStreamingTasks() { + return COMPLETION_ONLY; + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandlerTests.java index d64ac495c8c99..444415dfc8e48 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandlerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandlerTests.java @@ -132,7 +132,7 @@ private static void callCheckForFailureStatusCode(int statusCode, @Nullable Stri var mockRequest = mock(Request.class); when(mockRequest.getInferenceEntityId()).thenReturn(modelId); var httpResult = new HttpResult(httpResponse, errorMessage == null ? new byte[] {} : responseJson.getBytes(StandardCharsets.UTF_8)); - var handler = new CohereResponseHandler("", (request, result) -> null); + var handler = new CohereResponseHandler("", (request, result) -> null, false); handler.checkForFailureStatusCode(mockRequest, httpResult); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/cohere/CohereStreamingProcessorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/cohere/CohereStreamingProcessorTests.java new file mode 100644 index 0000000000000..87d6d63bb8c51 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/cohere/CohereStreamingProcessorTests.java @@ -0,0 +1,189 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.cohere; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.xcontent.ChunkedToXContent; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentParseException; +import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults; + +import java.io.IOException; +import java.util.ArrayDeque; +import java.util.concurrent.Flow; +import java.util.function.Consumer; + +import static org.elasticsearch.xpack.inference.common.DelegatingProcessorTests.onError; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.isA; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.assertArg; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +public class CohereStreamingProcessorTests extends ESTestCase { + + public void testParseErrorCallsOnError() { + var item = new ArrayDeque(); + item.offer("this is not json"); + + var exception = onError(new CohereStreamingProcessor(), item); + assertThat(exception, instanceOf(XContentParseException.class)); + } + + public void testUnrecognizedEventCallsOnError() { + var item = new ArrayDeque(); + item.offer("{\"event_type\":\"test\"}"); + + var exception = onError(new CohereStreamingProcessor(), item); + assertThat(exception, instanceOf(IOException.class)); + assertThat(exception.getMessage(), equalTo("Unknown eventType found: test")); + } + + public void testMissingTextCallsOnError() { + var item = new ArrayDeque(); + item.offer("{\"event_type\":\"text-generation\"}"); + + var exception = onError(new CohereStreamingProcessor(), item); + assertThat(exception, instanceOf(IOException.class)); + assertThat(exception.getMessage(), equalTo("Null text found in text-generation cohere event")); + } + + public void testEmptyResultsRequestsMoreData() throws Exception { + var emptyDeque = new ArrayDeque(); + + var processor = new CohereStreamingProcessor(); + + Flow.Subscriber downstream = mock(); + processor.subscribe(downstream); + + Flow.Subscription upstream = mock(); + processor.onSubscribe(upstream); + + processor.next(emptyDeque); + + verify(upstream, times(1)).request(1); + verify(downstream, times(0)).onNext(any()); + } + + public void testNonDataEventsAreSkipped() throws Exception { + var item = new ArrayDeque(); + item.offer("{\"event_type\":\"stream-start\"}"); + item.offer("{\"event_type\":\"search-queries-generation\"}"); + item.offer("{\"event_type\":\"search-results\"}"); + item.offer("{\"event_type\":\"citation-generation\"}"); + item.offer("{\"event_type\":\"tool-calls-generation\"}"); + item.offer("{\"event_type\":\"tool-calls-chunk\"}"); + + var processor = new CohereStreamingProcessor(); + + Flow.Subscriber downstream = mock(); + processor.subscribe(downstream); + + Flow.Subscription upstream = mock(); + processor.onSubscribe(upstream); + + processor.next(item); + + verify(upstream, times(1)).request(1); + verify(downstream, times(0)).onNext(any()); + } + + public void testParseError() { + var json = "{\"event_type\":\"stream-end\", \"finish_reason\":\"ERROR\", \"response\":{ \"text\": \"a wild error appears\" }}"; + testError(json, e -> { + assertThat(e.status().getStatus(), equalTo(500)); + assertThat(e.getMessage(), containsString("a wild error appears")); + }); + } + + private void testError(String json, Consumer test) { + var item = new ArrayDeque(); + item.offer(json); + + var processor = new CohereStreamingProcessor(); + + Flow.Subscriber downstream = mock(); + processor.subscribe(downstream); + + Flow.Subscription upstream = mock(); + processor.onSubscribe(upstream); + + try { + processor.next(item); + fail("Expected an exception to be thrown"); + } catch (ElasticsearchStatusException e) { + test.accept(e); + } catch (Exception e) { + fail(e, "Expected an exception of type ElasticsearchStatusException to be thrown"); + } + } + + public void testParseToxic() { + var json = "{\"event_type\":\"stream-end\", \"finish_reason\":\"ERROR_TOXIC\", \"response\":{ \"text\": \"by britney spears\" }}"; + testError(json, e -> { + assertThat(e.status().getStatus(), equalTo(500)); + assertThat(e.getMessage(), containsString("by britney spears")); + }); + } + + public void testParseLimit() { + var json = "{\"event_type\":\"stream-end\", \"finish_reason\":\"ERROR_LIMIT\", \"response\":{ \"text\": \"over the limit\" }}"; + testError(json, e -> { + assertThat(e.status().getStatus(), equalTo(429)); + assertThat(e.getMessage(), containsString("over the limit")); + }); + } + + public void testNonErrorFinishesAreSkipped() throws Exception { + var item = new ArrayDeque(); + item.offer("{\"event_type\":\"stream-end\", \"finish_reason\":\"COMPLETE\"}"); + item.offer("{\"event_type\":\"stream-end\", \"finish_reason\":\"STOP_SEQUENCE\"}"); + item.offer("{\"event_type\":\"stream-end\", \"finish_reason\":\"USER_CANCEL\"}"); + item.offer("{\"event_type\":\"stream-end\", \"finish_reason\":\"MAX_TOKENS\"}"); + + var processor = new CohereStreamingProcessor(); + + Flow.Subscriber downstream = mock(); + processor.subscribe(downstream); + + Flow.Subscription upstream = mock(); + processor.onSubscribe(upstream); + + processor.next(item); + + verify(upstream, times(1)).request(1); + verify(downstream, times(0)).onNext(any()); + } + + public void testParseCohereData() throws Exception { + var item = new ArrayDeque(); + item.offer("{\"event_type\":\"text-generation\", \"text\":\"hello there\"}"); + + var processor = new CohereStreamingProcessor(); + + Flow.Subscriber downstream = mock(); + processor.subscribe(downstream); + + Flow.Subscription upstream = mock(); + processor.onSubscribe(upstream); + + processor.next(item); + + verify(upstream, times(0)).request(1); + verify(downstream, times(1)).onNext(assertArg(chunks -> { + assertThat(chunks, isA(StreamingChatCompletionResults.Results.class)); + var results = (StreamingChatCompletionResults.Results) chunks; + assertThat(results.results().size(), equalTo(1)); + assertThat(results.results().getFirst().delta(), equalTo("hello there")); + })); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereCompletionRequestEntityTests.java index dbe6a9438d884..c3b534f42e7ee 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereCompletionRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereCompletionRequestEntityTests.java @@ -22,7 +22,7 @@ public class CohereCompletionRequestEntityTests extends ESTestCase { public void testXContent_WritesAllFields() throws IOException { - var entity = new CohereCompletionRequestEntity(List.of("some input"), "model"); + var entity = new CohereCompletionRequestEntity(List.of("some input"), "model", false); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); @@ -33,7 +33,7 @@ public void testXContent_WritesAllFields() throws IOException { } public void testXContent_DoesNotWriteModelIfNotSpecified() throws IOException { - var entity = new CohereCompletionRequestEntity(List.of("some input"), null); + var entity = new CohereCompletionRequestEntity(List.of("some input"), null, false); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); @@ -44,10 +44,10 @@ public void testXContent_DoesNotWriteModelIfNotSpecified() throws IOException { } public void testXContent_ThrowsIfInputIsNull() { - expectThrows(NullPointerException.class, () -> new CohereCompletionRequestEntity(null, null)); + expectThrows(NullPointerException.class, () -> new CohereCompletionRequestEntity(null, null, false)); } public void testXContent_ThrowsIfMessageInInputIsNull() { - expectThrows(NullPointerException.class, () -> new CohereCompletionRequestEntity(List.of((String) null), null)); + expectThrows(NullPointerException.class, () -> new CohereCompletionRequestEntity(List.of((String) null), null, false)); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereCompletionRequestTests.java index d6d0d5c00eaf4..f2e6d4305f9e6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereCompletionRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereCompletionRequestTests.java @@ -26,7 +26,7 @@ public class CohereCompletionRequestTests extends ESTestCase { public void testCreateRequest_UrlDefined() throws IOException { - var request = new CohereCompletionRequest(List.of("abc"), CohereCompletionModelTests.createModel("url", "secret", null)); + var request = new CohereCompletionRequest(List.of("abc"), CohereCompletionModelTests.createModel("url", "secret", null), false); var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); @@ -43,7 +43,7 @@ public void testCreateRequest_UrlDefined() throws IOException { } public void testCreateRequest_ModelDefined() throws IOException { - var request = new CohereCompletionRequest(List.of("abc"), CohereCompletionModelTests.createModel("url", "secret", "model")); + var request = new CohereCompletionRequest(List.of("abc"), CohereCompletionModelTests.createModel("url", "secret", "model"), false); var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); @@ -60,14 +60,14 @@ public void testCreateRequest_ModelDefined() throws IOException { } public void testTruncate_ReturnsSameInstance() { - var request = new CohereCompletionRequest(List.of("abc"), CohereCompletionModelTests.createModel("url", "secret", "model")); + var request = new CohereCompletionRequest(List.of("abc"), CohereCompletionModelTests.createModel("url", "secret", "model"), false); var truncatedRequest = request.truncate(); assertThat(truncatedRequest, sameInstance(request)); } public void testTruncationInfo_ReturnsNull() { - var request = new CohereCompletionRequest(List.of("abc"), CohereCompletionModelTests.createModel("url", "secret", "model")); + var request = new CohereCompletionRequest(List.of("abc"), CohereCompletionModelTests.createModel("url", "secret", "model"), false); assertNull(request.getTruncationInfo()); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/streaming/NewlineDelimitedByteProcessorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/streaming/NewlineDelimitedByteProcessorTests.java new file mode 100644 index 0000000000000..488cbccd0e7c3 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/streaming/NewlineDelimitedByteProcessorTests.java @@ -0,0 +1,112 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.streaming; + +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.junit.Before; +import org.mockito.ArgumentCaptor; + +import java.nio.charset.StandardCharsets; +import java.util.Deque; +import java.util.concurrent.Flow; +import java.util.stream.IntStream; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.assertArg; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +public class NewlineDelimitedByteProcessorTests extends ESTestCase { + private Flow.Subscription upstream; + private Flow.Subscriber> downstream; + private NewlineDelimitedByteProcessor processor; + + @Before + public void setUp() throws Exception { + super.setUp(); + upstream = mock(); + downstream = mock(); + processor = new NewlineDelimitedByteProcessor(); + processor.onSubscribe(upstream); + processor.subscribe(downstream); + } + + public void testEmptyBody() { + processor.next(result(null)); + processor.onComplete(); + verify(upstream, times(1)).request(1); + verify(downstream, times(0)).onNext(any()); + } + + private HttpResult result(String response) { + return new HttpResult(mock(), response == null ? new byte[0] : response.getBytes(StandardCharsets.UTF_8)); + } + + public void testEmptyParseResponse() { + processor.next(result("")); + verify(upstream, times(1)).request(1); + verify(downstream, times(0)).onNext(any()); + } + + public void testValidResponse() { + processor.next(result("{\"hello\":\"there\"}\n")); + verify(downstream, times(1)).onNext(assertArg(deque -> { + assertThat(deque, notNullValue()); + assertThat(deque.size(), is(1)); + assertThat(deque.getFirst(), is("{\"hello\":\"there\"}")); + })); + } + + public void testMultipleValidResponse() { + processor.next(result(""" + {"value": 1} + {"value": 2} + {"value": 3} + """)); + verify(upstream, times(0)).request(1); + verify(downstream, times(1)).onNext(assertArg(deque -> { + assertThat(deque, notNullValue()); + assertThat(deque.size(), is(3)); + var items = deque.iterator(); + IntStream.range(1, 4).forEach(i -> { + assertThat(items.hasNext(), is(true)); + assertThat(items.next(), containsString(String.valueOf(i))); + }); + })); + } + + public void testOnCompleteFlushesResponse() { + processor.next(result(""" + {"value": 1}""")); + + // onNext should not be called with only one value + verify(downstream, times(0)).onNext(any()); + verify(downstream, times(0)).onComplete(); + + // onComplete should flush the value pending, and onNext should be called + processor.onComplete(); + verify(downstream, times(1)).onNext(assertArg(deque -> { + assertThat(deque, notNullValue()); + assertThat(deque.size(), is(1)); + var item = deque.getFirst(); + assertThat(item, containsString(String.valueOf(1))); + })); + verify(downstream, times(0)).onComplete(); + + // next time the downstream requests data, onComplete is called + var downstreamSubscription = ArgumentCaptor.forClass(Flow.Subscription.class); + verify(downstream).onSubscribe(downstreamSubscription.capture()); + downstreamSubscription.getValue().request(1); + verify(downstream, times(1)).onComplete(); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java index af73986ec551e..d45d639f4dd11 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java @@ -38,6 +38,8 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion; +import org.elasticsearch.xpack.inference.services.cohere.completion.CohereCompletionModelTests; import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType; import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModel; import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModelTests; @@ -1347,6 +1349,54 @@ public void testDefaultSimilarity() { assertEquals(SimilarityMeasure.DOT_PRODUCT, CohereService.defaultSimilarity()); } + public void testInfer_StreamRequest() throws Exception { + String responseJson = """ + {"event_type":"text-generation", "text":"hello"} + {"event_type":"text-generation", "text":"there"} + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var result = streamChatCompletion(); + + InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoErrors().hasEvent(""" + {"completion":[{"delta":"hello"},{"delta":"there"}]}"""); + } + + private InferenceServiceResults streamChatCompletion() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + var model = CohereCompletionModelTests.createModel(getUrl(webServer), "secret", "model"); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + null, + List.of("abc"), + true, + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + return listener.actionGet(TIMEOUT); + } + } + + public void testInfer_StreamRequest_ErrorResponse() throws Exception { + String responseJson = """ + { "event_type":"stream-end", "finish_reason":"ERROR", "response":{ "text": "how dare you" } } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var result = streamChatCompletion(); + + InferenceEventsAssertion.assertThat(result) + .hasFinishedStream() + .hasNoEvents() + .hasErrorWithStatusCode(500) + .hasErrorContaining("how dare you"); + } + private Map getRequestConfigMap( Map serviceSettings, Map taskSettings, From 183aa71e769bd159f48be7ae12a22745c0edfb78 Mon Sep 17 00:00:00 2001 From: Dan Rubinstein Date: Mon, 7 Oct 2024 17:09:28 -0400 Subject: [PATCH 10/22] [8.x] Adding chunking settings to MistralService, GoogleAiStudioService, and HuggingFaceService (#113623) (#113886) * Adding chunking settings to MistralService, GoogleAiStudioService, and HuggingFaceService (#113623) * Adding chunking settings to MistralService, GoogleAiStudioService, and HuggingFaceService * Update docs/changelog/113623.yaml * Removing chunking settings from HuggingFaceElser model inputs * Removing unnecessary platfromArchitectures from InferenceService tests (#113890) * Removing unnecessary platfromArchitectures from InferenceService tests * Update docs/changelog/113890.yaml * Delete docs/changelog/113890.yaml * Fix HuggingFaceMixedIT test sometimes failing when run on version before 8.16 (#114061) * Fix HuggingFaceMixedIT test sometimes failing when run on version before 8.16 * Fixing typo in expected error message --------- Co-authored-by: Elastic Machine --- docs/changelog/113623.yaml | 6 + .../inference/ModelConfigurations.java | 10 + .../qa/mixed/HuggingFaceServiceMixedIT.java | 21 +- .../googleaistudio/GoogleAiStudioService.java | 48 ++- .../GoogleAiStudioEmbeddingsModel.java | 29 +- .../huggingface/HuggingFaceBaseService.java | 25 ++ .../huggingface/HuggingFaceService.java | 25 +- .../elser/HuggingFaceElserService.java | 2 + .../HuggingFaceEmbeddingsModel.java | 7 +- .../services/mistral/MistralService.java | 57 +++- .../embeddings/MistralEmbeddingsModel.java | 6 +- .../GoogleAiStudioServiceTests.java | 280 +++++++++++++++++- .../GoogleAiStudioEmbeddingsModelTests.java | 15 + .../HuggingFaceBaseServiceTests.java | 15 - .../huggingface/HuggingFaceServiceTests.java | 206 +++++++++++++ .../HuggingFaceEmbeddingsModelTests.java | 4 + .../services/mistral/MistralServiceTests.java | 265 ++++++++++++++++- .../MistralEmbeddingModelTests.java | 23 ++ 18 files changed, 1003 insertions(+), 41 deletions(-) create mode 100644 docs/changelog/113623.yaml diff --git a/docs/changelog/113623.yaml b/docs/changelog/113623.yaml new file mode 100644 index 0000000000000..8587687d27080 --- /dev/null +++ b/docs/changelog/113623.yaml @@ -0,0 +1,6 @@ +pr: 113623 +summary: "Adding chunking settings to `MistralService,` `GoogleAiStudioService,` and\ + \ `HuggingFaceService`" +area: Machine Learning +type: enhancement +issues: [] diff --git a/server/src/main/java/org/elasticsearch/inference/ModelConfigurations.java b/server/src/main/java/org/elasticsearch/inference/ModelConfigurations.java index e5bd5a629a912..9f7a247d00a3a 100644 --- a/server/src/main/java/org/elasticsearch/inference/ModelConfigurations.java +++ b/server/src/main/java/org/elasticsearch/inference/ModelConfigurations.java @@ -74,6 +74,16 @@ public ModelConfigurations(String inferenceEntityId, TaskType taskType, String s this(inferenceEntityId, taskType, service, serviceSettings, EmptyTaskSettings.INSTANCE); } + public ModelConfigurations( + String inferenceEntityId, + TaskType taskType, + String service, + ServiceSettings serviceSettings, + ChunkingSettings chunkingSettings + ) { + this(inferenceEntityId, taskType, service, serviceSettings, EmptyTaskSettings.INSTANCE, chunkingSettings); + } + public ModelConfigurations( String inferenceEntityId, TaskType taskType, diff --git a/x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/HuggingFaceServiceMixedIT.java b/x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/HuggingFaceServiceMixedIT.java index 59d3faf6489a6..457ae525b7f4b 100644 --- a/x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/HuggingFaceServiceMixedIT.java +++ b/x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/HuggingFaceServiceMixedIT.java @@ -19,6 +19,7 @@ import java.util.List; import java.util.Map; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.not; @@ -27,6 +28,7 @@ public class HuggingFaceServiceMixedIT extends BaseMixedTestCase { private static final String HF_EMBEDDINGS_ADDED = "8.12.0"; private static final String HF_ELSER_ADDED = "8.12.0"; + private static final String HF_EMBEDDINGS_CHUNKING_SETTINGS_ADDED = "8.16.0"; private static final String MINIMUM_SUPPORTED_VERSION = "8.15.0"; private static MockWebServer embeddingsServer; @@ -59,7 +61,24 @@ public void testHFEmbeddings() throws IOException { final String inferenceId = "mixed-cluster-embeddings"; embeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponse())); - put(inferenceId, embeddingConfig(getUrl(embeddingsServer)), TaskType.TEXT_EMBEDDING); + + try { + put(inferenceId, embeddingConfig(getUrl(embeddingsServer)), TaskType.TEXT_EMBEDDING); + } catch (Exception e) { + if (bwcVersion.before(Version.fromString(HF_EMBEDDINGS_CHUNKING_SETTINGS_ADDED))) { + // Chunking settings were added in 8.16.0. if the version is before that, an exception will be thrown if the index mapping + // was created based on a mapping from an old node + assertThat( + e.getMessage(), + containsString( + "One or more nodes in your cluster does not support chunking_settings. " + + "Please update all nodes in your cluster to the latest version to use chunking_settings." + ) + ); + return; + } + } + var configs = (List>) get(TaskType.TEXT_EMBEDDING, inferenceId).get("endpoints"); assertThat(configs, hasSize(1)); assertEquals("hugging_face", configs.get(0).get("service")); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java index b84dadc27dd77..9a04edf138996 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java @@ -15,6 +15,7 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInferenceServiceResults; import org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -23,6 +24,8 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.external.action.googleaistudio.GoogleAiStudioActionCreator; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; @@ -71,11 +74,19 @@ public void parseRequestConfig( Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + ChunkingSettings chunkingSettings = null; + if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap( + removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS) + ); + } + GoogleAiStudioModel model = createModel( inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, + chunkingSettings, serviceSettingsMap, TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), ConfigurationParseContext.REQUEST @@ -97,6 +108,7 @@ private static GoogleAiStudioModel createModel( TaskType taskType, Map serviceSettings, Map taskSettings, + ChunkingSettings chunkingSettings, @Nullable Map secretSettings, String failureMessage, ConfigurationParseContext context @@ -117,6 +129,7 @@ private static GoogleAiStudioModel createModel( NAME, serviceSettings, taskSettings, + chunkingSettings, secretSettings, context ); @@ -135,11 +148,17 @@ public GoogleAiStudioModel parsePersistedConfigWithSecrets( Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); + ChunkingSettings chunkingSettings = null; + if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + return createModelFromPersistent( inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, + chunkingSettings, secretSettingsMap, parsePersistedConfigErrorMsg(inferenceEntityId, NAME) ); @@ -150,6 +169,7 @@ private static GoogleAiStudioModel createModelFromPersistent( TaskType taskType, Map serviceSettings, Map taskSettings, + ChunkingSettings chunkingSettings, Map secretSettings, String failureMessage ) { @@ -158,6 +178,7 @@ private static GoogleAiStudioModel createModelFromPersistent( taskType, serviceSettings, taskSettings, + chunkingSettings, secretSettings, failureMessage, ConfigurationParseContext.PERSISTENT @@ -169,11 +190,17 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); + ChunkingSettings chunkingSettings = null; + if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + return createModelFromPersistent( inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, + chunkingSettings, null, parsePersistedConfigErrorMsg(inferenceEntityId, NAME) ); @@ -245,11 +272,22 @@ protected void doChunkedInfer( GoogleAiStudioModel googleAiStudioModel = (GoogleAiStudioModel) model; var actionCreator = new GoogleAiStudioActionCreator(getSender(), getServiceComponents()); - var batchedRequests = new EmbeddingRequestChunker( - inputs.getInputs(), - EMBEDDING_MAX_BATCH_SIZE, - EmbeddingRequestChunker.EmbeddingType.FLOAT - ).batchRequestsWithListeners(listener); + List batchedRequests; + if (ChunkingSettingsFeatureFlag.isEnabled()) { + batchedRequests = new EmbeddingRequestChunker( + inputs.getInputs(), + EMBEDDING_MAX_BATCH_SIZE, + EmbeddingRequestChunker.EmbeddingType.FLOAT, + googleAiStudioModel.getConfigurations().getChunkingSettings() + ).batchRequestsWithListeners(listener); + } else { + batchedRequests = new EmbeddingRequestChunker( + inputs.getInputs(), + EMBEDDING_MAX_BATCH_SIZE, + EmbeddingRequestChunker.EmbeddingType.FLOAT + ).batchRequestsWithListeners(listener); + } + for (var request : batchedRequests) { var action = googleAiStudioModel.accept(actionCreator, taskSettings, inputType); action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsModel.java index af19e26f3e97a..5d46a8e129dff 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsModel.java @@ -9,6 +9,7 @@ import org.apache.http.client.utils.URIBuilder; import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.ModelConfigurations; @@ -38,6 +39,7 @@ public GoogleAiStudioEmbeddingsModel( String service, Map serviceSettings, Map taskSettings, + ChunkingSettings chunkingSettings, Map secrets, ConfigurationParseContext context ) { @@ -47,6 +49,7 @@ public GoogleAiStudioEmbeddingsModel( service, GoogleAiStudioEmbeddingsServiceSettings.fromMap(serviceSettings, context), EmptyTaskSettings.INSTANCE, + chunkingSettings, DefaultSecretSettings.fromMap(secrets) ); } @@ -62,10 +65,11 @@ public GoogleAiStudioEmbeddingsModel(GoogleAiStudioEmbeddingsModel model, Google String service, GoogleAiStudioEmbeddingsServiceSettings serviceSettings, TaskSettings taskSettings, + ChunkingSettings chunkingSettings, @Nullable DefaultSecretSettings secrets ) { super( - new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings), + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings, chunkingSettings), new ModelSecrets(secrets), serviceSettings ); @@ -98,6 +102,29 @@ public GoogleAiStudioEmbeddingsModel(GoogleAiStudioEmbeddingsModel model, Google } } + // Should only be used directly for testing + GoogleAiStudioEmbeddingsModel( + String inferenceEntityId, + TaskType taskType, + String service, + String uri, + GoogleAiStudioEmbeddingsServiceSettings serviceSettings, + TaskSettings taskSettings, + ChunkingSettings chunkingsettings, + @Nullable DefaultSecretSettings secrets + ) { + super( + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings, chunkingsettings), + new ModelSecrets(secrets), + serviceSettings + ); + try { + this.uri = new URI(uri); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + @Override public GoogleAiStudioEmbeddingsServiceSettings getServiceSettings() { return (GoogleAiStudioEmbeddingsServiceSettings) super.getServiceSettings(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java index fb4221c0d5ebc..9f2615ac5c515 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java @@ -9,12 +9,15 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.external.action.huggingface.HuggingFaceActionCreator; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; @@ -26,6 +29,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; @@ -52,10 +56,18 @@ public void parseRequestConfig( try { Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + ChunkingSettings chunkingSettings = null; + if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap( + removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS) + ); + } + var model = createModel( inferenceEntityId, taskType, serviceSettingsMap, + chunkingSettings, serviceSettingsMap, TaskType.unsupportedTaskTypeErrorMsg(taskType, name()), ConfigurationParseContext.REQUEST @@ -80,10 +92,16 @@ public HuggingFaceModel parsePersistedConfigWithSecrets( Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS); + ChunkingSettings chunkingSettings = null; + if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + return createModel( inferenceEntityId, taskType, serviceSettingsMap, + chunkingSettings, secretSettingsMap, parsePersistedConfigErrorMsg(inferenceEntityId, name()), ConfigurationParseContext.PERSISTENT @@ -94,10 +112,16 @@ public HuggingFaceModel parsePersistedConfigWithSecrets( public HuggingFaceModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) { Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + ChunkingSettings chunkingSettings = null; + if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + return createModel( inferenceEntityId, taskType, serviceSettingsMap, + chunkingSettings, null, parsePersistedConfigErrorMsg(inferenceEntityId, name()), ConfigurationParseContext.PERSISTENT @@ -108,6 +132,7 @@ protected abstract HuggingFaceModel createModel( String inferenceEntityId, TaskType taskType, Map serviceSettings, + ChunkingSettings chunkingSettings, Map secretSettings, String failureMessage, ConfigurationParseContext context diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java index 6b142edca80aa..14af29bab9078 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java @@ -15,11 +15,13 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInferenceServiceResults; import org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.external.action.huggingface.HuggingFaceActionCreator; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; @@ -48,6 +50,7 @@ protected HuggingFaceModel createModel( String inferenceEntityId, TaskType taskType, Map serviceSettings, + ChunkingSettings chunkingSettings, @Nullable Map secretSettings, String failureMessage, ConfigurationParseContext context @@ -58,6 +61,7 @@ protected HuggingFaceModel createModel( taskType, NAME, serviceSettings, + chunkingSettings, secretSettings, context ); @@ -111,11 +115,22 @@ protected void doChunkedInfer( var huggingFaceModel = (HuggingFaceModel) model; var actionCreator = new HuggingFaceActionCreator(getSender(), getServiceComponents()); - var batchedRequests = new EmbeddingRequestChunker( - inputs.getInputs(), - EMBEDDING_MAX_BATCH_SIZE, - EmbeddingRequestChunker.EmbeddingType.FLOAT - ).batchRequestsWithListeners(listener); + List batchedRequests; + if (ChunkingSettingsFeatureFlag.isEnabled()) { + batchedRequests = new EmbeddingRequestChunker( + inputs.getInputs(), + EMBEDDING_MAX_BATCH_SIZE, + EmbeddingRequestChunker.EmbeddingType.FLOAT, + huggingFaceModel.getConfigurations().getChunkingSettings() + ).batchRequestsWithListeners(listener); + } else { + batchedRequests = new EmbeddingRequestChunker( + inputs.getInputs(), + EMBEDDING_MAX_BATCH_SIZE, + EmbeddingRequestChunker.EmbeddingType.FLOAT + ).batchRequestsWithListeners(listener); + } + for (var request : batchedRequests) { var action = huggingFaceModel.accept(actionCreator); action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java index b9540cab17a9a..a8de51c23831f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java @@ -16,6 +16,7 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInferenceServiceResults; import org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -56,6 +57,7 @@ protected HuggingFaceModel createModel( String inferenceEntityId, TaskType taskType, Map serviceSettings, + ChunkingSettings chunkingSettings, @Nullable Map secretSettings, String failureMessage, ConfigurationParseContext context diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/embeddings/HuggingFaceEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/embeddings/HuggingFaceEmbeddingsModel.java index fedd6380d035f..7c4d8094fc213 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/embeddings/HuggingFaceEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/embeddings/HuggingFaceEmbeddingsModel.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference.services.huggingface.embeddings; import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.TaskType; @@ -26,6 +27,7 @@ public HuggingFaceEmbeddingsModel( TaskType taskType, String service, Map serviceSettings, + ChunkingSettings chunkingSettings, @Nullable Map secrets, ConfigurationParseContext context ) { @@ -34,6 +36,7 @@ public HuggingFaceEmbeddingsModel( taskType, service, HuggingFaceServiceSettings.fromMap(serviceSettings, context), + chunkingSettings, DefaultSecretSettings.fromMap(secrets) ); } @@ -44,10 +47,11 @@ public HuggingFaceEmbeddingsModel( TaskType taskType, String service, HuggingFaceServiceSettings serviceSettings, + ChunkingSettings chunkingSettings, @Nullable DefaultSecretSettings secrets ) { super( - new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings), + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, chunkingSettings), new ModelSecrets(secrets), serviceSettings, secrets @@ -60,6 +64,7 @@ public HuggingFaceEmbeddingsModel(HuggingFaceEmbeddingsModel model, HuggingFaceS model.getTaskType(), model.getConfigurations().getService(), serviceSettings, + model.getConfigurations().getChunkingSettings(), model.getSecretSettings() ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java index 47191ba96cb82..8af817d06fefa 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java @@ -14,6 +14,7 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInferenceServiceResults; import org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -22,6 +23,8 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.external.action.mistral.MistralActionCreator; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; @@ -84,11 +87,21 @@ protected void doChunkedInfer( var actionCreator = new MistralActionCreator(getSender(), getServiceComponents()); if (model instanceof MistralEmbeddingsModel mistralEmbeddingsModel) { - var batchedRequests = new EmbeddingRequestChunker( - inputs.getInputs(), - MistralConstants.MAX_BATCH_SIZE, - EmbeddingRequestChunker.EmbeddingType.FLOAT - ).batchRequestsWithListeners(listener); + List batchedRequests; + if (ChunkingSettingsFeatureFlag.isEnabled()) { + batchedRequests = new EmbeddingRequestChunker( + inputs.getInputs(), + MistralConstants.MAX_BATCH_SIZE, + EmbeddingRequestChunker.EmbeddingType.FLOAT, + mistralEmbeddingsModel.getConfigurations().getChunkingSettings() + ).batchRequestsWithListeners(listener); + } else { + batchedRequests = new EmbeddingRequestChunker( + inputs.getInputs(), + MistralConstants.MAX_BATCH_SIZE, + EmbeddingRequestChunker.EmbeddingType.FLOAT + ).batchRequestsWithListeners(listener); + } for (var request : batchedRequests) { var action = mistralEmbeddingsModel.accept(actionCreator, taskSettings); @@ -115,11 +128,19 @@ public void parseRequestConfig( Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + ChunkingSettings chunkingSettings = null; + if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap( + removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS) + ); + } + MistralEmbeddingsModel model = createModel( modelId, taskType, serviceSettingsMap, taskSettingsMap, + chunkingSettings, serviceSettingsMap, TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), ConfigurationParseContext.REQUEST @@ -146,11 +167,17 @@ public Model parsePersistedConfigWithSecrets( Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); + ChunkingSettings chunkingSettings = null; + if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + return createModelFromPersistent( modelId, taskType, serviceSettingsMap, taskSettingsMap, + chunkingSettings, secretSettingsMap, parsePersistedConfigErrorMsg(modelId, NAME) ); @@ -161,11 +188,17 @@ public Model parsePersistedConfig(String modelId, TaskType taskType, Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + ChunkingSettings chunkingSettings = null; + if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + return createModelFromPersistent( modelId, taskType, serviceSettingsMap, taskSettingsMap, + chunkingSettings, null, parsePersistedConfigErrorMsg(modelId, NAME) ); @@ -181,12 +214,22 @@ private static MistralEmbeddingsModel createModel( TaskType taskType, Map serviceSettings, Map taskSettings, + ChunkingSettings chunkingSettings, @Nullable Map secretSettings, String failureMessage, ConfigurationParseContext context ) { if (taskType == TaskType.TEXT_EMBEDDING) { - return new MistralEmbeddingsModel(modelId, taskType, NAME, serviceSettings, taskSettings, secretSettings, context); + return new MistralEmbeddingsModel( + modelId, + taskType, + NAME, + serviceSettings, + taskSettings, + chunkingSettings, + secretSettings, + context + ); } throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); @@ -197,6 +240,7 @@ private MistralEmbeddingsModel createModelFromPersistent( TaskType taskType, Map serviceSettings, Map taskSettings, + ChunkingSettings chunkingSettings, Map secretSettings, String failureMessage ) { @@ -205,6 +249,7 @@ private MistralEmbeddingsModel createModelFromPersistent( taskType, serviceSettings, taskSettings, + chunkingSettings, secretSettings, failureMessage, ConfigurationParseContext.PERSISTENT diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsModel.java index d883e7c687c20..11f6c456b88cc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsModel.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference.services.mistral.embeddings; import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; @@ -37,6 +38,7 @@ public MistralEmbeddingsModel( String service, Map serviceSettings, Map taskSettings, + ChunkingSettings chunkingSettings, @Nullable Map secrets, ConfigurationParseContext context ) { @@ -46,6 +48,7 @@ public MistralEmbeddingsModel( service, MistralEmbeddingsServiceSettings.fromMap(serviceSettings, context), EmptyTaskSettings.INSTANCE, // no task settings for Mistral embeddings + chunkingSettings, DefaultSecretSettings.fromMap(secrets) ); } @@ -61,10 +64,11 @@ public MistralEmbeddingsModel( String service, MistralEmbeddingsServiceSettings serviceSettings, TaskSettings taskSettings, + ChunkingSettings chunkingSettings, DefaultSecretSettings secrets ) { super( - new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, new EmptyTaskSettings()), + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, new EmptyTaskSettings(), chunkingSettings), new ModelSecrets(secrets) ); setPropertiesFromServiceSettings(serviceSettings); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java index 8be61e728ff99..58b0e9cf5da6c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java @@ -17,6 +17,7 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInferenceServiceResults; import org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; @@ -29,6 +30,7 @@ import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; @@ -59,6 +61,8 @@ import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettings; +import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; @@ -152,6 +156,93 @@ public void testParseRequestConfig_CreatesAGoogleAiStudioEmbeddingsModel() throw } } + public void testParseRequestConfig_ThrowsElasticsearchStatusExceptionWhenChunkingSettingsProvidedAndFeatureFlagDisabled() + throws IOException { + var apiKey = "apiKey"; + var modelId = "model"; + assumeTrue("Only if 'inference_chunking_settings' feature flag is disabled", ChunkingSettingsFeatureFlag.isEnabled() == false); + try (var service = createGoogleAiStudioService()) { + ActionListener modelVerificationListener = ActionListener.wrap( + model -> fail("Expected exception, but got model: " + model), + exception -> { + assertThat(exception, Matchers.instanceOf(ElasticsearchStatusException.class)); + assertThat(exception.getMessage(), containsString("Model configuration contains settings")); + } + ); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)), + new HashMap<>(Map.of()), + createRandomChunkingSettingsMap(), + getSecretSettingsMap(apiKey) + ), + modelVerificationListener + ); + } + } + + public void testParseRequestConfig_CreatesAGoogleAiStudioEmbeddingsModelWhenChunkingSettingsProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + var apiKey = "apiKey"; + var modelId = "model"; + + try (var service = createGoogleAiStudioService()) { + ActionListener modelListener = ActionListener.wrap(model -> { + assertThat(model, instanceOf(GoogleAiStudioEmbeddingsModel.class)); + + var embeddingsModel = (GoogleAiStudioEmbeddingsModel) model; + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(apiKey)); + }, e -> fail("Model parsing should have succeeded, but failed: " + e.getMessage())); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)), + new HashMap<>(Map.of()), + createRandomChunkingSettingsMap(), + getSecretSettingsMap(apiKey) + ), + modelListener + ); + } + } + + public void testParseRequestConfig_CreatesAGoogleAiStudioEmbeddingsModelWhenChunkingSettingsNotProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + var apiKey = "apiKey"; + var modelId = "model"; + + try (var service = createGoogleAiStudioService()) { + ActionListener modelListener = ActionListener.wrap(model -> { + assertThat(model, instanceOf(GoogleAiStudioEmbeddingsModel.class)); + + var embeddingsModel = (GoogleAiStudioEmbeddingsModel) model; + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(apiKey)); + }, e -> fail("Model parsing should have succeeded, but failed: " + e.getMessage())); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)), + new HashMap<>(Map.of()), + getSecretSettingsMap(apiKey) + ), + modelListener + ); + } + } + public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOException { try (var service = createGoogleAiStudioService()) { var failureListener = getModelListenerForException( @@ -296,6 +387,98 @@ public void testParsePersistedConfigWithSecrets_CreatesAGoogleAiStudioEmbeddings } } + public void testParsePersistedConfigWithSecrets_CreatesAGoogleAiStudioEmbeddingsModelWithoutChunkingSettingsWhenFeatureFlagDisabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is disabled", ChunkingSettingsFeatureFlag.isEnabled() == false); + var modelId = "model"; + var apiKey = "apiKey"; + + try (var service = createGoogleAiStudioService()) { + var persistedConfig = getPersistedConfigMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)), + getTaskSettingsMapEmpty(), + createRandomChunkingSettingsMap(), + getSecretSettingsMap(apiKey) + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(GoogleAiStudioEmbeddingsModel.class)); + + var embeddingsModel = (GoogleAiStudioEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId)); + assertThat(embeddingsModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE)); + assertNull(embeddingsModel.getConfigurations().getChunkingSettings()); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(apiKey)); + } + } + + public void testParsePersistedConfigWithSecrets_CreatesAGoogleAiStudioEmbeddingsModelWhenChunkingSettingsProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + var modelId = "model"; + var apiKey = "apiKey"; + + try (var service = createGoogleAiStudioService()) { + var persistedConfig = getPersistedConfigMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)), + getTaskSettingsMapEmpty(), + createRandomChunkingSettingsMap(), + getSecretSettingsMap(apiKey) + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(GoogleAiStudioEmbeddingsModel.class)); + + var embeddingsModel = (GoogleAiStudioEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId)); + assertThat(embeddingsModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE)); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), Matchers.instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(apiKey)); + } + } + + public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + var modelId = "model"; + var apiKey = "apiKey"; + + try (var service = createGoogleAiStudioService()) { + var persistedConfig = getPersistedConfigMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)), + getTaskSettingsMapEmpty(), + getSecretSettingsMap(apiKey) + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(GoogleAiStudioEmbeddingsModel.class)); + + var embeddingsModel = (GoogleAiStudioEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId)); + assertThat(embeddingsModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE)); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), Matchers.instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(apiKey)); + } + } + public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { var modelId = "model"; var apiKey = "apiKey"; @@ -427,6 +610,74 @@ public void testParsePersistedConfig_CreatesAGoogleAiStudioCompletionModel() thr } } + public void testParsePersistedConfig_CreatesAGoogleAiEmbeddingsModelWithoutChunkingSettingsWhenChunkingSettingsFeatureFlagDisabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is disabled", ChunkingSettingsFeatureFlag.isEnabled() == false); + var modelId = "model"; + + try (var service = createGoogleAiStudioService()) { + var persistedConfig = getPersistedConfigMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)), + getTaskSettingsMapEmpty(), + createRandomChunkingSettingsMap() + ); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + assertThat(model, instanceOf(GoogleAiStudioEmbeddingsModel.class)); + + var embeddingsModel = (GoogleAiStudioEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId)); + assertThat(embeddingsModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE)); + assertNull(embeddingsModel.getConfigurations().getChunkingSettings()); + assertNull(embeddingsModel.getSecretSettings()); + } + } + + public void testParsePersistedConfig_CreatesAGoogleAiStudioEmbeddingsModelWhenChunkingSettingsProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + var modelId = "model"; + + try (var service = createGoogleAiStudioService()) { + var persistedConfig = getPersistedConfigMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)), + getTaskSettingsMapEmpty(), + createRandomChunkingSettingsMap() + ); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + assertThat(model, instanceOf(GoogleAiStudioEmbeddingsModel.class)); + + var embeddingsModel = (GoogleAiStudioEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId)); + assertThat(embeddingsModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE)); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertNull(embeddingsModel.getSecretSettings()); + } + } + + public void testParsePersistedConfig_CreatesAGoogleAiStudioEmbeddingsModelWhenChunkingSettingsNotProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + var modelId = "model"; + + try (var service = createGoogleAiStudioService()) { + var persistedConfig = getPersistedConfigMap(new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)), getTaskSettingsMapEmpty()); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + assertThat(model, instanceOf(GoogleAiStudioEmbeddingsModel.class)); + + var embeddingsModel = (GoogleAiStudioEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId)); + assertThat(embeddingsModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE)); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertNull(embeddingsModel.getSecretSettings()); + } + } + public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { var modelId = "model"; @@ -666,6 +917,22 @@ public void testInfer_SendsEmbeddingsRequest() throws IOException { public void testChunkedInfer_Batches() throws IOException { var modelId = "modelId"; var apiKey = "apiKey"; + var model = GoogleAiStudioEmbeddingsModelTests.createModel(modelId, apiKey, getUrl(webServer)); + + testChunkedInfer(modelId, apiKey, model); + } + + public void testChunkedInfer_ChunkingSettingsSetAndFeatureFlagEnabled() throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + var modelId = "modelId"; + var apiKey = "apiKey"; + var model = GoogleAiStudioEmbeddingsModelTests.createModel(modelId, createRandomChunkingSettings(), apiKey, getUrl(webServer)); + + testChunkedInfer(modelId, apiKey, model); + } + + private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbeddingsModel model) throws IOException { + var input = List.of("foo", "bar"); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); @@ -692,7 +959,6 @@ public void testChunkedInfer_Batches() throws IOException { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = GoogleAiStudioEmbeddingsModelTests.createModel(modelId, apiKey, getUrl(webServer)); PlainActionFuture> listener = new PlainActionFuture<>(); service.chunkedInfer( model, @@ -966,6 +1232,18 @@ private static ActionListener getModelListenerForException(Class excep }); } + private Map getRequestConfigMap( + Map serviceSettings, + Map taskSettings, + Map chunkingSettings, + Map secretSettings + ) { + var requestConfigMap = getRequestConfigMap(serviceSettings, taskSettings, secretSettings); + requestConfigMap.put(ModelConfigurations.CHUNKING_SETTINGS, chunkingSettings); + + return requestConfigMap; + } + private Map getRequestConfigMap( Map serviceSettings, Map taskSettings, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsModelTests.java index 5ea9bbfc9d970..32bd95c954292 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsModelTests.java @@ -9,6 +9,7 @@ import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; @@ -29,6 +30,19 @@ public static GoogleAiStudioEmbeddingsModel createModel(String model, String api ); } + public static GoogleAiStudioEmbeddingsModel createModel(String model, ChunkingSettings chunkingSettings, String apiKey, String url) { + return new GoogleAiStudioEmbeddingsModel( + "id", + TaskType.TEXT_EMBEDDING, + "service", + url, + new GoogleAiStudioEmbeddingsServiceSettings(model, null, null, SimilarityMeasure.DOT_PRODUCT, null), + EmptyTaskSettings.INSTANCE, + chunkingSettings, + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } + public static GoogleAiStudioEmbeddingsModel createModel( String url, String model, @@ -59,6 +73,7 @@ public static GoogleAiStudioEmbeddingsModel createModel( "service", new GoogleAiStudioEmbeddingsServiceSettings(model, tokenLimit, dimensions, SimilarityMeasure.DOT_PRODUCT, null), EmptyTaskSettings.INSTANCE, + null, new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java index 168110ae8f7c7..27b90f32bef2d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java @@ -13,13 +13,11 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; -import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.Sender; -import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.junit.After; import org.junit.Before; @@ -27,7 +25,6 @@ import java.io.IOException; import java.util.HashMap; import java.util.List; -import java.util.Map; import java.util.concurrent.TimeUnit; import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; @@ -105,17 +102,5 @@ public String name() { public TransportVersion getMinimalSupportedVersion() { return TransportVersion.current(); } - - @Override - protected HuggingFaceModel createModel( - String inferenceEntityId, - TaskType taskType, - Map serviceSettings, - Map secretSettings, - String failureMessage, - ConfigurationParseContext context - ) { - return null; - } } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java index 0faa8d8e9056e..cf613b2337706 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java @@ -17,6 +17,7 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInferenceServiceResults; import org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -28,6 +29,7 @@ import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; import org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults; @@ -56,6 +58,7 @@ import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; @@ -63,6 +66,7 @@ import static org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettingsTests.getServiceSettingsMap; import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; @@ -108,6 +112,72 @@ public void testParseRequestConfig_CreatesAnEmbeddingsModel() throws IOException } } + public void testParseRequestConfig_ThrowsElasticsearchStatusExceptionWhenChunkingSettingsProvidedAndFeatureFlagDisabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is disabled", ChunkingSettingsFeatureFlag.isEnabled() == false); + try (var service = createHuggingFaceService()) { + var config = getRequestConfigMap(getServiceSettingsMap("url"), getSecretSettingsMap("secret")); + config.put("extra_key", "value"); + + ActionListener modelVerificationListener = ActionListener.wrap( + model -> fail("Expected exception, but got model: " + model), + exception -> { + assertThat(exception, instanceOf(ElasticsearchStatusException.class)); + assertThat(exception.getMessage(), containsString("Model configuration contains settings")); + } + ); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap(getServiceSettingsMap("url"), createRandomChunkingSettingsMap(), getSecretSettingsMap("secret")), + modelVerificationListener + ); + } + } + + public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsProvidedAndFeatureFlagEnabled() throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = createHuggingFaceService()) { + ActionListener modelVerificationActionListener = ActionListener.wrap((model) -> { + assertThat(model, instanceOf(HuggingFaceEmbeddingsModel.class)); + + var embeddingsModel = (HuggingFaceEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + }, (e) -> fail("parse request should not fail " + e.getMessage())); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap(getServiceSettingsMap("url"), createRandomChunkingSettingsMap(), getSecretSettingsMap("secret")), + modelVerificationActionListener + ); + } + } + + public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvidedAndFeatureFlagEnabled() throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = createHuggingFaceService()) { + ActionListener modelVerificationActionListener = ActionListener.wrap((model) -> { + assertThat(model, instanceOf(HuggingFaceEmbeddingsModel.class)); + + var embeddingsModel = (HuggingFaceEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + }, (e) -> fail("parse request should not fail " + e.getMessage())); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap(getServiceSettingsMap("url"), getSecretSettingsMap("secret")), + modelVerificationActionListener + ); + } + } + public void testParseRequestConfig_CreatesAnElserModel() throws IOException { try (var service = createHuggingFaceService()) { ActionListener modelVerificationActionListener = ActionListener.wrap((model) -> { @@ -210,6 +280,82 @@ public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModel() throw } } + public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWithoutChunkingSettingsWhenFeatureFlagDisabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is disabled", ChunkingSettingsFeatureFlag.isEnabled() == false); + try (var service = createHuggingFaceService()) { + var persistedConfig = getPersistedConfigMap( + getServiceSettingsMap("url"), + new HashMap<>(), + createRandomChunkingSettingsMap(), + getSecretSettingsMap("secret") + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(HuggingFaceEmbeddingsModel.class)); + + var embeddingsModel = (HuggingFaceEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); + assertNull(embeddingsModel.getConfigurations().getChunkingSettings()); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + + public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = createHuggingFaceService()) { + var persistedConfig = getPersistedConfigMap( + getServiceSettingsMap("url"), + new HashMap<>(), + createRandomChunkingSettingsMap(), + getSecretSettingsMap("secret") + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(HuggingFaceEmbeddingsModel.class)); + + var embeddingsModel = (HuggingFaceEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + + public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = createHuggingFaceService()) { + var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("url"), new HashMap<>(), getSecretSettingsMap("secret")); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(HuggingFaceEmbeddingsModel.class)); + + var embeddingsModel = (HuggingFaceEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + public void testParsePersistedConfigWithSecrets_CreatesAnElserModel() throws IOException { try (var service = createHuggingFaceService()) { var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("url"), new HashMap<>(), getSecretSettingsMap("secret")); @@ -349,6 +495,55 @@ public void testParsePersistedConfig_CreatesAnEmbeddingsModel() throws IOExcepti } } + public void testParsePersistedConfig_CreatesAnEmbeddingsModelWithoutChunkingSettingsWhenChunkingSettingsFeatureFlagDisabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is disabled", ChunkingSettingsFeatureFlag.isEnabled() == false); + try (var service = createHuggingFaceService()) { + var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("url"), createRandomChunkingSettingsMap()); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + assertThat(model, instanceOf(HuggingFaceEmbeddingsModel.class)); + + var embeddingsModel = (HuggingFaceEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); + assertNull(embeddingsModel.getConfigurations().getChunkingSettings()); + assertNull(embeddingsModel.getSecretSettings()); + } + } + + public void testParsePersistedConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsProvidedAndFeatureFlagEnabled() throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = createHuggingFaceService()) { + var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("url"), createRandomChunkingSettingsMap()); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + assertThat(model, instanceOf(HuggingFaceEmbeddingsModel.class)); + + var embeddingsModel = (HuggingFaceEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertNull(embeddingsModel.getSecretSettings()); + } + } + + public void testParsePersistedConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvidedAndFeatureFlagEnabled() throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = createHuggingFaceService()) { + var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("url")); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + assertThat(model, instanceOf(HuggingFaceEmbeddingsModel.class)); + + var embeddingsModel = (HuggingFaceEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertNull(embeddingsModel.getSecretSettings()); + } + } + public void testParsePersistedConfig_CreatesAnElserModel() throws IOException { try (var service = createHuggingFaceService()) { var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("url"), new HashMap<>()); @@ -745,6 +940,17 @@ private HuggingFaceService createHuggingFaceService() { return new HuggingFaceService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); } + private Map getRequestConfigMap( + Map serviceSettings, + Map chunkingSettings, + Map secretSettings + ) { + var requestConfigMap = getRequestConfigMap(serviceSettings, secretSettings); + requestConfigMap.put(ModelConfigurations.CHUNKING_SETTINGS, chunkingSettings); + + return requestConfigMap; + } + private Map getRequestConfigMap(Map serviceSettings, Map secretSettings) { var builtServiceSettings = new HashMap<>(); builtServiceSettings.putAll(serviceSettings); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/embeddings/HuggingFaceEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/embeddings/HuggingFaceEmbeddingsModelTests.java index baf5467d8fe06..b81eabb20edc3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/embeddings/HuggingFaceEmbeddingsModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/embeddings/HuggingFaceEmbeddingsModelTests.java @@ -31,6 +31,7 @@ public static HuggingFaceEmbeddingsModel createModel(String url, String apiKey) TaskType.TEXT_EMBEDDING, "service", new HuggingFaceServiceSettings(url), + null, new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) ); } @@ -41,6 +42,7 @@ public static HuggingFaceEmbeddingsModel createModel(String url, String apiKey, TaskType.TEXT_EMBEDDING, "service", new HuggingFaceServiceSettings(createUri(url), null, null, tokenLimit, null), + null, new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) ); } @@ -51,6 +53,7 @@ public static HuggingFaceEmbeddingsModel createModel(String url, String apiKey, TaskType.TEXT_EMBEDDING, "service", new HuggingFaceServiceSettings(createUri(url), null, dimensions, tokenLimit, null), + null, new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) ); } @@ -67,6 +70,7 @@ public static HuggingFaceEmbeddingsModel createModel( TaskType.TEXT_EMBEDDING, "service", new HuggingFaceServiceSettings(createUri(url), similarityMeasure, dimensions, tokenLimit, null), + null, new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java index 19c96dd0e174f..d866424c6ec67 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java @@ -17,6 +17,7 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInferenceServiceResults; import org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -28,6 +29,7 @@ import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; import org.elasticsearch.xpack.inference.ModelConfigurationsTests; @@ -47,7 +49,6 @@ import org.junit.Before; import java.io.IOException; -import java.net.URISyntaxException; import java.util.Arrays; import java.util.HashMap; import java.util.List; @@ -58,6 +59,8 @@ import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettings; +import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; @@ -118,6 +121,87 @@ public void testParseRequestConfig_CreatesAMistralEmbeddingsModel() throws IOExc } } + public void testParseRequestConfig_ThrowsElasticsearchStatusExceptionWhenChunkingSettingsProvidedAndFeatureFlagDisabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is disabled", ChunkingSettingsFeatureFlag.isEnabled() == false); + try (var service = createService()) { + ActionListener modelVerificationListener = ActionListener.wrap( + model -> fail("Expected exception, but got model: " + model), + exception -> { + assertThat(exception, instanceOf(ElasticsearchStatusException.class)); + assertThat(exception.getMessage(), containsString("Model configuration contains settings")); + } + ); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + getEmbeddingsServiceSettingsMap("mistral-embed", null, null, null), + getEmbeddingsTaskSettingsMap(), + createRandomChunkingSettingsMap(), + getSecretSettingsMap("secret") + ), + modelVerificationListener + ); + } + } + + public void testParseRequestConfig_CreatesAMistralEmbeddingsModelWhenChunkingSettingsProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = createService()) { + ActionListener modelVerificationListener = ActionListener.wrap(model -> { + assertThat(model, instanceOf(MistralEmbeddingsModel.class)); + + var embeddingsModel = (MistralEmbeddingsModel) model; + var serviceSettings = (MistralEmbeddingsServiceSettings) model.getServiceSettings(); + assertThat(serviceSettings.modelId(), is("mistral-embed")); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + }, exception -> fail("Unexpected exception: " + exception)); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + getEmbeddingsServiceSettingsMap("mistral-embed", null, null, null), + getEmbeddingsTaskSettingsMap(), + createRandomChunkingSettingsMap(), + getSecretSettingsMap("secret") + ), + modelVerificationListener + ); + } + } + + public void testParseRequestConfig_CreatesAMistralEmbeddingsModelWhenChunkingSettingsNotProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = createService()) { + ActionListener modelVerificationListener = ActionListener.wrap(model -> { + assertThat(model, instanceOf(MistralEmbeddingsModel.class)); + + var embeddingsModel = (MistralEmbeddingsModel) model; + var serviceSettings = (MistralEmbeddingsServiceSettings) model.getServiceSettings(); + assertThat(serviceSettings.modelId(), is("mistral-embed")); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + }, exception -> fail("Unexpected exception: " + exception)); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + getEmbeddingsServiceSettingsMap("mistral-embed", null, null, null), + getEmbeddingsTaskSettingsMap(), + getSecretSettingsMap("secret") + ), + modelVerificationListener + ); + } + } + public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOException { try (var service = createService()) { ActionListener modelVerificationListener = ActionListener.wrap( @@ -237,6 +321,76 @@ public void testParsePersistedConfig_CreatesAMistralEmbeddingsModel() throws IOE } } + public void testParsePersistedConfig_CreatesAMistralEmbeddingsModelWithoutChunkingSettingsWhenFeatureFlagDisabled() throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is disabled", ChunkingSettingsFeatureFlag.isEnabled() == false); + try (var service = createService()) { + var config = getPersistedConfigMap( + getEmbeddingsServiceSettingsMap("mistral-embed", 1024, 512, null), + getEmbeddingsTaskSettingsMap(), + createRandomChunkingSettingsMap(), + getSecretSettingsMap("secret") + ); + + var model = service.parsePersistedConfigWithSecrets("id", TaskType.TEXT_EMBEDDING, config.config(), config.secrets()); + + assertThat(model, instanceOf(MistralEmbeddingsModel.class)); + + var embeddingsModel = (MistralEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().modelId(), is("mistral-embed")); + assertThat(embeddingsModel.getServiceSettings().dimensions(), is(1024)); + assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512)); + assertNull(embeddingsModel.getConfigurations().getChunkingSettings()); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + + public void testParsePersistedConfig_CreatesAMistralEmbeddingsModelWhenChunkingSettingsProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = createService()) { + var config = getPersistedConfigMap( + getEmbeddingsServiceSettingsMap("mistral-embed", 1024, 512, null), + getEmbeddingsTaskSettingsMap(), + createRandomChunkingSettingsMap(), + getSecretSettingsMap("secret") + ); + + var model = service.parsePersistedConfigWithSecrets("id", TaskType.TEXT_EMBEDDING, config.config(), config.secrets()); + + assertThat(model, instanceOf(MistralEmbeddingsModel.class)); + + var embeddingsModel = (MistralEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().modelId(), is("mistral-embed")); + assertThat(embeddingsModel.getServiceSettings().dimensions(), is(1024)); + assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512)); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + + public void testParsePersistedConfig_CreatesAMistralEmbeddingsModelWhenChunkingSettingsNotProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = createService()) { + var config = getPersistedConfigMap( + getEmbeddingsServiceSettingsMap("mistral-embed", 1024, 512, null), + getEmbeddingsTaskSettingsMap(), + getSecretSettingsMap("secret") + ); + + var model = service.parsePersistedConfigWithSecrets("id", TaskType.TEXT_EMBEDDING, config.config(), config.secrets()); + + assertThat(model, instanceOf(MistralEmbeddingsModel.class)); + + var embeddingsModel = (MistralEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().modelId(), is("mistral-embed")); + assertThat(embeddingsModel.getServiceSettings().dimensions(), is(1024)); + assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512)); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + public void testParsePersistedConfig_ThrowsUnsupportedModelType() throws IOException { try (var service = createService()) { ActionListener modelVerificationListener = ActionListener.wrap( @@ -358,6 +512,74 @@ public void testParsePersistedConfig_WithoutSecretsCreatesEmbeddingsModel() thro } } + public void testParsePersistedConfig_WithoutSecretsCreatesAnEmbeddingsModelWithoutChunkingSettingsWhenFeatureFlagDisabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is disabled", ChunkingSettingsFeatureFlag.isEnabled() == false); + try (var service = createService()) { + var config = getPersistedConfigMap( + getEmbeddingsServiceSettingsMap("mistral-embed", 1024, 512, null), + getEmbeddingsTaskSettingsMap(), + createRandomChunkingSettingsMap(), + Map.of() + ); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, config.config()); + + assertThat(model, instanceOf(MistralEmbeddingsModel.class)); + + var embeddingsModel = (MistralEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().modelId(), is("mistral-embed")); + assertThat(embeddingsModel.getServiceSettings().dimensions(), is(1024)); + assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512)); + assertNull(embeddingsModel.getConfigurations().getChunkingSettings()); + } + } + + public void testParsePersistedConfig_WithoutSecretsCreatesAnEmbeddingsModelWhenChunkingSettingsProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = createService()) { + var config = getPersistedConfigMap( + getEmbeddingsServiceSettingsMap("mistral-embed", 1024, 512, null), + getEmbeddingsTaskSettingsMap(), + createRandomChunkingSettingsMap(), + Map.of() + ); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, config.config()); + + assertThat(model, instanceOf(MistralEmbeddingsModel.class)); + + var embeddingsModel = (MistralEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().modelId(), is("mistral-embed")); + assertThat(embeddingsModel.getServiceSettings().dimensions(), is(1024)); + assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512)); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + } + } + + public void testParsePersistedConfig_WithoutSecretsCreatesAnEmbeddingsModelWhenChunkingSettingsNotProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = createService()) { + var config = getPersistedConfigMap( + getEmbeddingsServiceSettingsMap("mistral-embed", 1024, 512, null), + getEmbeddingsTaskSettingsMap(), + Map.of() + ); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, config.config()); + + assertThat(model, instanceOf(MistralEmbeddingsModel.class)); + + var embeddingsModel = (MistralEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().modelId(), is("mistral-embed")); + assertThat(embeddingsModel.getServiceSettings().dimensions(), is(1024)); + assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512)); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + } + } + public void testCheckModelConfig_ForEmbeddingsModel_Works() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); @@ -464,7 +686,31 @@ public void testInfer_ThrowsErrorWhenModelIsNotMistralEmbeddingsModel() throws I verifyNoMoreInteractions(sender); } - public void testChunkedInfer_Embeddings_CallsInfer_ConvertsFloatResponse() throws IOException, URISyntaxException { + public void testChunkedInfer_Embeddings_CallsInfer_ConvertsFloatResponse() throws IOException { + var model = MistralEmbeddingModelTests.createModel("id", "mistral-embed", "apikey", null, null, null, null); + model.setURI(getUrl(webServer)); + + testChunkedInfer(model); + } + + public void testChunkedInfer_ChunkingSettingsSetAndFeatureFlagEnabled() throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + var model = MistralEmbeddingModelTests.createModel( + "id", + "mistral-embed", + createRandomChunkingSettings(), + "apikey", + null, + null, + null, + null + ); + model.setURI(getUrl(webServer)); + + testChunkedInfer(model); + } + + public void testChunkedInfer(MistralEmbeddingsModel model) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) { @@ -499,9 +745,6 @@ public void testChunkedInfer_Embeddings_CallsInfer_ConvertsFloatResponse() throw """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = MistralEmbeddingModelTests.createModel("id", "mistral-embed", "apikey", null, null, null, null); - model.setURI(getUrl(webServer)); - PlainActionFuture> listener = new PlainActionFuture<>(); service.chunkedInfer( model, @@ -587,6 +830,18 @@ private MistralService createService() { return new MistralService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); } + private Map getRequestConfigMap( + Map serviceSettings, + Map taskSettings, + Map chunkingSettings, + Map secretSettings + ) { + var requestConfigMap = getRequestConfigMap(serviceSettings, taskSettings, secretSettings); + requestConfigMap.put(ModelConfigurations.CHUNKING_SETTINGS, chunkingSettings); + + return requestConfigMap; + } + private Map getRequestConfigMap( Map serviceSettings, Map taskSettings, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingModelTests.java index 0fe8723664c6e..6f8b40fd7f19c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingModelTests.java @@ -9,6 +9,7 @@ import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; @@ -24,6 +25,7 @@ public static MistralEmbeddingsModel createModel(String inferenceId, String mode public static MistralEmbeddingsModel createModel( String inferenceId, String model, + ChunkingSettings chunkingSettings, String apiKey, @Nullable Integer dimensions, @Nullable Integer maxTokens, @@ -36,6 +38,27 @@ public static MistralEmbeddingsModel createModel( "mistral", new MistralEmbeddingsServiceSettings(model, dimensions, maxTokens, similarity, rateLimitSettings), EmptyTaskSettings.INSTANCE, + chunkingSettings, + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } + + public static MistralEmbeddingsModel createModel( + String inferenceId, + String model, + String apiKey, + @Nullable Integer dimensions, + @Nullable Integer maxTokens, + @Nullable SimilarityMeasure similarity, + RateLimitSettings rateLimitSettings + ) { + return new MistralEmbeddingsModel( + inferenceId, + TaskType.TEXT_EMBEDDING, + "mistral", + new MistralEmbeddingsServiceSettings(model, dimensions, maxTokens, similarity, rateLimitSettings), + EmptyTaskSettings.INSTANCE, + null, new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) ); } From 374f448722f556e060a5a4372df177664c8d4319 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Mon, 7 Oct 2024 17:21:48 -0400 Subject: [PATCH 11/22] [ML] Deploy default on chunked infer (#114158) Co-authored-by: Elastic Machine --- .../ElasticsearchInternalService.java | 41 +++++++++++-------- 1 file changed, 24 insertions(+), 17 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index dd14e16412996..9b4c0e50bdebe 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -608,26 +608,33 @@ public void chunkedInfer( return; } - var configUpdate = chunkingOptions != null - ? new TokenizationConfigUpdate(chunkingOptions.windowSize(), chunkingOptions.span()) - : new TokenizationConfigUpdate(null, null); + if (model instanceof ElasticsearchInternalModel esModel) { - var request = buildInferenceRequest( - model.getConfigurations().getInferenceEntityId(), - configUpdate, - input, - inputType, - timeout, - true - ); + var configUpdate = chunkingOptions != null + ? new TokenizationConfigUpdate(chunkingOptions.windowSize(), chunkingOptions.span()) + : new TokenizationConfigUpdate(null, null); + + var request = buildInferenceRequest( + model.getConfigurations().getInferenceEntityId(), + configUpdate, + input, + inputType, + timeout, + true + ); - client.execute( - InferModelAction.INSTANCE, - request, - listener.delegateFailureAndWrap( + ActionListener mlResultsListener = listener.delegateFailureAndWrap( (l, inferenceResult) -> l.onResponse(translateToChunkedResults(inferenceResult.getInferenceResults())) - ) - ); + ); + + var maybeDeployListener = mlResultsListener.delegateResponse( + (l, exception) -> maybeStartDeployment(esModel, exception, request, mlResultsListener) + ); + + client.execute(InferModelAction.INSTANCE, request, maybeDeployListener); + } else { + listener.onFailure(notElasticsearchModelException(model)); + } } private static List translateToChunkedResults(List inferenceResults) { From 586ea41217d6ee39d933cc1d9827bbb8cbb6b807 Mon Sep 17 00:00:00 2001 From: Joe Gallo Date: Mon, 7 Oct 2024 18:23:45 -0300 Subject: [PATCH 12/22] Add postal_code support to the City and Enterprise databases (#114193) (#114263) --- .../org/elasticsearch/ingest/geoip/Database.java | 9 ++++++--- .../ingest/geoip/MaxmindIpDataLookups.java | 13 +++++++++++++ .../ingest/geoip/GeoIpProcessorFactoryTests.java | 3 ++- .../ingest/geoip/GeoIpProcessorTests.java | 6 ++++-- .../ingest/geoip/MaxMindSupportTests.java | 6 +++--- 5 files changed, 28 insertions(+), 9 deletions(-) diff --git a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/Database.java b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/Database.java index 31d7a43e3869a..10817c920e1e5 100644 --- a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/Database.java +++ b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/Database.java @@ -40,7 +40,8 @@ enum Database { Property.REGION_NAME, Property.CITY_NAME, Property.TIMEZONE, - Property.LOCATION + Property.LOCATION, + Property.POSTAL_CODE ), Set.of( Property.COUNTRY_ISO_CODE, @@ -108,7 +109,8 @@ enum Database { Property.MOBILE_COUNTRY_CODE, Property.MOBILE_NETWORK_CODE, Property.USER_TYPE, - Property.CONNECTION_TYPE + Property.CONNECTION_TYPE, + Property.POSTAL_CODE ), Set.of( Property.COUNTRY_ISO_CODE, @@ -228,7 +230,8 @@ enum Property { MOBILE_NETWORK_CODE, CONNECTION_TYPE, USER_TYPE, - TYPE; + TYPE, + POSTAL_CODE; /** * Parses a string representation of a property into an actual Property instance. Not all properties that exist are diff --git a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/MaxmindIpDataLookups.java b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/MaxmindIpDataLookups.java index 5b22b3f4005a9..2e0d4a031a072 100644 --- a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/MaxmindIpDataLookups.java +++ b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/MaxmindIpDataLookups.java @@ -23,6 +23,7 @@ import com.maxmind.geoip2.model.IspResponse; import com.maxmind.geoip2.record.Continent; import com.maxmind.geoip2.record.Location; +import com.maxmind.geoip2.record.Postal; import com.maxmind.geoip2.record.Subdivision; import org.elasticsearch.common.network.InetAddresses; @@ -139,6 +140,7 @@ protected Map transform(final CityResponse response) { Location location = response.getLocation(); Continent continent = response.getContinent(); Subdivision subdivision = response.getMostSpecificSubdivision(); + Postal postal = response.getPostal(); Map data = new HashMap<>(); for (Database.Property property : this.properties) { @@ -206,6 +208,11 @@ protected Map transform(final CityResponse response) { data.put("location", locationObject); } } + case POSTAL_CODE -> { + if (postal != null && postal.getCode() != null) { + data.put("postal_code", postal.getCode()); + } + } } } return data; @@ -324,6 +331,7 @@ protected Map transform(final EnterpriseResponse response) { Location location = response.getLocation(); Continent continent = response.getContinent(); Subdivision subdivision = response.getMostSpecificSubdivision(); + Postal postal = response.getPostal(); Long asn = response.getTraits().getAutonomousSystemNumber(); String organizationName = response.getTraits().getAutonomousSystemOrganization(); @@ -413,6 +421,11 @@ protected Map transform(final EnterpriseResponse response) { data.put("location", locationObject); } } + case POSTAL_CODE -> { + if (postal != null && postal.getCode() != null) { + data.put("postal_code", postal.getCode()); + } + } case ASN -> { if (asn != null) { data.put("asn", asn); diff --git a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorFactoryTests.java b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorFactoryTests.java index 9972db26b3642..d4017268b53db 100644 --- a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorFactoryTests.java +++ b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorFactoryTests.java @@ -274,7 +274,8 @@ public void testBuildIllegalFieldOption() { e.getMessage(), equalTo( "[properties] illegal property value [invalid]. valid values are [IP, COUNTRY_ISO_CODE, " - + "COUNTRY_NAME, CONTINENT_CODE, CONTINENT_NAME, REGION_ISO_CODE, REGION_NAME, CITY_NAME, TIMEZONE, LOCATION]" + + "COUNTRY_NAME, CONTINENT_CODE, CONTINENT_NAME, REGION_ISO_CODE, REGION_NAME, CITY_NAME, TIMEZONE, " + + "LOCATION, POSTAL_CODE]" ) ); diff --git a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorTests.java b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorTests.java index 46024cb6ad215..3fb082f33b3f6 100644 --- a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorTests.java +++ b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorTests.java @@ -222,7 +222,7 @@ public void testCity_withIpV6() throws Exception { @SuppressWarnings("unchecked") Map geoData = (Map) ingestDocument.getSourceAndMetadata().get("target_field"); assertThat(geoData, notNullValue()); - assertThat(geoData.size(), equalTo(10)); + assertThat(geoData.size(), equalTo(11)); assertThat(geoData.get("ip"), equalTo(ip)); assertThat(geoData.get("country_iso_code"), equalTo("US")); assertThat(geoData.get("country_name"), equalTo("United States")); @@ -233,6 +233,7 @@ public void testCity_withIpV6() throws Exception { assertThat(geoData.get("city_name"), equalTo("Homestead")); assertThat(geoData.get("timezone"), equalTo("America/New_York")); assertThat(geoData.get("location"), equalTo(Map.of("lat", 25.4573d, "lon", -80.4572d))); + assertThat(geoData.get("postal_code"), equalTo("33035")); } public void testCityWithMissingLocation() throws Exception { @@ -470,7 +471,7 @@ public void testEnterprise() throws Exception { @SuppressWarnings("unchecked") Map geoData = (Map) ingestDocument.getSourceAndMetadata().get("target_field"); assertThat(geoData, notNullValue()); - assertThat(geoData.size(), equalTo(24)); + assertThat(geoData.size(), equalTo(25)); assertThat(geoData.get("ip"), equalTo(ip)); assertThat(geoData.get("country_iso_code"), equalTo("US")); assertThat(geoData.get("country_name"), equalTo("United States")); @@ -481,6 +482,7 @@ public void testEnterprise() throws Exception { assertThat(geoData.get("city_name"), equalTo("Chatham")); assertThat(geoData.get("timezone"), equalTo("America/New_York")); assertThat(geoData.get("location"), equalTo(Map.of("lat", 42.3478, "lon", -73.5549))); + assertThat(geoData.get("postal_code"), equalTo("12037")); assertThat(geoData.get("asn"), equalTo(14671L)); assertThat(geoData.get("organization_name"), equalTo("FairPoint Communications")); assertThat(geoData.get("network"), equalTo("74.209.16.0/20")); diff --git a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/MaxMindSupportTests.java b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/MaxMindSupportTests.java index 3b12003637783..7a3de6ca199aa 100644 --- a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/MaxMindSupportTests.java +++ b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/MaxMindSupportTests.java @@ -84,7 +84,8 @@ public class MaxMindSupportTests extends ESTestCase { "location.longitude", "location.timeZone", "mostSpecificSubdivision.isoCode", - "mostSpecificSubdivision.name" + "mostSpecificSubdivision.name", + "postal.code" ); private static final Set CITY_UNSUPPORTED_FIELDS = Set.of( "city.confidence", @@ -109,7 +110,6 @@ public class MaxMindSupportTests extends ESTestCase { "mostSpecificSubdivision.confidence", "mostSpecificSubdivision.geoNameId", "mostSpecificSubdivision.names", - "postal.code", "postal.confidence", "registeredCountry.confidence", "registeredCountry.geoNameId", @@ -223,6 +223,7 @@ public class MaxMindSupportTests extends ESTestCase { "location.timeZone", "mostSpecificSubdivision.isoCode", "mostSpecificSubdivision.name", + "postal.code", "traits.anonymous", "traits.anonymousVpn", "traits.autonomousSystemNumber", @@ -263,7 +264,6 @@ public class MaxMindSupportTests extends ESTestCase { "mostSpecificSubdivision.confidence", "mostSpecificSubdivision.geoNameId", "mostSpecificSubdivision.names", - "postal.code", "postal.confidence", "registeredCountry.confidence", "registeredCountry.geoNameId", From 69a23d41e47c0856e0f45282a0b1c3bda71f467c Mon Sep 17 00:00:00 2001 From: Nik Everett Date: Mon, 7 Oct 2024 18:17:53 -0400 Subject: [PATCH 13/22] ESQL: Reenable part of heap attack test (#114252) (#114255) This reenables a test and adds more debugging to another one. We'll use this to collect more information the next time it fails. --- .../elasticsearch/xpack/esql/heap_attack/HeapAttackIT.java | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/HeapAttackIT.java b/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/HeapAttackIT.java index e45ac8a9e0f70..a24bd91206ac0 100644 --- a/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/HeapAttackIT.java +++ b/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/HeapAttackIT.java @@ -355,7 +355,6 @@ public void testManyEval() throws IOException { assertMap(map, mapMatcher.entry("columns", columns).entry("values", hasSize(10_000)).entry("took", greaterThanOrEqualTo(0))); } - @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch-serverless/issues/1874") public void testTooManyEval() throws IOException { initManyLongs(); assertCircuitBreaks(() -> manyEval(490)); @@ -616,14 +615,13 @@ private void initMvLongsIndex(int docs, int fields, int fieldValues) throws IOEx private void bulk(String name, String bulk) throws IOException { Request request = new Request("POST", "/" + name + "/_bulk"); - request.addParameter("filter_path", "errors"); request.setJsonEntity(bulk); request.setOptions( RequestOptions.DEFAULT.toBuilder() .setRequestConfig(RequestConfig.custom().setSocketTimeout(Math.toIntExact(TimeValue.timeValueMinutes(5).millis())).build()) ); Response response = client().performRequest(request); - assertThat(EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8), equalTo("{\"errors\":false}")); + assertThat(entityAsMap(response), matchesMap().entry("errors", false).extraOk()); } private void initIndex(String name, String bulk) throws IOException { From 204950ad06d8d7660c46ddb68fbccf4cf26fc990 Mon Sep 17 00:00:00 2001 From: Keith Massey Date: Mon, 7 Oct 2024 21:26:26 -0500 Subject: [PATCH 14/22] Supporting more maxmind fields in the geoip processor (#114268) (#114270) --- .../elasticsearch/ingest/geoip/Database.java | 27 +++++++++-- .../ingest/geoip/MaxmindIpDataLookups.java | 48 +++++++++++++++++++ .../geoip/GeoIpProcessorFactoryTests.java | 6 +-- .../ingest/geoip/GeoIpProcessorTests.java | 17 +++++-- .../ingest/geoip/MaxMindSupportTests.java | 16 +++---- 5 files changed, 95 insertions(+), 19 deletions(-) diff --git a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/Database.java b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/Database.java index 10817c920e1e5..128c16e163764 100644 --- a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/Database.java +++ b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/Database.java @@ -32,6 +32,7 @@ enum Database { City( Set.of( Property.IP, + Property.COUNTRY_IN_EUROPEAN_UNION, Property.COUNTRY_ISO_CODE, Property.CONTINENT_CODE, Property.COUNTRY_NAME, @@ -41,7 +42,8 @@ enum Database { Property.CITY_NAME, Property.TIMEZONE, Property.LOCATION, - Property.POSTAL_CODE + Property.POSTAL_CODE, + Property.ACCURACY_RADIUS ), Set.of( Property.COUNTRY_ISO_CODE, @@ -54,7 +56,14 @@ enum Database { ) ), Country( - Set.of(Property.IP, Property.CONTINENT_CODE, Property.CONTINENT_NAME, Property.COUNTRY_NAME, Property.COUNTRY_ISO_CODE), + Set.of( + Property.IP, + Property.CONTINENT_CODE, + Property.CONTINENT_NAME, + Property.COUNTRY_NAME, + Property.COUNTRY_IN_EUROPEAN_UNION, + Property.COUNTRY_ISO_CODE + ), Set.of(Property.CONTINENT_NAME, Property.COUNTRY_NAME, Property.COUNTRY_ISO_CODE) ), Asn( @@ -85,12 +94,15 @@ enum Database { Enterprise( Set.of( Property.IP, + Property.COUNTRY_CONFIDENCE, + Property.COUNTRY_IN_EUROPEAN_UNION, Property.COUNTRY_ISO_CODE, Property.COUNTRY_NAME, Property.CONTINENT_CODE, Property.CONTINENT_NAME, Property.REGION_ISO_CODE, Property.REGION_NAME, + Property.CITY_CONFIDENCE, Property.CITY_NAME, Property.TIMEZONE, Property.LOCATION, @@ -110,7 +122,9 @@ enum Database { Property.MOBILE_NETWORK_CODE, Property.USER_TYPE, Property.CONNECTION_TYPE, - Property.POSTAL_CODE + Property.POSTAL_CODE, + Property.POSTAL_CONFIDENCE, + Property.ACCURACY_RADIUS ), Set.of( Property.COUNTRY_ISO_CODE, @@ -205,12 +219,15 @@ public Set parseProperties(@Nullable final List propertyNames) enum Property { IP, + COUNTRY_CONFIDENCE, + COUNTRY_IN_EUROPEAN_UNION, COUNTRY_ISO_CODE, COUNTRY_NAME, CONTINENT_CODE, CONTINENT_NAME, REGION_ISO_CODE, REGION_NAME, + CITY_CONFIDENCE, CITY_NAME, TIMEZONE, LOCATION, @@ -231,7 +248,9 @@ enum Property { CONNECTION_TYPE, USER_TYPE, TYPE, - POSTAL_CODE; + POSTAL_CODE, + POSTAL_CONFIDENCE, + ACCURACY_RADIUS; /** * Parses a string representation of a property into an actual Property instance. Not all properties that exist are diff --git a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/MaxmindIpDataLookups.java b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/MaxmindIpDataLookups.java index 2e0d4a031a072..e7c3481938033 100644 --- a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/MaxmindIpDataLookups.java +++ b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/MaxmindIpDataLookups.java @@ -146,6 +146,12 @@ protected Map transform(final CityResponse response) { for (Database.Property property : this.properties) { switch (property) { case IP -> data.put("ip", response.getTraits().getIpAddress()); + case COUNTRY_IN_EUROPEAN_UNION -> { + if (country.getIsoCode() != null) { + // isInEuropeanUnion is a boolean so it can't be null. But it really only makes sense if we have a country + data.put("country_in_european_union", country.isInEuropeanUnion()); + } + } case COUNTRY_ISO_CODE -> { String countryIsoCode = country.getIsoCode(); if (countryIsoCode != null) { @@ -208,6 +214,12 @@ protected Map transform(final CityResponse response) { data.put("location", locationObject); } } + case ACCURACY_RADIUS -> { + Integer accuracyRadius = location.getAccuracyRadius(); + if (accuracyRadius != null) { + data.put("accuracy_radius", accuracyRadius); + } + } case POSTAL_CODE -> { if (postal != null && postal.getCode() != null) { data.put("postal_code", postal.getCode()); @@ -261,6 +273,12 @@ protected Map transform(final CountryResponse response) { for (Database.Property property : this.properties) { switch (property) { case IP -> data.put("ip", response.getTraits().getIpAddress()); + case COUNTRY_IN_EUROPEAN_UNION -> { + if (country.getIsoCode() != null) { + // isInEuropeanUnion is a boolean so it can't be null. But it really only makes sense if we have a country + data.put("country_in_european_union", country.isInEuropeanUnion()); + } + } case COUNTRY_ISO_CODE -> { String countryIsoCode = country.getIsoCode(); if (countryIsoCode != null) { @@ -359,6 +377,18 @@ protected Map transform(final EnterpriseResponse response) { for (Database.Property property : this.properties) { switch (property) { case IP -> data.put("ip", response.getTraits().getIpAddress()); + case COUNTRY_CONFIDENCE -> { + Integer countryConfidence = country.getConfidence(); + if (countryConfidence != null) { + data.put("country_confidence", countryConfidence); + } + } + case COUNTRY_IN_EUROPEAN_UNION -> { + if (country.getIsoCode() != null) { + // isInEuropeanUnion is a boolean so it can't be null. But it really only makes sense if we have a country + data.put("country_in_european_union", country.isInEuropeanUnion()); + } + } case COUNTRY_ISO_CODE -> { String countryIsoCode = country.getIsoCode(); if (countryIsoCode != null) { @@ -399,6 +429,12 @@ protected Map transform(final EnterpriseResponse response) { data.put("region_name", subdivisionName); } } + case CITY_CONFIDENCE -> { + Integer cityConfidence = city.getConfidence(); + if (cityConfidence != null) { + data.put("city_confidence", cityConfidence); + } + } case CITY_NAME -> { String cityName = city.getName(); if (cityName != null) { @@ -421,11 +457,23 @@ protected Map transform(final EnterpriseResponse response) { data.put("location", locationObject); } } + case ACCURACY_RADIUS -> { + Integer accuracyRadius = location.getAccuracyRadius(); + if (accuracyRadius != null) { + data.put("accuracy_radius", accuracyRadius); + } + } case POSTAL_CODE -> { if (postal != null && postal.getCode() != null) { data.put("postal_code", postal.getCode()); } } + case POSTAL_CONFIDENCE -> { + Integer postalConfidence = postal.getConfidence(); + if (postalConfidence != null) { + data.put("postal_confidence", postalConfidence); + } + } case ASN -> { if (asn != null) { data.put("asn", asn); diff --git a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorFactoryTests.java b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorFactoryTests.java index d4017268b53db..cfea54d2520bd 100644 --- a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorFactoryTests.java +++ b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorFactoryTests.java @@ -195,7 +195,7 @@ public void testBuildWithCountryDbAndAsnFields() { equalTo( "[properties] illegal property value [" + asnProperty - + "]. valid values are [IP, COUNTRY_ISO_CODE, COUNTRY_NAME, CONTINENT_CODE, CONTINENT_NAME]" + + "]. valid values are [IP, COUNTRY_IN_EUROPEAN_UNION, COUNTRY_ISO_CODE, COUNTRY_NAME, CONTINENT_CODE, CONTINENT_NAME]" ) ); } @@ -273,9 +273,9 @@ public void testBuildIllegalFieldOption() { assertThat( e.getMessage(), equalTo( - "[properties] illegal property value [invalid]. valid values are [IP, COUNTRY_ISO_CODE, " + "[properties] illegal property value [invalid]. valid values are [IP, COUNTRY_IN_EUROPEAN_UNION, COUNTRY_ISO_CODE, " + "COUNTRY_NAME, CONTINENT_CODE, CONTINENT_NAME, REGION_ISO_CODE, REGION_NAME, CITY_NAME, TIMEZONE, " - + "LOCATION, POSTAL_CODE]" + + "LOCATION, POSTAL_CODE, ACCURACY_RADIUS]" ) ); diff --git a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorTests.java b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorTests.java index 3fb082f33b3f6..ffc40324bd886 100644 --- a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorTests.java +++ b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorTests.java @@ -106,8 +106,9 @@ public void testCity() throws Exception { @SuppressWarnings("unchecked") Map geoData = (Map) ingestDocument.getSourceAndMetadata().get("target_field"); assertThat(geoData, notNullValue()); - assertThat(geoData.size(), equalTo(7)); + assertThat(geoData.size(), equalTo(9)); assertThat(geoData.get("ip"), equalTo(ip)); + assertThat(geoData.get("country_in_european_union"), equalTo(false)); assertThat(geoData.get("country_iso_code"), equalTo("US")); assertThat(geoData.get("country_name"), equalTo("United States")); assertThat(geoData.get("continent_code"), equalTo("NA")); @@ -222,8 +223,9 @@ public void testCity_withIpV6() throws Exception { @SuppressWarnings("unchecked") Map geoData = (Map) ingestDocument.getSourceAndMetadata().get("target_field"); assertThat(geoData, notNullValue()); - assertThat(geoData.size(), equalTo(11)); + assertThat(geoData.size(), equalTo(13)); assertThat(geoData.get("ip"), equalTo(ip)); + assertThat(geoData.get("country_in_european_union"), equalTo(false)); assertThat(geoData.get("country_iso_code"), equalTo("US")); assertThat(geoData.get("country_name"), equalTo("United States")); assertThat(geoData.get("continent_code"), equalTo("NA")); @@ -233,6 +235,7 @@ public void testCity_withIpV6() throws Exception { assertThat(geoData.get("city_name"), equalTo("Homestead")); assertThat(geoData.get("timezone"), equalTo("America/New_York")); assertThat(geoData.get("location"), equalTo(Map.of("lat", 25.4573d, "lon", -80.4572d))); + assertThat(geoData.get("accuracy_radius"), equalTo(50)); assertThat(geoData.get("postal_code"), equalTo("33035")); } @@ -288,8 +291,9 @@ public void testCountry() throws Exception { @SuppressWarnings("unchecked") Map geoData = (Map) ingestDocument.getSourceAndMetadata().get("target_field"); assertThat(geoData, notNullValue()); - assertThat(geoData.size(), equalTo(5)); + assertThat(geoData.size(), equalTo(6)); assertThat(geoData.get("ip"), equalTo(ip)); + assertThat(geoData.get("country_in_european_union"), equalTo(true)); assertThat(geoData.get("country_iso_code"), equalTo("NL")); assertThat(geoData.get("country_name"), equalTo("Netherlands")); assertThat(geoData.get("continent_code"), equalTo("EU")); @@ -471,18 +475,23 @@ public void testEnterprise() throws Exception { @SuppressWarnings("unchecked") Map geoData = (Map) ingestDocument.getSourceAndMetadata().get("target_field"); assertThat(geoData, notNullValue()); - assertThat(geoData.size(), equalTo(25)); + assertThat(geoData.size(), equalTo(30)); assertThat(geoData.get("ip"), equalTo(ip)); + assertThat(geoData.get("country_confidence"), equalTo(99)); + assertThat(geoData.get("country_in_european_union"), equalTo(false)); assertThat(geoData.get("country_iso_code"), equalTo("US")); assertThat(geoData.get("country_name"), equalTo("United States")); assertThat(geoData.get("continent_code"), equalTo("NA")); assertThat(geoData.get("continent_name"), equalTo("North America")); assertThat(geoData.get("region_iso_code"), equalTo("US-NY")); assertThat(geoData.get("region_name"), equalTo("New York")); + assertThat(geoData.get("city_confidence"), equalTo(11)); assertThat(geoData.get("city_name"), equalTo("Chatham")); assertThat(geoData.get("timezone"), equalTo("America/New_York")); assertThat(geoData.get("location"), equalTo(Map.of("lat", 42.3478, "lon", -73.5549))); + assertThat(geoData.get("accuracy_radius"), equalTo(27)); assertThat(geoData.get("postal_code"), equalTo("12037")); + assertThat(geoData.get("city_confidence"), equalTo(11)); assertThat(geoData.get("asn"), equalTo(14671L)); assertThat(geoData.get("organization_name"), equalTo("FairPoint Communications")); assertThat(geoData.get("network"), equalTo("74.209.16.0/20")); diff --git a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/MaxMindSupportTests.java b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/MaxMindSupportTests.java index 7a3de6ca199aa..1e05cf2b3ba33 100644 --- a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/MaxMindSupportTests.java +++ b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/MaxMindSupportTests.java @@ -78,8 +78,10 @@ public class MaxMindSupportTests extends ESTestCase { "city.name", "continent.code", "continent.name", + "country.inEuropeanUnion", "country.isoCode", "country.name", + "location.accuracyRadius", "location.latitude", "location.longitude", "location.timeZone", @@ -95,14 +97,12 @@ public class MaxMindSupportTests extends ESTestCase { "continent.names", "country.confidence", "country.geoNameId", - "country.inEuropeanUnion", "country.names", "leastSpecificSubdivision.confidence", "leastSpecificSubdivision.geoNameId", "leastSpecificSubdivision.isoCode", "leastSpecificSubdivision.name", "leastSpecificSubdivision.names", - "location.accuracyRadius", "location.averageIncome", "location.metroCode", "location.populationDensity", @@ -159,6 +159,7 @@ public class MaxMindSupportTests extends ESTestCase { private static final Set COUNTRY_SUPPORTED_FIELDS = Set.of( "continent.name", + "country.inEuropeanUnion", "country.isoCode", "continent.code", "country.name" @@ -168,7 +169,6 @@ public class MaxMindSupportTests extends ESTestCase { "continent.names", "country.confidence", "country.geoNameId", - "country.inEuropeanUnion", "country.names", "maxMind", "registeredCountry.confidence", @@ -213,17 +213,22 @@ public class MaxMindSupportTests extends ESTestCase { private static final Set DOMAIN_UNSUPPORTED_FIELDS = Set.of("ipAddress", "network"); private static final Set ENTERPRISE_SUPPORTED_FIELDS = Set.of( + "city.confidence", "city.name", "continent.code", "continent.name", + "country.confidence", + "country.inEuropeanUnion", "country.isoCode", "country.name", + "location.accuracyRadius", "location.latitude", "location.longitude", "location.timeZone", "mostSpecificSubdivision.isoCode", "mostSpecificSubdivision.name", "postal.code", + "postal.confidence", "traits.anonymous", "traits.anonymousVpn", "traits.autonomousSystemNumber", @@ -242,21 +247,17 @@ public class MaxMindSupportTests extends ESTestCase { "traits.userType" ); private static final Set ENTERPRISE_UNSUPPORTED_FIELDS = Set.of( - "city.confidence", "city.geoNameId", "city.names", "continent.geoNameId", "continent.names", - "country.confidence", "country.geoNameId", - "country.inEuropeanUnion", "country.names", "leastSpecificSubdivision.confidence", "leastSpecificSubdivision.geoNameId", "leastSpecificSubdivision.isoCode", "leastSpecificSubdivision.name", "leastSpecificSubdivision.names", - "location.accuracyRadius", "location.averageIncome", "location.metroCode", "location.populationDensity", @@ -264,7 +265,6 @@ public class MaxMindSupportTests extends ESTestCase { "mostSpecificSubdivision.confidence", "mostSpecificSubdivision.geoNameId", "mostSpecificSubdivision.names", - "postal.confidence", "registeredCountry.confidence", "registeredCountry.geoNameId", "registeredCountry.inEuropeanUnion", From 3e4957da52e4874e4071566f659eca1980696379 Mon Sep 17 00:00:00 2001 From: David Turner Date: Tue, 8 Oct 2024 07:24:15 +0100 Subject: [PATCH 15/22] Document that `?wait_for_active_shards=0` is permitted (#114091) (#114273) Today the docs for the `?wait_for_active_shards` parameter say that it must be a positive integer, proscribing `0`, yet `0` is a legitimate value for this parameter. This commit fixes this point and rewords the docs slightly for clarity. --- docs/reference/rest-api/common-parms.asciidoc | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/docs/reference/rest-api/common-parms.asciidoc b/docs/reference/rest-api/common-parms.asciidoc index fabd495cdc525..993bb8cb894f9 100644 --- a/docs/reference/rest-api/common-parms.asciidoc +++ b/docs/reference/rest-api/common-parms.asciidoc @@ -1298,10 +1298,11 @@ tag::wait_for_active_shards[] `wait_for_active_shards`:: + -- -(Optional, string) The number of shard copies that must be active before -proceeding with the operation. Set to `all` or any positive integer up -to the total number of shards in the index (`number_of_replicas+1`). -Default: 1, the primary shard. +(Optional, string) The number of copies of each shard that must be active +before proceeding with the operation. Set to `all` or any non-negative integer +up to the total number of copies of each shard in the index +(`number_of_replicas+1`). Defaults to `1`, meaning to wait just for each +primary shard to be active. See <>. -- From 309d234bf08d9361c987411146acc0003cb511fc Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Tue, 8 Oct 2024 09:33:30 +0300 Subject: [PATCH 16/22] Updating toXContent implementation for retrievers (#114017) (#114272) --- .../search/retriever/RetrieverBuilder.java | 8 +++ .../builder/SearchSourceBuilderTests.java | 71 +++++++++++++++++++ .../KnnRetrieverBuilderParsingTests.java | 2 +- .../StandardRetrieverBuilderParsingTests.java | 2 +- .../random/RandomRankRetrieverBuilder.java | 5 +- .../TextSimilarityRankRetrieverBuilder.java | 8 +-- .../RandomRankRetrieverBuilderTests.java | 10 ++- ...xtSimilarityRankRetrieverBuilderTests.java | 46 +++++++++++- .../xpack/rank/rrf/RRFRetrieverBuilder.java | 3 - .../rrf/RRFRetrieverBuilderParsingTests.java | 57 ++++++++++++++- 10 files changed, 187 insertions(+), 25 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java index 1328106896bcb..1c6f8c4a7ce44 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java @@ -251,11 +251,19 @@ public ActionRequestValidationException validate( @Override public final XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { builder.startObject(); + builder.startObject(getName()); if (preFilterQueryBuilders.isEmpty() == false) { builder.field(PRE_FILTER_FIELD.getPreferredName(), preFilterQueryBuilders); } + if (minScore != null) { + builder.field(MIN_SCORE_FIELD.getPreferredName(), minScore); + } + if (retrieverName != null) { + builder.field(NAME_FIELD.getPreferredName(), retrieverName); + } doToXContent(builder, params); builder.endObject(); + builder.endObject(); return builder; } diff --git a/server/src/test/java/org/elasticsearch/search/builder/SearchSourceBuilderTests.java b/server/src/test/java/org/elasticsearch/search/builder/SearchSourceBuilderTests.java index 3f33bbfe6f6cb..240a677f4cbfd 100644 --- a/server/src/test/java/org/elasticsearch/search/builder/SearchSourceBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/builder/SearchSourceBuilderTests.java @@ -41,6 +41,8 @@ import org.elasticsearch.search.collapse.CollapseBuilderTests; import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder; import org.elasticsearch.search.rescore.QueryRescorerBuilder; +import org.elasticsearch.search.retriever.KnnRetrieverBuilder; +import org.elasticsearch.search.retriever.StandardRetrieverBuilder; import org.elasticsearch.search.slice.SliceBuilder; import org.elasticsearch.search.sort.FieldSortBuilder; import org.elasticsearch.search.sort.ScoreSortBuilder; @@ -600,6 +602,75 @@ public void testNegativeTrackTotalHits() throws IOException { } } + public void testStandardRetrieverParsing() throws IOException { + String restContent = "{" + + " \"retriever\": {" + + " \"standard\": {" + + " \"query\": {" + + " \"match_all\": {}" + + " }," + + " \"min_score\": 10," + + " \"_name\": \"foo_standard\"" + + " }" + + " }" + + "}"; + SearchUsageHolder searchUsageHolder = new UsageService().getSearchUsageHolder(); + try (XContentParser jsonParser = createParser(JsonXContent.jsonXContent, restContent)) { + SearchSourceBuilder source = new SearchSourceBuilder().parseXContent(jsonParser, true, searchUsageHolder, nf -> true); + assertThat(source.retriever(), instanceOf(StandardRetrieverBuilder.class)); + StandardRetrieverBuilder parsed = (StandardRetrieverBuilder) source.retriever(); + assertThat(parsed.minScore(), equalTo(10f)); + assertThat(parsed.retrieverName(), equalTo("foo_standard")); + try (XContentParser parseSerialized = createParser(JsonXContent.jsonXContent, Strings.toString(source))) { + SearchSourceBuilder deserializedSource = new SearchSourceBuilder().parseXContent( + parseSerialized, + true, + searchUsageHolder, + nf -> true + ); + assertThat(deserializedSource.retriever(), instanceOf(StandardRetrieverBuilder.class)); + StandardRetrieverBuilder deserialized = (StandardRetrieverBuilder) source.retriever(); + assertThat(parsed, equalTo(deserialized)); + } + } + } + + public void testKnnRetrieverParsing() throws IOException { + String restContent = "{" + + " \"retriever\": {" + + " \"knn\": {" + + " \"query_vector\": [" + + " 3" + + " ]," + + " \"field\": \"vector\"," + + " \"k\": 10," + + " \"num_candidates\": 15," + + " \"min_score\": 10," + + " \"_name\": \"foo_knn\"" + + " }" + + " }" + + "}"; + SearchUsageHolder searchUsageHolder = new UsageService().getSearchUsageHolder(); + try (XContentParser jsonParser = createParser(JsonXContent.jsonXContent, restContent)) { + SearchSourceBuilder source = new SearchSourceBuilder().parseXContent(jsonParser, true, searchUsageHolder, nf -> true); + assertThat(source.retriever(), instanceOf(KnnRetrieverBuilder.class)); + KnnRetrieverBuilder parsed = (KnnRetrieverBuilder) source.retriever(); + assertThat(parsed.minScore(), equalTo(10f)); + assertThat(parsed.retrieverName(), equalTo("foo_knn")); + try (XContentParser parseSerialized = createParser(JsonXContent.jsonXContent, Strings.toString(source))) { + SearchSourceBuilder deserializedSource = new SearchSourceBuilder().parseXContent( + parseSerialized, + true, + searchUsageHolder, + nf -> true + ); + assertThat(deserializedSource.retriever(), instanceOf(KnnRetrieverBuilder.class)); + KnnRetrieverBuilder deserialized = (KnnRetrieverBuilder) source.retriever(); + assertThat(parsed, equalTo(deserialized)); + } + } + } + public void testStoredFieldsUsage() throws IOException { Set storedFieldRestVariations = Set.of( "{\"stored_fields\" : [\"_none_\"]}", diff --git a/server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java b/server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java index f3dd86e0b1fa2..b0bf7e6636498 100644 --- a/server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java +++ b/server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java @@ -74,7 +74,7 @@ protected KnnRetrieverBuilder createTestInstance() { @Override protected KnnRetrieverBuilder doParseInstance(XContentParser parser) throws IOException { - return KnnRetrieverBuilder.fromXContent( + return (KnnRetrieverBuilder) RetrieverBuilder.parseTopLevelRetrieverBuilder( parser, new RetrieverParserContext( new SearchUsage(), diff --git a/server/src/test/java/org/elasticsearch/search/retriever/StandardRetrieverBuilderParsingTests.java b/server/src/test/java/org/elasticsearch/search/retriever/StandardRetrieverBuilderParsingTests.java index d2a1cac43c154..eacd949077bc4 100644 --- a/server/src/test/java/org/elasticsearch/search/retriever/StandardRetrieverBuilderParsingTests.java +++ b/server/src/test/java/org/elasticsearch/search/retriever/StandardRetrieverBuilderParsingTests.java @@ -98,7 +98,7 @@ protected StandardRetrieverBuilder createTestInstance() { @Override protected StandardRetrieverBuilder doParseInstance(XContentParser parser) throws IOException { - return StandardRetrieverBuilder.fromXContent( + return (StandardRetrieverBuilder) RetrieverBuilder.parseTopLevelRetrieverBuilder( parser, new RetrieverParserContext( new SearchUsage(), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/random/RandomRankRetrieverBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/random/RandomRankRetrieverBuilder.java index eb36c445506a7..134f8af0e083d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/random/RandomRankRetrieverBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/random/RandomRankRetrieverBuilder.java @@ -103,10 +103,7 @@ public int rankWindowSize() { @Override protected void doToXContent(XContentBuilder builder, Params params) throws IOException { - builder.field(RETRIEVER_FIELD.getPreferredName()); - builder.startObject(); - builder.field(retrieverBuilder.getName(), retrieverBuilder); - builder.endObject(); + builder.field(RETRIEVER_FIELD.getPreferredName(), retrieverBuilder); builder.field(FIELD_FIELD.getPreferredName(), field); builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize); if (seed != null) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java index ab013e0275a69..50d762e7b90aa 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java @@ -179,17 +179,11 @@ public int rankWindowSize() { @Override protected void doToXContent(XContentBuilder builder, Params params) throws IOException { - builder.field(RETRIEVER_FIELD.getPreferredName()); - builder.startObject(); - builder.field(retrieverBuilder.getName(), retrieverBuilder); - builder.endObject(); + builder.field(RETRIEVER_FIELD.getPreferredName(), retrieverBuilder); builder.field(INFERENCE_ID_FIELD.getPreferredName(), inferenceId); builder.field(INFERENCE_TEXT_FIELD.getPreferredName(), inferenceText); builder.field(FIELD_FIELD.getPreferredName(), field); builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize); - if (minScore != null) { - builder.field(MIN_SCORE_FIELD.getPreferredName(), minScore); - } } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/random/RandomRankRetrieverBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/random/RandomRankRetrieverBuilderTests.java index c33f30d461350..c0ef4e45f101f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/random/RandomRankRetrieverBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/random/RandomRankRetrieverBuilderTests.java @@ -17,8 +17,6 @@ import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.json.JsonXContent; -import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankBuilder; -import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder; import java.io.IOException; import java.util.ArrayList; @@ -48,8 +46,8 @@ protected RandomRankRetrieverBuilder createTestInstance() { } @Override - protected RandomRankRetrieverBuilder doParseInstance(XContentParser parser) { - return RandomRankRetrieverBuilder.PARSER.apply( + protected RandomRankRetrieverBuilder doParseInstance(XContentParser parser) throws IOException { + return (RandomRankRetrieverBuilder) RetrieverBuilder.parseTopLevelRetrieverBuilder( parser, new RetrieverParserContext( new SearchUsage(), @@ -77,8 +75,8 @@ protected NamedXContentRegistry xContentRegistry() { entries.add( new NamedXContentRegistry.Entry( RetrieverBuilder.class, - new ParseField(TextSimilarityRankBuilder.NAME), - (p, c) -> TextSimilarityRankRetrieverBuilder.PARSER.apply(p, (RetrieverParserContext) c) + new ParseField(RandomRankBuilder.NAME), + (p, c) -> RandomRankRetrieverBuilder.PARSER.apply(p, (RetrieverParserContext) c) ) ); return new NamedXContentRegistry(entries); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilderTests.java index 1a72cb0da2899..140b181a42a0a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilderTests.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference.rank.textsimilarity; import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.common.Strings; import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.MatchAllQueryBuilder; import org.elasticsearch.index.query.MatchNoneQueryBuilder; @@ -25,6 +26,8 @@ import org.elasticsearch.test.AbstractXContentTestCase; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.usage.SearchUsage; +import org.elasticsearch.usage.SearchUsageHolder; +import org.elasticsearch.usage.UsageService; import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentParser; @@ -72,8 +75,8 @@ protected TextSimilarityRankRetrieverBuilder createTestInstance() { } @Override - protected TextSimilarityRankRetrieverBuilder doParseInstance(XContentParser parser) { - return TextSimilarityRankRetrieverBuilder.PARSER.apply( + protected TextSimilarityRankRetrieverBuilder doParseInstance(XContentParser parser) throws IOException { + return (TextSimilarityRankRetrieverBuilder) RetrieverBuilder.parseTopLevelRetrieverBuilder( parser, new RetrieverParserContext( new SearchUsage(), @@ -208,6 +211,45 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder } } + public void testTextSimilarityRetrieverParsing() throws IOException { + String restContent = "{" + + " \"retriever\": {" + + " \"text_similarity_reranker\": {" + + " \"retriever\": {" + + " \"test\": {" + + " \"value\": \"my-test-retriever\"" + + " }" + + " }," + + " \"field\": \"my-field\"," + + " \"inference_id\": \"my-inference-id\"," + + " \"inference_text\": \"my-inference-text\"," + + " \"rank_window_size\": 100," + + " \"min_score\": 20.0," + + " \"_name\": \"foo_reranker\"" + + " }" + + " }" + + "}"; + SearchUsageHolder searchUsageHolder = new UsageService().getSearchUsageHolder(); + try (XContentParser jsonParser = createParser(JsonXContent.jsonXContent, restContent)) { + SearchSourceBuilder source = new SearchSourceBuilder().parseXContent(jsonParser, true, searchUsageHolder, nf -> true); + assertThat(source.retriever(), instanceOf(TextSimilarityRankRetrieverBuilder.class)); + TextSimilarityRankRetrieverBuilder parsed = (TextSimilarityRankRetrieverBuilder) source.retriever(); + assertThat(parsed.minScore(), equalTo(20f)); + assertThat(parsed.retrieverName(), equalTo("foo_reranker")); + try (XContentParser parseSerialized = createParser(JsonXContent.jsonXContent, Strings.toString(source))) { + SearchSourceBuilder deserializedSource = new SearchSourceBuilder().parseXContent( + parseSerialized, + true, + searchUsageHolder, + nf -> true + ); + assertThat(deserializedSource.retriever(), instanceOf(TextSimilarityRankRetrieverBuilder.class)); + TextSimilarityRankRetrieverBuilder deserialized = (TextSimilarityRankRetrieverBuilder) source.retriever(); + assertThat(parsed, equalTo(deserialized)); + } + } + } + public void testIsCompound() { RetrieverBuilder compoundInnerRetriever = new TestRetrieverBuilder(ESTestCase.randomAlphaOfLengthBetween(5, 10)) { @Override diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java index 496af99574431..5f19e361d857d 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java @@ -180,10 +180,7 @@ public void doToXContent(XContentBuilder builder, Params params) throws IOExcept builder.startArray(RETRIEVERS_FIELD.getPreferredName()); for (var entry : innerRetrievers) { - builder.startObject(); - builder.field(entry.retriever().getName()); entry.retriever().toXContent(builder, params); - builder.endObject(); } builder.endArray(); } diff --git a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderParsingTests.java b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderParsingTests.java index e360237371a82..d324effe41c22 100644 --- a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderParsingTests.java +++ b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderParsingTests.java @@ -8,19 +8,27 @@ package org.elasticsearch.xpack.rank.rrf; import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.common.Strings; +import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.retriever.RetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverParserContext; import org.elasticsearch.search.retriever.TestRetrieverBuilder; import org.elasticsearch.test.AbstractXContentTestCase; import org.elasticsearch.usage.SearchUsage; +import org.elasticsearch.usage.SearchUsageHolder; +import org.elasticsearch.usage.UsageService; import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.json.JsonXContent; import java.io.IOException; import java.util.ArrayList; import java.util.List; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; + public class RRFRetrieverBuilderParsingTests extends AbstractXContentTestCase { /** @@ -53,7 +61,10 @@ protected RRFRetrieverBuilder createTestInstance() { @Override protected RRFRetrieverBuilder doParseInstance(XContentParser parser) throws IOException { - return RRFRetrieverBuilder.PARSER.apply(parser, new RetrieverParserContext(new SearchUsage(), nf -> true)); + return (RRFRetrieverBuilder) RetrieverBuilder.parseTopLevelRetrieverBuilder( + parser, + new RetrieverParserContext(new SearchUsage(), nf -> true) + ); } @Override @@ -81,4 +92,48 @@ protected NamedXContentRegistry xContentRegistry() { ); return new NamedXContentRegistry(entries); } + + public void testRRFRetrieverParsing() throws IOException { + String restContent = "{" + + " \"retriever\": {" + + " \"rrf\": {" + + " \"retrievers\": [" + + " {" + + " \"test\": {" + + " \"value\": \"foo\"" + + " }" + + " }," + + " {" + + " \"test\": {" + + " \"value\": \"bar\"" + + " }" + + " }" + + " ]," + + " \"rank_window_size\": 100," + + " \"rank_constant\": 10," + + " \"min_score\": 20.0," + + " \"_name\": \"foo_rrf\"" + + " }" + + " }" + + "}"; + SearchUsageHolder searchUsageHolder = new UsageService().getSearchUsageHolder(); + try (XContentParser jsonParser = createParser(JsonXContent.jsonXContent, restContent)) { + SearchSourceBuilder source = new SearchSourceBuilder().parseXContent(jsonParser, true, searchUsageHolder, nf -> true); + assertThat(source.retriever(), instanceOf(RRFRetrieverBuilder.class)); + RRFRetrieverBuilder parsed = (RRFRetrieverBuilder) source.retriever(); + assertThat(parsed.minScore(), equalTo(20f)); + assertThat(parsed.retrieverName(), equalTo("foo_rrf")); + try (XContentParser parseSerialized = createParser(JsonXContent.jsonXContent, Strings.toString(source))) { + SearchSourceBuilder deserializedSource = new SearchSourceBuilder().parseXContent( + parseSerialized, + true, + searchUsageHolder, + nf -> true + ); + assertThat(deserializedSource.retriever(), instanceOf(RRFRetrieverBuilder.class)); + RRFRetrieverBuilder deserialized = (RRFRetrieverBuilder) source.retriever(); + assertThat(parsed, equalTo(deserialized)); + } + } + } } From 74f523d4ccb6d26dc65c99138ccad1d43d6cdaaf Mon Sep 17 00:00:00 2001 From: Luca Cavanna Date: Tue, 8 Oct 2024 08:48:53 +0200 Subject: [PATCH 17/22] Span term query to convert to match no docs when unmapped field is targeted (#113251) SpanTermQueryBuilder currently creates a valid SpanTermQuery against unmapped fields. In practice, if the field is unmapped, there won't be a match. This commit changes the toQuery impl to return a MatchNoDocsQuery instead like we do in similar scenarios. --- docs/changelog/113251.yaml | 5 ++++ .../upgrades/QueryBuilderBWCIT.java | 29 ++++++++---------- .../index/query/SpanTermQueryBuilder.java | 20 ++++++------- .../FieldMaskingSpanQueryBuilderTests.java | 5 ++-- .../query/SpanTermQueryBuilderTests.java | 30 ++++++++++++------- 5 files changed, 48 insertions(+), 41 deletions(-) create mode 100644 docs/changelog/113251.yaml diff --git a/docs/changelog/113251.yaml b/docs/changelog/113251.yaml new file mode 100644 index 0000000000000..49167e6e4c915 --- /dev/null +++ b/docs/changelog/113251.yaml @@ -0,0 +1,5 @@ +pr: 113251 +summary: Span term query to convert to match no docs when unmapped field is targeted +area: Search +type: bug +issues: [] diff --git a/qa/full-cluster-restart/src/javaRestTest/java/org/elasticsearch/upgrades/QueryBuilderBWCIT.java b/qa/full-cluster-restart/src/javaRestTest/java/org/elasticsearch/upgrades/QueryBuilderBWCIT.java index e817290d71a6a..42a0a3d6da055 100644 --- a/qa/full-cluster-restart/src/javaRestTest/java/org/elasticsearch/upgrades/QueryBuilderBWCIT.java +++ b/qa/full-cluster-restart/src/javaRestTest/java/org/elasticsearch/upgrades/QueryBuilderBWCIT.java @@ -154,29 +154,29 @@ public QueryBuilderBWCIT(@Name("cluster") FullClusterRestartUpgradeStatus upgrad ); addCandidate( """ - "span_near": {"clauses": [{ "span_term": { "keyword_field": "value1" }}, \ - { "span_term": { "keyword_field": "value2" }}]} + "span_near": {"clauses": [{ "span_term": { "text_field": "value1" }}, \ + { "span_term": { "text_field": "value2" }}]} """, - new SpanNearQueryBuilder(new SpanTermQueryBuilder("keyword_field", "value1"), 0).addClause( - new SpanTermQueryBuilder("keyword_field", "value2") + new SpanNearQueryBuilder(new SpanTermQueryBuilder("text_field", "value1"), 0).addClause( + new SpanTermQueryBuilder("text_field", "value2") ) ); addCandidate( """ - "span_near": {"clauses": [{ "span_term": { "keyword_field": "value1" }}, \ - { "span_term": { "keyword_field": "value2" }}], "slop": 2} + "span_near": {"clauses": [{ "span_term": { "text_field": "value1" }}, \ + { "span_term": { "text_field": "value2" }}], "slop": 2} """, - new SpanNearQueryBuilder(new SpanTermQueryBuilder("keyword_field", "value1"), 2).addClause( - new SpanTermQueryBuilder("keyword_field", "value2") + new SpanNearQueryBuilder(new SpanTermQueryBuilder("text_field", "value1"), 2).addClause( + new SpanTermQueryBuilder("text_field", "value2") ) ); addCandidate( """ - "span_near": {"clauses": [{ "span_term": { "keyword_field": "value1" }}, \ - { "span_term": { "keyword_field": "value2" }}], "slop": 2, "in_order": false} + "span_near": {"clauses": [{ "span_term": { "text_field": "value1" }}, \ + { "span_term": { "text_field": "value2" }}], "slop": 2, "in_order": false} """, - new SpanNearQueryBuilder(new SpanTermQueryBuilder("keyword_field", "value1"), 2).addClause( - new SpanTermQueryBuilder("keyword_field", "value2") + new SpanNearQueryBuilder(new SpanTermQueryBuilder("text_field", "value1"), 2).addClause( + new SpanTermQueryBuilder("text_field", "value2") ).inOrder(false) ); } @@ -204,11 +204,6 @@ public void testQueryBuilderBWC() throws Exception { mappingsAndSettings.field("type", "percolator"); mappingsAndSettings.endObject(); } - { - mappingsAndSettings.startObject("keyword_field"); - mappingsAndSettings.field("type", "keyword"); - mappingsAndSettings.endObject(); - } { mappingsAndSettings.startObject("text_field"); mappingsAndSettings.field("type", "text"); diff --git a/server/src/main/java/org/elasticsearch/index/query/SpanTermQueryBuilder.java b/server/src/main/java/org/elasticsearch/index/query/SpanTermQueryBuilder.java index 55df1f144bb14..fcbe4ded2fea0 100644 --- a/server/src/main/java/org/elasticsearch/index/query/SpanTermQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/index/query/SpanTermQueryBuilder.java @@ -9,16 +9,14 @@ package org.elasticsearch.index.query; -import org.apache.lucene.index.Term; -import org.apache.lucene.queries.spans.SpanQuery; import org.apache.lucene.queries.spans.SpanTermQuery; import org.apache.lucene.search.Query; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.lucene.BytesRefs; import org.elasticsearch.index.mapper.MappedFieldType; +import org.elasticsearch.lucene.queries.SpanMatchNoDocsQuery; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentParser; @@ -71,16 +69,18 @@ public SpanTermQueryBuilder(StreamInput in) throws IOException { } @Override - protected SpanQuery doToQuery(SearchExecutionContext context) throws IOException { + protected Query doToQuery(SearchExecutionContext context) throws IOException { MappedFieldType mapper = context.getFieldType(fieldName); - Term term; if (mapper == null) { - term = new Term(fieldName, BytesRefs.toBytesRef(value)); - } else { - Query termQuery = mapper.termQuery(value, context); - term = MappedFieldType.extractTerm(termQuery); + return new SpanMatchNoDocsQuery(fieldName, "unmapped field: " + fieldName); } - return new SpanTermQuery(term); + Query termQuery = mapper.termQuery(value, context); + if (mapper.getTextSearchInfo().hasPositions() == false) { + throw new IllegalArgumentException( + "Span term query requires position data, but field " + fieldName + " was indexed without position data" + ); + } + return new SpanTermQuery(MappedFieldType.extractTerm(termQuery)); } public static SpanTermQueryBuilder fromXContent(XContentParser parser) throws IOException, ParsingException { diff --git a/server/src/test/java/org/elasticsearch/index/query/FieldMaskingSpanQueryBuilderTests.java b/server/src/test/java/org/elasticsearch/index/query/FieldMaskingSpanQueryBuilderTests.java index bc2ba833536ff..3890b53e29ffb 100644 --- a/server/src/test/java/org/elasticsearch/index/query/FieldMaskingSpanQueryBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/index/query/FieldMaskingSpanQueryBuilderTests.java @@ -9,13 +9,12 @@ package org.elasticsearch.index.query; -import org.apache.lucene.index.Term; import org.apache.lucene.queries.spans.FieldMaskingSpanQuery; -import org.apache.lucene.queries.spans.SpanTermQuery; import org.apache.lucene.search.BoostQuery; import org.apache.lucene.search.Query; import org.elasticsearch.common.ParsingException; import org.elasticsearch.core.Strings; +import org.elasticsearch.lucene.queries.SpanMatchNoDocsQuery; import org.elasticsearch.test.AbstractQueryTestCase; import java.io.IOException; @@ -105,7 +104,7 @@ public void testJsonWithTopLevelBoost() throws IOException { } }""", NAME.getPreferredName()); Query q = parseQuery(json).toQuery(createSearchExecutionContext()); - assertEquals(new BoostQuery(new FieldMaskingSpanQuery(new SpanTermQuery(new Term("value", "foo")), "mapped_geo_shape"), 42.0f), q); + assertEquals(new BoostQuery(new FieldMaskingSpanQuery(new SpanMatchNoDocsQuery("value", null), "mapped_geo_shape"), 42.0f), q); } public void testJsonWithDeprecatedName() throws IOException { diff --git a/server/src/test/java/org/elasticsearch/index/query/SpanTermQueryBuilderTests.java b/server/src/test/java/org/elasticsearch/index/query/SpanTermQueryBuilderTests.java index 5e01f6c304efd..c4a9267ff68a0 100644 --- a/server/src/test/java/org/elasticsearch/index/query/SpanTermQueryBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/index/query/SpanTermQueryBuilderTests.java @@ -11,11 +11,13 @@ import org.apache.lucene.index.Term; import org.apache.lucene.queries.spans.SpanTermQuery; +import org.apache.lucene.search.BoostQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.TermQuery; import org.elasticsearch.common.ParsingException; -import org.elasticsearch.common.lucene.BytesRefs; +import org.elasticsearch.index.mapper.IdFieldMapper; import org.elasticsearch.index.mapper.MappedFieldType; +import org.elasticsearch.lucene.queries.SpanMatchNoDocsQuery; import org.elasticsearch.xcontent.json.JsonStringEncoder; import java.io.IOException; @@ -47,18 +49,16 @@ protected SpanTermQueryBuilder createQueryBuilder(String fieldName, Object value @Override protected void doAssertLuceneQuery(SpanTermQueryBuilder queryBuilder, Query query, SearchExecutionContext context) throws IOException { - assertThat(query, instanceOf(SpanTermQuery.class)); - SpanTermQuery spanTermQuery = (SpanTermQuery) query; - - String expectedFieldName = expectedFieldName(queryBuilder.fieldName); - assertThat(spanTermQuery.getTerm().field(), equalTo(expectedFieldName)); - MappedFieldType mapper = context.getFieldType(queryBuilder.fieldName()); if (mapper != null) { + String expectedFieldName = expectedFieldName(queryBuilder.fieldName); + assertThat(query, instanceOf(SpanTermQuery.class)); + SpanTermQuery spanTermQuery = (SpanTermQuery) query; + assertThat(spanTermQuery.getTerm().field(), equalTo(expectedFieldName)); Term term = ((TermQuery) mapper.termQuery(queryBuilder.value(), null)).getTerm(); assertThat(spanTermQuery.getTerm(), equalTo(term)); } else { - assertThat(spanTermQuery.getTerm().bytes(), equalTo(BytesRefs.toBytesRef(queryBuilder.value()))); + assertThat(query, instanceOf(SpanMatchNoDocsQuery.class)); } } @@ -115,13 +115,21 @@ public void testParseFailsWithMultipleFields() throws IOException { assertEquals("[span_term] query doesn't support multiple fields, found [message1] and [message2]", e.getMessage()); } - public void testWithMetadataField() throws IOException { + public void testWithBoost() throws IOException { SearchExecutionContext context = createSearchExecutionContext(); - for (String field : new String[] { "field1", "field2" }) { + for (String field : new String[] { TEXT_FIELD_NAME, TEXT_ALIAS_FIELD_NAME }) { SpanTermQueryBuilder spanTermQueryBuilder = new SpanTermQueryBuilder(field, "toto"); + spanTermQueryBuilder.boost(10); Query query = spanTermQueryBuilder.toQuery(context); - Query expected = new SpanTermQuery(new Term(field, "toto")); + Query expected = new BoostQuery(new SpanTermQuery(new Term(TEXT_FIELD_NAME, "toto")), 10); assertEquals(expected, query); } } + + public void testFieldWithoutPositions() { + SearchExecutionContext context = createSearchExecutionContext(); + SpanTermQueryBuilder spanTermQueryBuilder = new SpanTermQueryBuilder(IdFieldMapper.NAME, "1234"); + IllegalArgumentException iae = expectThrows(IllegalArgumentException.class, () -> spanTermQueryBuilder.toQuery(context)); + assertEquals("Span term query requires position data, but field _id was indexed without position data", iae.getMessage()); + } } From 419d534f0bc09b3390e80d03da7fd6bb813ab55b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Istv=C3=A1n=20Zolt=C3=A1n=20Szab=C3=B3?= Date: Tue, 8 Oct 2024 09:05:23 +0200 Subject: [PATCH 18/22] [DOCS] Adds DeBERTa v2 tokenization params to infer trained model API docs (#114242) (#114277) * [DOCS] Adds DeBERTa v2 tokenization params to infer trained model API docs. * [DOCS] Mode edits. --- .../apis/infer-trained-model.asciidoc | 89 +++++++++++++++++++ 1 file changed, 89 insertions(+) diff --git a/docs/reference/ml/trained-models/apis/infer-trained-model.asciidoc b/docs/reference/ml/trained-models/apis/infer-trained-model.asciidoc index 99c3ecad03a9d..7acbc0bd23859 100644 --- a/docs/reference/ml/trained-models/apis/infer-trained-model.asciidoc +++ b/docs/reference/ml/trained-models/apis/infer-trained-model.asciidoc @@ -225,6 +225,17 @@ include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenizatio (Optional, string) include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-truncate] ======= +`deberta_v2`:::: +(Optional, object) +include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-deberta-v2] ++ +.Properties of deberta_v2 +[%collapsible%open] +======= +`truncate`:::: +(Optional, string) +include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-truncate-deberta-v2] +======= `roberta`:::: (Optional, object) include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-roberta] @@ -301,6 +312,17 @@ include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenizatio (Optional, string) include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-truncate] ======= +`deberta_v2`:::: +(Optional, object) +include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-deberta-v2] ++ +.Properties of deberta_v2 +[%collapsible%open] +======= +`truncate`:::: +(Optional, string) +include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-truncate-deberta-v2] +======= `roberta`:::: (Optional, object) include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-roberta] @@ -397,6 +419,21 @@ include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenizatio (Optional, string) include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-truncate] ======= +`deberta_v2`:::: +(Optional, object) +include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-deberta-v2] ++ +.Properties of deberta_v2 +[%collapsible%open] +======= +`span`:::: +(Optional, integer) +include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-span] + +`truncate`:::: +(Optional, string) +include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-truncate-deberta-v2] +======= `roberta`:::: (Optional, object) include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-roberta] @@ -517,6 +554,21 @@ include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenizatio (Optional, string) include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-truncate] ======= +`deberta_v2`:::: +(Optional, object) +include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-deberta-v2] ++ +.Properties of deberta_v2 +[%collapsible%open] +======= +`span`:::: +(Optional, integer) +include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-span] + +`truncate`:::: +(Optional, string) +include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-truncate-deberta-v2] +======= `roberta`:::: (Optional, object) include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-roberta] @@ -608,6 +660,17 @@ include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenizatio (Optional, string) include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-truncate] ======= +`deberta_v2`:::: +(Optional, object) +include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-deberta-v2] ++ +.Properties of deberta_v2 +[%collapsible%open] +======= +`truncate`:::: +(Optional, string) +include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-truncate-deberta-v2] +======= `roberta`:::: (Optional, object) include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-roberta] @@ -687,6 +750,21 @@ include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenizatio (Optional, integer) include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-span] +`with_special_tokens`:::: +(Optional, boolean) +include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-with-special-tokens] +======= +`deberta_v2`:::: +(Optional, object) +include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-deberta-v2] ++ +.Properties of deberta_v2 +[%collapsible%open] +======= +`span`:::: +(Optional, integer) +include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-span] + `with_special_tokens`:::: (Optional, boolean) include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-with-special-tokens] @@ -790,6 +868,17 @@ include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenizatio (Optional, string) include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-truncate] ======= +`deberta_v2`:::: +(Optional, object) +include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-deberta-v2] ++ +.Properties of deberta_v2 +[%collapsible%open] +======= +`truncate`:::: +(Optional, string) +include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-truncate-deberta-v2] +======= `roberta`:::: (Optional, object) include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-roberta] From 3bfdc8943c29fef1b137f0f32a18bb49728a7c6e Mon Sep 17 00:00:00 2001 From: David Turner Date: Tue, 8 Oct 2024 09:48:00 +0100 Subject: [PATCH 19/22] Show only committed cluster UUID in `GET /` (#114275) (#114281) Today we show `Metadata#clusterUUID` in the response to `GET /` regardless of whether this value is committed or not, which means that in theory users may see this value change even if nothing is going wrong. To avoid any doubt about the stability of this cluster UUID, this commit suppresses the cluster UUID in this API response until it is committed. --- .../rest/root/TransportMainAction.java | 3 +- .../rest/root/MainActionTests.java | 42 +++++++++++-------- .../cluster/metadata/Metadata.java | 5 +++ 3 files changed, 31 insertions(+), 19 deletions(-) diff --git a/modules/rest-root/src/main/java/org/elasticsearch/rest/root/TransportMainAction.java b/modules/rest-root/src/main/java/org/elasticsearch/rest/root/TransportMainAction.java index 15f23f7511445..2598f943ea27c 100644 --- a/modules/rest-root/src/main/java/org/elasticsearch/rest/root/TransportMainAction.java +++ b/modules/rest-root/src/main/java/org/elasticsearch/rest/root/TransportMainAction.java @@ -14,6 +14,7 @@ import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.TransportAction; import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.Metadata; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.EsExecutors; @@ -48,7 +49,7 @@ protected void doExecute(Task task, MainRequest request, ActionListener responseRef = new AtomicReference<>(); - action.doExecute(mock(Task.class), new MainRequest(), new ActionListener<>() { - @Override - public void onResponse(MainResponse mainResponse) { - responseRef.set(mainResponse); - } - - @Override - public void onFailure(Exception e) { - logger.error("unexpected error", e); - } - }); + final AtomicBoolean listenerCalled = new AtomicBoolean(); + new TransportMainAction(settings, transportService, mock(ActionFilters.class), clusterService).doExecute( + mock(Task.class), + new MainRequest(), + ActionTestUtils.assertNoFailureListener(mainResponse -> { + assertNotNull(mainResponse); + assertEquals( + state.metadata().clusterUUIDCommitted() ? state.metadata().clusterUUID() : Metadata.UNKNOWN_CLUSTER_UUID, + mainResponse.getClusterUuid() + ); + assertFalse(listenerCalled.getAndSet(true)); + }) + ); - assertNotNull(responseRef.get()); + assertTrue(listenerCalled.get()); verify(clusterService, times(1)).state(); } } diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/Metadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/Metadata.java index 1319e75520ed6..6b3904d84a397 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/Metadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/Metadata.java @@ -697,6 +697,11 @@ public long version() { return this.version; } + /** + * @return A UUID which identifies this cluster. Nodes record the UUID of the cluster they first join on disk, and will then refuse to + * join clusters with different UUIDs. Note that when the cluster is forming for the first time this value may not yet be committed, + * and therefore it may change. Check {@link #clusterUUIDCommitted()} to verify that the value is committed if needed. + */ public String clusterUUID() { return this.clusterUUID; } From 029c6836d950f8bb133e079ffd22ab0b57424294 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christoph=20B=C3=BCscher?= Date: Tue, 8 Oct 2024 11:21:09 +0200 Subject: [PATCH 20/22] Remove dangling AwaitsFix (#113941) (#114286) The issue referenced seems to be closed, maybe this just wasn't removed properly. --- .../SearchQueryThenFetchAsyncActionTests.java | 483 +++++++++--------- 1 file changed, 253 insertions(+), 230 deletions(-) diff --git a/server/src/test/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncActionTests.java b/server/src/test/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncActionTests.java index 3c4976d9bfa86..b63c88f623e21 100644 --- a/server/src/test/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncActionTests.java @@ -258,11 +258,10 @@ public void run() { } } - @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/101932") public void testMinimumVersionSameAsNewVersion() throws Exception { var newVersion = VersionInformation.CURRENT; var oldVersion = new VersionInformation( - VersionUtils.randomVersionBetween(random(), Version.CURRENT.minimumCompatibilityVersion(), VersionUtils.getPreviousVersion()), + VersionUtils.randomCompatibleVersion(random(), VersionUtils.getPreviousVersion()), IndexVersions.MINIMUM_COMPATIBLE, IndexVersionUtils.randomCompatibleVersion(random()) ); @@ -340,65 +339,69 @@ private void testMixedVersionsShardsSearch(VersionInformation oldVersion, Versio SearchTransportService searchTransportService = new SearchTransportService(null, null, null); SearchPhaseController controller = new SearchPhaseController((t, r) -> InternalAggregationTestCase.emptyReduceContextBuilder()); SearchTask task = new SearchTask(0, "n/a", "n/a", () -> "test", null, Collections.emptyMap()); - QueryPhaseResultConsumer resultConsumer = new QueryPhaseResultConsumer( - searchRequest, - EsExecutors.DIRECT_EXECUTOR_SERVICE, - new NoopCircuitBreaker(CircuitBreaker.REQUEST), - controller, - task::isCancelled, - task.getProgressListener(), - shardsIter.size(), - exc -> {} - ); - final List responses = new ArrayList<>(); - SearchQueryThenFetchAsyncAction newSearchAsyncAction = new SearchQueryThenFetchAsyncAction( - logger, - null, - searchTransportService, - (clusterAlias, node) -> lookup.get(node), - Collections.singletonMap("_na_", AliasFilter.EMPTY), - Collections.emptyMap(), - EsExecutors.DIRECT_EXECUTOR_SERVICE, - resultConsumer, - searchRequest, - new ActionListener<>() { - @Override - public void onFailure(Exception e) { - responses.add(e); - } + try ( + QueryPhaseResultConsumer resultConsumer = new QueryPhaseResultConsumer( + searchRequest, + EsExecutors.DIRECT_EXECUTOR_SERVICE, + new NoopCircuitBreaker(CircuitBreaker.REQUEST), + controller, + task::isCancelled, + task.getProgressListener(), + shardsIter.size(), + exc -> {} + ) + ) { + final List responses = new ArrayList<>(); + SearchQueryThenFetchAsyncAction newSearchAsyncAction = new SearchQueryThenFetchAsyncAction( + logger, + null, + searchTransportService, + (clusterAlias, node) -> lookup.get(node), + Collections.singletonMap("_na_", AliasFilter.EMPTY), + Collections.emptyMap(), + EsExecutors.DIRECT_EXECUTOR_SERVICE, + resultConsumer, + searchRequest, + new ActionListener<>() { + @Override + public void onFailure(Exception e) { + responses.add(e); + } - public void onResponse(SearchResponse response) { - responses.add(response); - }; - }, - shardsIter, - timeProvider, - new ClusterState.Builder(new ClusterName("test")).build(), - task, - SearchResponse.Clusters.EMPTY, - null - ); + public void onResponse(SearchResponse response) { + responses.add(response); + } - newSearchAsyncAction.start(); - assertThat(responses, hasSize(1)); - assertThat(responses.get(0), instanceOf(SearchPhaseExecutionException.class)); - SearchPhaseExecutionException e = (SearchPhaseExecutionException) responses.get(0); - assertThat(e.getCause(), instanceOf(VersionMismatchException.class)); - assertThat( - e.getCause().getMessage(), - equalTo("One of the shards is incompatible with the required minimum version [" + minVersion + "]") - ); + ; + }, + shardsIter, + timeProvider, + new ClusterState.Builder(new ClusterName("test")).build(), + task, + SearchResponse.Clusters.EMPTY, + null + ); + + newSearchAsyncAction.start(); + assertThat(responses, hasSize(1)); + assertThat(responses.get(0), instanceOf(SearchPhaseExecutionException.class)); + SearchPhaseExecutionException e = (SearchPhaseExecutionException) responses.get(0); + assertThat(e.getCause(), instanceOf(VersionMismatchException.class)); + assertThat( + e.getCause().getMessage(), + equalTo("One of the shards is incompatible with the required minimum version [" + minVersion + "]") + ); + } } - @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/101932") public void testMinimumVersionSameAsOldVersion() throws Exception { - Version newVersion = Version.CURRENT; - Version oldVersion = VersionUtils.randomVersionBetween( - random(), - Version.CURRENT.minimumCompatibilityVersion(), - VersionUtils.getPreviousVersion(newVersion) + var newVersion = VersionInformation.CURRENT; + var oldVersion = new VersionInformation( + VersionUtils.randomCompatibleVersion(random(), VersionUtils.getPreviousVersion()), + IndexVersions.MINIMUM_COMPATIBLE, + IndexVersionUtils.randomCompatibleVersion(random()) ); - Version minVersion = oldVersion; + Version minVersion = oldVersion.nodeVersion(); final TransportSearchAction.SearchTimeProvider timeProvider = new TransportSearchAction.SearchTimeProvider( 0, @@ -456,98 +459,106 @@ public void sendExecuteQuery( new SearchShardTarget("node1", new ShardId("idx", "na", shardId), null), null ); - SortField sortField = new SortField("timestamp", SortField.Type.LONG); - if (shardId == 0) { - queryResult.topDocs( - new TopDocsAndMaxScore( - new TopFieldDocs( - new TotalHits(1, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), - new FieldDoc[] { new FieldDoc(randomInt(1000), Float.NaN, new Object[] { shardId }) }, - new SortField[] { sortField } + try { + SortField sortField = new SortField("timestamp", SortField.Type.LONG); + if (shardId == 0) { + queryResult.topDocs( + new TopDocsAndMaxScore( + new TopFieldDocs( + new TotalHits(1, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), + new FieldDoc[] { new FieldDoc(randomInt(1000), Float.NaN, new Object[] { shardId }) }, + new SortField[] { sortField } + ), + Float.NaN ), - Float.NaN - ), - new DocValueFormat[] { DocValueFormat.RAW } - ); - } else if (shardId == 1) { - queryResult.topDocs( - new TopDocsAndMaxScore( - new TopFieldDocs( - new TotalHits(1, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), - new FieldDoc[] { new FieldDoc(randomInt(1000), Float.NaN, new Object[] { shardId }) }, - new SortField[] { sortField } + new DocValueFormat[] { DocValueFormat.RAW } + ); + } else if (shardId == 1) { + queryResult.topDocs( + new TopDocsAndMaxScore( + new TopFieldDocs( + new TotalHits(1, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), + new FieldDoc[] { new FieldDoc(randomInt(1000), Float.NaN, new Object[] { shardId }) }, + new SortField[] { sortField } + ), + Float.NaN ), - Float.NaN - ), - new DocValueFormat[] { DocValueFormat.RAW } - ); + new DocValueFormat[] { DocValueFormat.RAW } + ); + } + queryResult.from(0); + queryResult.size(1); + successfulOps.incrementAndGet(); + queryResult.incRef(); + new Thread(() -> ActionListener.respondAndRelease(listener, queryResult)).start(); + } finally { + queryResult.decRef(); } - queryResult.from(0); - queryResult.size(1); - successfulOps.incrementAndGet(); - new Thread(() -> listener.onResponse(queryResult)).start(); } }; SearchPhaseController controller = new SearchPhaseController((t, r) -> InternalAggregationTestCase.emptyReduceContextBuilder()); SearchTask task = new SearchTask(0, "n/a", "n/a", () -> "test", null, Collections.emptyMap()); - QueryPhaseResultConsumer resultConsumer = new QueryPhaseResultConsumer( - searchRequest, - EsExecutors.DIRECT_EXECUTOR_SERVICE, - new NoopCircuitBreaker(CircuitBreaker.REQUEST), - controller, - task::isCancelled, - task.getProgressListener(), - shardsIter.size(), - exc -> {} - ); - CountDownLatch latch = new CountDownLatch(1); - SearchQueryThenFetchAsyncAction action = new SearchQueryThenFetchAsyncAction( - logger, - null, - searchTransportService, - (clusterAlias, node) -> lookup.get(node), - Collections.singletonMap("_na_", AliasFilter.EMPTY), - Collections.emptyMap(), - EsExecutors.DIRECT_EXECUTOR_SERVICE, - resultConsumer, - searchRequest, - null, - shardsIter, - timeProvider, - new ClusterState.Builder(new ClusterName("test")).build(), - task, - SearchResponse.Clusters.EMPTY, - null + try ( + QueryPhaseResultConsumer resultConsumer = new QueryPhaseResultConsumer( + searchRequest, + EsExecutors.DIRECT_EXECUTOR_SERVICE, + new NoopCircuitBreaker(CircuitBreaker.REQUEST), + controller, + task::isCancelled, + task.getProgressListener(), + shardsIter.size(), + exc -> {} + ) ) { - @Override - protected SearchPhase getNextPhase(SearchPhaseResults results, SearchPhaseContext context) { - return new SearchPhase("test") { - @Override - public void run() { - latch.countDown(); - } - }; - } - }; + CountDownLatch latch = new CountDownLatch(1); + SearchQueryThenFetchAsyncAction action = new SearchQueryThenFetchAsyncAction( + logger, + null, + searchTransportService, + (clusterAlias, node) -> lookup.get(node), + Collections.singletonMap("_na_", AliasFilter.EMPTY), + Collections.emptyMap(), + EsExecutors.DIRECT_EXECUTOR_SERVICE, + resultConsumer, + searchRequest, + null, + shardsIter, + timeProvider, + new ClusterState.Builder(new ClusterName("test")).build(), + task, + SearchResponse.Clusters.EMPTY, + null + ) { + @Override + protected SearchPhase getNextPhase(SearchPhaseResults results, SearchPhaseContext context) { + return new SearchPhase("test") { + @Override + public void run() { + latch.countDown(); + } + }; + } + }; - action.start(); - latch.await(); - assertThat(successfulOps.get(), equalTo(2)); - SearchPhaseController.ReducedQueryPhase phase = action.results.reduce(); - assertThat(phase.numReducePhases(), greaterThanOrEqualTo(1)); - assertThat(phase.totalHits().value, equalTo(2L)); - assertThat(phase.totalHits().relation, equalTo(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO)); + action.start(); + latch.await(); + assertThat(successfulOps.get(), equalTo(2)); + SearchPhaseController.ReducedQueryPhase phase = action.results.reduce(); + assertThat(phase.numReducePhases(), greaterThanOrEqualTo(1)); + assertThat(phase.totalHits().value, equalTo(2L)); + assertThat(phase.totalHits().relation, equalTo(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO)); + } } - @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/101932") public void testMinimumVersionShardDuringPhaseExecution() throws Exception { - Version newVersion = Version.CURRENT; - Version oldVersion = VersionUtils.randomVersionBetween( - random(), - Version.CURRENT.minimumCompatibilityVersion(), - VersionUtils.getPreviousVersion(newVersion) + var newVersion = VersionInformation.CURRENT; + var oldVersion = new VersionInformation( + VersionUtils.randomCompatibleVersion(random(), VersionUtils.getPreviousVersion()), + IndexVersions.MINIMUM_COMPATIBLE, + IndexVersionUtils.randomCompatibleVersion(random()) ); - Version minVersion = newVersion; + + Version minVersion = newVersion.nodeVersion(); final TransportSearchAction.SearchTimeProvider timeProvider = new TransportSearchAction.SearchTimeProvider( 0, @@ -607,111 +618,123 @@ public void sendExecuteQuery( new SearchShardTarget("node1", new ShardId("idx", "na", shardId), null), null ); - SortField sortField = new SortField("timestamp", SortField.Type.LONG); - if (shardId == 0) { - queryResult.topDocs( - new TopDocsAndMaxScore( - new TopFieldDocs( - new TotalHits(1, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), - new FieldDoc[] { new FieldDoc(randomInt(1000), Float.NaN, new Object[] { shardId }) }, - new SortField[] { sortField } + try { + SortField sortField = new SortField("timestamp", SortField.Type.LONG); + if (shardId == 0) { + queryResult.topDocs( + new TopDocsAndMaxScore( + new TopFieldDocs( + new TotalHits(1, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), + new FieldDoc[] { new FieldDoc(randomInt(1000), Float.NaN, new Object[] { shardId }) }, + new SortField[] { sortField } + ), + Float.NaN ), - Float.NaN - ), - new DocValueFormat[] { DocValueFormat.RAW } - ); - } else if (shardId == 1) { - queryResult.topDocs( - new TopDocsAndMaxScore( - new TopFieldDocs( - new TotalHits(1, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), - new FieldDoc[] { new FieldDoc(randomInt(1000), Float.NaN, new Object[] { shardId }) }, - new SortField[] { sortField } + new DocValueFormat[] { DocValueFormat.RAW } + ); + } else if (shardId == 1) { + queryResult.topDocs( + new TopDocsAndMaxScore( + new TopFieldDocs( + new TotalHits(1, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), + new FieldDoc[] { new FieldDoc(randomInt(1000), Float.NaN, new Object[] { shardId }) }, + new SortField[] { sortField } + ), + Float.NaN ), - Float.NaN - ), - new DocValueFormat[] { DocValueFormat.RAW } - ); + new DocValueFormat[] { DocValueFormat.RAW } + ); + } + queryResult.from(0); + queryResult.size(1); + successfulOps.incrementAndGet(); + queryResult.incRef(); + new Thread(() -> ActionListener.respondAndRelease(listener, queryResult)).start(); + } finally { + queryResult.decRef(); } - queryResult.from(0); - queryResult.size(1); - successfulOps.incrementAndGet(); - new Thread(() -> listener.onResponse(queryResult)).start(); } }; SearchPhaseController controller = new SearchPhaseController((t, r) -> InternalAggregationTestCase.emptyReduceContextBuilder()); SearchTask task = new SearchTask(0, "n/a", "n/a", () -> "test", null, Collections.emptyMap()); - QueryPhaseResultConsumer resultConsumer = new QueryPhaseResultConsumer( - searchRequest, - EsExecutors.DIRECT_EXECUTOR_SERVICE, - new NoopCircuitBreaker(CircuitBreaker.REQUEST), - controller, - task::isCancelled, - task.getProgressListener(), - shardsIter.size(), - exc -> {} - ); + CountDownLatch latch = new CountDownLatch(1); - SearchQueryThenFetchAsyncAction action = new SearchQueryThenFetchAsyncAction( - logger, - null, - searchTransportService, - (clusterAlias, node) -> lookup.get(node), - Collections.singletonMap("_na_", AliasFilter.EMPTY), - Collections.emptyMap(), - EsExecutors.DIRECT_EXECUTOR_SERVICE, - resultConsumer, - searchRequest, - null, - shardsIter, - timeProvider, - new ClusterState.Builder(new ClusterName("test")).build(), - task, - SearchResponse.Clusters.EMPTY, - null + try ( + QueryPhaseResultConsumer resultConsumer = new QueryPhaseResultConsumer( + searchRequest, + EsExecutors.DIRECT_EXECUTOR_SERVICE, + new NoopCircuitBreaker(CircuitBreaker.REQUEST), + controller, + task::isCancelled, + task.getProgressListener(), + shardsIter.size(), + exc -> {} + ) ) { - @Override - protected SearchPhase getNextPhase(SearchPhaseResults results, SearchPhaseContext context) { - return new SearchPhase("test") { - @Override - public void run() { - latch.countDown(); - } - }; - } - }; - ShardRouting routingOldVersionShard = ShardRouting.newUnassigned( - new ShardId(new Index("idx", "_na_"), 2), - true, - RecoverySource.EmptyStoreRecoverySource.INSTANCE, - new UnassignedInfo(UnassignedInfo.Reason.INDEX_CREATED, "foobar"), - ShardRouting.Role.DEFAULT - ); - SearchShardIterator shardIt = new SearchShardIterator( - null, - new ShardId(new Index("idx", "_na_"), 2), - singletonList(routingOldVersionShard), - idx - ); - routingOldVersionShard = routingOldVersionShard.initialize(oldVersionNode.getId(), "p2", 0); - routingOldVersionShard.started(); - action.start(); - latch.await(); - assertThat(successfulOps.get(), equalTo(2)); - SearchPhaseController.ReducedQueryPhase phase = action.results.reduce(); - assertThat(phase.numReducePhases(), greaterThanOrEqualTo(1)); - assertThat(phase.totalHits().value, equalTo(2L)); - assertThat(phase.totalHits().relation, equalTo(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO)); - - SearchShardTarget searchShardTarget = new SearchShardTarget("node3", shardIt.shardId(), null); - SearchActionListener listener = new SearchActionListener(searchShardTarget, 0) { - @Override - public void onFailure(Exception e) {} + SearchQueryThenFetchAsyncAction action = new SearchQueryThenFetchAsyncAction( + logger, + null, + searchTransportService, + (clusterAlias, node) -> lookup.get(node), + Collections.singletonMap("_na_", AliasFilter.EMPTY), + Collections.emptyMap(), + EsExecutors.DIRECT_EXECUTOR_SERVICE, + resultConsumer, + searchRequest, + null, + shardsIter, + timeProvider, + new ClusterState.Builder(new ClusterName("test")).build(), + task, + SearchResponse.Clusters.EMPTY, + null + ) { + @Override + protected SearchPhase getNextPhase(SearchPhaseResults results, SearchPhaseContext context) { + return new SearchPhase("test") { + @Override + public void run() { + latch.countDown(); + } + }; + } + }; + ShardRouting routingOldVersionShard = ShardRouting.newUnassigned( + new ShardId(new Index("idx", "_na_"), 2), + true, + RecoverySource.EmptyStoreRecoverySource.INSTANCE, + new UnassignedInfo(UnassignedInfo.Reason.INDEX_CREATED, "foobar"), + ShardRouting.Role.DEFAULT + ); + SearchShardIterator shardIt = new SearchShardIterator( + null, + new ShardId(new Index("idx", "_na_"), 2), + singletonList(routingOldVersionShard), + idx + ); + routingOldVersionShard = routingOldVersionShard.initialize(oldVersionNode.getId(), "p2", 0); + routingOldVersionShard.started(); + action.start(); + latch.await(); + assertThat(successfulOps.get(), equalTo(2)); + SearchPhaseController.ReducedQueryPhase phase = action.results.reduce(); + assertThat(phase.numReducePhases(), greaterThanOrEqualTo(1)); + assertThat(phase.totalHits().value, equalTo(2L)); + assertThat(phase.totalHits().relation, equalTo(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO)); - @Override - protected void innerOnResponse(SearchPhaseResult response) {} - }; - Exception e = expectThrows(VersionMismatchException.class, () -> action.executePhaseOnShard(shardIt, searchShardTarget, listener)); - assertThat(e.getMessage(), equalTo("One of the shards is incompatible with the required minimum version [" + minVersion + "]")); + SearchShardTarget searchShardTarget = new SearchShardTarget("node3", shardIt.shardId(), null); + SearchActionListener listener = new SearchActionListener(searchShardTarget, 0) { + @Override + public void onFailure(Exception e) {} + + @Override + protected void innerOnResponse(SearchPhaseResult response) {} + }; + Exception e = expectThrows( + VersionMismatchException.class, + () -> action.executePhaseOnShard(shardIt, searchShardTarget, listener) + ); + assertThat(e.getMessage(), equalTo("One of the shards is incompatible with the required minimum version [" + minVersion + "]")); + } } } From e745c92ebc6a40067d30e041023d7643c0ff2941 Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Tue, 8 Oct 2024 13:44:00 +0300 Subject: [PATCH 21/22] Backporting text_similarity_reranker retriever rework to be evaluated during rewrite phase to 8.x (#114282) * backporting text_similarity_reranker rework to 8.x * fixing comp --------- Co-authored-by: Elastic Machine --- .../org/elasticsearch/TransportVersions.java | 1 + .../action/search/SearchPhaseController.java | 9 + .../elasticsearch/search/rank/RankDoc.java | 11 +- .../retriever/CompoundRetrieverBuilder.java | 4 +- .../search/sort/ShardDocSortField.java | 6 +- ...bstractRankDocWireSerializingTestCase.java | 57 +++ .../search/rank/RankDocTests.java | 16 +- .../xpack/inference/InferenceFeatures.java | 3 +- .../xpack/inference/InferencePlugin.java | 3 + .../textsimilarity/TextSimilarityRankDoc.java | 103 ++++++ .../TextSimilarityRankRetrieverBuilder.java | 127 ++++--- .../TextSimilarityRankDocTests.java | 88 +++++ ...xtSimilarityRankRetrieverBuilderTests.java | 121 +------ .../70_text_similarity_rank_retriever.yml | 40 ++- x-pack/plugin/rank-rrf/build.gradle | 4 + .../xpack/rank/rrf/RRFRankDoc.java | 6 + .../xpack/rank/rrf/RRFRankDocTests.java | 21 +- .../rrf/RRFRankClientYamlTestSuiteIT.java | 2 + ...ith_text_similarity_reranker_retriever.yml | 334 ++++++++++++++++++ 19 files changed, 760 insertions(+), 196 deletions(-) create mode 100644 server/src/test/java/org/elasticsearch/search/rank/AbstractRankDocWireSerializingTestCase.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankDoc.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankDocTests.java create mode 100644 x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/800_rrf_with_text_similarity_reranker_retriever.yml diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 3ef128c4b1d3a..9b5392803b178 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -236,6 +236,7 @@ static TransportVersion def(int id) { public static final TransportVersion INGEST_GEO_DATABASE_PROVIDERS = def(8_760_00_0); public static final TransportVersion DATE_TIME_DOC_VALUES_LOCALES = def(8_761_00_0); public static final TransportVersion FAST_REFRESH_RCO = def(8_762_00_0); + public static final TransportVersion TEXT_SIMILARITY_RERANKER_QUERY_REWRITE = def(8_763_00_0); /* * STOP! READ THIS FIRST! No, really, diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java index a6acb3ee2a52e..1c4eb1c191370 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java @@ -36,6 +36,7 @@ import org.elasticsearch.search.SearchHits; import org.elasticsearch.search.SearchPhaseResult; import org.elasticsearch.search.SearchService; +import org.elasticsearch.search.SearchSortValues; import org.elasticsearch.search.aggregations.AggregationReduceContext; import org.elasticsearch.search.aggregations.AggregatorFactories; import org.elasticsearch.search.aggregations.InternalAggregations; @@ -51,6 +52,7 @@ import org.elasticsearch.search.query.QuerySearchResult; import org.elasticsearch.search.rank.RankDoc; import org.elasticsearch.search.rank.context.QueryPhaseRankCoordinatorContext; +import org.elasticsearch.search.sort.ShardDocSortField; import org.elasticsearch.search.suggest.Suggest; import org.elasticsearch.search.suggest.Suggest.Suggestion; import org.elasticsearch.search.suggest.completion.CompletionSuggestion; @@ -464,6 +466,13 @@ private static SearchHits getHits( assert shardDoc instanceof RankDoc; searchHit.setRank(((RankDoc) shardDoc).rank); searchHit.score(shardDoc.score); + long shardAndDoc = ShardDocSortField.encodeShardAndDoc(shardDoc.shardIndex, shardDoc.doc); + searchHit.sortValues( + new SearchSortValues( + new Object[] { shardDoc.score, shardAndDoc }, + new DocValueFormat[] { DocValueFormat.RAW, DocValueFormat.RAW } + ) + ); } else if (sortedTopDocs.isSortedByField) { FieldDoc fieldDoc = (FieldDoc) shardDoc; searchHit.sortValues(fieldDoc.fields, reducedQueryPhase.sortValueFormats); diff --git a/server/src/main/java/org/elasticsearch/search/rank/RankDoc.java b/server/src/main/java/org/elasticsearch/search/rank/RankDoc.java index b16a234931115..9ab14aa9362b5 100644 --- a/server/src/main/java/org/elasticsearch/search/rank/RankDoc.java +++ b/server/src/main/java/org/elasticsearch/search/rank/RankDoc.java @@ -11,9 +11,11 @@ import org.apache.lucene.search.Explanation; import org.apache.lucene.search.ScoreDoc; -import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.VersionedNamedWriteable; import org.elasticsearch.xcontent.ToXContentFragment; import org.elasticsearch.xcontent.XContentBuilder; @@ -24,7 +26,7 @@ * {@code RankDoc} is the base class for all ranked results. * Subclasses should extend this with additional information required for their global ranking method. */ -public class RankDoc extends ScoreDoc implements NamedWriteable, ToXContentFragment, Comparable { +public class RankDoc extends ScoreDoc implements VersionedNamedWriteable, ToXContentFragment, Comparable { public static final String NAME = "rank_doc"; @@ -40,6 +42,11 @@ public String getWriteableName() { return NAME; } + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.RANK_DOCS_RETRIEVER; + } + @Override public final int compareTo(RankDoc other) { if (score != other.score) { diff --git a/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java index 1962145d7336d..22bef026523e9 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java @@ -160,7 +160,7 @@ public void onFailure(Exception e) { @Override public final QueryBuilder topDocsQuery() { - throw new IllegalStateException(getName() + " cannot be nested"); + throw new IllegalStateException("Should not be called, missing a rewrite?"); } @Override @@ -208,7 +208,7 @@ public int doHashCode() { return Objects.hash(innerRetrievers); } - private SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) { + protected SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) { var sourceBuilder = new SearchSourceBuilder().pointInTimeBuilder(pit) .trackTotalHits(false) .storedFields(new StoredFieldsContext(false)) diff --git a/server/src/main/java/org/elasticsearch/search/sort/ShardDocSortField.java b/server/src/main/java/org/elasticsearch/search/sort/ShardDocSortField.java index a38a24eb75fca..be1ce9c925037 100644 --- a/server/src/main/java/org/elasticsearch/search/sort/ShardDocSortField.java +++ b/server/src/main/java/org/elasticsearch/search/sort/ShardDocSortField.java @@ -64,7 +64,7 @@ public void setTopValue(Long value) { @Override public Long value(int slot) { - return (((long) shardRequestIndex) << 32) | (delegate.value(slot) & 0xFFFFFFFFL); + return encodeShardAndDoc(shardRequestIndex, delegate.value(slot)); } @Override @@ -87,4 +87,8 @@ public static int decodeDoc(long value) { public static int decodeShardRequestIndex(long value) { return (int) (value >> 32); } + + public static long encodeShardAndDoc(int shardIndex, int doc) { + return (((long) shardIndex) << 32) | (doc & 0xFFFFFFFFL); + } } diff --git a/server/src/test/java/org/elasticsearch/search/rank/AbstractRankDocWireSerializingTestCase.java b/server/src/test/java/org/elasticsearch/search/rank/AbstractRankDocWireSerializingTestCase.java new file mode 100644 index 0000000000000..d0c85a33acf09 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/search/rank/AbstractRankDocWireSerializingTestCase.java @@ -0,0 +1,57 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.rank; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.search.SearchModule; +import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder; +import org.elasticsearch.test.AbstractWireSerializingTestCase; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import static org.hamcrest.Matchers.equalTo; + +public abstract class AbstractRankDocWireSerializingTestCase extends AbstractWireSerializingTestCase { + + protected abstract T createTestRankDoc(); + + @Override + protected NamedWriteableRegistry writableRegistry() { + SearchModule searchModule = new SearchModule(Settings.EMPTY, Collections.emptyList()); + List entries = searchModule.getNamedWriteables(); + entries.addAll(getAdditionalNamedWriteables()); + return new NamedWriteableRegistry(entries); + } + + protected abstract List getAdditionalNamedWriteables(); + + @Override + protected T createTestInstance() { + return createTestRankDoc(); + } + + @SuppressWarnings({ "unchecked", "rawtypes" }) + public void testRankDocSerialization() throws IOException { + int totalDocs = randomIntBetween(10, 100); + Set docs = new HashSet<>(); + for (int i = 0; i < totalDocs; i++) { + docs.add(createTestRankDoc()); + } + RankDocsQueryBuilder rankDocsQueryBuilder = new RankDocsQueryBuilder(docs.toArray((T[]) new RankDoc[0]), null, randomBoolean()); + RankDocsQueryBuilder copy = (RankDocsQueryBuilder) copyNamedWriteable(rankDocsQueryBuilder, writableRegistry(), QueryBuilder.class); + assertThat(rankDocsQueryBuilder, equalTo(copy)); + } +} diff --git a/server/src/test/java/org/elasticsearch/search/rank/RankDocTests.java b/server/src/test/java/org/elasticsearch/search/rank/RankDocTests.java index d190139309c31..21101b2bc7db1 100644 --- a/server/src/test/java/org/elasticsearch/search/rank/RankDocTests.java +++ b/server/src/test/java/org/elasticsearch/search/rank/RankDocTests.java @@ -9,27 +9,29 @@ package org.elasticsearch.search.rank; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.test.AbstractWireSerializingTestCase; import java.io.IOException; +import java.util.Collections; +import java.util.List; -public class RankDocTests extends AbstractWireSerializingTestCase { +public class RankDocTests extends AbstractRankDocWireSerializingTestCase { - static RankDoc createTestRankDoc() { + protected RankDoc createTestRankDoc() { RankDoc rankDoc = new RankDoc(randomNonNegativeInt(), randomFloat(), randomIntBetween(0, 1)); rankDoc.rank = randomNonNegativeInt(); return rankDoc; } @Override - protected Writeable.Reader instanceReader() { - return RankDoc::new; + protected List getAdditionalNamedWriteables() { + return Collections.emptyList(); } @Override - protected RankDoc createTestInstance() { - return createTestRankDoc(); + protected Writeable.Reader instanceReader() { + return RankDoc::new; } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java index fd330a8cf6cc6..640c979deda56 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java @@ -25,7 +25,8 @@ public Set getFeatures() { return Set.of( TextSimilarityRankRetrieverBuilder.TEXT_SIMILARITY_RERANKER_RETRIEVER_SUPPORTED, RandomRankRetrieverBuilder.RANDOM_RERANKER_RETRIEVER_SUPPORTED, - SemanticTextFieldMapper.SEMANTIC_TEXT_SEARCH_INFERENCE_ID + SemanticTextFieldMapper.SEMANTIC_TEXT_SEARCH_INFERENCE_ID, + TextSimilarityRankRetrieverBuilder.TEXT_SIMILARITY_RERANKER_COMPOSITION_SUPPORTED ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index dbb9130ab91e1..927fd94809886 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -36,6 +36,7 @@ import org.elasticsearch.rest.RestController; import org.elasticsearch.rest.RestHandler; import org.elasticsearch.search.rank.RankBuilder; +import org.elasticsearch.search.rank.RankDoc; import org.elasticsearch.threadpool.ExecutorBuilder; import org.elasticsearch.threadpool.ScalingExecutorBuilder; import org.elasticsearch.xcontent.ParseField; @@ -66,6 +67,7 @@ import org.elasticsearch.xpack.inference.rank.random.RandomRankBuilder; import org.elasticsearch.xpack.inference.rank.random.RandomRankRetrieverBuilder; import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankBuilder; +import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankDoc; import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.rest.RestDeleteInferenceEndpointAction; @@ -253,6 +255,7 @@ public List getNamedWriteables() { var entries = new ArrayList<>(InferenceNamedWriteablesProvider.getNamedWriteables()); entries.add(new NamedWriteableRegistry.Entry(RankBuilder.class, TextSimilarityRankBuilder.NAME, TextSimilarityRankBuilder::new)); entries.add(new NamedWriteableRegistry.Entry(RankBuilder.class, RandomRankBuilder.NAME, RandomRankBuilder::new)); + entries.add(new NamedWriteableRegistry.Entry(RankDoc.class, TextSimilarityRankDoc.NAME, TextSimilarityRankDoc::new)); return entries; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankDoc.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankDoc.java new file mode 100644 index 0000000000000..d208623e53324 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankDoc.java @@ -0,0 +1,103 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.rank.textsimilarity; + +import org.apache.lucene.search.Explanation; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.search.rank.RankDoc; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Objects; + +public class TextSimilarityRankDoc extends RankDoc { + + public static final String NAME = "text_similarity_rank_doc"; + + public final String inferenceId; + public final String field; + + public TextSimilarityRankDoc(int doc, float score, int shardIndex, String inferenceId, String field) { + super(doc, score, shardIndex); + this.inferenceId = inferenceId; + this.field = field; + } + + public TextSimilarityRankDoc(StreamInput in) throws IOException { + super(in); + inferenceId = in.readString(); + field = in.readString(); + } + + @Override + public Explanation explain(Explanation[] sources, String[] queryNames) { + final String queryAlias = queryNames[0] == null ? "" : "[" + queryNames[0] + "]"; + return Explanation.match( + score, + "text_similarity_reranker match using inference endpoint: [" + + inferenceId + + "] on document field: [" + + field + + "] matching on source query " + + queryAlias, + sources + ); + } + + @Override + public void doWriteTo(StreamOutput out) throws IOException { + out.writeString(inferenceId); + out.writeString(field); + } + + @Override + public boolean doEquals(RankDoc rd) { + TextSimilarityRankDoc tsrd = (TextSimilarityRankDoc) rd; + return Objects.equals(inferenceId, tsrd.inferenceId) && Objects.equals(field, tsrd.field); + } + + @Override + public int doHashCode() { + return Objects.hash(inferenceId, field); + } + + @Override + public String toString() { + return "TextSimilarityRankDoc{" + + "doc=" + + doc + + ", shardIndex=" + + shardIndex + + ", score=" + + score + + ", inferenceId=" + + inferenceId + + ", field=" + + field + + '}'; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + protected void doToXContent(XContentBuilder builder, Params params) throws IOException { + builder.field("inferenceId", inferenceId); + builder.field("field", field); + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.TEXT_SIMILARITY_RERANKER_QUERY_REWRITE; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java index 50d762e7b90aa..f201251599f9d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java @@ -7,14 +7,20 @@ package org.elasticsearch.xpack.inference.rank.textsimilarity; +import org.apache.lucene.search.ScoreDoc; import org.elasticsearch.common.ParsingException; import org.elasticsearch.features.NodeFeature; +import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; -import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.license.LicenseUtils; +import org.elasticsearch.search.builder.PointInTimeBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.fetch.StoredFieldsContext; +import org.elasticsearch.search.rank.RankDoc; +import org.elasticsearch.search.retriever.CompoundRetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverParserContext; +import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentBuilder; @@ -32,11 +38,14 @@ /** * A {@code RetrieverBuilder} for parsing and constructing a text similarity reranker retriever. */ -public class TextSimilarityRankRetrieverBuilder extends RetrieverBuilder { +public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder { public static final NodeFeature TEXT_SIMILARITY_RERANKER_RETRIEVER_SUPPORTED = new NodeFeature( "text_similarity_reranker_retriever_supported" ); + public static final NodeFeature TEXT_SIMILARITY_RERANKER_COMPOSITION_SUPPORTED = new NodeFeature( + "text_similarity_reranker_retriever_composition_supported" + ); public static final ParseField RETRIEVER_FIELD = new ParseField("retriever"); public static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id"); @@ -51,7 +60,6 @@ public class TextSimilarityRankRetrieverBuilder extends RetrieverBuilder { String inferenceText = (String) args[2]; String field = (String) args[3]; int rankWindowSize = args[4] == null ? DEFAULT_RANK_WINDOW_SIZE : (int) args[4]; - return new TextSimilarityRankRetrieverBuilder(retrieverBuilder, inferenceId, inferenceText, field, rankWindowSize); }); @@ -70,17 +78,20 @@ public static TextSimilarityRankRetrieverBuilder fromXContent(XContentParser par if (context.clusterSupportsFeature(TEXT_SIMILARITY_RERANKER_RETRIEVER_SUPPORTED) == false) { throw new ParsingException(parser.getTokenLocation(), "unknown retriever [" + TextSimilarityRankBuilder.NAME + "]"); } + if (context.clusterSupportsFeature(TEXT_SIMILARITY_RERANKER_COMPOSITION_SUPPORTED) == false) { + throw new UnsupportedOperationException( + "[text_similarity_reranker] retriever composition feature is not supported by all nodes in the cluster" + ); + } if (TextSimilarityRankBuilder.TEXT_SIMILARITY_RERANKER_FEATURE.check(XPackPlugin.getSharedLicenseState()) == false) { throw LicenseUtils.newComplianceException(TextSimilarityRankBuilder.NAME); } return PARSER.apply(parser, context); } - private final RetrieverBuilder retrieverBuilder; private final String inferenceId; private final String inferenceText; private final String field; - private final int rankWindowSize; public TextSimilarityRankRetrieverBuilder( RetrieverBuilder retrieverBuilder, @@ -89,15 +100,14 @@ public TextSimilarityRankRetrieverBuilder( String field, int rankWindowSize ) { - this.retrieverBuilder = retrieverBuilder; + super(List.of(new RetrieverSource(retrieverBuilder, null)), rankWindowSize); this.inferenceId = inferenceId; this.inferenceText = inferenceText; this.field = field; - this.rankWindowSize = rankWindowSize; } public TextSimilarityRankRetrieverBuilder( - RetrieverBuilder retrieverBuilder, + List retrieverSource, String inferenceId, String inferenceText, String field, @@ -106,66 +116,75 @@ public TextSimilarityRankRetrieverBuilder( String retrieverName, List preFilterQueryBuilders ) { - this.retrieverBuilder = retrieverBuilder; + super(retrieverSource, rankWindowSize); + if (retrieverSource.size() != 1) { + throw new IllegalArgumentException("[" + getName() + "] retriever should have exactly one inner retriever"); + } this.inferenceId = inferenceId; this.inferenceText = inferenceText; this.field = field; - this.rankWindowSize = rankWindowSize; this.minScore = minScore; this.retrieverName = retrieverName; this.preFilterQueryBuilders = preFilterQueryBuilders; } @Override - public QueryBuilder topDocsQuery() { - // the original matching set of the TextSimilarityRank retriever is specified by its nested retriever - return retrieverBuilder.topDocsQuery(); + protected TextSimilarityRankRetrieverBuilder clone(List newChildRetrievers) { + return new TextSimilarityRankRetrieverBuilder( + newChildRetrievers, + inferenceId, + inferenceText, + field, + rankWindowSize, + minScore, + retrieverName, + preFilterQueryBuilders + ); } @Override - public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException { - // rewrite prefilters - boolean hasChanged = false; - var newPreFilters = rewritePreFilters(ctx); - hasChanged |= newPreFilters != preFilterQueryBuilders; - - // rewrite nested retriever - RetrieverBuilder newRetriever = retrieverBuilder.rewrite(ctx); - hasChanged |= newRetriever != retrieverBuilder; - if (hasChanged) { - return new TextSimilarityRankRetrieverBuilder( - newRetriever, - field, - inferenceText, - inferenceId, - rankWindowSize, - minScore, - this.retrieverName, - newPreFilters - ); + protected RankDoc[] combineInnerRetrieverResults(List rankResults) { + assert rankResults.size() == 1; + ScoreDoc[] scoreDocs = rankResults.get(0); + TextSimilarityRankDoc[] textSimilarityRankDocs = new TextSimilarityRankDoc[scoreDocs.length]; + for (int i = 0; i < scoreDocs.length; i++) { + ScoreDoc scoreDoc = scoreDocs[i]; + textSimilarityRankDocs[i] = new TextSimilarityRankDoc(scoreDoc.doc, scoreDoc.score, scoreDoc.shardIndex, inferenceId, field); } - return this; + return textSimilarityRankDocs; } @Override - public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) { - retrieverBuilder.getPreFilterQueryBuilders().addAll(preFilterQueryBuilders); - retrieverBuilder.extractToSearchSourceBuilder(searchSourceBuilder, compoundUsed); - // Combining with other rank builder (such as RRF) is not supported yet - if (searchSourceBuilder.rankBuilder() != null) { - throw new IllegalArgumentException("text similarity rank builder cannot be combined with other rank builders"); - } + public QueryBuilder explainQuery() { + // the original matching set of the TextSimilarityRank retriever is specified by its nested retriever + return new RankDocsQueryBuilder(rankDocs, new QueryBuilder[] { innerRetrievers.get(0).retriever().explainQuery() }, true); + } - searchSourceBuilder.rankBuilder( + @Override + protected SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) { + var sourceBuilder = new SearchSourceBuilder().pointInTimeBuilder(pit) + .trackTotalHits(false) + .storedFields(new StoredFieldsContext(false)) + .size(rankWindowSize); + if (preFilterQueryBuilders.isEmpty() == false) { + retrieverBuilder.getPreFilterQueryBuilders().addAll(preFilterQueryBuilders); + } + retrieverBuilder.extractToSearchSourceBuilder(sourceBuilder, true); + + // apply the pre-filters + if (preFilterQueryBuilders.size() > 0) { + QueryBuilder query = sourceBuilder.query(); + BoolQueryBuilder newQuery = new BoolQueryBuilder(); + if (query != null) { + newQuery.must(query); + } + preFilterQueryBuilders.forEach(newQuery::filter); + sourceBuilder.query(newQuery); + } + sourceBuilder.rankBuilder( new TextSimilarityRankBuilder(this.field, this.inferenceId, this.inferenceText, this.rankWindowSize, this.minScore) ); - } - - /** - * Determines if this retriever contains sub-retrievers that need to be executed prior to search. - */ - public boolean isCompound() { - return retrieverBuilder.isCompound(); + return sourceBuilder; } @Override @@ -179,7 +198,7 @@ public int rankWindowSize() { @Override protected void doToXContent(XContentBuilder builder, Params params) throws IOException { - builder.field(RETRIEVER_FIELD.getPreferredName(), retrieverBuilder); + builder.field(RETRIEVER_FIELD.getPreferredName(), innerRetrievers.get(0).retriever()); builder.field(INFERENCE_ID_FIELD.getPreferredName(), inferenceId); builder.field(INFERENCE_TEXT_FIELD.getPreferredName(), inferenceText); builder.field(FIELD_FIELD.getPreferredName(), field); @@ -187,9 +206,9 @@ protected void doToXContent(XContentBuilder builder, Params params) throws IOExc } @Override - protected boolean doEquals(Object other) { + public boolean doEquals(Object other) { TextSimilarityRankRetrieverBuilder that = (TextSimilarityRankRetrieverBuilder) other; - return Objects.equals(retrieverBuilder, that.retrieverBuilder) + return super.doEquals(other) && Objects.equals(inferenceId, that.inferenceId) && Objects.equals(inferenceText, that.inferenceText) && Objects.equals(field, that.field) @@ -198,7 +217,7 @@ protected boolean doEquals(Object other) { } @Override - protected int doHashCode() { - return Objects.hash(retrieverBuilder, inferenceId, inferenceText, field, rankWindowSize, minScore); + public int doHashCode() { + return Objects.hash(inferenceId, inferenceText, field, rankWindowSize, minScore); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankDocTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankDocTests.java new file mode 100644 index 0000000000000..fed4565c54bd4 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankDocTests.java @@ -0,0 +1,88 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.rank.textsimilarity; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.search.rank.AbstractRankDocWireSerializingTestCase; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.InferencePlugin; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.search.rank.RankDoc.NO_RANK; + +public class TextSimilarityRankDocTests extends AbstractRankDocWireSerializingTestCase { + + static TextSimilarityRankDoc createTestTextSimilarityRankDoc() { + TextSimilarityRankDoc instance = new TextSimilarityRankDoc( + randomNonNegativeInt(), + randomFloat(), + randomBoolean() ? -1 : randomNonNegativeInt(), + randomAlphaOfLength(randomIntBetween(2, 5)), + randomAlphaOfLength(randomIntBetween(2, 5)) + ); + instance.rank = randomBoolean() ? NO_RANK : randomIntBetween(1, 10000); + return instance; + } + + @Override + protected List getAdditionalNamedWriteables() { + try (InferencePlugin plugin = new InferencePlugin(Settings.EMPTY)) { + return plugin.getNamedWriteables(); + } + } + + @Override + protected Writeable.Reader instanceReader() { + return TextSimilarityRankDoc::new; + } + + @Override + protected TextSimilarityRankDoc createTestRankDoc() { + return createTestTextSimilarityRankDoc(); + } + + @Override + protected TextSimilarityRankDoc mutateInstance(TextSimilarityRankDoc instance) throws IOException { + int doc = instance.doc; + int shardIndex = instance.shardIndex; + float score = instance.score; + int rank = instance.rank; + String inferenceId = instance.inferenceId; + String field = instance.field; + + switch (randomInt(5)) { + case 0: + doc = randomValueOtherThan(doc, ESTestCase::randomNonNegativeInt); + break; + case 1: + shardIndex = shardIndex == -1 ? randomNonNegativeInt() : -1; + break; + case 2: + score = randomValueOtherThan(score, ESTestCase::randomFloat); + break; + case 3: + rank = rank == NO_RANK ? randomIntBetween(1, 10000) : NO_RANK; + break; + case 4: + inferenceId = randomValueOtherThan(inferenceId, () -> randomAlphaOfLength(randomIntBetween(2, 5))); + break; + case 5: + field = randomValueOtherThan(field, () -> randomAlphaOfLength(randomIntBetween(2, 5))); + break; + default: + throw new AssertionError(); + } + TextSimilarityRankDoc mutated = new TextSimilarityRankDoc(doc, score, shardIndex, inferenceId, field); + mutated.rank = rank; + return mutated; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilderTests.java index 140b181a42a0a..32301bf9efea9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilderTests.java @@ -9,17 +9,9 @@ import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.common.Strings; -import org.elasticsearch.index.query.BoolQueryBuilder; -import org.elasticsearch.index.query.MatchAllQueryBuilder; -import org.elasticsearch.index.query.MatchNoneQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; -import org.elasticsearch.index.query.QueryRewriteContext; -import org.elasticsearch.index.query.RandomQueryBuilder; -import org.elasticsearch.index.query.RangeQueryBuilder; -import org.elasticsearch.index.query.Rewriteable; import org.elasticsearch.index.query.TermQueryBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder; -import org.elasticsearch.search.builder.SubSearchSourceBuilder; import org.elasticsearch.search.retriever.RetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverParserContext; import org.elasticsearch.search.retriever.TestRetrieverBuilder; @@ -38,10 +30,8 @@ import java.util.List; import static org.elasticsearch.search.rank.RankBuilder.DEFAULT_RANK_WINDOW_SIZE; -import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; -import static org.mockito.Mockito.mock; public class TextSimilarityRankRetrieverBuilderTests extends AbstractXContentTestCase { @@ -82,6 +72,7 @@ protected TextSimilarityRankRetrieverBuilder doParseInstance(XContentParser pars new SearchUsage(), nf -> nf == RetrieverBuilder.RETRIEVERS_SUPPORTED || nf == TextSimilarityRankRetrieverBuilder.TEXT_SIMILARITY_RERANKER_RETRIEVER_SUPPORTED + || nf == TextSimilarityRankRetrieverBuilder.TEXT_SIMILARITY_RERANKER_COMPOSITION_SUPPORTED ) ); } @@ -131,86 +122,6 @@ public void testParserDefaults() throws IOException { } } - public void testRewriteInnerRetriever() throws IOException { - final boolean[] rewritten = { false }; - List preFilterQueryBuilders = new ArrayList<>(); - if (randomBoolean()) { - for (int i = 0; i < randomIntBetween(1, 5); i++) { - preFilterQueryBuilders.add(RandomQueryBuilder.createQuery(random())); - } - } - RetrieverBuilder innerRetriever = new TestRetrieverBuilder("top-level-retriever") { - @Override - public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException { - if (randomBoolean()) { - return this; - } - rewritten[0] = true; - return new TestRetrieverBuilder("nested-rewritten-retriever") { - @Override - public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) { - if (preFilterQueryBuilders.isEmpty() == false) { - BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); - - for (QueryBuilder preFilterQueryBuilder : preFilterQueryBuilders) { - boolQueryBuilder.filter(preFilterQueryBuilder); - } - boolQueryBuilder.must(new RangeQueryBuilder("some_field")); - searchSourceBuilder.subSearches().add(new SubSearchSourceBuilder(boolQueryBuilder)); - } else { - searchSourceBuilder.subSearches().add(new SubSearchSourceBuilder(new RangeQueryBuilder("some_field"))); - } - } - }; - } - - @Override - public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) { - if (preFilterQueryBuilders.isEmpty() == false) { - BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); - - for (QueryBuilder preFilterQueryBuilder : preFilterQueryBuilders) { - boolQueryBuilder.filter(preFilterQueryBuilder); - } - boolQueryBuilder.must(new TermQueryBuilder("field", "value")); - searchSourceBuilder.subSearches().add(new SubSearchSourceBuilder(boolQueryBuilder)); - } else { - searchSourceBuilder.subSearches().add(new SubSearchSourceBuilder(new TermQueryBuilder("field", "value"))); - } - } - }; - TextSimilarityRankRetrieverBuilder textSimilarityRankRetrieverBuilder = createRandomTextSimilarityRankRetrieverBuilder( - innerRetriever - ); - textSimilarityRankRetrieverBuilder.getPreFilterQueryBuilders().addAll(preFilterQueryBuilders); - SearchSourceBuilder source = new SearchSourceBuilder().retriever(textSimilarityRankRetrieverBuilder); - QueryRewriteContext queryRewriteContext = mock(QueryRewriteContext.class); - source = Rewriteable.rewrite(source, queryRewriteContext); - assertNull(source.retriever()); - if (false == preFilterQueryBuilders.isEmpty()) { - if (source.query() instanceof MatchAllQueryBuilder == false && source.query() instanceof MatchNoneQueryBuilder == false) { - assertThat(source.query(), instanceOf(BoolQueryBuilder.class)); - BoolQueryBuilder bq = (BoolQueryBuilder) source.query(); - assertFalse(bq.must().isEmpty()); - assertThat(bq.must().size(), equalTo(1)); - if (rewritten[0]) { - assertThat(bq.must().get(0), instanceOf(RangeQueryBuilder.class)); - } else { - assertThat(bq.must().get(0), instanceOf(TermQueryBuilder.class)); - } - for (int j = 0; j < bq.filter().size(); j++) { - assertEqualQueryOrMatchAllNone(bq.filter().get(j), preFilterQueryBuilders.get(j)); - } - } - } else { - if (rewritten[0]) { - assertThat(source.query(), instanceOf(RangeQueryBuilder.class)); - } else { - assertThat(source.query(), instanceOf(TermQueryBuilder.class)); - } - } - } - public void testTextSimilarityRetrieverParsing() throws IOException { String restContent = "{" + " \"retriever\": {" @@ -250,29 +161,6 @@ public void testTextSimilarityRetrieverParsing() throws IOException { } } - public void testIsCompound() { - RetrieverBuilder compoundInnerRetriever = new TestRetrieverBuilder(ESTestCase.randomAlphaOfLengthBetween(5, 10)) { - @Override - public boolean isCompound() { - return true; - } - }; - RetrieverBuilder nonCompoundInnerRetriever = new TestRetrieverBuilder(ESTestCase.randomAlphaOfLengthBetween(5, 10)) { - @Override - public boolean isCompound() { - return false; - } - }; - TextSimilarityRankRetrieverBuilder compoundTextSimilarityRankRetrieverBuilder = createRandomTextSimilarityRankRetrieverBuilder( - compoundInnerRetriever - ); - assertTrue(compoundTextSimilarityRankRetrieverBuilder.isCompound()); - TextSimilarityRankRetrieverBuilder nonCompoundTextSimilarityRankRetrieverBuilder = createRandomTextSimilarityRankRetrieverBuilder( - nonCompoundInnerRetriever - ); - assertFalse(nonCompoundTextSimilarityRankRetrieverBuilder.isCompound()); - } - public void testTopDocsQuery() { RetrieverBuilder innerRetriever = new TestRetrieverBuilder(ESTestCase.randomAlphaOfLengthBetween(5, 10)) { @Override @@ -281,11 +169,6 @@ public QueryBuilder topDocsQuery() { } }; TextSimilarityRankRetrieverBuilder retriever = createRandomTextSimilarityRankRetrieverBuilder(innerRetriever); - assertThat(retriever.topDocsQuery(), instanceOf(TermQueryBuilder.class)); - } - - private static void assertEqualQueryOrMatchAllNone(QueryBuilder actual, QueryBuilder expected) { - assertThat(actual, anyOf(instanceOf(MatchAllQueryBuilder.class), instanceOf(MatchNoneQueryBuilder.class), equalTo(expected))); + expectThrows(IllegalStateException.class, "Should not be called, missing a rewrite?", retriever::topDocsQuery); } - } diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/70_text_similarity_rank_retriever.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/70_text_similarity_rank_retriever.yml index e2c1417057578..9a4d7f4416164 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/70_text_similarity_rank_retriever.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/70_text_similarity_rank_retriever.yml @@ -87,11 +87,9 @@ setup: - length: { hits.hits: 2 } - match: { hits.hits.0._id: "doc_2" } - - match: { hits.hits.0._rank: 1 } - close_to: { hits.hits.0._score: { value: 0.4, error: 0.001 } } - match: { hits.hits.1._id: "doc_1" } - - match: { hits.hits.1._rank: 2 } - close_to: { hits.hits.1._score: { value: 0.2, error: 0.001 } } --- @@ -123,7 +121,6 @@ setup: - length: { hits.hits: 1 } - match: { hits.hits.0._id: "doc_1" } - - match: { hits.hits.0._rank: 1 } - close_to: { hits.hits.0._score: { value: 0.2, error: 0.001 } } @@ -178,3 +175,40 @@ setup: field: text size: 10 + +--- +"text similarity reranking with explain": + + - do: + search: + index: test-index + body: + track_total_hits: true + fields: [ "text", "topic" ] + retriever: { + text_similarity_reranker: { + retriever: + { + standard: { + query: { + term: { + topic: "science" + } + } + } + }, + rank_window_size: 10, + inference_id: my-rerank-model, + inference_text: "How often does the moon hide the sun?", + field: text + } + } + size: 10 + explain: true + + - match: { hits.hits.0._id: "doc_2" } + - match: { hits.hits.1._id: "doc_1" } + + - close_to: { hits.hits.0._explanation.value: { value: 0.4, error: 0.000001 } } + - match: {hits.hits.0._explanation.description: "/text_similarity_reranker.match.using.inference.endpoint:.\\[my-rerank-model\\].on.document.field:.\\[text\\].*/" } + - match: {hits.hits.0._explanation.details.0.description: "/weight.*science.*/" } diff --git a/x-pack/plugin/rank-rrf/build.gradle b/x-pack/plugin/rank-rrf/build.gradle index 2db33fa0f2c8d..2c3f217243aa4 100644 --- a/x-pack/plugin/rank-rrf/build.gradle +++ b/x-pack/plugin/rank-rrf/build.gradle @@ -20,7 +20,11 @@ dependencies { compileOnly project(path: xpackModule('core')) testImplementation(testArtifact(project(xpackModule('core')))) + testImplementation(testArtifact(project(':server'))) clusterModules project(xpackModule('rank-rrf')) + clusterModules project(xpackModule('inference')) clusterModules project(':modules:lang-painless') + + clusterPlugins project(':x-pack:plugin:inference:qa:test-service-plugin') } diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankDoc.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankDoc.java index 500ed17395127..272df248e53e9 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankDoc.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankDoc.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.rank.rrf; import org.apache.lucene.search.Explanation; +import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -169,4 +170,9 @@ protected void doToXContent(XContentBuilder builder, Params params) throws IOExc builder.field("scores", scores); builder.field("rankConstant", rankConstant); } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.RRF_QUERY_REWRITE; + } } diff --git a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRankDocTests.java b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRankDocTests.java index 4b64b6c173c92..5548392270a08 100644 --- a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRankDocTests.java +++ b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRankDocTests.java @@ -7,15 +7,17 @@ package org.elasticsearch.xpack.rank.rrf; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable.Reader; -import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.search.rank.AbstractRankDocWireSerializingTestCase; import org.elasticsearch.test.ESTestCase; import java.io.IOException; +import java.util.List; import static org.elasticsearch.xpack.rank.rrf.RRFRankDoc.NO_RANK; -public class RRFRankDocTests extends AbstractWireSerializingTestCase { +public class RRFRankDocTests extends AbstractRankDocWireSerializingTestCase { static RRFRankDoc createTestRRFRankDoc(int queryCount) { RRFRankDoc instance = new RRFRankDoc( @@ -35,9 +37,13 @@ static RRFRankDoc createTestRRFRankDoc(int queryCount) { return instance; } - static RRFRankDoc createTestRRFRankDoc() { - int queryCount = randomIntBetween(2, 20); - return createTestRRFRankDoc(queryCount); + @Override + protected List getAdditionalNamedWriteables() { + try (RRFRankPlugin rrfRankPlugin = new RRFRankPlugin()) { + return rrfRankPlugin.getNamedWriteables(); + } catch (IOException ex) { + throw new AssertionError("Failed to create RRFRankPlugin", ex); + } } @Override @@ -46,8 +52,9 @@ protected Reader instanceReader() { } @Override - protected RRFRankDoc createTestInstance() { - return createTestRRFRankDoc(); + protected RRFRankDoc createTestRankDoc() { + int queryCount = randomIntBetween(2, 20); + return createTestRRFRankDoc(queryCount); } @Override diff --git a/x-pack/plugin/rank-rrf/src/yamlRestTest/java/org/elasticsearch/xpack/rank/rrf/RRFRankClientYamlTestSuiteIT.java b/x-pack/plugin/rank-rrf/src/yamlRestTest/java/org/elasticsearch/xpack/rank/rrf/RRFRankClientYamlTestSuiteIT.java index 3a577eb62faa3..32b5aedd5d99a 100644 --- a/x-pack/plugin/rank-rrf/src/yamlRestTest/java/org/elasticsearch/xpack/rank/rrf/RRFRankClientYamlTestSuiteIT.java +++ b/x-pack/plugin/rank-rrf/src/yamlRestTest/java/org/elasticsearch/xpack/rank/rrf/RRFRankClientYamlTestSuiteIT.java @@ -23,7 +23,9 @@ public class RRFRankClientYamlTestSuiteIT extends ESClientYamlSuiteTestCase { .nodes(2) .module("rank-rrf") .module("lang-painless") + .module("x-pack-inference") .setting("xpack.license.self_generated.type", "trial") + .plugin("inference-service-test") .build(); public RRFRankClientYamlTestSuiteIT(@Name("yaml") ClientYamlTestCandidate testCandidate) { diff --git a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/800_rrf_with_text_similarity_reranker_retriever.yml b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/800_rrf_with_text_similarity_reranker_retriever.yml new file mode 100644 index 0000000000000..3e758ae11f7e6 --- /dev/null +++ b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/800_rrf_with_text_similarity_reranker_retriever.yml @@ -0,0 +1,334 @@ +setup: + - requires: + cluster_features: ['rrf_retriever_composition_supported', 'text_similarity_reranker_retriever_supported'] + reason: need to have support for rrf and semantic reranking composition + test_runner_features: "close_to" + + - do: + inference.put: + task_type: rerank + inference_id: my-rerank-model + body: > + { + "service": "test_reranking_service", + "service_settings": { + "model_id": "my_model", + "api_key": "abc64" + }, + "task_settings": { + } + } + + + - do: + indices.create: + index: test-index + body: + settings: + number_of_shards: 1 + mappings: + properties: + text: + type: text + topic: + type: keyword + subtopic: + type: keyword + integer: + type: integer + + - do: + index: + index: test-index + id: doc_1 + body: + text: "Sun Moon Lake is a lake in Nantou County, Taiwan. It is the largest lake in Taiwan." + topic: [ "geography" ] + integer: 1 + + - do: + index: + index: test-index + id: doc_2 + body: + text: "The phases of the Moon come from the position of the Moon relative to the Earth and Sun." + topic: [ "science" ] + subtopic: [ "astronomy" ] + integer: 2 + + - do: + index: + index: test-index + id: doc_3 + body: + text: "As seen from Earth, a solar eclipse happens when the Moon is directly between the Earth and the Sun." + topic: [ "science" ] + subtopic: [ "technology" ] + integer: 3 + + - do: + indices.refresh: {} + +--- +"rrf retriever with a nested text similarity reranker": + + - do: + search: + index: test-index + body: + track_total_hits: true + fields: [ "text", "topic" ] + retriever: + rrf: { + retrievers: + [ + { + standard: { + query: { + bool: { + should: + [ + { + constant_score: { + filter: { + term: { + integer: 1 + } + }, + boost: 10 + } + }, + { + constant_score: + { + filter: + { + term: + { + integer: 2 + } + }, + boost: 1 + } + } + ] + } + } + } + }, + { + text_similarity_reranker: { + retriever: + { + standard: { + query: { + term: { + topic: "science" + } + } + } + }, + rank_window_size: 10, + inference_id: my-rerank-model, + inference_text: "How often does the moon hide the sun?", + field: text + } + } + ], + rank_window_size: 10, + rank_constant: 1 + } + size: 10 + from: 1 + aggs: + topics: + terms: + field: topic + size: 10 + + - match: { hits.total.value: 3 } + - length: { hits.hits: 2 } + + - match: { hits.hits.0._id: "doc_1" } + - match: { hits.hits.1._id: "doc_3" } + + - match: { aggregations.topics.buckets.0.key: "science" } + - match: { aggregations.topics.buckets.0.doc_count: 2 } + - match: { aggregations.topics.buckets.1.key: "geography" } + - match: { aggregations.topics.buckets.1.doc_count: 1 } + +--- +"Text similarity reranker on top of an RRF retriever": + + - do: + search: + index: test-index + body: + track_total_hits: true + fields: [ "text", "topic" ] + retriever: + { + text_similarity_reranker: { + retriever: + { + rrf: { + retrievers: + [ + { + standard: { + query: { + bool: { + should: + [ + { + constant_score: { + filter: { + term: { + integer: 1 + } + }, + boost: 10 + } + }, + { + constant_score: + { + filter: + { + term: + { + integer: 3 + } + }, + boost: 1 + } + } + ] + } + } + } + }, + { + standard: { + query: { + term: { + topic: "geography" + } + } + } + } + ], + rank_window_size: 10, + rank_constant: 1 + } + }, + rank_window_size: 10, + inference_id: my-rerank-model, + inference_text: "How often does the moon hide the sun?", + field: text + } + } + size: 10 + aggs: + topics: + terms: + field: topic + size: 10 + + - match: { hits.total.value: 2 } + - length: { hits.hits: 2 } + + - match: { hits.hits.0._id: "doc_3" } + - match: { hits.hits.1._id: "doc_1" } + + - match: { aggregations.topics.buckets.0.key: "geography" } + - match: { aggregations.topics.buckets.0.doc_count: 1 } + - match: { aggregations.topics.buckets.1.key: "science" } + - match: { aggregations.topics.buckets.1.doc_count: 1 } + + +--- +"explain using rrf retriever and text-similarity": + + - do: + search: + index: test-index + body: + track_total_hits: true + fields: [ "text", "topic" ] + retriever: + rrf: { + retrievers: + [ + { + standard: { + query: { + bool: { + should: + [ + { + constant_score: { + filter: { + term: { + integer: 1 + } + }, + boost: 10 + } + }, + { + constant_score: + { + filter: + { + term: + { + integer: 2 + } + }, + boost: 1 + } + } + ] + } + } + } + }, + { + text_similarity_reranker: { + retriever: + { + standard: { + query: { + term: { + topic: "science" + } + } + } + }, + rank_window_size: 10, + inference_id: my-rerank-model, + inference_text: "How often does the moon hide the sun?", + field: text + } + } + ], + rank_window_size: 10, + rank_constant: 1 + } + size: 10 + explain: true + + - match: { hits.hits.0._id: "doc_2" } + - match: { hits.hits.1._id: "doc_1" } + - match: { hits.hits.2._id: "doc_3" } + + - close_to: { hits.hits.0._explanation.value: { value: 0.6666667, error: 0.000001 } } + - match: {hits.hits.0._explanation.description: "/rrf.score:.\\[0.6666667\\].*/" } + - match: {hits.hits.0._explanation.details.0.value: 2} + - match: {hits.hits.0._explanation.details.0.description: "/rrf.score:.\\[0.33333334\\].*/" } + - match: {hits.hits.0._explanation.details.0.details.0.details.0.description: "/ConstantScore.*/" } + - match: {hits.hits.0._explanation.details.1.value: 2} + - match: {hits.hits.0._explanation.details.1.description: "/rrf.score:.\\[0.33333334\\].*/" } + - match: {hits.hits.0._explanation.details.1.details.0.description: "/text_similarity_reranker.match.using.inference.endpoint:.\\[my-rerank-model\\].on.document.field:.\\[text\\].*/" } + - match: {hits.hits.0._explanation.details.1.details.0.details.0.description: "/weight.*science.*/" } From c1115d24d74a498c93255ec2c119b4b5f0833935 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Tue, 8 Oct 2024 12:47:08 +0100 Subject: [PATCH 22/22] Add a size limit to outputs from mustache (#114002) (#114297) Backport #114002 to 8.16 --- docs/changelog/114002.yaml | 5 ++ .../script/ScriptServiceBridge.java | 2 +- .../script/mustache/MustachePlugin.java | 2 +- .../script/mustache/MustacheScriptEngine.java | 27 +++++++- .../mustache/CustomMustacheFactoryTests.java | 7 +- .../mustache/MustacheScriptEngineTests.java | 28 +++++++- .../script/mustache/MustacheTests.java | 3 +- .../ingest/AbstractScriptTestCase.java | 2 +- .../common/text/SizeLimitingStringWriter.java | 69 +++++++++++++++++++ .../text/SizeLimitingStringWriterTests.java | 29 ++++++++ .../support/mapper/TemplateRoleNameTests.java | 14 ++-- .../SecurityQueryTemplateEvaluatorTests.java | 3 +- .../WildcardServiceProviderResolverTests.java | 2 +- .../ltr/LearningToRankServiceTests.java | 2 +- .../authc/ldap/ActiveDirectoryRealmTests.java | 2 +- .../security/authc/ldap/LdapRealmTests.java | 2 +- .../mapper/ClusterStateRoleMapperTests.java | 2 +- .../mapper/NativeRoleMappingStoreTests.java | 2 +- .../watcher/support/WatcherTemplateTests.java | 2 +- 19 files changed, 181 insertions(+), 24 deletions(-) create mode 100644 docs/changelog/114002.yaml create mode 100644 server/src/main/java/org/elasticsearch/common/text/SizeLimitingStringWriter.java create mode 100644 server/src/test/java/org/elasticsearch/common/text/SizeLimitingStringWriterTests.java diff --git a/docs/changelog/114002.yaml b/docs/changelog/114002.yaml new file mode 100644 index 0000000000000..b6bc7e25bcdea --- /dev/null +++ b/docs/changelog/114002.yaml @@ -0,0 +1,5 @@ +pr: 114002 +summary: Add a `mustache.max_output_size_bytes` setting to limit the length of results from mustache scripts +area: Infra/Scripting +type: enhancement +issues: [] diff --git a/libs/logstash-bridge/src/main/java/org/elasticsearch/logstashbridge/script/ScriptServiceBridge.java b/libs/logstash-bridge/src/main/java/org/elasticsearch/logstashbridge/script/ScriptServiceBridge.java index caf6e87c1a530..1f7a19e333308 100644 --- a/libs/logstash-bridge/src/main/java/org/elasticsearch/logstashbridge/script/ScriptServiceBridge.java +++ b/libs/logstash-bridge/src/main/java/org/elasticsearch/logstashbridge/script/ScriptServiceBridge.java @@ -53,7 +53,7 @@ private static ScriptService getScriptService(final Settings settings, final Lon PainlessScriptEngine.NAME, new PainlessScriptEngine(settings, scriptContexts), MustacheScriptEngine.NAME, - new MustacheScriptEngine() + new MustacheScriptEngine(settings) ); return new ScriptService(settings, scriptEngines, ScriptModule.CORE_CONTEXTS, timeProvider); } diff --git a/modules/lang-mustache/src/main/java/org/elasticsearch/script/mustache/MustachePlugin.java b/modules/lang-mustache/src/main/java/org/elasticsearch/script/mustache/MustachePlugin.java index 64bc2799c24d0..b24d60cb8d887 100644 --- a/modules/lang-mustache/src/main/java/org/elasticsearch/script/mustache/MustachePlugin.java +++ b/modules/lang-mustache/src/main/java/org/elasticsearch/script/mustache/MustachePlugin.java @@ -44,7 +44,7 @@ public class MustachePlugin extends Plugin implements ScriptPlugin, ActionPlugin @Override public ScriptEngine getScriptEngine(Settings settings, Collection> contexts) { - return new MustacheScriptEngine(); + return new MustacheScriptEngine(settings); } @Override diff --git a/modules/lang-mustache/src/main/java/org/elasticsearch/script/mustache/MustacheScriptEngine.java b/modules/lang-mustache/src/main/java/org/elasticsearch/script/mustache/MustacheScriptEngine.java index ca06a853b1ed6..e7b1727791510 100644 --- a/modules/lang-mustache/src/main/java/org/elasticsearch/script/mustache/MustacheScriptEngine.java +++ b/modules/lang-mustache/src/main/java/org/elasticsearch/script/mustache/MustacheScriptEngine.java @@ -14,6 +14,13 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.elasticsearch.ElasticsearchParseException; +import org.elasticsearch.ExceptionsHelper; +import org.elasticsearch.common.settings.Setting; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.text.SizeLimitingStringWriter; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.common.unit.MemorySizeValue; import org.elasticsearch.script.GeneralScriptException; import org.elasticsearch.script.Script; import org.elasticsearch.script.ScriptContext; @@ -47,6 +54,19 @@ public final class MustacheScriptEngine implements ScriptEngine { public static final String NAME = "mustache"; + public static final Setting MUSTACHE_RESULT_SIZE_LIMIT = new Setting<>( + "mustache.max_output_size_bytes", + s -> "1mb", + s -> MemorySizeValue.parseBytesSizeValueOrHeapRatio(s, "mustache.max_output_size_bytes"), + Setting.Property.NodeScope + ); + + private final int sizeLimit; + + public MustacheScriptEngine(Settings settings) { + sizeLimit = (int) MUSTACHE_RESULT_SIZE_LIMIT.get(settings).getBytes(); + } + /** * Compile a template string to (in this case) a Mustache object than can * later be re-used for execution to fill in missing parameter values. @@ -118,10 +138,15 @@ private class MustacheExecutableScript extends TemplateScript { @Override public String execute() { - final StringWriter writer = new StringWriter(); + StringWriter writer = new SizeLimitingStringWriter(sizeLimit); try { template.execute(writer, params); } catch (Exception e) { + // size limit exception can appear at several places in the causal list depending on script & context + if (ExceptionsHelper.unwrap(e, SizeLimitingStringWriter.SizeLimitExceededException.class) != null) { + // don't log, client problem + throw new ElasticsearchParseException("Mustache script result size limit exceeded", e); + } if (shouldLogException(e)) { logger.error(() -> format("Error running %s", template), e); } diff --git a/modules/lang-mustache/src/test/java/org/elasticsearch/script/mustache/CustomMustacheFactoryTests.java b/modules/lang-mustache/src/test/java/org/elasticsearch/script/mustache/CustomMustacheFactoryTests.java index 014d6854121b6..eb9c1a6dc3031 100644 --- a/modules/lang-mustache/src/test/java/org/elasticsearch/script/mustache/CustomMustacheFactoryTests.java +++ b/modules/lang-mustache/src/test/java/org/elasticsearch/script/mustache/CustomMustacheFactoryTests.java @@ -9,6 +9,7 @@ package org.elasticsearch.script.mustache; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.script.Script; import org.elasticsearch.script.ScriptEngine; import org.elasticsearch.script.TemplateScript; @@ -65,7 +66,7 @@ public void testCreateEncoder() { } public void testJsonEscapeEncoder() { - final ScriptEngine engine = new MustacheScriptEngine(); + final ScriptEngine engine = new MustacheScriptEngine(Settings.EMPTY); final Map params = randomBoolean() ? Map.of(Script.CONTENT_TYPE_OPTION, JSON_MEDIA_TYPE) : Map.of(); TemplateScript.Factory compiled = engine.compile(null, "{\"field\": \"{{value}}\"}", TemplateScript.CONTEXT, params); @@ -75,7 +76,7 @@ public void testJsonEscapeEncoder() { } public void testDefaultEncoder() { - final ScriptEngine engine = new MustacheScriptEngine(); + final ScriptEngine engine = new MustacheScriptEngine(Settings.EMPTY); final Map params = Map.of(Script.CONTENT_TYPE_OPTION, PLAIN_TEXT_MEDIA_TYPE); TemplateScript.Factory compiled = engine.compile(null, "{\"field\": \"{{value}}\"}", TemplateScript.CONTEXT, params); @@ -85,7 +86,7 @@ public void testDefaultEncoder() { } public void testUrlEncoder() { - final ScriptEngine engine = new MustacheScriptEngine(); + final ScriptEngine engine = new MustacheScriptEngine(Settings.EMPTY); final Map params = Map.of(Script.CONTENT_TYPE_OPTION, X_WWW_FORM_URLENCODED_MEDIA_TYPE); TemplateScript.Factory compiled = engine.compile(null, "{\"field\": \"{{value}}\"}", TemplateScript.CONTEXT, params); diff --git a/modules/lang-mustache/src/test/java/org/elasticsearch/script/mustache/MustacheScriptEngineTests.java b/modules/lang-mustache/src/test/java/org/elasticsearch/script/mustache/MustacheScriptEngineTests.java index 089a154079a83..bc1cd30ad45bf 100644 --- a/modules/lang-mustache/src/test/java/org/elasticsearch/script/mustache/MustacheScriptEngineTests.java +++ b/modules/lang-mustache/src/test/java/org/elasticsearch/script/mustache/MustacheScriptEngineTests.java @@ -8,8 +8,13 @@ */ package org.elasticsearch.script.mustache; +import com.github.mustachejava.MustacheException; import com.github.mustachejava.MustacheFactory; +import org.elasticsearch.ElasticsearchParseException; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.text.SizeLimitingStringWriter; import org.elasticsearch.script.GeneralScriptException; import org.elasticsearch.script.Script; import org.elasticsearch.script.TemplateScript; @@ -24,6 +29,9 @@ import java.util.List; import java.util.Map; +import static org.elasticsearch.test.LambdaMatchers.transformedMatch; +import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.endsWith; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.startsWith; @@ -37,7 +45,7 @@ public class MustacheScriptEngineTests extends ESTestCase { @Before public void setup() { - qe = new MustacheScriptEngine(); + qe = new MustacheScriptEngine(Settings.builder().put(MustacheScriptEngine.MUSTACHE_RESULT_SIZE_LIMIT.getKey(), "1kb").build()); factory = CustomMustacheFactory.builder().build(); } @@ -402,6 +410,24 @@ public void testEscapeJson() throws IOException { } } + public void testResultSizeLimit() throws IOException { + String vals = "\"" + "{{val}}".repeat(200) + "\""; + String params = "\"val\":\"aaaaaaaaaa\""; + XContentParser parser = createParser(JsonXContent.jsonXContent, Strings.format("{\"source\":%s,\"params\":{%s}}", vals, params)); + Script script = Script.parse(parser); + var compiled = qe.compile(null, script.getIdOrCode(), TemplateScript.CONTEXT, Map.of()); + TemplateScript templateScript = compiled.newInstance(script.getParams()); + var ex = expectThrows(ElasticsearchParseException.class, templateScript::execute); + assertThat(ex.getCause(), instanceOf(MustacheException.class)); + assertThat( + ex.getCause().getCause(), + allOf( + instanceOf(SizeLimitingStringWriter.SizeLimitExceededException.class), + transformedMatch(Throwable::getMessage, endsWith("has exceeded the size limit [1024]")) + ) + ); + } + private String getChars() { String string = randomRealisticUnicodeOfCodepointLengthBetween(0, 10); for (int i = 0; i < string.length(); i++) { diff --git a/modules/lang-mustache/src/test/java/org/elasticsearch/script/mustache/MustacheTests.java b/modules/lang-mustache/src/test/java/org/elasticsearch/script/mustache/MustacheTests.java index 82c4637f600fe..335cfe91df87d 100644 --- a/modules/lang-mustache/src/test/java/org/elasticsearch/script/mustache/MustacheTests.java +++ b/modules/lang-mustache/src/test/java/org/elasticsearch/script/mustache/MustacheTests.java @@ -9,6 +9,7 @@ package org.elasticsearch.script.mustache; import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.Strings; import org.elasticsearch.script.ScriptEngine; @@ -39,7 +40,7 @@ public class MustacheTests extends ESTestCase { - private ScriptEngine engine = new MustacheScriptEngine(); + private ScriptEngine engine = new MustacheScriptEngine(Settings.EMPTY); public void testBasics() { String template = """ diff --git a/qa/smoke-test-ingest-with-all-dependencies/src/yamlRestTest/java/org/elasticsearch/ingest/AbstractScriptTestCase.java b/qa/smoke-test-ingest-with-all-dependencies/src/yamlRestTest/java/org/elasticsearch/ingest/AbstractScriptTestCase.java index ff6e75bed047f..8d877bd48c1e3 100644 --- a/qa/smoke-test-ingest-with-all-dependencies/src/yamlRestTest/java/org/elasticsearch/ingest/AbstractScriptTestCase.java +++ b/qa/smoke-test-ingest-with-all-dependencies/src/yamlRestTest/java/org/elasticsearch/ingest/AbstractScriptTestCase.java @@ -31,7 +31,7 @@ public abstract class AbstractScriptTestCase extends ESTestCase { @Before public void init() throws Exception { - MustacheScriptEngine engine = new MustacheScriptEngine(); + MustacheScriptEngine engine = new MustacheScriptEngine(Settings.EMPTY); Map engines = Collections.singletonMap(engine.getType(), engine); scriptService = new ScriptService(Settings.EMPTY, engines, ScriptModule.CORE_CONTEXTS, () -> 1L); } diff --git a/server/src/main/java/org/elasticsearch/common/text/SizeLimitingStringWriter.java b/server/src/main/java/org/elasticsearch/common/text/SizeLimitingStringWriter.java new file mode 100644 index 0000000000000..2df7e6537c609 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/common/text/SizeLimitingStringWriter.java @@ -0,0 +1,69 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.common.text; + +import org.elasticsearch.common.Strings; + +import java.io.StringWriter; + +/** + * A {@link StringWriter} that throws an exception if the string exceeds a specified size. + */ +public class SizeLimitingStringWriter extends StringWriter { + + public static class SizeLimitExceededException extends IllegalStateException { + public SizeLimitExceededException(String message) { + super(message); + } + } + + private final int sizeLimit; + + public SizeLimitingStringWriter(int sizeLimit) { + this.sizeLimit = sizeLimit; + } + + private void checkSizeLimit(int additionalChars) { + int bufLen = getBuffer().length(); + if (bufLen + additionalChars > sizeLimit) { + throw new SizeLimitExceededException( + Strings.format("String [%s...] has exceeded the size limit [%s]", getBuffer().substring(0, Math.min(bufLen, 20)), sizeLimit) + ); + } + } + + @Override + public void write(int c) { + checkSizeLimit(1); + super.write(c); + } + + // write(char[]) delegates to write(char[], int, int) + + @Override + public void write(char[] cbuf, int off, int len) { + checkSizeLimit(len); + super.write(cbuf, off, len); + } + + @Override + public void write(String str) { + checkSizeLimit(str.length()); + super.write(str); + } + + @Override + public void write(String str, int off, int len) { + checkSizeLimit(len); + super.write(str, off, len); + } + + // append(...) delegates to write(...) methods +} diff --git a/server/src/test/java/org/elasticsearch/common/text/SizeLimitingStringWriterTests.java b/server/src/test/java/org/elasticsearch/common/text/SizeLimitingStringWriterTests.java new file mode 100644 index 0000000000000..32a8de20df9aa --- /dev/null +++ b/server/src/test/java/org/elasticsearch/common/text/SizeLimitingStringWriterTests.java @@ -0,0 +1,29 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.common.text; + +import org.elasticsearch.test.ESTestCase; + +public class SizeLimitingStringWriterTests extends ESTestCase { + public void testSizeIsLimited() { + SizeLimitingStringWriter writer = new SizeLimitingStringWriter(10); + + writer.write("a".repeat(10)); + + // test all the methods + expectThrows(SizeLimitingStringWriter.SizeLimitExceededException.class, () -> writer.write('a')); + expectThrows(SizeLimitingStringWriter.SizeLimitExceededException.class, () -> writer.write("a")); + expectThrows(SizeLimitingStringWriter.SizeLimitExceededException.class, () -> writer.write(new char[1])); + expectThrows(SizeLimitingStringWriter.SizeLimitExceededException.class, () -> writer.write(new char[1], 0, 1)); + expectThrows(SizeLimitingStringWriter.SizeLimitExceededException.class, () -> writer.append('a')); + expectThrows(SizeLimitingStringWriter.SizeLimitExceededException.class, () -> writer.append("a")); + expectThrows(SizeLimitingStringWriter.SizeLimitExceededException.class, () -> writer.append("a", 0, 1)); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authc/support/mapper/TemplateRoleNameTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authc/support/mapper/TemplateRoleNameTests.java index 195e126662481..05044303561d8 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authc/support/mapper/TemplateRoleNameTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authc/support/mapper/TemplateRoleNameTests.java @@ -89,7 +89,7 @@ public void testEqualsAndHashCode() throws Exception { public void testEvaluateRoles() throws Exception { final ScriptService scriptService = new ScriptService( Settings.EMPTY, - Collections.singletonMap(MustacheScriptEngine.NAME, new MustacheScriptEngine()), + Collections.singletonMap(MustacheScriptEngine.NAME, new MustacheScriptEngine(Settings.EMPTY)), ScriptModule.CORE_CONTEXTS, () -> 1L ); @@ -145,7 +145,7 @@ public void tryEquals(TemplateRoleName original) { public void testValidate() { final ScriptService scriptService = new ScriptService( Settings.EMPTY, - Collections.singletonMap(MustacheScriptEngine.NAME, new MustacheScriptEngine()), + Collections.singletonMap(MustacheScriptEngine.NAME, new MustacheScriptEngine(Settings.EMPTY)), ScriptModule.CORE_CONTEXTS, () -> 1L ); @@ -173,7 +173,7 @@ public void testValidate() { public void testValidateWillPassWithEmptyContext() { final ScriptService scriptService = new ScriptService( Settings.EMPTY, - Collections.singletonMap(MustacheScriptEngine.NAME, new MustacheScriptEngine()), + Collections.singletonMap(MustacheScriptEngine.NAME, new MustacheScriptEngine(Settings.EMPTY)), ScriptModule.CORE_CONTEXTS, () -> 1L ); @@ -204,7 +204,7 @@ public void testValidateWillPassWithEmptyContext() { public void testValidateWillFailForSyntaxError() { final ScriptService scriptService = new ScriptService( Settings.EMPTY, - Collections.singletonMap(MustacheScriptEngine.NAME, new MustacheScriptEngine()), + Collections.singletonMap(MustacheScriptEngine.NAME, new MustacheScriptEngine(Settings.EMPTY)), ScriptModule.CORE_CONTEXTS, () -> 1L ); @@ -268,7 +268,7 @@ public void testValidationWillFailWhenInlineScriptIsNotEnabled() { final Settings settings = Settings.builder().put("script.allowed_types", ScriptService.ALLOW_NONE).build(); final ScriptService scriptService = new ScriptService( settings, - Collections.singletonMap(MustacheScriptEngine.NAME, new MustacheScriptEngine()), + Collections.singletonMap(MustacheScriptEngine.NAME, new MustacheScriptEngine(Settings.EMPTY)), ScriptModule.CORE_CONTEXTS, () -> 1L ); @@ -285,7 +285,7 @@ public void testValidateWillFailWhenStoredScriptIsNotEnabled() { final Settings settings = Settings.builder().put("script.allowed_types", ScriptService.ALLOW_NONE).build(); final ScriptService scriptService = new ScriptService( settings, - Collections.singletonMap(MustacheScriptEngine.NAME, new MustacheScriptEngine()), + Collections.singletonMap(MustacheScriptEngine.NAME, new MustacheScriptEngine(Settings.EMPTY)), ScriptModule.CORE_CONTEXTS, () -> 1L ); @@ -314,7 +314,7 @@ public void testValidateWillFailWhenStoredScriptIsNotEnabled() { public void testValidateWillFailWhenStoredScriptIsNotFound() { final ScriptService scriptService = new ScriptService( Settings.EMPTY, - Collections.singletonMap(MustacheScriptEngine.NAME, new MustacheScriptEngine()), + Collections.singletonMap(MustacheScriptEngine.NAME, new MustacheScriptEngine(Settings.EMPTY)), ScriptModule.CORE_CONTEXTS, () -> 1L ); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/support/SecurityQueryTemplateEvaluatorTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/support/SecurityQueryTemplateEvaluatorTests.java index fed11c75715b5..58ae99df60bcd 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/support/SecurityQueryTemplateEvaluatorTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/support/SecurityQueryTemplateEvaluatorTests.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.core.security.authz.support; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.query.TermQueryBuilder; import org.elasticsearch.script.Script; import org.elasticsearch.script.ScriptService; @@ -94,7 +95,7 @@ public void testDocLevelSecurityTemplateWithOpenIdConnectStyleMetadata() throws true ); - final MustacheScriptEngine mustache = new MustacheScriptEngine(); + final MustacheScriptEngine mustache = new MustacheScriptEngine(Settings.EMPTY); when(scriptService.compile(any(Script.class), eq(TemplateScript.CONTEXT))).thenAnswer(inv -> { assertThat(inv.getArguments(), arrayWithSize(2)); diff --git a/x-pack/plugin/identity-provider/src/test/java/org/elasticsearch/xpack/idp/saml/sp/WildcardServiceProviderResolverTests.java b/x-pack/plugin/identity-provider/src/test/java/org/elasticsearch/xpack/idp/saml/sp/WildcardServiceProviderResolverTests.java index 70e5325878c0a..832a6e8163aca 100644 --- a/x-pack/plugin/identity-provider/src/test/java/org/elasticsearch/xpack/idp/saml/sp/WildcardServiceProviderResolverTests.java +++ b/x-pack/plugin/identity-provider/src/test/java/org/elasticsearch/xpack/idp/saml/sp/WildcardServiceProviderResolverTests.java @@ -95,7 +95,7 @@ public void setUpResolver() { final Settings settings = Settings.EMPTY; final ScriptService scriptService = new ScriptService( settings, - Collections.singletonMap(MustacheScriptEngine.NAME, new MustacheScriptEngine()), + Collections.singletonMap(MustacheScriptEngine.NAME, new MustacheScriptEngine(Settings.EMPTY)), ScriptModule.CORE_CONTEXTS, () -> 1L ); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankServiceTests.java index 6ca9ae4296789..46e54ff3f8c3d 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankServiceTests.java @@ -241,7 +241,7 @@ private LearningToRankService getTestLearningToRankService(TrainedModelProvider } private ScriptService getTestScriptService() { - ScriptEngine scriptEngine = new MustacheScriptEngine(); + ScriptEngine scriptEngine = new MustacheScriptEngine(Settings.EMPTY); return new ScriptService(Settings.EMPTY, Map.of(DEFAULT_TEMPLATE_LANG, scriptEngine), ScriptModule.CORE_CONTEXTS, () -> 1L); } } diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/ldap/ActiveDirectoryRealmTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/ldap/ActiveDirectoryRealmTests.java index b0821864aacc7..e72bbd77697c2 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/ldap/ActiveDirectoryRealmTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/ldap/ActiveDirectoryRealmTests.java @@ -434,7 +434,7 @@ public void testRealmWithTemplatedRoleMapping() throws Exception { final ScriptService scriptService = new ScriptService( settings, - Collections.singletonMap(MustacheScriptEngine.NAME, new MustacheScriptEngine()), + Collections.singletonMap(MustacheScriptEngine.NAME, new MustacheScriptEngine(Settings.EMPTY)), ScriptModule.CORE_CONTEXTS, () -> 1L ); diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/ldap/LdapRealmTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/ldap/LdapRealmTests.java index 7083d1301a3e6..83e2ff88f9dc3 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/ldap/LdapRealmTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/ldap/LdapRealmTests.java @@ -533,7 +533,7 @@ public void testLdapRealmWithTemplatedRoleMapping() throws Exception { final ScriptService scriptService = new ScriptService( defaultGlobalSettings, - Collections.singletonMap(MustacheScriptEngine.NAME, new MustacheScriptEngine()), + Collections.singletonMap(MustacheScriptEngine.NAME, new MustacheScriptEngine(Settings.EMPTY)), ScriptModule.CORE_CONTEXTS, () -> 1L ); diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/mapper/ClusterStateRoleMapperTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/mapper/ClusterStateRoleMapperTests.java index 515b5ef741a00..063245e004476 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/mapper/ClusterStateRoleMapperTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/mapper/ClusterStateRoleMapperTests.java @@ -51,7 +51,7 @@ public class ClusterStateRoleMapperTests extends ESTestCase { public void setup() { scriptService = new ScriptService( Settings.EMPTY, - Collections.singletonMap(MustacheScriptEngine.NAME, new MustacheScriptEngine()), + Collections.singletonMap(MustacheScriptEngine.NAME, new MustacheScriptEngine(Settings.EMPTY)), ScriptModule.CORE_CONTEXTS, () -> 1L ); diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/mapper/NativeRoleMappingStoreTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/mapper/NativeRoleMappingStoreTests.java index 2a084bacfaf76..38f01d4d18bc7 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/mapper/NativeRoleMappingStoreTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/mapper/NativeRoleMappingStoreTests.java @@ -87,7 +87,7 @@ public class NativeRoleMappingStoreTests extends ESTestCase { public void setup() { scriptService = new ScriptService( Settings.EMPTY, - Collections.singletonMap(MustacheScriptEngine.NAME, new MustacheScriptEngine()), + Collections.singletonMap(MustacheScriptEngine.NAME, new MustacheScriptEngine(Settings.EMPTY)), ScriptModule.CORE_CONTEXTS, () -> 1L ); diff --git a/x-pack/plugin/watcher/src/test/java/org/elasticsearch/xpack/watcher/support/WatcherTemplateTests.java b/x-pack/plugin/watcher/src/test/java/org/elasticsearch/xpack/watcher/support/WatcherTemplateTests.java index 5ddff34a0ac45..3018afbe97338 100644 --- a/x-pack/plugin/watcher/src/test/java/org/elasticsearch/xpack/watcher/support/WatcherTemplateTests.java +++ b/x-pack/plugin/watcher/src/test/java/org/elasticsearch/xpack/watcher/support/WatcherTemplateTests.java @@ -37,7 +37,7 @@ public class WatcherTemplateTests extends ESTestCase { @Before public void init() throws Exception { - MustacheScriptEngine engine = new MustacheScriptEngine(); + MustacheScriptEngine engine = new MustacheScriptEngine(Settings.EMPTY); Map engines = Collections.singletonMap(engine.getType(), engine); Map> contexts = Collections.singletonMap( Watcher.SCRIPT_TEMPLATE_CONTEXT.name,