Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restrict PyTorch version not to be more advanced than that used in Elasticsearch #479

Merged
merged 7 commits into from
Jul 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,10 @@ $ conda install -c conda-forge eland
### Compatibility

- Supports Python 3.7+ and Pandas 1.3
- Supports Elasticsearch clusters that are 7.11+, recommended 7.14 or later for all features to work.
Make sure your Eland major version matches the major version of your Elasticsearch cluster.
- Supports Elasticsearch clusters that are 7.11+; 8.3 or later is recommended for all features to work.
If you are using the NLP with PyTorch feature, make sure your Eland minor version matches the minor
version of your Elasticsearch cluster. For all other features, it is sufficient for the major versions
to match.

### Prerequisites

Expand Down
22 changes: 16 additions & 6 deletions eland/ml/pytorch/_pytorch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterable,
List,
Mapping,
Expand All @@ -39,6 +40,8 @@
if TYPE_CHECKING:
from elasticsearch import Elasticsearch

from elasticsearch._sync.client.utils import _quote

DEFAULT_CHUNK_SIZE = 4 * 1024 * 1024 # 4MB
DEFAULT_TIMEOUT = "60s"

Expand Down Expand Up @@ -125,12 +128,19 @@ def infer(
docs: List[Mapping[str, str]],
timeout: str = DEFAULT_TIMEOUT,
) -> Any:
return self._client.options(
request_timeout=60
).ml.infer_trained_model_deployment(
model_id=self.model_id,
timeout=timeout,
docs=docs,
if docs is None:
raise ValueError("Empty value passed for parameter 'docs'")

__body: Dict[str, Any] = {}
__body["docs"] = docs

__path = f"/_ml/trained_models/{_quote(self.model_id)}/deployment/_infer"
__query: Dict[str, Any] = {}
__query["timeout"] = timeout
__headers = {"accept": "application/json", "content-type": "application/json"}

return self._client.options(request_timeout=60).perform_request(
"POST", __path, params=__query, headers=__headers, body=__body
)

def start(self, timeout: str = DEFAULT_TIMEOUT) -> None:
Expand Down
2 changes: 1 addition & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def lint(session):
# Install numpy to use its mypy plugin
# https://numpy.org/devdocs/reference/typing.html#mypy-plugin
session.install("black", "flake8", "mypy", "isort", "numpy")
session.install("--pre", "elasticsearch>=8.0.0a1,<9")
session.install("--pre", "elasticsearch>=8.3,<9")
session.run("python", "utils/license-headers.py", "check", *SOURCE_FILES)
session.run("black", "--check", "--target-version=py37", *SOURCE_FILES)
session.run("isort", "--check", "--profile=black", *SOURCE_FILES)
Expand Down
11 changes: 7 additions & 4 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#
# Basic requirements
#
elasticsearch>=8,<9
elasticsearch>=8.3,<9
pandas>=1.2,<2
matplotlib<4
numpy<2
Expand All @@ -16,9 +16,12 @@ scikit-learn>=0.22.1,<2
lightgbm>=2,<4

# PyTorch doesn't support Python 3.10 yet (pytorch/pytorch#66424)
sentence-transformers>=2.1.0,<3; python_version<'3.10'
torch>=1.11.0,<2; python_version<'3.10'
transformers[torch]>=4.12.0,<5; python_version<'3.10'

# Elasticsearch uses v1.11.0 of PyTorch
torch>=1.11.0,<1.12.0; python_version<'3.10'
# Versions known to be compatible with torch 1.11
sentence-transformers>=2.1.0,<=2.2.2; python_version<'3.10'
transformers[torch]>=4.12.0,<=4.20.1; python_version<'3.10'

#
# Testing
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#
# Basic requirements
#
elasticsearch>=8,<9
elasticsearch>=8.3,<9
pandas>=1.2,<2
matplotlib<4
numpy<2
8 changes: 4 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@
"scikit-learn": ["scikit-learn>=0.22.1,<2"],
"lightgbm": ["lightgbm>=2,<4"],
"pytorch": [
"sentence-transformers>=2.1.0,<3",
"torch>=1.11.0,<2",
"transformers[torch]>=4.12.0,<5",
"torch>=1.11.0,<1.12.0",
"sentence-transformers>=2.1.0,<=2.2.2",
"transformers[torch]>=4.12.0,<=4.20.1",
],
}
extras["all"] = list({dep for deps in extras.values() for dep in deps})
Expand All @@ -82,7 +82,7 @@
keywords="elastic eland pandas python",
packages=find_packages(include=["eland", "eland.*"]),
install_requires=[
"elasticsearch>=8,<9",
"elasticsearch>=8.3,<9",
"pandas>=1.2,<2",
"matplotlib<4",
"numpy<2",
Expand Down