Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Add embedding_size to text embedding config #95176

Merged
merged 3 commits into from
Apr 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/95176.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Release-notes changelog entry, rendered into the release docs by the build.
pr: 95176
summary: Add `embedding_size` to text embedding config
area: Machine Learning
type: enhancement
issues: []
4 changes: 4 additions & 0 deletions docs/reference/ml/ml-shared.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -1070,6 +1070,10 @@ context. These embeddings can be used in a <<dense-vector,dense vector>> field
for powerful insights.
end::inference-config-text-embedding[]

tag::inference-config-text-embedding-size[]
The number of dimensions in the embedding vector produced by the model.
end::inference-config-text-embedding-size[]

tag::inference-config-text-similarity[]
Text similarity takes an input sequence and compares it with another input sequence. This is commonly referred to
as cross-encoding. This task is useful for ranking document text when comparing it to another provided text input.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,14 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-text-embedding
.Properties of text_embedding inference
[%collapsible%open]
======
`embedding_size`::::
(Optional, integer)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-text-embedding-size]

`results_field`::::
(Optional, string)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-results-field]

`tokenization`::::
(Optional, object)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -939,6 +939,10 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-text-embedding
.Properties of text_embedding inference
[%collapsible%open]
=====
`embedding_size`::::
(Optional, integer)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-text-embedding-size]

