Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace MLModel(overwrite) with es_if_exists #249

Merged
merged 5 commits into from
Aug 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,8 @@ currently using a minimum version of PyCharm 2019.2.4.
Tools\'-\>\'Docstring format\' to `numpy`
- Install development requirements. Open terminal in virtual
environment and run `pip install -r requirements-dev.txt`
- Setup Elasticsearch instance (assumes `localhost:9200`), and run
- Setup Elasticsearch instance with docker `ELASTICSEARCH_VERSION=elasticsearch:7.x-SNAPSHOT .ci/run-elasticsearch.sh` and check `http://localhost:9200`
V1NAY8 marked this conversation as resolved.
Show resolved Hide resolved
- Run
`python -m eland.tests.setup_tests` to setup test environment -*note
this modifies Elasticsearch indices*
- Install local `eland` module (required to execute notebook tests)
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ dtype: int64
13057 20819.488281
13058 18315.431274
Length: 13059, dtype: float64
>>> print(s.info_es())
>>> print(s.es_info())
index_pattern: flights
Index:
index_field: _id
Expand Down
41 changes: 36 additions & 5 deletions eland/ml/imported_ml_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .ml_model import MLModel
from .transformers import get_model_transformer
from ..common import es_version
import warnings


if TYPE_CHECKING:
Expand Down Expand Up @@ -100,7 +101,13 @@ class ImportedMLModel(MLModel):
classification_weights: List[str]
Weights of the classification targets

overwrite: bool
es_if_exists: {'fail', 'replace'} default 'fail'
How to behave if model already exists

- fail: Raise a ValueError
- replace: Overwrite existing model

overwrite: **DEPRECATED** - bool
Delete and overwrite existing model (if exists)

es_compress_model_definition: bool
Expand All @@ -127,7 +134,7 @@ class ImportedMLModel(MLModel):
>>> # Serialise the model to Elasticsearch
>>> feature_names = ["f0", "f1", "f2", "f3", "f4"]
>>> model_id = "test_decision_tree_classifier"
>>> es_model = ImportedMLModel('localhost', model_id, classifier, feature_names, overwrite=True)
>>> es_model = ImportedMLModel('localhost', model_id, classifier, feature_names, es_if_exists='replace')

>>> # Get some test results from Elasticsearch model
>>> es_model.predict(test_data)
Expand Down Expand Up @@ -155,7 +162,8 @@ def __init__(
feature_names: List[str],
classification_labels: Optional[List[str]] = None,
classification_weights: Optional[List[float]] = None,
overwrite: bool = False,
es_if_exists: Optional[str] = None,
overwrite: Optional[bool] = None,
es_compress_model_definition: bool = True,
):
super().__init__(es_client, model_id)
Expand All @@ -171,7 +179,30 @@ def __init__(
self._model_type = transformer.model_type
serializer = transformer.transform()

if overwrite:
# Verify if both parameters are given
if overwrite is not None and es_if_exists is not None:
raise ValueError(
"Using 'overwrite' and 'es_if_exists' together is invalid, use only 'es_if_exists'"
)

if overwrite is not None:
warnings.warn(
"'overwrite' parameter is deprecated, use 'es_if_exists' instead",
DeprecationWarning,
stacklevel=2,
)
es_if_exists = "replace" if overwrite else "fail"
elif es_if_exists is None:
es_if_exists = "fail"

if es_if_exists not in ("fail", "replace"):
raise ValueError("'es_if_exists' must be either 'fail' or 'replace'")
elif es_if_exists == "fail":
if self.check_existing_model():
raise ValueError(
f"Trained machine learning model {model_id} already exists"
)
elif es_if_exists == "replace":
self.delete_model()

body: Dict[str, Any] = {
Expand Down Expand Up @@ -224,7 +255,7 @@ def predict(self, X: Union[List[float], List[List[float]]]) -> np.ndarray:
>>> # Serialise the model to Elasticsearch
>>> feature_names = ["f0", "f1", "f2", "f3", "f4", "f5"]
>>> model_id = "test_xgb_regressor"
>>> es_model = ImportedMLModel('localhost', model_id, regressor, feature_names, overwrite=True)
>>> es_model = ImportedMLModel('localhost', model_id, regressor, feature_names, es_if_exists='replace')

>>> # Get some test results from Elasticsearch model
>>> es_model.predict(test_data)
Expand Down
12 changes: 12 additions & 0 deletions eland/ml/ml_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,15 @@ def delete_model(self) -> None:
self._client.ml.delete_trained_model(model_id=self._model_id, ignore=(404,))
except elasticsearch.NotFoundError:
pass

def check_existing_model(self) -> bool:
    """
    Report whether a trained model with this model's ID can be fetched
    from Elasticsearch.

    Returns
    -------
    bool
        False when the lookup raises ``elasticsearch.NotFoundError``,
        True when the lookup completes without that error.
    """
    try:
        # The definition payload is not requested: only the outcome of the
        # lookup is used, never its response body.
        self._client.ml.get_trained_models(
            model_id=self._model_id, include_model_definition=False
        )
    except elasticsearch.NotFoundError:
        found = False
    else:
        found = True
    return found
129 changes: 118 additions & 11 deletions eland/tests/ml/test_imported_ml_model_pytest.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def test_unpack_and_raise_errors_in_ingest_simulate(self, mocker):
model_id,
classifier,
feature_names,
overwrite=True,
es_if_exists="replace",
es_compress_model_definition=True,
)

Expand Down Expand Up @@ -147,7 +147,7 @@ def test_decision_tree_classifier(self, compress_model_definition):
model_id,
classifier,
feature_names,
overwrite=True,
es_if_exists="replace",
es_compress_model_definition=compress_model_definition,
)

