Skip to content

Commit

Permalink
[NLP] Add support for the pass_through task #526
Browse files Browse the repository at this point in the history
  • Loading branch information
davidkyle authored Apr 6, 2023
1 parent 8e0d897 commit 940f2a9
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions eland/ml/pytorch/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
SUPPORTED_TASK_TYPES = {
"fill_mask",
"ner",
"pass_through",
"text_classification",
"text_embedding",
"text_expansion",
Expand Down Expand Up @@ -510,6 +511,15 @@ def _prepare_inputs(self) -> transformers.BatchEncoding:
)


class _TraceablePassThroughModel(_TransformerTraceableModel):
def _prepare_inputs(self) -> transformers.BatchEncoding:
return self._tokenizer(
"This is an example sentence.",
padding="max_length",
return_tensors="pt",
)


class _TraceableTextClassificationModel(_TraceableClassificationModel):
def _prepare_inputs(self) -> transformers.BatchEncoding:
return self._tokenizer(
Expand Down Expand Up @@ -709,6 +719,11 @@ def _create_traceable_model(self) -> TraceableModel:
)
model = _DistilBertWrapper.try_wrapping(model)
return _TraceableTextSimilarityModel(self._tokenizer, model)
elif self._task_type == "pass_through":
model = transformers.AutoModel.from_pretrained(
self._model_id, torchscript=True
)
return _TraceablePassThroughModel(self._tokenizer, model)
else:
raise TypeError(
f"Unknown task type {self._task_type}, must be one of: {SUPPORTED_TASK_TYPES_NAMES}"
Expand Down

0 comments on commit 940f2a9

Please sign in to comment.