Skip to content

Commit

Permalink
[ML] add support for question_answering NLP tasks (#457)
Browse files Browse the repository at this point in the history
Adds support for `question_answering` NLP models within the pytorch model uploader.

Related: elastic/elasticsearch#85958
  • Loading branch information
benwtrent authored May 4, 2022
1 parent afe08f8 commit 70fadc9
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 17 deletions.
70 changes: 54 additions & 16 deletions eland/ml/pytorch/nlp_ml_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,22 @@


class NlpTokenizationConfig:
def __init__(self, *, configuration_type: str):
def __init__(
self,
*,
configuration_type: str,
with_special_tokens: t.Optional[bool] = None,
max_sequence_length: t.Optional[int] = None,
truncate: t.Optional[
t.Union["t.Literal['first', 'none', 'second']", str]
] = None,
span: t.Optional[int] = None,
):
self.name = configuration_type
self.with_special_tokens = with_special_tokens
self.max_sequence_length = max_sequence_length
self.truncate = truncate
self.span = span

def to_dict(self):
return {
Expand All @@ -42,12 +56,14 @@ def __init__(
] = None,
span: t.Optional[int] = None,
):
super().__init__(configuration_type="roberta")
super().__init__(
configuration_type="roberta",
with_special_tokens=with_special_tokens,
max_sequence_length=max_sequence_length,
truncate=truncate,
span=span,
)
self.add_prefix_space = add_prefix_space
self.with_special_tokens = with_special_tokens
self.max_sequence_length = max_sequence_length
self.truncate = truncate
self.span = span


class NlpBertTokenizationConfig(NlpTokenizationConfig):
Expand All @@ -62,12 +78,14 @@ def __init__(
] = None,
span: t.Optional[int] = None,
):
super().__init__(configuration_type="bert")
super().__init__(
configuration_type="bert",
with_special_tokens=with_special_tokens,
max_sequence_length=max_sequence_length,
truncate=truncate,
span=span,
)
self.do_lower_case = do_lower_case
self.with_special_tokens = with_special_tokens
self.max_sequence_length = max_sequence_length
self.truncate = truncate
self.span = span


class NlpMPNetTokenizationConfig(NlpTokenizationConfig):
Expand All @@ -82,12 +100,14 @@ def __init__(
] = None,
span: t.Optional[int] = None,
):
super().__init__(configuration_type="mpnet")
super().__init__(
configuration_type="mpnet",
with_special_tokens=with_special_tokens,
max_sequence_length=max_sequence_length,
truncate=truncate,
span=span,
)
self.do_lower_case = do_lower_case
self.with_special_tokens = with_special_tokens
self.max_sequence_length = max_sequence_length
self.truncate = truncate
self.span = span