Expand Down Expand Up @@ -176,7 +176,7 @@ def test_decision_tree_regressor(self, compress_model_definition):
model_id,
regressor,
feature_names,
overwrite=True,
es_if_exists="replace",
es_compress_model_definition=compress_model_definition,
)
# Get some test results
Expand Down Expand Up @@ -204,7 +204,7 @@ def test_random_forest_classifier(self, compress_model_definition):
model_id,
classifier,
feature_names,
overwrite=True,
es_if_exists="replace",
es_compress_model_definition=compress_model_definition,
)
# Get some test results
Expand Down Expand Up @@ -232,7 +232,7 @@ def test_random_forest_regressor(self, compress_model_definition):
model_id,
regressor,
feature_names,
overwrite=True,
es_if_exists="replace",
es_compress_model_definition=compress_model_definition,
)
# Get some test results
Expand Down Expand Up @@ -270,7 +270,7 @@ def test_xgb_classifier(self, compress_model_definition, multi_class):
model_id,
classifier,
feature_names,
overwrite=True,
es_if_exists="replace",
es_compress_model_definition=compress_model_definition,
)
# Get some test results
Expand Down Expand Up @@ -306,7 +306,7 @@ def test_xgb_classifier_objectives_and_booster(self, objective, booster):
model_id = "test_xgb_classifier"

es_model = ImportedMLModel(
ES_TEST_CLIENT, model_id, classifier, feature_names, overwrite=True
ES_TEST_CLIENT, model_id, classifier, feature_names, es_if_exists="replace"
)
# Get some test results
check_prediction_equality(
Expand Down Expand Up @@ -342,7 +342,7 @@ def test_xgb_regressor(self, compress_model_definition, objective, booster):
model_id,
regressor,
feature_names,
overwrite=True,
es_if_exists="replace",
es_compress_model_definition=compress_model_definition,
)
# Get some test results
Expand All @@ -369,7 +369,7 @@ def test_predict_single_feature_vector(self):
model_id = "test_xgb_regressor"

es_model = ImportedMLModel(
ES_TEST_CLIENT, model_id, regressor, feature_names, overwrite=True
ES_TEST_CLIENT, model_id, regressor, feature_names, es_if_exists="replace"
)

# Single feature
Expand Down Expand Up @@ -410,7 +410,7 @@ def test_lgbm_regressor(self, compress_model_definition, objective, booster):
model_id,
regressor,
feature_names,
overwrite=True,
es_if_exists="replace",
es_compress_model_definition=compress_model_definition,
)
# Get some test results
Expand Down Expand Up @@ -451,7 +451,7 @@ def test_lgbm_classifier_objectives_and_booster(
model_id,
classifier,
feature_names,
overwrite=True,
es_if_exists="replace",
es_compress_model_definition=compress_model_definition,
)

