Commit 50d301f

Set embedding_size config parameter for Text Embedding models (#532)
davidkyle authored Apr 25, 2023
Parent: 940f2a9 · Commit: 50d301f
Showing 4 changed files with 41 additions and 18 deletions.
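
The idea behind the change: instead of requiring the embedding size to be declared up front, the converter performs one sample forward pass and takes the size of the output's last dimension. A minimal sketch of that idea with a toy model (the module, names, and shapes below are invented for illustration, not taken from the commit):

    import torch
    import torch.nn as nn

    # Toy stand-in for a text embedding model: token ids in, pooled vector out.
    class ToyEmbeddingModel(nn.Module):
        def __init__(self, vocab_size: int = 100, hidden: int = 384):
            super().__init__()
            self.embed = nn.Embedding(vocab_size, hidden)

        def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
            # (batch, seq, hidden) -> mean-pool over seq -> (batch, hidden)
            return self.embed(input_ids).mean(dim=1)

    model = ToyEmbeddingModel()
    sample = model(torch.zeros((1, 8), dtype=torch.long))
    embedding_size = sample.size(-1)  # 384, discovered rather than hard-coded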
eland/ml/pytorch/nlp_ml_model.py (2 additions, 0 deletions)
@@ -238,10 +238,12 @@ def __init__(
         *,
         tokenization: NlpTokenizationConfig,
         results_field: t.Optional[str] = None,
+        embedding_size: t.Optional[int] = None,
     ):
         super().__init__(configuration_type="text_embedding")
         self.tokenization = tokenization
         self.results_field = results_field
+        self.embedding_size = embedding_size


 class TextExpansionInferenceOptions(InferenceConfig):
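
For illustration, a sketch of how the new parameter can be supplied when building the options object; the tokenization class and attribute values are assumptions, modeled on the transformers.py hunk further down, which sets tokenization attributes the same way:

    from eland.ml.pytorch.nlp_ml_model import (
        NlpBertTokenizationConfig,
        TextEmbeddingInferenceOptions,
    )

    tokenization = NlpBertTokenizationConfig()
    tokenization.max_sequence_length = 512  # assumed value

    options = TextEmbeddingInferenceOptions(
        tokenization=tokenization,
        embedding_size=384,  # e.g. a MiniLM-sized model
    )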
eland/ml/pytorch/traceable_model.py (4 additions, 0 deletions)
@@ -49,6 +49,10 @@ def trace(self) -> TracedModelTypes:
         self._model.eval()
         return self._trace()

+    @abstractmethod
+    def sample_output(self) -> torch.Tensor:
+        ...
+
     @abstractmethod
     def _trace(self) -> TracedModelTypes:
         ...
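
A minimal sketch of what the new abstract method asks of subclasses, assuming the wrapper keeps its torch module in self._model (as trace() above suggests); the toy input shape is invented:

    import torch

    class MyTraceableModel(TraceableModel):
        def sample_output(self) -> torch.Tensor:
            input_ids = torch.zeros((1, 8), dtype=torch.long)  # assumed shape
            return self._model(input_ids)

        def _trace(self) -> TracedModelTypes:
            input_ids = torch.zeros((1, 8), dtype=torch.long)
            return torch.jit.trace(self._model, (input_ids,))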
eland/ml/pytorch/transformers.py (28 additions, 18 deletions)
@@ -443,6 +443,14 @@ def __init__(
         self._tokenizer = tokenizer

     def _trace(self) -> TracedModelTypes:
+        inputs = self._compatible_inputs()
+        return torch.jit.trace(self._model, inputs)
+
+    def sample_output(self) -> Tensor:
+        inputs = self._compatible_inputs()
+        return self._model(*inputs)
+
+    def _compatible_inputs(self) -> Tuple[Tensor, ...]:
         inputs = self._prepare_inputs()

         # Add params when not provided by the tokenizer (e.g. DistilBERT), to conform to BERT interface
@@ -458,21 +466,16 @@ def _trace(self) -> TracedModelTypes:
                 transformers.BartConfig,
             ),
         ):
-            return torch.jit.trace(
-                self._model,
-                (inputs["input_ids"], inputs["attention_mask"]),
-            )
+            del inputs["token_type_ids"]
+            return (inputs["input_ids"], inputs["attention_mask"])

         position_ids = torch.arange(inputs["input_ids"].size(1), dtype=torch.long)

-        return torch.jit.trace(
-            self._model,
-            (
-                inputs["input_ids"],
-                inputs["attention_mask"],
-                inputs["token_type_ids"],
-                position_ids,
-            ),
-        )
+        inputs["position_ids"] = position_ids
+        return (
+            inputs["input_ids"],
+            inputs["attention_mask"],
+            inputs["token_type_ids"],
+            inputs["position_ids"],
+        )

     @abstractmethod
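
The refactor works because torch.jit.trace accepts its example inputs as a positional tuple, and unpacking the same tuple reproduces the call eagerly; one _compatible_inputs() result therefore serves both tracing and sampling. Sketch (method names from the diff, usage assumed):

    inputs = self._compatible_inputs()             # tuple of input tensors
    traced = torch.jit.trace(self._model, inputs)  # trace with example inputs
    sample = self._model(*inputs)                  # same tensors, eager forward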
@@ -640,16 +643,23 @@ def _create_config(self) -> NlpTrainedModelConfig:
             tokenization_config.max_sequence_length = 386
             tokenization_config.span = 128
             tokenization_config.truncate = "none"
-        inference_config = (
-            TASK_TYPE_TO_INFERENCE_CONFIG[self._task_type](
+
+        if self._traceable_model.classification_labels():
+            inference_config = TASK_TYPE_TO_INFERENCE_CONFIG[self._task_type](
                 tokenization=tokenization_config,
                 classification_labels=self._traceable_model.classification_labels(),
             )
-            if self._traceable_model.classification_labels()
-            else TASK_TYPE_TO_INFERENCE_CONFIG[self._task_type](
+        elif self._task_type == "text_embedding":
+            sample_embedding, _ = self._traceable_model.sample_output()
+            embedding_size = sample_embedding.size(-1)
+            inference_config = TASK_TYPE_TO_INFERENCE_CONFIG[self._task_type](
+                tokenization=tokenization_config,
+                embedding_size=embedding_size,
+            )
+        else:
+            inference_config = TASK_TYPE_TO_INFERENCE_CONFIG[self._task_type](
                 tokenization=tokenization_config
             )
-        )

         return NlpTrainedModelConfig(
             description=f"Model {self._model_id} for task type '{self._task_type}'",
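
The sample_embedding, _ unpacking assumes the embedding tensor is the first of the model's outputs; the config value is then just its last dimension. A self-contained check with an assumed shape:

    import torch

    sample_embedding = torch.zeros((1, 384))    # (batch, embedding_size), assumed
    embedding_size = sample_embedding.size(-1)  # -> 384
    assert embedding_size == 384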
tests/ml/pytorch/test_transformer_pytorch_model_pytest.py (7 additions, 0 deletions)
@@ -136,6 +136,13 @@ def _trace(self) -> TracedModelTypes:
             ),
         )

+    def sample_output(self) -> torch.Tensor:
+        input_ids = torch.tensor(np.array(range(0, len(TEST_BERT_VOCAB))))
+        attention_mask = torch.tensor([1] * len(TEST_BERT_VOCAB))
+        token_type_ids = torch.tensor([0] * len(TEST_BERT_VOCAB))
+        position_ids = torch.arange(len(TEST_BERT_VOCAB), dtype=torch.long)
+        return self._model(input_ids, attention_mask, token_type_ids, position_ids)
+

 class NerModule(nn.Module):
     def forward(
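
A sketch of the kind of assertion this plumbing enables (invented, not in this commit): the size written into the config should match the model's actual output width.

    sample_embedding, _ = traceable_model.sample_output()
    assert inference_config.embedding_size == sample_embedding.size(-1)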