class InferenceConfig:
Expand Down Expand Up @@ -180,6 +200,24 @@ def __init__(
self.results_field = results_field


class QuestionAnsweringInferenceOptions(InferenceConfig):
def __init__(
self,
*,
tokenization: NlpTokenizationConfig,
results_field: t.Optional[str] = None,
max_answer_length: t.Optional[int] = None,
question: t.Optional[str] = None,
num_top_classes: t.Optional[int] = None,
):
super().__init__(configuration_type="question_answering")
self.tokenization = tokenization
self.results_field = results_field
self.max_answer_length = max_answer_length
self.question = question
self.num_top_classes = num_top_classes


class TextEmbeddingInferenceOptions(InferenceConfig):
def __init__(
self,
Expand Down
103 changes: 102 additions & 1 deletion eland/ml/pytorch/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from transformers import (
AutoConfig,
AutoModel,
AutoModelForQuestionAnswering,
PreTrainedModel,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
Expand All @@ -46,6 +47,7 @@
NlpTokenizationConfig,
NlpTrainedModelConfig,
PassThroughInferenceOptions,
QuestionAnsweringInferenceOptions,
TextClassificationInferenceOptions,
TextEmbeddingInferenceOptions,
TrainedModelInput,
Expand All @@ -59,6 +61,7 @@
"text_classification",
"text_embedding",
"zero_shot_classification",
"question_answering",
}
TASK_TYPE_TO_INFERENCE_CONFIG = {
"fill_mask": FillMaskInferenceOptions,
Expand All @@ -67,6 +70,7 @@
"text_embedding": TextEmbeddingInferenceOptions,
"zero_shot_classification": ZeroShotClassificationInferenceOptions,
"pass_through": PassThroughInferenceOptions,
"question_answering": QuestionAnsweringInferenceOptions,
}
SUPPORTED_TASK_TYPES_NAMES = ", ".join(sorted(SUPPORTED_TASK_TYPES))
SUPPORTED_TOKENIZERS = (
Expand All @@ -92,6 +96,86 @@
]


class _QuestionAnsweringWrapperModule(nn.Module): # type: ignore
"""
A wrapper around a question answering model.
Our inference engine only takes the first tuple if the inference response
is a tuple.
This wrapper transforms the output to be a stacked tensor if its a tuple.
Otherwise it passes it through
"""

def __init__(self, model: PreTrainedModel):
super().__init__()
self._hf_model = model
self.config = model.config

@staticmethod
def from_pretrained(model_id: str) -> Optional[Any]:
model = AutoModelForQuestionAnswering.from_pretrained(
model_id, torchscript=True
)
if isinstance(
model.config,
(
transformers.MPNetConfig,
transformers.RobertaConfig,
transformers.BartConfig,
),
):
return _TwoParameterQuestionAnsweringWrapper(model)
else:
return _QuestionAnsweringWrapper(model)


class _QuestionAnsweringWrapper(_QuestionAnsweringWrapperModule):
def __init__(self, model: PreTrainedModel):
super().__init__(model=model)

def forward(
self,
input_ids: Tensor,
attention_mask: Tensor,
token_type_ids: Tensor,
position_ids: Tensor,
) -> Tensor:
"""Wrap the input and output to conform to the native process interface."""

inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
"position_ids": position_ids,
}

# remove inputs for specific model types
if isinstance(self._hf_model.config, transformers.DistilBertConfig):
del inputs["token_type_ids"]
del inputs["position_ids"]
response = self._hf_model(**inputs)
if isinstance(response, tuple):
return torch.stack(list(response), dim=0)
return response


class _TwoParameterQuestionAnsweringWrapper(_QuestionAnsweringWrapperModule):
def __init__(self, model: PreTrainedModel):
super().__init__(model=model)

def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
"""Wrap the input and output to conform to the native process interface."""
inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}
response = self._hf_model(**inputs)
if isinstance(response, tuple):
return torch.stack(list(response), dim=0)
return response


class _DistilBertWrapper(nn.Module): # type: ignore
"""
A simple wrapper around DistilBERT model which makes the model inputs
Expand Down Expand Up @@ -404,6 +488,16 @@ def _prepare_inputs(self) -> transformers.BatchEncoding:
)


class _TraceableQuestionAnsweringModel(_TraceableModel):
def _prepare_inputs(self) -> transformers.BatchEncoding:
return self._tokenizer(
"What is the meaning of life?"
"The meaning of life, according to the hitchikers guide, is 42.",
padding="max_length",
return_tensors="pt",
)


class TransformerModel:
def __init__(self, model_id: str, task_type: str, quantize: bool = False):
self._model_id = model_id
Expand Down Expand Up @@ -472,6 +566,11 @@ def _create_tokenization_config(self) -> NlpTokenizationConfig:
def _create_config(self) -> NlpTrainedModelConfig:
tokenization_config = self._create_tokenization_config()

# Set squad well known defaults
if self._task_type == "question_answering":
tokenization_config.max_sequence_length = 386
tokenization_config.span = 128
tokenization_config.truncate = "none"
inference_config = (
TASK_TYPE_TO_INFERENCE_CONFIG[self._task_type](
tokenization=tokenization_config,
Expand Down Expand Up @@ -530,7 +629,9 @@ def _create_traceable_model(self) -> _TraceableModel:
)
model = _DistilBertWrapper.try_wrapping(model)
return _TraceableZeroShotClassificationModel(self._tokenizer, model)

elif self._task_type == "question_answering":
model = _QuestionAnsweringWrapperModule.from_pretrained(self._model_id)
return _TraceableQuestionAnsweringModel(self._tokenizer, model)
else:
raise TypeError(
f"Unknown task type {self._task_type}, must be one of: {SUPPORTED_TASK_TYPES_NAMES}"
Expand Down

0 comments on commit 70fadc9

Please sign in to comment.