diff --git a/eland/ml/pytorch/nlp_ml_model.py b/eland/ml/pytorch/nlp_ml_model.py
index b9441e0f1..ff80b7298 100644
--- a/eland/ml/pytorch/nlp_ml_model.py
+++ b/eland/ml/pytorch/nlp_ml_model.py
@@ -19,8 +19,22 @@
 
 
 class NlpTokenizationConfig:
-    def __init__(self, *, configuration_type: str):
+    def __init__(
+        self,
+        *,
+        configuration_type: str,
+        with_special_tokens: t.Optional[bool] = None,
+        max_sequence_length: t.Optional[int] = None,
+        truncate: t.Optional[
+            t.Union["t.Literal['first', 'none', 'second']", str]
+        ] = None,
+        span: t.Optional[int] = None,
+    ):
         self.name = configuration_type
+        self.with_special_tokens = with_special_tokens
+        self.max_sequence_length = max_sequence_length
+        self.truncate = truncate
+        self.span = span
 
     def to_dict(self):
         return {
@@ -42,12 +56,14 @@ def __init__(
         ] = None,
         span: t.Optional[int] = None,
     ):
-        super().__init__(configuration_type="roberta")
+        super().__init__(
+            configuration_type="roberta",
+            with_special_tokens=with_special_tokens,
+            max_sequence_length=max_sequence_length,
+            truncate=truncate,
+            span=span,
+        )
         self.add_prefix_space = add_prefix_space
-        self.with_special_tokens = with_special_tokens
-        self.max_sequence_length = max_sequence_length
-        self.truncate = truncate
-        self.span = span
 
 
 class NlpBertTokenizationConfig(NlpTokenizationConfig):
@@ -62,12 +78,14 @@ def __init__(
         ] = None,
         span: t.Optional[int] = None,
     ):
-        super().__init__(configuration_type="bert")
+        super().__init__(
+            configuration_type="bert",
+            with_special_tokens=with_special_tokens,
+            max_sequence_length=max_sequence_length,
+            truncate=truncate,
+            span=span,
+        )
         self.do_lower_case = do_lower_case
-        self.with_special_tokens = with_special_tokens
-        self.max_sequence_length = max_sequence_length
-        self.truncate = truncate
-        self.span = span
 
 
 class NlpMPNetTokenizationConfig(NlpTokenizationConfig):
@@ -82,12 +100,14 @@ def __init__(
         ] = None,
         span: t.Optional[int] = None,
     ):
-        super().__init__(configuration_type="mpnet")
+        super().__init__(
+            configuration_type="mpnet",
+            with_special_tokens=with_special_tokens,
+            max_sequence_length=max_sequence_length,
+            truncate=truncate,
+            span=span,
+        )
         self.do_lower_case = do_lower_case
-        self.with_special_tokens = with_special_tokens
-        self.max_sequence_length = max_sequence_length
-        self.truncate = truncate
-        self.span = span
 
 
 class InferenceConfig:
