Add context to handler functions #103

Merged · 7 commits · Sep 30, 2023
Changes from 5 commits
72 changes: 54 additions & 18 deletions src/sagemaker_huggingface_inference_toolkit/handler_service.py
@@ -18,6 +18,7 @@
import sys
import time
from abc import ABC
from inspect import signature

from sagemaker_inference import environment, utils
from transformers.pipelines import SUPPORTED_TASKS
@@ -74,7 +75,7 @@ def initialize(self, context):
self.validate_and_initialize_user_module()

self.device = self.get_device()
self.model = self.load(self.model_dir)
self.model = self.run_handler_function(self.load, *(self.model_dir,))
self.initialized = True
# # Load methods from file
# if (not self._initialized) and ENABLE_MULTI_MODEL:
@@ -92,10 +93,15 @@ def get_device(self):
else:
return -1

def load(self, model_dir):
def load(self, model_dir, context=None):
"""
The Load handler is responsible for loading the Hugging Face transformer model.
It can be overridden to load the model from storage
It can be overridden to load the model from storage.

Args:
model_dir (str): The directory where model files are stored.
context (obj): metadata on the incoming request data (default: None).

Returns:
hf_pipeline (Pipeline): A Hugging Face Transformer pipeline.
"""
@@ -111,14 +117,16 @@ def load(self, model_dir):
)
return hf_pipeline

def preprocess(self, input_data, content_type):
def preprocess(self, input_data, content_type, context=None):
"""
The preprocess handler is responsible for deserializing the input data into
an object for prediction, can handle JSON.
The preprocess handler can be overridden for data or feature transformation,
The preprocess handler can be overridden for data or feature transformation.

Args:
input_data: the request payload serialized in the content_type format
content_type: the request content_type
input_data: the request payload serialized in the content_type format.
content_type: the request content_type.
context (obj): metadata on the incoming request data (default: None).

Returns:
decoded_input_data (dict): deserialized input_data into a Python dictionary.
@@ -136,13 +144,16 @@ def preprocess(self, input_data, content_type):
decoded_input_data = decoder_encoder.decode(input_data, content_type)
return decoded_input_data

def predict(self, data, model):
def predict(self, data, model, context=None):
"""The predict handler is responsible for model predictions. Calls the `__call__` method of the provided `Pipeline`
on decoded_input_data deserialized in input_fn. Runs prediction on GPU if one is available.
The predict handler can be overridden to implement the model inference.

Args:
data (dict): deserialized decoded_input_data returned by the input_fn
model : Model returned by the `load` method or if it is a custom module `model_fn`.
context (obj): metadata on the incoming request data (default: None).

Returns:
obj (dict): prediction result.
"""
@@ -155,41 +166,47 @@ def predict(self, data, model):
if parameters is not None:
prediction = model(inputs, **parameters)
else:
print("These are the inputs")
print(inputs)
prediction = model(inputs)
return prediction

def postprocess(self, prediction, accept):
def postprocess(self, prediction, accept, context=None):
"""
The postprocess handler is responsible for serializing the prediction result to
the desired accept type, can handle JSON.
The postprocess handler can be overridden for inference response transformation
The postprocess handler can be overridden for inference response transformation.

Args:
prediction (dict): a prediction result from predict
accept (str): type which the output data needs to be serialized
prediction (dict): a prediction result from predict.
accept (str): type which the output data needs to be serialized.
context (obj): metadata on the incoming request data (default: None).
Returns: output data serialized
"""
return decoder_encoder.encode(prediction, accept)

def transform_fn(self, model, input_data, content_type, accept):
def transform_fn(self, model, input_data, content_type, accept, context=None):
"""
Transform function ("transform_fn") can be used to write one function with pre/post-processing steps and predict step in it.
This function can't be mixed with "input_fn", "output_fn" or "predict_fn"
This function can't be mixed with "input_fn", "output_fn" or "predict_fn".

Args:
model: Model returned by the model_fn above
input_data: Data received for inference
content_type: The content type of the inference data
accept: The response accept type.
context (obj): metadata on the incoming request data (default: None).

Returns: Response in the "accept" format type.

"""
# run pipeline
start_time = time.time()
processed_data = self.preprocess(input_data, content_type)
processed_data = self.run_handler_function(self.preprocess, *(input_data, content_type))
preprocess_time = time.time() - start_time
predictions = self.predict(processed_data, model)
predictions = self.run_handler_function(self.predict, *(processed_data, model))
predict_time = time.time() - preprocess_time - start_time
response = self.postprocess(predictions, accept)
response = self.run_handler_function(self.postprocess, *(predictions, accept))
postprocess_time = time.time() - predict_time - preprocess_time - start_time

logger.info(
@@ -231,7 +248,7 @@ def handle(self, data, context):
input_data = input_data.decode("utf-8")

predict_start = time.time()
response = self.transform_fn(self.model, input_data, content_type, accept)
response = self.run_handler_function(self.transform_fn, *(self.model, input_data, content_type, accept))
predict_end = time.time()

context.metrics.add_time("Transform Fn", round((predict_end - predict_start) * 1000, 2))
@@ -272,3 +289,22 @@ def validate_and_initialize_user_module(self):
self.postprocess = postprocess_fn
if transform_fn is not None:
self.transform_fn = transform_fn

def run_handler_function(self, func, *argv):
"""Helper to call the handler function which covers 2 cases:
1. the handle function takes context
2. the handle function does not take context
"""
num_func_input = len(signature(func).parameters)
if num_func_input == len(argv):
# function does not take context
result = func(*argv)
elif num_func_input == len(argv) + 1:
# function takes context
argv_context = argv + (self.context,)
result = func(*argv_context)
else:
raise TypeError(
"{} takes {} arguments but {} were given.".format(func.__name__, num_func_input, len(argv))
)
return result
Collaborator

Is that needed? What overhead does it add to a request/function call?

Why don't we just pass self.context in the handle method, on the line

response = self.transform_fn(self.model, input_data, content_type, accept)

Contributor Author (@sachanub, Sep 25, 2023)

Hi @philschmid. Thank you for your review. Maybe I am not understanding your point correctly, but I don't think we can directly pass self.context in the handle method. If we do that, don't we risk breaking existing customers who are not using context as an input argument? With the above-mentioned run_handler_function, we should be able to support both customers who do not want to use context and customers who do. Please correct me if I misunderstood. Thanks!
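
To illustrate the two cases (a minimal sketch, not code from this PR; the function names and the gpu_id lookup below are hypothetical):

    # Legacy inference.py written before this change: handlers declare no context parameter.
    def predict_fn(data, model):
        return model(data["inputs"])

    # New-style inference.py that opts in to the extra trailing context argument,
    # which receives the multi-model-server Context object for the request.
    def predict_fn_with_context(data, model, context=None):
        # e.g. read worker metadata such as the assigned GPU id (hypothetical usage)
        gpu_id = context.system_properties.get("gpu_id") if context else None
        return model(data["inputs"])

run_handler_function counts the parameters each handler declares and only appends self.context when the handler has room for it, so both styles keep working without changes on the customer side.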

Collaborator

Why would it break? You mean if they have a custom inference.py that defines a transform_fn method?

Contributor Author

Yes, we might affect those who define a custom transform_fn, right?

Contributor

In handler_service.py, we try to import the user-defined functions from inference.py:

            load_fn = getattr(user_module, "model_fn", None)
            preprocess_fn = getattr(user_module, "input_fn", None)
            predict_fn = getattr(user_module, "predict_fn", None)
            postprocess_fn = getattr(user_module, "output_fn", None)
            transform_fn = getattr(user_module, "transform_fn", None)

If they exist, they are used to overwrite self.load, self.preprocess, self.predict, self.postprocess and self.transform_fn:

            if load_fn is not None:
                self.load = load_fn
            if preprocess_fn is not None:
                self.preprocess = preprocess_fn
            if predict_fn is not None:
                self.predict = predict_fn
            if postprocess_fn is not None:
                self.postprocess = postprocess_fn
            if transform_fn is not None:
                self.transform_fn = transform_fn

In this PR, we introduce an additional context parameter for the default handlers. However, existing user scripts do not have this parameter in their customized handler functions, so we have to be careful when we call these functions. That's why @sachanub is adding a helper function to determine when to call a handler with the new parameter.
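
A small self-contained sketch of how the signature check separates the two cases (illustrative only; the Handler class and input_fn below are placeholders, not toolkit code):

    from inspect import signature

    class Handler:
        # Default handler: a bound method, so `self` is not counted by signature().
        def preprocess(self, input_data, content_type, context=None):
            return input_data

    def input_fn(input_data, content_type):
        # Legacy user handler imported from inference.py: no context parameter.
        return input_data

    h = Handler()
    print(len(signature(h.preprocess).parameters))  # 3 -> context gets appended
    print(len(signature(input_fn).parameters))      # 2 -> called without context

Because the default, context-aware handlers expose one more parameter than the number of arguments being passed, run_handler_function can decide per call whether to append self.context.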

Collaborator

Is that something we could deprecate and remove with a v3 of the toolkit? It feels very error-prone, and a big overhead, to parse signatures and add arguments like this for every incoming request.

What do you think of adding a check to see whether an inference.py is provided and, if not, using self.transform_fn directly? Most customers deploy models using the "zero-code" deployment feature, where you provide a MODEL_ID and TASK and don't need an inference.py script.
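
One possible shape of that suggestion, sketched for discussion only (the user_module_loaded flag is hypothetical; nothing like this is part of the PR). The relevant lines of handle() might become:

    # Skip the per-request signature inspection when no custom inference.py is provided.
    if self.user_module_loaded:  # hypothetical flag set in validate_and_initialize_user_module
        # custom inference.py: handler signatures are unknown, keep the helper
        response = self.run_handler_function(self.transform_fn, *(self.model, input_data, content_type, accept))
    else:
        # zero-code deployment (HF_MODEL_ID / HF_TASK): the default transform_fn
        # is known to accept context, so call it directly with no reflection cost
        response = self.transform_fn(self.model, input_data, content_type, accept, context)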

@@ -0,0 +1,14 @@
def model_fn(model_dir, context=None):
return "model"


def input_fn(data, content_type, context=None):
return "data"


def predict_fn(data, model, context=None):
return "output"


def output_fn(prediction, accept, context=None):
return prediction
@@ -0,0 +1,9 @@
import os


def model_fn(model_dir, context=None):
return f"Loading {os.path.basename(__file__)}"


def transform_fn(a, b, c, d, context=None):
return f"output {b}"
168 changes: 168 additions & 0 deletions tests/unit/test_handler_service_with_context.py
@@ -0,0 +1,168 @@
# Copyright 2021 The HuggingFace Team, Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import tempfile

import pytest
from sagemaker_inference import content_types
from transformers.testing_utils import require_torch, slow

from mms.context import Context, RequestProcessor
from mms.metrics.metrics_store import MetricsStore
from mock import Mock
from sagemaker_huggingface_inference_toolkit import handler_service
from sagemaker_huggingface_inference_toolkit.transformers_utils import _load_model_from_hub, get_pipeline


TASK = "text-classification"
MODEL = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"
INPUT = {"inputs": "My name is Wolfgang and I live in Berlin"}
OUTPUT = [
{"word": "Wolfgang", "score": 0.99, "entity": "I-PER", "index": 4, "start": 11, "end": 19},
{"word": "Berlin", "score": 0.99, "entity": "I-LOC", "index": 9, "start": 34, "end": 40},
]


@pytest.fixture()
def inference_handler():
return handler_service.HuggingFaceHandlerService()


def test_get_device_cpu(inference_handler):
device = inference_handler.get_device()
assert device == -1


@slow
def test_get_device_gpu(inference_handler):
device = inference_handler.get_device()
assert device > -1


@require_torch
def test_test_initialize(inference_handler):
with tempfile.TemporaryDirectory() as tmpdirname:
storage_folder = _load_model_from_hub(
model_id=MODEL,
model_dir=tmpdirname,
)
CONTEXT = Context(MODEL, storage_folder, {}, 1, -1, "1.1.4")

inference_handler.initialize(CONTEXT)
assert inference_handler.initialized is True


@require_torch
def test_handle(inference_handler):
with tempfile.TemporaryDirectory() as tmpdirname:
storage_folder = _load_model_from_hub(
model_id=MODEL,
model_dir=tmpdirname,
)
CONTEXT = Context(MODEL, storage_folder, {}, 1, -1, "1.1.4")
CONTEXT.request_processor = [RequestProcessor({"Content-Type": "application/json"})]
CONTEXT.metrics = MetricsStore(1, MODEL)

inference_handler.initialize(CONTEXT)
json_data = json.dumps(INPUT)
prediction = inference_handler.handle([{"body": json_data.encode()}], CONTEXT)
loaded_response = json.loads(prediction[0])
assert "entity" in loaded_response[0]
assert "score" in loaded_response[0]


@require_torch
def test_load(inference_handler):
context = Mock()
with tempfile.TemporaryDirectory() as tmpdirname:
storage_folder = _load_model_from_hub(
model_id=MODEL,
model_dir=tmpdirname,
)
# test with automatic infer
hf_pipeline_without_task = inference_handler.load(storage_folder, context)
assert hf_pipeline_without_task.task == "token-classification"

# test with task defined via the HF_TASK environment variable
os.environ["HF_TASK"] = TASK
hf_pipeline_with_task = inference_handler.load(storage_folder, context)
assert hf_pipeline_with_task.task == TASK


def test_preprocess(inference_handler):
context = Mock()
json_data = json.dumps(INPUT)
decoded_input_data = inference_handler.preprocess(json_data, content_types.JSON, context)
assert "inputs" in decoded_input_data


def test_preprocess_bad_content_type(inference_handler):
context = Mock()
with pytest.raises(json.decoder.JSONDecodeError):
inference_handler.preprocess(b"", content_types.JSON, context)


@require_torch
def test_predict(inference_handler):
context = Mock()
with tempfile.TemporaryDirectory() as tmpdirname:
storage_folder = _load_model_from_hub(
model_id=MODEL,
model_dir=tmpdirname,
)
inference_handler.model = get_pipeline(task=TASK, device=-1, model_dir=storage_folder)
prediction = inference_handler.predict(INPUT, inference_handler.model, context)
assert "label" in prediction[0]
assert "score" in prediction[0]


def test_postprocess(inference_handler):
context = Mock()
output = inference_handler.postprocess(OUTPUT, content_types.JSON, context)
assert isinstance(output, str)


def test_validate_and_initialize_user_module(inference_handler):
model_dir = os.path.join(os.getcwd(), "tests/resources/model_input_predict_output_fn_with_context")
CONTEXT = Context("", model_dir, {}, 1, -1, "1.1.4")

inference_handler.initialize(CONTEXT)
CONTEXT.request_processor = [RequestProcessor({"Content-Type": "application/json"})]
CONTEXT.metrics = MetricsStore(1, MODEL)

prediction = inference_handler.handle([{"body": b""}], CONTEXT)
assert "output" in prediction[0]

assert inference_handler.load({}, CONTEXT) == "model"
assert inference_handler.preprocess({}, "", CONTEXT) == "data"
assert inference_handler.predict({}, "model", CONTEXT) == "output"
assert inference_handler.postprocess("output", "", CONTEXT) == "output"


def test_validate_and_initialize_user_module_transform_fn():
os.environ["SAGEMAKER_PROGRAM"] = "inference_tranform_fn.py"
inference_handler = handler_service.HuggingFaceHandlerService()
model_dir = os.path.join(os.getcwd(), "tests/resources/model_transform_fn_with_context")
CONTEXT = Context("dummy", model_dir, {}, 1, -1, "1.1.4")

inference_handler.initialize(CONTEXT)
CONTEXT.request_processor = [RequestProcessor({"Content-Type": "application/json"})]
CONTEXT.metrics = MetricsStore(1, MODEL)
assert "output" in inference_handler.handle([{"body": b"dummy"}], CONTEXT)[0]
assert inference_handler.load({}, CONTEXT) == "Loading inference_tranform_fn.py"
assert (
inference_handler.transform_fn("model", "dummy", "application/json", "application/json", CONTEXT)
== "output dummy"
)
@@ -129,7 +129,7 @@ def test_postprocess(inference_handler):


def test_validate_and_initialize_user_module(inference_handler):
model_dir = os.path.join(os.getcwd(), "tests/resources/model_input_predict_output_fn")
model_dir = os.path.join(os.getcwd(), "tests/resources/model_input_predict_output_fn_without_context")
CONTEXT = Context("", model_dir, {}, 1, -1, "1.1.4")

inference_handler.initialize(CONTEXT)
@@ -148,7 +148,7 @@ def test_validate_and_initialize_user_module(inference_handler):
def test_validate_and_initialize_user_module_transform_fn():
os.environ["SAGEMAKER_PROGRAM"] = "inference_tranform_fn.py"
inference_handler = handler_service.HuggingFaceHandlerService()
model_dir = os.path.join(os.getcwd(), "tests/resources/model_transform_fn")
model_dir = os.path.join(os.getcwd(), "tests/resources/model_transform_fn_without_context")
CONTEXT = Context("dummy", model_dir, {}, 1, -1, "1.1.4")

inference_handler.initialize(CONTEXT)