-
Notifications
You must be signed in to change notification settings - Fork 280
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Model container testing function to clipper admin (Vanilla python) #394
Changes from 7 commits
4495bba
2c1474c
bed8072
ffa7525
a3aa5eb
2b6cb33
8637156
a215b62
8cba54a
41c111d
f4e005d
80819ff
e3951fb
1214ba5
83ff62b
5ab12c5
2190404
3609f3a
a53355f
f9a19d1
b689450
7f66829
fcae76c
1678e2e
e606a2e
f261d68
c6b6fec
917aeca
9972853
8fa0806
b4f998d
d45addf
a250da2
4964ece
53cc658
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,6 +11,9 @@ | |
import os | ||
import tarfile | ||
import six | ||
from cloudpickle import CloudPickler | ||
import pickle | ||
import numpy as np | ||
|
||
from .container_manager import CONTAINERLESS_MODEL_IMAGE | ||
from .exceptions import ClipperException, UnconnectedException | ||
|
@@ -1187,3 +1190,68 @@ def stop_all(self): | |
""" | ||
self.cm.stop_all() | ||
logger.info("Stopped all Clipper cluster and all model containers") | ||
|
||
def test_predict_function(self, query, func, input_type): | ||
"""Tests that the user's function has the correct signature and can be properly saved and loaded. | ||
|
||
The function should take a dict request object like the query frontend expects JSON, | ||
the predict function, and the input type for the model. | ||
|
||
Parameters | ||
---------- | ||
query: JSON or list of dicts | ||
Inputs to test the prediction function on. | ||
func: function | ||
Predict function to test. | ||
input_type: str | ||
The input_type to be associated with the registered app and deployed model. | ||
One of "integers", "floats", "doubles", "bytes", or "strings". | ||
""" | ||
query_data = list(x for x in list(query.values())) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What's going on here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You should be checking the JSON key as well, to ensure that their input is properly formatted There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Assuming that the query is in a JSON/dict structure, I'm getting the values of each key. I'm then checking to make sure the values are of the right input_type before then converting it into the respective numpy array of right dtype. What formatting of the key input needs to be checked? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If the user wants to provide a single input, the key should be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Got it. Thanks for the clarification. |
||
|
||
if type(query_data[0][0]) == list: | ||
query_data = query_data[0] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What are you checking for here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm checking the nesting of the query data, whether it is a list or a list of lists. |
||
|
||
flattened_data = [item for sublist in query_data for item in sublist] | ||
numpy_data = None | ||
|
||
if input_type == "bytes": | ||
numpy_data = list(np.int8(x) for x in query_data) | ||
for x in flattened_data: | ||
if type(x) != bytes: | ||
return "Invalid input type" | ||
|
||
if input_type == "integers": | ||
numpy_data = list(np.int32(x) for x in query_data) | ||
for x in flattened_data: | ||
if type(x) != int: | ||
return "Invalid input type" | ||
|
||
if input_type == "floats" or input_type == "doubles": | ||
if input_type == "floats": | ||
numpy_data = list(np.float32(x) for x in query_data) | ||
else: | ||
numpy_data = list(np.float64(x) for x in query_data) | ||
for x in flattened_data: | ||
if type(x) != float: | ||
return "Invalid input type" | ||
|
||
if input_type == "string": | ||
numpy_data = list(np.str_(x) for x in query_data) | ||
for x in flattened_data: | ||
if type(x) != str: | ||
return "Invalid input type" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add checks for the other two accepted input types: (floats and bytes) |
||
|
||
s = six.StringIO() | ||
c = CloudPickler(s, 2) | ||
c.dump(func) | ||
serialized_func = s.getvalue() | ||
reloaded_func = pickle.loads(serialized_func) | ||
|
||
try: | ||
assert reloaded_func | ||
except AssertionError as e: | ||
logger.error("Function does not properly serialize and reload") | ||
return "Function does not properly serialize and reload" | ||
|
||
return reloaded_func(numpy_data) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,8 @@ | |
import os | ||
import posixpath | ||
import shutil | ||
import pickle | ||
import numpy as np | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Revert the changes to this file |
||
from ..version import __version__ | ||
|
||
from .deployer_utils import save_python_function | ||
|
@@ -87,7 +89,6 @@ def create_endpoint( | |
|
||
clipper_conn.link_model_to_app(name, name) | ||
|
||
|
||
def deploy_python_closure( | ||
clipper_conn, | ||
name, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,6 +11,7 @@ | |
import os | ||
import json | ||
import time | ||
import numpy as np | ||
import requests | ||
import tempfile | ||
import shutil | ||
|
@@ -25,6 +26,7 @@ | |
from clipper_admin.deployers.python import create_endpoint as create_py_endpoint | ||
from clipper_admin.deployers.python import deploy_python_closure | ||
from clipper_admin import __version__ as clipper_version | ||
from clipper_admin import test_predict_function | ||
|
||
sys.path.insert(0, os.path.abspath('%s/util_direct_import/' % cur_dir)) | ||
from util_package import mock_module_in_package as mmip | ||
|
@@ -343,6 +345,20 @@ def predict_func(inputs): | |
}) | ||
self.assertEqual(len(containers), 1) | ||
|
||
def test_test_predict_function(self): | ||
def predict_func(xs): | ||
return [sum(x) for x in xs] | ||
|
||
deploy_python_closure(self.clipper_conn, name="sum-model", version=1, input_type="doubles", func=predict_func) | ||
self.clipper_conn.link_model_to_app(app_name="hello-world", model_name="sum-model") | ||
|
||
headers = {"Content-type": "application/json"} | ||
test_input = list(np.random.random(10)) | ||
pred = requests.post("http://localhost:1337/hello-world/predict", headers=headers, data=json.dumps({"input": test_input})).json() | ||
test_predict_result = test_predict_function(self.clipper_conn, query={"input": test_input}, func=predict_func, input_type="doubles") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You should call the function like this Can you add a second test prediction that uses the batch_input = [list(np.random.random(10)) for _ in range(4)]
batch_pred = requests.post("http://localhost:1337/hello-world/predict", headers=headers, data=json.dumps({"input_batch": batch_input})).json()
test_batch_predict_result = self.clipper_conn.test_predict_function(query={"input_batch": test_input}, func=predict_func, input_type="doubles") |
||
|
||
self.assertEqual(pred, test_predict_result) | ||
|
||
|
||
class ClipperManagerTestCaseLong(unittest.TestCase): | ||
@classmethod | ||
|
@@ -374,8 +390,10 @@ def setUpClass(self): | |
self.latency_slo_micros) | ||
|
||
self.clipper_conn.register_application( | ||
self.app_name_4, self.input_type, self.default_output, | ||
self.latency_slo_micros) | ||
self.app_name_4, | ||
self.input_type, | ||
self.default_output, | ||
slo_micros=30000000) | ||
|
||
@classmethod | ||
def tearDownClass(self): | ||
|
@@ -481,6 +499,7 @@ def test_fixed_batch_size_model_processes_specified_query_batch_size_when_satura | |
model_version = 1 | ||
|
||
def predict_func(inputs): | ||
time.sleep(.5) | ||
batch_size = len(inputs) | ||
return [str(batch_size) for _ in inputs] | ||
|
||
|
@@ -534,6 +553,7 @@ def predict_func(inputs): | |
'test_stop_models', | ||
'test_python_closure_deploys_successfully', | ||
'test_register_py_endpoint', | ||
'test_test_predict_function' | ||
] | ||
|
||
LONG_TEST_ORDERING = [ | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add an example in the method comment?