@@ -180,6 +200,24 @@ def __init__(
         self.results_field = results_field
 
 
+class QuestionAnsweringInferenceOptions(InferenceConfig):
+    def __init__(
+        self,
+        *,
+        tokenization: NlpTokenizationConfig,
+        results_field: t.Optional[str] = None,
+        max_answer_length: t.Optional[int] = None,
+        question: t.Optional[str] = None,
+        num_top_classes: t.Optional[int] = None,
+    ):
+        super().__init__(configuration_type="question_answering")
+        self.tokenization = tokenization
+        self.results_field = results_field
+        self.max_answer_length = max_answer_length
+        self.question = question
+        self.num_top_classes = num_top_classes
+
+
 class TextEmbeddingInferenceOptions(InferenceConfig):
     def __init__(
         self,
diff --git a/eland/ml/pytorch/transformers.py b/eland/ml/pytorch/transformers.py
index c360e271d..86c5f6251 100644
--- a/eland/ml/pytorch/transformers.py
+++ b/eland/ml/pytorch/transformers.py
@@ -32,6 +32,7 @@
 from transformers import (
     AutoConfig,
     AutoModel,
+    AutoModelForQuestionAnswering,
     PreTrainedModel,
     PreTrainedTokenizer,
     PreTrainedTokenizerFast,
@@ -46,6 +47,7 @@
     NlpTokenizationConfig,
     NlpTrainedModelConfig,
     PassThroughInferenceOptions,
+    QuestionAnsweringInferenceOptions,
     TextClassificationInferenceOptions,
     TextEmbeddingInferenceOptions,
     TrainedModelInput,
@@ -59,6 +61,7 @@
     "text_classification",
     "text_embedding",
     "zero_shot_classification",
+    "question_answering",
 }
 TASK_TYPE_TO_INFERENCE_CONFIG = {
     "fill_mask": FillMaskInferenceOptions,
@@ -67,6 +70,7 @@
     "text_embedding": TextEmbeddingInferenceOptions,
     "zero_shot_classification": ZeroShotClassificationInferenceOptions,
     "pass_through": PassThroughInferenceOptions,
+    "question_answering": QuestionAnsweringInferenceOptions,
 }
 SUPPORTED_TASK_TYPES_NAMES = ", ".join(sorted(SUPPORTED_TASK_TYPES))
 SUPPORTED_TOKENIZERS = (
@@ -92,6 +96,86 @@
 ]
 
 
+class _QuestionAnsweringWrapperModule(nn.Module):  # type: ignore
+    """
+    A wrapper around a question answering model.
+    Our inference engine only takes the first element if the inference response
+    is a tuple.
+
+    This wrapper transforms the output to be a stacked tensor if it's a tuple.
+
+    Otherwise, it passes it through.
+    """
+
+    def __init__(self, model: PreTrainedModel):
+        super().__init__()
+        self._hf_model = model
+        self.config = model.config
+
+    @staticmethod
+    def from_pretrained(model_id: str) -> Optional[Any]:
+        model = AutoModelForQuestionAnswering.from_pretrained(
+            model_id, torchscript=True
+        )
+        if isinstance(
+            model.config,
+            (
+                transformers.MPNetConfig,
+                transformers.RobertaConfig,
+                transformers.BartConfig,
+            ),
+        ):
+            return _TwoParameterQuestionAnsweringWrapper(model)
+        else:
+            return _QuestionAnsweringWrapper(model)
+
+
+class _QuestionAnsweringWrapper(_QuestionAnsweringWrapperModule):
+    def __init__(self, model: PreTrainedModel):
+        super().__init__(model=model)
+
+    def forward(
+        self,
+        input_ids: Tensor,
+        attention_mask: Tensor,
+        token_type_ids: Tensor,
+        position_ids: Tensor,
+    ) -> Tensor:
+        """Wrap the input and output to conform to the native process interface."""
+
+        inputs = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "token_type_ids": token_type_ids,
+            "position_ids": position_ids,
+        }
+
+        # remove inputs for specific model types
+        if isinstance(self._hf_model.config, transformers.DistilBertConfig):
+            del inputs["token_type_ids"]
+            del inputs["position_ids"]
+        response = self._hf_model(**inputs)
+        if isinstance(response, tuple):
+            return torch.stack(list(response), dim=0)
+        return response
+
+
+class _TwoParameterQuestionAnsweringWrapper(_QuestionAnsweringWrapperModule):
+    def __init__(self, model: PreTrainedModel):
+        super().__init__(model=model)
+
+    def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
+        """Wrap the input and output to conform to the native process interface."""
+        inputs = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+        }
+        response = self._hf_model(**inputs)
+        if isinstance(response, tuple):
+            return torch.stack(list(response), dim=0)
+        return response
+
+
 class _DistilBertWrapper(nn.Module):  # type: ignore
     """
     A simple wrapper around DistilBERT model which makes the model inputs
@@ -404,6 +488,16 @@ def _prepare_inputs(self) -> transformers.BatchEncoding:
         )
 
 
+class _TraceableQuestionAnsweringModel(_TraceableModel):
+    def _prepare_inputs(self) -> transformers.BatchEncoding:
+        return self._tokenizer(
+            "What is the meaning of life?"
+ "The meaning of life, according to the hitchikers guide, is 42.", + padding="max_length", + return_tensors="pt", + ) + + class TransformerModel: def __init__(self, model_id: str, task_type: str, quantize: bool = False): self._model_id = model_id @@ -472,6 +566,11 @@ def _create_tokenization_config(self) -> NlpTokenizationConfig: def _create_config(self) -> NlpTrainedModelConfig: tokenization_config = self._create_tokenization_config() + # Set squad well known defaults + if self._task_type == "question_answering": + tokenization_config.max_sequence_length = 386 + tokenization_config.span = 128 + tokenization_config.truncate = "none" inference_config = ( TASK_TYPE_TO_INFERENCE_CONFIG[self._task_type]( tokenization=tokenization_config, @@ -530,7 +629,9 @@ def _create_traceable_model(self) -> _TraceableModel: ) model = _DistilBertWrapper.try_wrapping(model) return _TraceableZeroShotClassificationModel(self._tokenizer, model) - + elif self._task_type == "question_answering": + model = _QuestionAnsweringWrapperModule.from_pretrained(self._model_id) + return _TraceableQuestionAnsweringModel(self._tokenizer, model) else: raise TypeError( f"Unknown task type {self._task_type}, must be one of: {SUPPORTED_TASK_TYPES_NAMES}"