Allow importing private HuggingFace models (#608)
pquentin authored Sep 25, 2023
1 parent 5ec7606 commit 566bb9e
Showing 3 changed files with 46 additions and 22 deletions.
15 changes: 12 additions & 3 deletions docs/guide/machine-learning.asciidoc
@@ -160,7 +160,7 @@ underscores `__`.
 
 The following authentication options are available when using the import script:
 
-* username and password authentication (specified with the `-u` and `-p` options):
+* Elasticsearch username and password authentication (specified with the `-u` and `-p` options):
 +
 --
 [source,bash]
@@ -170,7 +170,7 @@ eland_import_hub_model -u <username> -p <password> --cloud-id <cloud-id> ...
 These `-u` and `-p` options also work when you use `--url`.
 --
 
-* username and password authentication (embedded in the URL):
+* Elasticsearch username and password authentication (embedded in the URL):
 +
 --
 [source,bash]
@@ -179,11 +179,20 @@ eland_import_hub_model --url https://<user>:<password>@<hostname>:<port> ...
 --------------------------------------------------
 --
 
-* API key authentication:
+* Elasticsearch API key authentication:
 +
 --
 [source,bash]
 --------------------------------------------------
 eland_import_hub_model --es-api-key <api-key> --url https://<hostname>:<port> ...
 --------------------------------------------------
 --
+
+* HuggingFace Hub access token (for private models):
++
+--
+[source,bash]
+--------------------------------------------------
+eland_import_hub_model --hub-access-token <access-token> ...
+--------------------------------------------------
+--
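For readers driving the import from Python rather than the CLI, here is a minimal sketch of passing the new access token programmatically. It mirrors what the import script does internally; the model ID and token are placeholders, and it assumes eland is installed with its PyTorch extras:

[source,python]
--------------------------------------------------
import tempfile

from eland.ml.pytorch.transformers import TransformerModel

# Hypothetical private repository and token; substitute your own values.
tm = TransformerModel(
    model_id="my-org/private-bert-model",  # placeholder model ID
    access_token="hf_xxx",  # Hugging Face token with read access
    task_type="text_classification",
)

# Trace the model to TorchScript in a temporary directory, as the CLI does.
with tempfile.TemporaryDirectory() as tmp_dir:
    model_path, config, vocab_path = tm.save(tmp_dir)
--------------------------------------------------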
7 changes: 7 additions & 0 deletions eland/cli/eland_import_hub_model.py
@@ -58,6 +58,12 @@ def get_arg_parser():
         help="The model ID in the Hugging Face model hub, "
         "e.g. dbmdz/bert-large-cased-finetuned-conll03-english",
     )
+    parser.add_argument(
+        "--hub-access-token",
+        required=False,
+        default=os.environ.get("HUB_ACCESS_TOKEN"),
+        help="The Hugging Face access token, needed to access private models",
+    )
     parser.add_argument(
         "--es-model-id",
         required=False,
@@ -234,6 +240,7 @@ def main():
     try:
         tm = TransformerModel(
             model_id=args.hub_model_id,
+            access_token=args.hub_access_token,
             task_type=args.task_type,
             es_version=cluster_version,
             quantize=args.quantize,
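Because the default reads os.environ.get("HUB_ACCESS_TOKEN"), the flag can be omitted entirely when that environment variable is set. A standalone sketch of the same argparse pattern (the parser and option names here simply mirror the ones above):

[source,python]
--------------------------------------------------
import argparse
import os

parser = argparse.ArgumentParser()
# An explicit --hub-access-token wins; otherwise argparse falls back to
# the value HUB_ACCESS_TOKEN held when the parser was built, else None.
parser.add_argument(
    "--hub-access-token",
    required=False,
    default=os.environ.get("HUB_ACCESS_TOKEN"),
)

args = parser.parse_args([])  # no flag passed on the command line
print(args.hub_access_token)  # None unless HUB_ACCESS_TOKEN was set
--------------------------------------------------

Note that the environment variable is read once, when the parser is built, not each time arguments are parsed.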
46 changes: 27 additions & 19 deletions eland/ml/pytorch/transformers.py
@@ -179,9 +179,9 @@ def __init__(self, model: PreTrainedModel):
         self.config = model.config
 
     @staticmethod
-    def from_pretrained(model_id: str) -> Optional[Any]:
+    def from_pretrained(model_id: str, *, token: Optional[str] = None) -> Optional[Any]:
         model = AutoModelForQuestionAnswering.from_pretrained(
-            model_id, torchscript=True
+            model_id, token=token, torchscript=True
         )
         if isinstance(
             model.config,
@@ -292,9 +292,12 @@ def __init__(self, model: PreTrainedModel, output_key: str = DEFAULT_OUTPUT_KEY)
 
     @staticmethod
     def from_pretrained(
-        model_id: str, output_key: str = DEFAULT_OUTPUT_KEY
+        model_id: str,
+        *,
+        token: Optional[str] = None,
+        output_key: str = DEFAULT_OUTPUT_KEY,
     ) -> Optional[Any]:
-        model = AutoModel.from_pretrained(model_id, torchscript=True)
+        model = AutoModel.from_pretrained(model_id, token=token, torchscript=True)
         if isinstance(
             model.config,
             (
@@ -393,8 +396,8 @@ def __init__(
         self.config = model.config
 
     @staticmethod
-    def from_pretrained(model_id: str) -> Optional[Any]:
-        config = AutoConfig.from_pretrained(model_id)
+    def from_pretrained(model_id: str, *, token: Optional[str] = None) -> Optional[Any]:
+        config = AutoConfig.from_pretrained(model_id, token=token)
 
         def is_compatible() -> bool:
             is_dpr_model = config.model_type == "dpr"
@@ -579,9 +582,10 @@ def _prepare_inputs(self) -> transformers.BatchEncoding:
 class TransformerModel:
     def __init__(
         self,
+        *,
         model_id: str,
+        access_token: Optional[str],
         task_type: str,
-        *,
         es_version: Optional[Tuple[int, int, int]] = None,
         quantize: bool = False,
     ):
@@ -609,14 +613,14 @@ def __init__(
         """
 
         self._model_id = model_id
+        self._access_token = access_token
         self._task_type = task_type.replace("-", "_")
 
         # load Hugging Face model and tokenizer
         # use padding in the tokenizer to ensure max length sequences are used for tracing (at call time)
         # - see: https://huggingface.co/transformers/serialization.html#dummy-inputs-and-standard-lengths
         self._tokenizer = transformers.AutoTokenizer.from_pretrained(
-            self._model_id,
-            use_fast=False,
+            self._model_id, token=self._access_token, use_fast=False
         )
 
         # check for a supported tokenizer
@@ -755,7 +759,7 @@ def _create_config(
     def _create_traceable_model(self) -> TraceableModel:
         if self._task_type == "auto":
             model = transformers.AutoModel.from_pretrained(
-                self._model_id, torchscript=True
+                self._model_id, token=self._access_token, torchscript=True
             )
             maybe_task_type = task_type_from_model_config(model.config)
             if maybe_task_type is None:
@@ -767,54 +771,58 @@ def _create_traceable_model(self) -> TraceableModel:
 
         if self._task_type == "fill_mask":
             model = transformers.AutoModelForMaskedLM.from_pretrained(
-                self._model_id, torchscript=True
+                self._model_id, token=self._access_token, torchscript=True
             )
             model = _DistilBertWrapper.try_wrapping(model)
             return _TraceableFillMaskModel(self._tokenizer, model)
 
         elif self._task_type == "ner":
             model = transformers.AutoModelForTokenClassification.from_pretrained(
-                self._model_id, torchscript=True
+                self._model_id, token=self._access_token, torchscript=True
             )
             model = _DistilBertWrapper.try_wrapping(model)
             return _TraceableNerModel(self._tokenizer, model)
 
         elif self._task_type == "text_classification":
             model = transformers.AutoModelForSequenceClassification.from_pretrained(
-                self._model_id, torchscript=True
+                self._model_id, token=self._access_token, torchscript=True
             )
             model = _DistilBertWrapper.try_wrapping(model)
             return _TraceableTextClassificationModel(self._tokenizer, model)
 
         elif self._task_type == "text_embedding":
-            model = _DPREncoderWrapper.from_pretrained(self._model_id)
+            model = _DPREncoderWrapper.from_pretrained(
+                self._model_id, token=self._access_token
+            )
             if not model:
                 model = _SentenceTransformerWrapperModule.from_pretrained(
-                    self._model_id
+                    self._model_id, token=self._access_token
                 )
             return _TraceableTextEmbeddingModel(self._tokenizer, model)
 
         elif self._task_type == "zero_shot_classification":
             model = transformers.AutoModelForSequenceClassification.from_pretrained(
-                self._model_id, torchscript=True
+                self._model_id, token=self._access_token, torchscript=True
             )
             model = _DistilBertWrapper.try_wrapping(model)
             return _TraceableZeroShotClassificationModel(self._tokenizer, model)
 
         elif self._task_type == "question_answering":
-            model = _QuestionAnsweringWrapperModule.from_pretrained(self._model_id)
+            model = _QuestionAnsweringWrapperModule.from_pretrained(
+                self._model_id, token=self._access_token
+            )
             return _TraceableQuestionAnsweringModel(self._tokenizer, model)
 
         elif self._task_type == "text_similarity":
             model = transformers.AutoModelForSequenceClassification.from_pretrained(
-                self._model_id, torchscript=True
+                self._model_id, token=self._access_token, torchscript=True
            )
             model = _DistilBertWrapper.try_wrapping(model)
             return _TraceableTextSimilarityModel(self._tokenizer, model)
 
         elif self._task_type == "pass_through":
             model = transformers.AutoModel.from_pretrained(
-                self._model_id, torchscript=True
+                self._model_id, token=self._access_token, torchscript=True
             )
             return _TraceablePassThroughModel(self._tokenizer, model)
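Every call site above simply forwards the token to Hugging Face's from_pretrained loaders. A minimal sketch of what that keyword does, assuming a transformers release where token= is the supported authentication argument (it superseded use_auth_token) and using placeholder values:

[source,python]
--------------------------------------------------
from transformers import AutoModel, AutoTokenizer

model_id = "my-org/private-model"  # hypothetical private repository
token = "hf_xxx"  # placeholder Hugging Face read token

# Without token=..., loading a private repository fails with a
# "repository not found" / 401 error from the Hugging Face Hub.
tokenizer = AutoTokenizer.from_pretrained(model_id, token=token, use_fast=False)
model = AutoModel.from_pretrained(model_id, token=token, torchscript=True)
--------------------------------------------------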
