Commit 5bc1a82

Add PyTorch modules to noxfile

We added the `pytorch` module, which is type checked but was not listed in the
noxfile as such. This change also addresses the type errors that surfaced once
type checking was enabled.

joshdevins authored Nov 29, 2021
1 parent 7209f61 commit 5bc1a82
Showing 3 changed files with 25 additions and 22 deletions.
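
For background on why the file list matters: the noxfile runs mypy over an explicit allow-list of files, so a module can be fully annotated yet silently unchecked until it is registered there. Below is a minimal sketch of that pattern; the `TYPED_FILES` name, the `lint` session name, and the `--strict` flag are illustrative assumptions, not necessarily eland's exact configuration.

    import nox

    # Illustrative allow-list: mypy only sees files named here.
    TYPED_FILES = (
        "eland/ml/pytorch/__init__.py",
        "eland/ml/pytorch/_pytorch_model.py",
        "eland/ml/pytorch/transformers.py",
    )

    @nox.session()
    def lint(session):
        session.install("mypy")
        # A file missing from the allow-list is never type checked,
        # which is how the pytorch modules were overlooked.
        session.run("mypy", "--strict", *TYPED_FILES)

Running such a session (e.g. `nox -s lint`) is what surfaces the errors fixed in the two source files below.
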
eland/ml/pytorch/_pytorch_model.py (6 changes: 3 additions & 3 deletions)
@@ -21,7 +21,7 @@
 import os
 from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Set, Tuple, Union
 
-from tqdm.auto import tqdm
+from tqdm.auto import tqdm  # type: ignore
 
 from eland.common import ensure_es_client
 
@@ -101,7 +101,7 @@ def import_model(
 
     def infer(
         self, body: Dict[str, Any], timeout: str = DEFAULT_TIMEOUT
-    ) -> Dict[str, Any]:
+    ) -> Union[bool, Any]:
         return self._client.transport.perform_request(
             "POST",
             f"/_ml/trained_models/{self.model_id}/deployment/_infer",
@@ -124,7 +124,7 @@ def stop(self) -> None:
         )
 
     def delete(self) -> None:
-        self._client.ml.delete_trained_model(self.model_id, ignore=(404,))
+        self._client.ml.delete_trained_model(model_id=self.model_id, ignore=(404,))
 
     @classmethod
     def list(
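
A note on the `infer` change above: the method returns the raw result of the client's `transport.perform_request`, so its annotation was loosened from a concrete `Dict[str, Any]`, presumably to match what that transport call is typed to return (`Union[bool, Any]`). A hypothetical usage sketch follows; the host, model ID, and payload shape are illustrative and not taken from this commit.

    from elasticsearch import Elasticsearch

    from eland.ml.pytorch import PyTorchModel

    es = Elasticsearch("http://localhost:9200")
    model = PyTorchModel(es, "my-text-embedding-model")
    # The payload follows the _ml/trained_models/<id>/deployment/_infer body shape.
    response = model.infer(body={"docs": [{"text_field": "some example text"}]})
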
eland/ml/pytorch/transformers.py (38 changes: 19 additions & 19 deletions)
@@ -23,11 +23,11 @@
 import json
 import os.path
 from abc import ABC, abstractmethod
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
-import torch
-import transformers
-from sentence_transformers import SentenceTransformer
+import torch  # type: ignore
+import transformers  # type: ignore
+from sentence_transformers import SentenceTransformer  # type: ignore
 from torch import Tensor, nn
 from transformers import (
     AutoConfig,
@@ -66,7 +66,7 @@
 ]
 
 
-class _DistilBertWrapper(nn.Module):
+class _DistilBertWrapper(nn.Module):  # type: ignore
     """
     A simple wrapper around DistilBERT model which makes the model inputs
     conform to Elasticsearch's native inference processor interface.
@@ -96,7 +96,7 @@ def forward(
         return self._model(input_ids=input_ids, attention_mask=attention_mask)
 
 
-class _SentenceTransformerWrapper(nn.Module):
+class _SentenceTransformerWrapper(nn.Module):  # type: ignore
     """
     A wrapper around sentence-transformer models to provide pooling,
     normalization and other graph layers that are not defined in the base
@@ -122,7 +122,7 @@ def from_pretrained(
         else:
             return None
 
-    def _remove_pooling_layer(self):
+    def _remove_pooling_layer(self) -> None:
         """
         Removes any last pooling layer which is not used to create embeddings.
         Leaving this layer in will cause it to return a NoneType which in turn
@@ -135,7 +135,7 @@ def _remove_pooling_layer(self):
         if hasattr(self._hf_model, "pooler"):
             self._hf_model.pooler = None
 
-    def _replace_transformer_layer(self):
+    def _replace_transformer_layer(self) -> None:
         """
         Replaces the HuggingFace Transformer layer in the SentenceTransformer
         modules so we can set it with one that has pooling layer removed and
@@ -167,7 +167,7 @@ def forward(
         return self._st_model(inputs)[self._output_key]
 
 
-class _DPREncoderWrapper(nn.Module):
+class _DPREncoderWrapper(nn.Module):  # type: ignore
     """
     AutoModel loading does not work for DPRContextEncoders, this only exists as
     a workaround. This may never be fixed so this is likely permanent.
@@ -240,10 +240,7 @@ def __init__(
         self._tokenizer = tokenizer
         self._model = model
 
-    def classification_labels(self) -> Optional[List[str]]:
-        return None
-
-    def quantize(self):
+    def quantize(self) -> None:
         torch.quantization.quantize_dynamic(
             self._model, {torch.nn.Linear}, dtype=torch.qint8
         )
@@ -275,11 +272,14 @@ def trace(self) -> TracedModelTypes:
     def _prepare_inputs(self) -> transformers.BatchEncoding:
         ...
 
+    def classification_labels(self) -> Optional[List[str]]:
+        return None
+
 
 class _TraceableClassificationModel(_TraceableModel, ABC):
     def classification_labels(self) -> Optional[List[str]]:
         id_label_items = self._model.config.id2label.items()
-        labels = [v for _, v in sorted(id_label_items, key=lambda kv: kv[0])]
+        labels = [v for _, v in sorted(id_label_items, key=lambda kv: kv[0])]  # type: ignore
 
         # Make classes like I-PER into I_PER which fits Java enumerations
         return [label.replace("-", "_") for label in labels]
@@ -361,15 +361,15 @@ def __init__(self, model_id: str, task_type: str, quantize: bool = False):
         self._vocab = self._load_vocab()
         self._config = self._create_config()
 
-    def _load_vocab(self):
+    def _load_vocab(self) -> Dict[str, List[str]]:
         vocab_items = self._tokenizer.get_vocab().items()
-        vocabulary = [k for k, _ in sorted(vocab_items, key=lambda kv: kv[1])]
+        vocabulary = [k for k, _ in sorted(vocab_items, key=lambda kv: kv[1])]  # type: ignore
         return {
             "vocabulary": vocabulary,
         }
 
-    def _create_config(self):
-        inference_config = {
+    def _create_config(self) -> Dict[str, Any]:
+        inference_config: Dict[str, Dict[str, Any]] = {
             self._task_type: {
                 "tokenization": {
                     "bert": {
@@ -448,7 +448,7 @@ def _create_traceable_model(self) -> _TraceableModel:
             f"Unknown task type {self._task_type}, must be one of: {SUPPORTED_TASK_TYPES_NAMES}"
         )
 
-    def elasticsearch_model_id(self):
+    def elasticsearch_model_id(self) -> str:
         # Elasticsearch model IDs need to be a specific format: no special chars, all lowercase, max 64 chars
         return self._model_id.replace("/", "__").lower()[:64]
 
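
A pattern worth calling out in the diff above: `torch`, `transformers`, and `sentence_transformers` ship without type information, so mypy flags their imports and each gets a per-line `# type: ignore`. The ignores on the class definitions follow from the same root cause: once `torch` is imported as untyped, `nn.Module` is effectively `Any`, and strict mypy (with `disallow_subclassing_any`) refuses to subclass it. A minimal, self-contained reproduction of both ignores:

    import torch  # type: ignore  # untyped package: no stubs, no py.typed marker
    from torch import nn


    class MyWrapper(nn.Module):  # type: ignore  # base class resolves to Any
        # Without the ignore above, strict mypy reports an error along the
        # lines of "Class cannot subclass 'Module' (has type 'Any')".
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x
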
noxfile.py (3 changes: 3 additions & 0 deletions)
@@ -44,6 +44,9 @@
     "eland/ml/_optional.py",
     "eland/ml/_model_serializer.py",
     "eland/ml/ml_model.py",
+    "eland/ml/pytorch/__init__.py",
+    "eland/ml/pytorch/_pytorch_model.py",
+    "eland/ml/pytorch/transformers.py",
     "eland/ml/transformers/__init__.py",
     "eland/ml/transformers/base.py",
     "eland/ml/transformers/lightgbm.py",
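
To verify the newly listed files locally without going through nox, one could also invoke mypy programmatically over the pytorch package; a sketch, with the `--strict` flag being an assumption rather than eland's exact settings:

    # Sketch: run mypy over the newly typed package via its Python API.
    from mypy import api

    stdout, stderr, exit_code = api.run(["--strict", "eland/ml/pytorch/"])
    print(stdout)
    if exit_code != 0:
        raise SystemExit(stderr or "type errors remain")
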