Expand All @@ -461,3 +461,110 @@ def test_lgbm_classifier_objectives_and_booster(

# Clean up
es_model.delete_model()

# Supplying both 'overwrite' and 'es_if_exists' must be rejected.
@requires_sklearn
@pytest.mark.parametrize("compress_model_definition", [True, False])
@pytest.mark.parametrize("es_if_exists", ["fail", "replace"])
@pytest.mark.parametrize("overwrite", [True, False])
def test_imported_mlmodel_bothparams(
    self, compress_model_definition, es_if_exists, overwrite
):
    """Every combination of 'overwrite' and 'es_if_exists' raises ValueError."""
    # Fit a small regressor to hand to ImportedMLModel
    features, targets = datasets.make_regression(n_features=5)
    regressor = RandomForestRegressor()
    regressor.fit(features, targets)

    model_id = "test_random_forest_regressor"
    feature_names = ["f0", "f1", "f2", "f3", "f4"]

    # The deprecated flag and its replacement, passed together
    conflicting_kwargs = dict(
        es_if_exists=es_if_exists,
        overwrite=overwrite,
        es_compress_model_definition=compress_model_definition,
    )

    expected_message = "Using 'overwrite' and 'es_if_exists' together is invalid, use only 'es_if_exists'"
    with pytest.raises(ValueError, match=expected_message):
        ImportedMLModel(
            ES_TEST_CLIENT, model_id, regressor, feature_names, **conflicting_kwargs
        )

# 'overwrite=True' still works but must emit a DeprecationWarning
@requires_sklearn
@pytest.mark.parametrize("compress_model_definition", [True, False])
@pytest.mark.parametrize("overwrite", [True])
def test_imported_mlmodel_overwrite_true(
    self, compress_model_definition, overwrite
):
    """Importing with the deprecated 'overwrite' flag warns the caller."""
    # Fit a small regressor to hand to ImportedMLModel
    features, targets = datasets.make_regression(n_features=5)
    regressor = RandomForestRegressor()
    regressor.fit(features, targets)

    model_id = "test_random_forest_regressor"
    feature_names = ["f0", "f1", "f2", "f3", "f4"]

    deprecated_kwargs = dict(
        overwrite=overwrite,
        es_compress_model_definition=compress_model_definition,
    )

    expected_warning = (
        "'overwrite' parameter is deprecated, use 'es_if_exists' instead"
    )
    with pytest.warns(DeprecationWarning, match=expected_warning):
        ImportedMLModel(
            ES_TEST_CLIENT, model_id, regressor, feature_names, **deprecated_kwargs
        )

@requires_sklearn
@pytest.mark.parametrize("compress_model_definition", [True, False])
@pytest.mark.parametrize("overwrite", [False])
def test_imported_mlmodel_overwrite_false(
    self, compress_model_definition, overwrite
):
    """
    'overwrite=False' must emit a DeprecationWarning and then behave like
    es_if_exists='fail': raise ValueError because the model already exists.
    """
    # Train model
    training_data = datasets.make_regression(n_features=5)
    regressor = RandomForestRegressor()
    regressor.fit(training_data[0], training_data[1])

    feature_names = ["f0", "f1", "f2", "f3", "f4"]
    model_id = "test_random_forest_regressor"

    # Make sure the model is present so the "already exists" branch is
    # reached regardless of which tests ran earlier. The seeded model is
    # left in place afterwards, matching the cluster state this test
    # previously ended with.
    ImportedMLModel(
        ES_TEST_CLIENT,
        model_id,
        regressor,
        feature_names,
        es_if_exists="replace",
        es_compress_model_definition=compress_model_definition,
    )

    match_error = f"Trained machine learning model {model_id} already exists"
    match_warning = (
        "'overwrite' parameter is deprecated, use 'es_if_exists' instead"
    )
    # pytest.warns is deliberately the OUTER context manager. With the
    # opposite nesting the ValueError propagates through pytest.warns, and
    # pytest releases before 8 skip the warning check while an exception is
    # in flight, so the deprecation assertion would never be enforced.
    with pytest.warns(DeprecationWarning, match=match_warning):
        with pytest.raises(ValueError, match=match_error):
            ImportedMLModel(
                ES_TEST_CLIENT,
                model_id,
                regressor,
                feature_names,
                overwrite=overwrite,
                es_compress_model_definition=compress_model_definition,
            )

# Raise ValueError if Model exists when es_if_exists = 'fail'
@requires_sklearn
@pytest.mark.parametrize("compress_model_definition", [True, False])
def test_es_if_exists_fail(self, compress_model_definition):
    """Importing over an existing model with es_if_exists='fail' raises ValueError."""
    # Train model
    training_data = datasets.make_regression(n_features=5)
    regressor = RandomForestRegressor()
    regressor.fit(training_data[0], training_data[1])

    feature_names = ["f0", "f1", "f2", "f3", "f4"]
    model_id = "test_random_forest_regressor"

    # Make sure the model is present so the failure path is exercised
    # regardless of which tests ran earlier. The seeded model is left in
    # place afterwards, matching the cluster state this test previously
    # ended with.
    ImportedMLModel(
        ES_TEST_CLIENT,
        model_id,
        regressor,
        feature_names,
        es_if_exists="replace",
        es_compress_model_definition=compress_model_definition,
    )

    # A second import of the same model_id with es_if_exists='fail' must be refused.
    match = f"Trained machine learning model {model_id} already exists"
    with pytest.raises(ValueError, match=match):
        ImportedMLModel(
            ES_TEST_CLIENT,
            model_id,
            regressor,
            feature_names,
            es_if_exists="fail",
            es_compress_model_definition=compress_model_definition,
        )
1 change: 1 addition & 0 deletions eland/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def wrapped(*args: Any, **kwargs: Any) -> RT:
warnings.warn(
f"{f.__name__} is deprecated, use {replace_with} instead",
DeprecationWarning,
stacklevel=2,
)
return f(*args, **kwargs)

Expand Down