`results_field`::::
(Optional, string)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-results-field]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
Expand All @@ -27,6 +28,8 @@ public class TextEmbeddingConfig implements NlpConfig {

public static final String NAME = "text_embedding";

// Field name for the optional embedding dimension in the task configuration.
// Declared final: a public static constant must not be reassignable.
public static final ParseField EMBEDDING_SIZE = new ParseField("embedding_size");

/**
 * Parses a {@code TextEmbeddingConfig} from X-Content using the strict parser
 * (the one built with {@code ignoreUnknownFields == false}, so unknown fields
 * are rejected rather than silently skipped).
 *
 * @param parser positioned at the start of the config object
 * @return the parsed config
 */
public static TextEmbeddingConfig fromXContentStrict(XContentParser parser) {
return STRICT_PARSER.apply(parser, null);
}
Expand All @@ -42,7 +45,7 @@ private static ConstructingObjectParser<TextEmbeddingConfig, Void> createParser(
ConstructingObjectParser<TextEmbeddingConfig, Void> parser = new ConstructingObjectParser<>(
NAME,
ignoreUnknownFields,
a -> new TextEmbeddingConfig((VocabularyConfig) a[0], (Tokenization) a[1], (String) a[2])
a -> new TextEmbeddingConfig((VocabularyConfig) a[0], (Tokenization) a[1], (String) a[2], (Integer) a[3])
);
parser.declareObject(ConstructingObjectParser.optionalConstructorArg(), (p, c) -> {
if (ignoreUnknownFields == false) {
Expand All @@ -59,22 +62,33 @@ private static ConstructingObjectParser<TextEmbeddingConfig, Void> createParser(
TOKENIZATION
);
parser.declareString(ConstructingObjectParser.optionalConstructorArg(), RESULTS_FIELD);
parser.declareInt(ConstructingObjectParser.optionalConstructorArg(), EMBEDDING_SIZE);
return parser;
}

private final VocabularyConfig vocabularyConfig;
private final Tokenization tokenization;
private final String resultsField;
private final Integer embeddingSize;

public TextEmbeddingConfig(
@Nullable VocabularyConfig vocabularyConfig,
@Nullable Tokenization tokenization,
@Nullable String resultsField
@Nullable String resultsField,
@Nullable Integer embeddingSize
) {
this.vocabularyConfig = Optional.ofNullable(vocabularyConfig)
.orElse(new VocabularyConfig(InferenceIndexConstants.nativeDefinitionStore()));
this.tokenization = tokenization == null ? Tokenization.createDefault() : tokenization;
this.resultsField = resultsField;
if (embeddingSize != null && embeddingSize <= 0) {
throw ExceptionsHelper.badRequestException(
"[{}] must be a number greater than 0; configured size [{}]",
EMBEDDING_SIZE.getPreferredName(),
embeddingSize
);
}
this.embeddingSize = embeddingSize;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this is serialized as strictly non-negative there should be a check in the public constructor that a negative number hasn't been supplied.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 I've added that check to the ctor, it makes no sense for the size to be <= 0

commit: 1fbbdde

if (this.tokenization.span != -1) {
throw ExceptionsHelper.badRequestException(
"[{}] does not support windowing long text sequences; configured span [{}]",
Expand All @@ -88,6 +102,11 @@ public TextEmbeddingConfig(StreamInput in) throws IOException {
vocabularyConfig = new VocabularyConfig(in);
tokenization = in.readNamedWriteable(Tokenization.class);
resultsField = in.readOptionalString();
if (in.getTransportVersion().onOrAfter(TransportVersion.V_8_8_0)) {
embeddingSize = in.readOptionalVInt();
} else {
embeddingSize = null;
}
}

@Override
Expand All @@ -98,6 +117,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (resultsField != null) {
builder.field(RESULTS_FIELD.getPreferredName(), resultsField);
}
if (embeddingSize != null) {
builder.field(EMBEDDING_SIZE.getPreferredName(), embeddingSize);
}
builder.endObject();
return builder;
}
Expand All @@ -112,6 +134,9 @@ public void writeTo(StreamOutput out) throws IOException {
vocabularyConfig.writeTo(out);
out.writeNamedWriteable(tokenization);
out.writeOptionalString(resultsField);
if (out.getTransportVersion().onOrAfter(TransportVersion.V_8_8_0)) {
out.writeOptionalVInt(embeddingSize);
}
}

@Override
Expand Down Expand Up @@ -147,12 +172,13 @@ public boolean equals(Object o) {
TextEmbeddingConfig that = (TextEmbeddingConfig) o;
return Objects.equals(vocabularyConfig, that.vocabularyConfig)
&& Objects.equals(tokenization, that.tokenization)
&& Objects.equals(resultsField, that.resultsField);
&& Objects.equals(resultsField, that.resultsField)
&& Objects.equals(embeddingSize, that.embeddingSize);
}

@Override
public int hashCode() {
// Must stay in sync with equals(): hashes exactly the four fields compared there.
return Objects.hash(vocabularyConfig, tokenization, resultsField, embeddingSize);
}
}

@Override
Expand All @@ -169,4 +195,8 @@ public Tokenization getTokenization() {
// Name of the field the inference result is written to; may be null when unset.
public String getResultsField() {
return resultsField;
}

// The configured embedding vector dimension, or null when it was not specified.
// When non-null it is strictly positive (enforced in the constructor).
public Integer getEmbeddingSize() {
return embeddingSize;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ public InferenceConfig apply(InferenceConfig originalConfig) {
return new TextEmbeddingConfig(
embeddingConfig.getVocabularyConfig(),
tokenizationUpdate == null ? embeddingConfig.getTokenization() : tokenizationUpdate.apply(embeddingConfig.getTokenization()),
resultsField == null ? embeddingConfig.getResultsField() : resultsField
resultsField == null ? embeddingConfig.getResultsField() : resultsField,
embeddingConfig.getEmbeddingSize()
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.xcontent.XContentParser;
Expand All @@ -18,11 +19,21 @@
public class TextEmbeddingConfigTests extends InferenceConfigItemTestCase<TextEmbeddingConfig> {

public static TextEmbeddingConfig mutateForVersion(TextEmbeddingConfig instance, TransportVersion version) {
// embedding_size is only wire-serialized from 8.8.0 onwards, so a round trip
// through an older version drops it; everything else survives unchanged.
Integer embeddingSize = version.before(TransportVersion.V_8_8_0) ? null : instance.getEmbeddingSize();
return new TextEmbeddingConfig(
    instance.getVocabularyConfig(),
    InferenceConfigTestScaffolding.mutateTokenizationForVersion(instance.getTokenization(), version),
    instance.getResultsField(),
    embeddingSize
);
}

@Override
Expand Down Expand Up @@ -60,6 +71,18 @@ protected TextEmbeddingConfig mutateInstanceForVersion(TextEmbeddingConfig insta
return mutateForVersion(instance, version);
}

public void testInvariants() {
// A non-positive embedding_size must be rejected by the constructor.
var badSize = expectThrows(
    ElasticsearchStatusException.class,
    () -> new TextEmbeddingConfig(null, BertTokenizationTests.createRandom(), null, 0)
);
assertEquals("[embedding_size] must be a number greater than 0; configured size [0]", badSize.getMessage());

// A tokenization configured with a span is incompatible with text_embedding.
var spanTokenization = new BertTokenization(true, true, 512, Tokenization.Truncate.NONE, 128);
var badSpan = expectThrows(
    ElasticsearchStatusException.class,
    () -> new TextEmbeddingConfig(null, spanTokenization, null, 200)
);
assertEquals("[text_embedding] does not support windowing long text sequences; configured span [128]", badSpan.getMessage());
}

public static TextEmbeddingConfig createRandom() {
return new TextEmbeddingConfig(
randomBoolean() ? null : VocabularyConfigTests.createRandom(),
Expand All @@ -70,7 +93,8 @@ public static TextEmbeddingConfig createRandom() {
MPNetTokenizationTests.createRandom(),
RobertaTokenizationTests.createRandom()
),
randomBoolean() ? null : randomAlphaOfLength(7)
randomBoolean() ? null : randomAlphaOfLength(7),
randomBoolean() ? null : randomIntBetween(1, 1000)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,24 @@ public void testApply() {
assertThat(originalConfig, sameInstance(new TextEmbeddingConfigUpdate.Builder().build().apply(originalConfig)));

assertThat(
new TextEmbeddingConfig(originalConfig.getVocabularyConfig(), originalConfig.getTokenization(), "ml-results"),
new TextEmbeddingConfig(
originalConfig.getVocabularyConfig(),
originalConfig.getTokenization(),
"ml-results",
originalConfig.getEmbeddingSize()
),
equalTo(new TextEmbeddingConfigUpdate.Builder().setResultsField("ml-results").build().apply(originalConfig))
);

Tokenization.Truncate truncate = randomFrom(Tokenization.Truncate.values());
Tokenization tokenization = cloneWithNewTruncation(originalConfig.getTokenization(), truncate);
assertThat(
new TextEmbeddingConfig(originalConfig.getVocabularyConfig(), tokenization, originalConfig.getResultsField()),
new TextEmbeddingConfig(
originalConfig.getVocabularyConfig(),
tokenization,
originalConfig.getResultsField(),
originalConfig.getEmbeddingSize()
),
equalTo(
new TextEmbeddingConfigUpdate.Builder().setTokenizationUpdate(
createTokenizationUpdate(originalConfig.getTokenization(), truncate, null)
Expand Down