[NLP] Add prefix_string config option to the import model hub script #642

Merged · 2 commits · Jan 19, 2024
15 changes: 15 additions & 0 deletions eland/cli/eland_import_hub_model.py
@@ -128,6 +128,19 @@ def get_arg_parser():
"--ca-certs", required=False, default=DEFAULT, help="Path to CA bundle"
)

parser.add_argument(
"--ingest-prefix",
required=False,
default=None,
help="String to prepend to model input at ingest",
)
parser.add_argument(
"--search-prefix",
required=False,
default=None,
help="String to prepend to model input at search",
)

return parser


@@ -244,6 +257,8 @@ def main():
task_type=args.task_type,
es_version=cluster_version,
quantize=args.quantize,
ingest_prefix=args.ingest_prefix,
search_prefix=args.search_prefix,
)
model_path, config, vocab_path = tm.save(tmp_dir)
except TaskTypeError as err:
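As a usage sketch, the new flags slot into the existing import command like this. The cluster URL, model id, and prefix values are illustrative (E5-style "passage: "/"query: " prefixes), not taken from this PR:

```sh
eland_import_hub_model \
  --url http://localhost:9200 \
  --hub-model-id intfloat/e5-small-v2 \
  --task-type text_embedding \
  --ingest-prefix "passage: " \
  --search-prefix "query: "
```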
19 changes: 19 additions & 0 deletions eland/ml/pytorch/nlp_ml_model.py
@@ -308,6 +308,23 @@ def to_dict(self) -> t.Dict[str, t.Any]:
return self.__dict__


class PrefixStrings:
def __init__(
self, *, ingest_prefix: t.Optional[str], search_prefix: t.Optional[str]
):
self.ingest_prefix = ingest_prefix
self.search_prefix = search_prefix

def to_dict(self) -> t.Dict[str, t.Any]:
config = {}
if self.ingest_prefix is not None:
config["ingest"] = self.ingest_prefix
if self.search_prefix is not None:
config["search"] = self.search_prefix

return config


class NlpTrainedModelConfig:
def __init__(
self,
@@ -318,13 +335,15 @@ def __init__(
metadata: t.Optional[dict] = None,
model_type: t.Union["t.Literal['pytorch']", str] = "pytorch",
tags: t.Optional[t.Union[t.List[str], t.Tuple[str, ...]]] = None,
prefix_strings: t.Optional[PrefixStrings],
):
self.tags = tags
self.description = description
self.inference_config = inference_config
self.input = input
self.metadata = metadata
self.model_type = model_type
self.prefix_strings = prefix_strings

def to_dict(self) -> t.Dict[str, t.Any]:
return {
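A minimal sketch of how the new PrefixStrings class serializes, based only on the to_dict implementation above; an unset side is omitted from the dict rather than emitted as None:

```python
from eland.ml.pytorch.nlp_ml_model import PrefixStrings

# Both prefixes set: both keys appear in the serialized config.
both = PrefixStrings(ingest_prefix="passage: ", search_prefix="query: ")
print(both.to_dict())  # {'ingest': 'passage: ', 'search': 'query: '}

# Only the ingest side set: the "search" key is omitted entirely.
ingest_only = PrefixStrings(ingest_prefix="passage: ", search_prefix=None)
print(ingest_only.to_dict())  # {'ingest': 'passage: '}
```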
28 changes: 28 additions & 0 deletions eland/ml/pytorch/transformers.py
@@ -53,6 +53,7 @@
NlpTrainedModelConfig,
NlpXLMRobertaTokenizationConfig,
PassThroughInferenceOptions,
PrefixStrings,
QuestionAnsweringInferenceOptions,
TextClassificationInferenceOptions,
TextEmbeddingInferenceOptions,
@@ -596,6 +597,8 @@ def __init__(
es_version: Optional[Tuple[int, int, int]] = None,
quantize: bool = False,
access_token: Optional[str] = None,
ingest_prefix: Optional[str] = None,
search_prefix: Optional[str] = None,
):
"""
Loads a model from the Hugging Face repository or local file and creates
@@ -618,11 +621,22 @@

quantize: bool, default False
Quantize the model.

access_token: Optional[str]
For the HuggingFace Hub private model access

ingest_prefix: Optional[str]
Prefix string to prepend to input at ingest

search_prefix: Optional[str]
Prefix string to prepend to input at search
"""

self._model_id = model_id
self._access_token = access_token
self._task_type = task_type.replace("-", "_")
self._ingest_prefix = ingest_prefix
self._search_prefix = search_prefix

# load Hugging Face model and tokenizer
# use padding in the tokenizer to ensure max length sequences are used for tracing (at call time)
@@ -783,6 +797,19 @@ def _create_config(
"per_allocation_memory_bytes": per_allocation_memory_bytes,
}

prefix_strings = (
PrefixStrings(
ingest_prefix=self._ingest_prefix, search_prefix=self._search_prefix
)
if self._ingest_prefix or self._search_prefix
else None
)
prefix_strings_supported = es_version is None or es_version >= (8, 12, 0)
if not prefix_strings_supported and prefix_strings:
raise Exception(
f"The Elasticsearch cluster version {es_version} does not support prefix strings. Support was added in version 8.12.0"
)

return NlpTrainedModelConfig(
description=f"Model {self._model_id} for task type '{self._task_type}'",
model_type="pytorch",
@@ … @@
field_names=["text_field"],
),
metadata=metadata,
prefix_strings=prefix_strings,
)

def _get_per_deployment_memory(self) -> float:
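The version gate above means a prefix on a pre-8.12.0 cluster fails at config-creation time, which (per the tests below) surfaces when the TransformerModel is constructed. A sketch of the failure path — note that constructing the model downloads it from the Hugging Face Hub:

```python
from eland.ml.pytorch.transformers import TransformerModel

try:
    TransformerModel(
        model_id="sentence-transformers/all-distilroberta-v1",
        task_type="text_embedding",
        es_version=(8, 11, 0),  # prefix strings require >= (8, 12, 0)
        ingest_prefix="INGEST:",
    )
except Exception as err:
    print(err)  # "The Elasticsearch cluster version (8, 11, 0) does not support prefix strings. ..."
```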
31 changes: 28 additions & 3 deletions tests/ml/pytorch/test_pytorch_model_config_pytest.py
@@ -154,13 +154,13 @@
MODEL_CONFIGURATIONS = []


@pytest.mark.skip(reason="https://github.com/elastic/eland/issues/633")
class TestModelConfguration:
@pytest.mark.skip(reason="https://github.com/elastic/eland/issues/633")
@pytest.mark.parametrize(
"model_id,task_type,config_type,tokenizer_type,max_sequence_len,embedding_size",
MODEL_CONFIGURATIONS,
)
def test_text_prediction(
def test_model_config(
self,
model_id,
task_type,
@@ -170,7 +170,6 @@ def test_text_prediction(
embedding_size,
):
with tempfile.TemporaryDirectory() as tmp_dir:
print("loading model " + model_id)
tm = TransformerModel(
model_id=model_id,
task_type=task_type,
@@ -183,6 +182,7 @@ def test_text_prediction(
assert isinstance(config.inference_config, config_type)
tokenization = config.inference_config.tokenization
assert isinstance(config.metadata, dict)
assert config.prefix_strings is None
assert (
"per_deployment_memory_bytes" in config.metadata
and config.metadata["per_deployment_memory_bytes"] > 0
@@ -210,3 +210,28 @@
assert len(config.inference_config.classification_labels) > 0

del tm

def test_model_config_with_prefix_string(self):
with tempfile.TemporaryDirectory() as tmp_dir:
tm = TransformerModel(
model_id="sentence-transformers/all-distilroberta-v1",
task_type="text_embedding",
es_version=(8, 12, 0),
quantize=False,
ingest_prefix="INGEST:",
search_prefix="SEARCH:",
)
_, config, _ = tm.save(tmp_dir)
assert config.prefix_strings.to_dict()["ingest"] == "INGEST:"
assert config.prefix_strings.to_dict()["search"] == "SEARCH:"

def test_model_config_with_prefix_string_not_supported(self):
with pytest.raises(Exception):
TransformerModel(
model_id="sentence-transformers/all-distilroberta-v1",
task_type="text_embedding",
es_version=(8, 11, 0),
quantize=False,
ingest_prefix="INGEST:",
search_prefix="SEARCH:",
)
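
To exercise just the new tests, an invocation along these lines should work (the file path comes from the diff header above; both tests download the sentence-transformers model from the Hugging Face Hub):

```sh
pytest tests/ml/pytorch/test_pytorch_model_config_pytest.py -k prefix_string
```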