diff --git a/eland/ml/pytorch/_pytorch_model.py b/eland/ml/pytorch/_pytorch_model.py index 66948b4f..0797e730 100644 --- a/eland/ml/pytorch/_pytorch_model.py +++ b/eland/ml/pytorch/_pytorch_model.py @@ -21,7 +21,7 @@ import os from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Set, Tuple, Union -from tqdm.auto import tqdm +from tqdm.auto import tqdm # type: ignore from eland.common import ensure_es_client @@ -101,7 +101,7 @@ def import_model( def infer( self, body: Dict[str, Any], timeout: str = DEFAULT_TIMEOUT - ) -> Dict[str, Any]: + ) -> Union[bool, Any]: return self._client.transport.perform_request( "POST", f"/_ml/trained_models/{self.model_id}/deployment/_infer", @@ -124,7 +124,7 @@ def stop(self) -> None: ) def delete(self) -> None: - self._client.ml.delete_trained_model(self.model_id, ignore=(404,)) + self._client.ml.delete_trained_model(model_id=self.model_id, ignore=(404,)) @classmethod def list( diff --git a/eland/ml/pytorch/transformers.py b/eland/ml/pytorch/transformers.py index dc5f37c3..41069675 100644 --- a/eland/ml/pytorch/transformers.py +++ b/eland/ml/pytorch/transformers.py @@ -23,11 +23,11 @@ import json import os.path from abc import ABC, abstractmethod -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union -import torch -import transformers -from sentence_transformers import SentenceTransformer +import torch # type: ignore +import transformers # type: ignore +from sentence_transformers import SentenceTransformer # type: ignore from torch import Tensor, nn from transformers import ( AutoConfig, @@ -66,7 +66,7 @@ ] -class _DistilBertWrapper(nn.Module): +class _DistilBertWrapper(nn.Module): # type: ignore """ A simple wrapper around DistilBERT model which makes the model inputs conform to Elasticsearch's native inference processor interface. @@ -96,7 +96,7 @@ def forward( return self._model(input_ids=input_ids, attention_mask=attention_mask) -class _SentenceTransformerWrapper(nn.Module): +class _SentenceTransformerWrapper(nn.Module): # type: ignore """ A wrapper around sentence-transformer models to provide pooling, normalization and other graph layers that are not defined in the base @@ -122,7 +122,7 @@ def from_pretrained( else: return None - def _remove_pooling_layer(self): + def _remove_pooling_layer(self) -> None: """ Removes any last pooling layer which is not used to create embeddings. Leaving this layer in will cause it to return a NoneType which in turn @@ -135,7 +135,7 @@ def _remove_pooling_layer(self): if hasattr(self._hf_model, "pooler"): self._hf_model.pooler = None - def _replace_transformer_layer(self): + def _replace_transformer_layer(self) -> None: """ Replaces the HuggingFace Transformer layer in the SentenceTransformer modules so we can set it with one that has pooling layer removed and @@ -167,7 +167,7 @@ def forward( return self._st_model(inputs)[self._output_key] -class _DPREncoderWrapper(nn.Module): +class _DPREncoderWrapper(nn.Module): # type: ignore """ AutoModel loading does not work for DPRContextEncoders, this only exists as a workaround. This may never be fixed so this is likely permanent. @@ -240,10 +240,7 @@ def __init__( self._tokenizer = tokenizer self._model = model - def classification_labels(self) -> Optional[List[str]]: - return None - - def quantize(self): + def quantize(self) -> None: torch.quantization.quantize_dynamic( self._model, {torch.nn.Linear}, dtype=torch.qint8 ) @@ -275,11 +272,14 @@ def trace(self) -> TracedModelTypes: def _prepare_inputs(self) -> transformers.BatchEncoding: ... + def classification_labels(self) -> Optional[List[str]]: + return None + class _TraceableClassificationModel(_TraceableModel, ABC): def classification_labels(self) -> Optional[List[str]]: id_label_items = self._model.config.id2label.items() - labels = [v for _, v in sorted(id_label_items, key=lambda kv: kv[0])] + labels = [v for _, v in sorted(id_label_items, key=lambda kv: kv[0])] # type: ignore # Make classes like I-PER into I_PER which fits Java enumerations return [label.replace("-", "_") for label in labels] @@ -361,15 +361,15 @@ def __init__(self, model_id: str, task_type: str, quantize: bool = False): self._vocab = self._load_vocab() self._config = self._create_config() - def _load_vocab(self): + def _load_vocab(self) -> Dict[str, List[str]]: vocab_items = self._tokenizer.get_vocab().items() - vocabulary = [k for k, _ in sorted(vocab_items, key=lambda kv: kv[1])] + vocabulary = [k for k, _ in sorted(vocab_items, key=lambda kv: kv[1])] # type: ignore return { "vocabulary": vocabulary, } - def _create_config(self): - inference_config = { + def _create_config(self) -> Dict[str, Any]: + inference_config: Dict[str, Dict[str, Any]] = { self._task_type: { "tokenization": { "bert": { @@ -448,7 +448,7 @@ def _create_traceable_model(self) -> _TraceableModel: f"Unknown task type {self._task_type}, must be one of: {SUPPORTED_TASK_TYPES_NAMES}" ) - def elasticsearch_model_id(self): + def elasticsearch_model_id(self) -> str: # Elasticsearch model IDs need to be a specific format: no special chars, all lowercase, max 64 chars return self._model_id.replace("/", "__").lower()[:64] diff --git a/noxfile.py b/noxfile.py index e9043143..0a4f45cc 100644 --- a/noxfile.py +++ b/noxfile.py @@ -44,6 +44,9 @@ "eland/ml/_optional.py", "eland/ml/_model_serializer.py", "eland/ml/ml_model.py", + "eland/ml/pytorch/__init__.py", + "eland/ml/pytorch/_pytorch_model.py", + "eland/ml/pytorch/transformers.py", "eland/ml/transformers/__init__.py", "eland/ml/transformers/base.py", "eland/ml/transformers/lightgbm.py",