Add context to handler functions #103
Merged

Commits (7 total):
- fb84927 Add context to handler functions
- d4372c9 Add context to validate_and_initialize_user_module unit tests
- 1e54f98 Fix formatting issues
- ad06a7f Update comments in handler_service.py
- 778408f Include unit tests for handler service with and without context
- 129fa2f (sachanub) Add logic to check for custom handler functions
- 4c5911f (sachanub) Modify logic to identify extra arguments in initialize stage
14 changes: 14 additions & 0 deletions
tests/resources/model_input_predict_output_fn_with_context/code/inference.py

```python
def model_fn(model_dir, context=None):
    return "model"


def input_fn(data, content_type, context=None):
    return "data"


def predict_fn(data, model, context=None):
    return "output"


def output_fn(prediction, accept, context=None):
    return prediction
```
File renamed without changes.
9 changes: 9 additions & 0 deletions
tests/resources/model_transform_fn_with_context/code/inference_tranform_fn.py

```python
import os


def model_fn(model_dir, context=None):
    return f"Loading {os.path.basename(__file__)}"


def transform_fn(a, b, c, d, context=None):
    return f"output {b}"
```
File renamed without changes.
New test file: 168 additions & 0 deletions

```python
# Copyright 2021 The HuggingFace Team, Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import tempfile

import pytest
from sagemaker_inference import content_types
from transformers.testing_utils import require_torch, slow

from mms.context import Context, RequestProcessor
from mms.metrics.metrics_store import MetricsStore
from mock import Mock
from sagemaker_huggingface_inference_toolkit import handler_service
from sagemaker_huggingface_inference_toolkit.transformers_utils import _load_model_from_hub, get_pipeline


TASK = "text-classification"
MODEL = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"
INPUT = {"inputs": "My name is Wolfgang and I live in Berlin"}
OUTPUT = [
    {"word": "Wolfgang", "score": 0.99, "entity": "I-PER", "index": 4, "start": 11, "end": 19},
    {"word": "Berlin", "score": 0.99, "entity": "I-LOC", "index": 9, "start": 34, "end": 40},
]


@pytest.fixture()
def inference_handler():
    return handler_service.HuggingFaceHandlerService()


def test_get_device_cpu(inference_handler):
    device = inference_handler.get_device()
    assert device == -1


@slow
def test_get_device_gpu(inference_handler):
    device = inference_handler.get_device()
    assert device > -1


@require_torch
def test_test_initialize(inference_handler):
    with tempfile.TemporaryDirectory() as tmpdirname:
        storage_folder = _load_model_from_hub(
            model_id=MODEL,
            model_dir=tmpdirname,
        )
        CONTEXT = Context(MODEL, storage_folder, {}, 1, -1, "1.1.4")

        inference_handler.initialize(CONTEXT)
        assert inference_handler.initialized is True


@require_torch
def test_handle(inference_handler):
    with tempfile.TemporaryDirectory() as tmpdirname:
        storage_folder = _load_model_from_hub(
            model_id=MODEL,
            model_dir=tmpdirname,
        )
        CONTEXT = Context(MODEL, storage_folder, {}, 1, -1, "1.1.4")
        CONTEXT.request_processor = [RequestProcessor({"Content-Type": "application/json"})]
        CONTEXT.metrics = MetricsStore(1, MODEL)

        inference_handler.initialize(CONTEXT)
        json_data = json.dumps(INPUT)
        prediction = inference_handler.handle([{"body": json_data.encode()}], CONTEXT)
        loaded_response = json.loads(prediction[0])
        assert "entity" in loaded_response[0]
        assert "score" in loaded_response[0]


@require_torch
def test_load(inference_handler):
    context = Mock()
    with tempfile.TemporaryDirectory() as tmpdirname:
        storage_folder = _load_model_from_hub(
            model_id=MODEL,
            model_dir=tmpdirname,
        )
        # test with automatically inferred task
        hf_pipeline_without_task = inference_handler.load(storage_folder, context)
        assert hf_pipeline_without_task.task == "token-classification"

        # test with task set explicitly via the environment
        os.environ["HF_TASK"] = TASK
        hf_pipeline_with_task = inference_handler.load(storage_folder, context)
        assert hf_pipeline_with_task.task == TASK


def test_preprocess(inference_handler):
    context = Mock()
    json_data = json.dumps(INPUT)
    decoded_input_data = inference_handler.preprocess(json_data, content_types.JSON, context)
    assert "inputs" in decoded_input_data


def test_preprocess_bad_content_type(inference_handler):
    context = Mock()
    with pytest.raises(json.decoder.JSONDecodeError):
        inference_handler.preprocess(b"", content_types.JSON, context)


@require_torch
def test_predict(inference_handler):
    context = Mock()
    with tempfile.TemporaryDirectory() as tmpdirname:
        storage_folder = _load_model_from_hub(
            model_id=MODEL,
            model_dir=tmpdirname,
        )
        inference_handler.model = get_pipeline(task=TASK, device=-1, model_dir=storage_folder)
        prediction = inference_handler.predict(INPUT, inference_handler.model, context)
        assert "label" in prediction[0]
        assert "score" in prediction[0]


def test_postprocess(inference_handler):
    context = Mock()
    output = inference_handler.postprocess(OUTPUT, content_types.JSON, context)
    assert isinstance(output, str)


def test_validate_and_initialize_user_module(inference_handler):
    model_dir = os.path.join(os.getcwd(), "tests/resources/model_input_predict_output_fn_with_context")
    CONTEXT = Context("", model_dir, {}, 1, -1, "1.1.4")

    inference_handler.initialize(CONTEXT)
    CONTEXT.request_processor = [RequestProcessor({"Content-Type": "application/json"})]
    CONTEXT.metrics = MetricsStore(1, MODEL)

    prediction = inference_handler.handle([{"body": b""}], CONTEXT)
    assert "output" in prediction[0]

    assert inference_handler.load({}, CONTEXT) == "model"
    assert inference_handler.preprocess({}, "", CONTEXT) == "data"
    assert inference_handler.predict({}, "model", CONTEXT) == "output"
    assert inference_handler.postprocess("output", "", CONTEXT) == "output"


def test_validate_and_initialize_user_module_transform_fn():
    os.environ["SAGEMAKER_PROGRAM"] = "inference_tranform_fn.py"
    inference_handler = handler_service.HuggingFaceHandlerService()
    model_dir = os.path.join(os.getcwd(), "tests/resources/model_transform_fn_with_context")
    CONTEXT = Context("dummy", model_dir, {}, 1, -1, "1.1.4")

    inference_handler.initialize(CONTEXT)
    CONTEXT.request_processor = [RequestProcessor({"Content-Type": "application/json"})]
    CONTEXT.metrics = MetricsStore(1, MODEL)
    assert "output" in inference_handler.handle([{"body": b"dummy"}], CONTEXT)[0]
    assert inference_handler.load({}, CONTEXT) == "Loading inference_tranform_fn.py"
    assert (
        inference_handler.transform_fn("model", "dummy", "application/json", "application/json", CONTEXT)
        == "output dummy"
    )
```
philschmid:

Is that needed? What overhead does it add to a request/function call? Why don't we just pass `self.context` in the `handle` method, in `sagemaker-huggingface-inference-toolkit/src/sagemaker_huggingface_inference_toolkit/handler_service.py` (line 234 in 44e3dec)?
sachanub:

Hi @philschmid. Thank you for your review. Maybe I am not understanding your point correctly, but I don't think we can directly pass `self.context` in the `handle` method. If we do that, don't we risk breaking existing customers who are not using `context` as an input argument? With the above-mentioned `run_handler_function`, we should be able to support both customers who do not want to use `context` and customers who do. Please correct me if I misunderstood. Thanks!
philschmid:

Why would it break? You mean if they have a custom `inference.py` that defines a `transform_fn` method?
sachanub:

Yes, we might affect those who define a custom `transform_fn`, right?
In `handler_service.py`, we try to import user-defined functions from `inference.py`. If they exist, they are used to overwrite `self.load`, `self.preprocess`, `self.predict`, `self.postprocess` and `self.transform_fn`. In this PR, we introduce an additional `context` parameter for the default handlers. However, existing user scripts don't take this parameter in their customized handler functions, so we have to be careful when we call them. That's why @sachanub is adding a helper function to determine when to call a function with the new parameter.
philschmid:

Is that something we could deprecate and remove in a v3 of the toolkit? It feels error-prone, and a big overhead, to parse and add args like this for every incoming request. What do you think of adding a check for whether an `inference.py` is provided and, if not, calling `self.transform` directly? Most customers deploy models using the "zero-code" deployment feature, where you provide a `MODEL_ID` and `TASK` and don't need an `inference.py` script.
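The suggested fast path could look something like this sketch (hypothetical helper; the toolkit's real module-resolution logic in `handler_service.py` may differ): check once whether a user script exists in the model directory's `code/` folder, and skip all per-call signature handling when it doesn't.

```python
import os
import tempfile


def has_user_module(model_dir, script_name="inference.py"):
    """Return True if a user-provided handler script exists under
    `<model_dir>/code/` (sketch of the "zero-code" check suggested
    in the review; not the toolkit's actual implementation)."""
    return os.path.isfile(os.path.join(model_dir, "code", script_name))


with tempfile.TemporaryDirectory() as model_dir:
    # Zero-code deployment: no inference.py, so defaults run directly.
    assert not has_user_module(model_dir)

    # User script present: fall back to the wrapped handler calls.
    os.makedirs(os.path.join(model_dir, "code"))
    open(os.path.join(model_dir, "code", "inference.py"), "w").close()
    assert has_user_module(model_dir)
```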