diff --git a/eland/cli/eland_import_hub_model.py b/eland/cli/eland_import_hub_model.py index 4ca85447..9496e91e 100755 --- a/eland/cli/eland_import_hub_model.py +++ b/eland/cli/eland_import_hub_model.py @@ -229,7 +229,6 @@ def check_cluster_version(es_client, logger): sem_ver = parse_es_version(es_info["version"]["number"]) major_version = sem_ver[0] - minor_version = sem_ver[1] # NLP models added in 8 if major_version < 8: @@ -238,13 +237,13 @@ def check_cluster_version(es_client, logger): ) exit(1) - # PyTorch was upgraded to version 2.1.2 in 8.13 + # PyTorch was upgraded to version 2.3.1 in 8.15.2 # and is incompatible with earlier versions - if major_version == 8 and minor_version < 13: + if sem_ver < (8, 15, 2): import torch logger.error( - f"Eland uses PyTorch version {torch.__version__} which is incompatible with Elasticsearch versions prior to 8.13. Please upgrade Elasticsearch to at least version 8.13" + f"Eland uses PyTorch version {torch.__version__} which is incompatible with Elasticsearch versions prior to 8.15.2. Please upgrade Elasticsearch to at least version 8.15.2" ) exit(1) diff --git a/eland/ml/pytorch/transformers.py b/eland/ml/pytorch/transformers.py index ab89e55b..271a2431 100644 --- a/eland/ml/pytorch/transformers.py +++ b/eland/ml/pytorch/transformers.py @@ -36,6 +36,7 @@ AutoConfig, AutoModel, AutoModelForQuestionAnswering, + BertTokenizer, PretrainedConfig, PreTrainedModel, PreTrainedTokenizer, @@ -757,6 +758,9 @@ def _find_max_sequence_length(self) -> int: if max_len is not None and max_len < REASONABLE_MAX_LENGTH: return int(max_len) + if isinstance(self._tokenizer, BertTokenizer): + return 512 + raise UnknownModelInputSizeError("Cannot determine model max input length") def _create_config( diff --git a/noxfile.py b/noxfile.py index a60950ec..e8a57191 100644 --- a/noxfile.py +++ b/noxfile.py @@ -121,7 +121,7 @@ def test(session, pandas_version: str): "--nbval", ) - # PyTorch 2.1.2 doesn't support Python 3.12 + # PyTorch 2.3.1 doesn't support Python 3.12 if session.python == "3.12": pytest_args += ("--ignore=eland/ml/pytorch",) session.run( diff --git a/setup.py b/setup.py index 1767ea32..1befe7d0 100644 --- a/setup.py +++ b/setup.py @@ -60,10 +60,12 @@ "lightgbm": ["lightgbm>=2,<4"], "pytorch": [ "requests<3", - "torch==2.1.2", + "torch==2.3.1", "tqdm", - "sentence-transformers>=2.1.0,<=2.3.1", - "transformers[torch]>=4.31.0,<4.36.0", + "sentence-transformers>=2.1.0,<=2.7.0", + # sentencepiece is a required dependency for the slow tokenizers + # https://huggingface.co/transformers/v4.4.2/migration.html#sentencepiece-is-removed-from-the-required-dependencies + "transformers[sentencepiece]>=4.31.0,<4.44.0", ], } extras["all"] = list({dep for deps in extras.values() for dep in deps}) diff --git a/tests/ml/pytorch/test_pytorch_model_config_pytest.py b/tests/ml/pytorch/test_pytorch_model_config_pytest.py index 50ea4aa9..c12be3a8 100644 --- a/tests/ml/pytorch/test_pytorch_model_config_pytest.py +++ b/tests/ml/pytorch/test_pytorch_model_config_pytest.py @@ -58,8 +58,8 @@ pytestmark = [ pytest.mark.skipif( - ES_VERSION < (8, 13, 0), - reason="Eland uses Pytorch 2.1.2, versions of Elasticsearch prior to 8.13.0 are incompatible with PyTorch 2.1.2", + ES_VERSION < (8, 15, 1), + reason="Eland uses Pytorch 2.3.1, versions of Elasticsearch prior to 8.15.1 are incompatible with PyTorch 2.3.1", ), pytest.mark.skipif( not HAS_SKLEARN, reason="This test requires 'scikit-learn' package to run" @@ -149,21 +149,12 @@ 1024, None, ), - ( - "cardiffnlp/twitter-roberta-base-sentiment", - "text_classification", - TextClassificationInferenceOptions, - NlpRobertaTokenizationConfig, - 512, - None, - ), ] else: MODEL_CONFIGURATIONS = [] class TestModelConfguration: - @pytest.mark.skip(reason="https://github.com/elastic/eland/issues/633") @pytest.mark.parametrize( "model_id,task_type,config_type,tokenizer_type,max_sequence_len,embedding_size", MODEL_CONFIGURATIONS, diff --git a/tests/ml/pytorch/test_pytorch_model_upload_pytest.py b/tests/ml/pytorch/test_pytorch_model_upload_pytest.py index 7eac6a8d..c84a77e0 100644 --- a/tests/ml/pytorch/test_pytorch_model_upload_pytest.py +++ b/tests/ml/pytorch/test_pytorch_model_upload_pytest.py @@ -39,8 +39,8 @@ pytestmark = [ pytest.mark.skipif( - ES_VERSION < (8, 13, 0), - reason="Eland uses Pytorch 2.1.2, versions of Elasticsearch prior to 8.13.0 are incompatible with PyTorch 2.1.2", + ES_VERSION < (8, 15, 2), + reason="Eland uses Pytorch 2.3.1, versions of Elasticsearch prior to 8.15.2 are incompatible with PyTorch 2.3.1", ), pytest.mark.skipif( not HAS_SKLEARN, reason="This test requires 'scikit-learn' package to run"