diff --git a/src/sagemaker_huggingface_inference_toolkit/handler_service.py b/src/sagemaker_huggingface_inference_toolkit/handler_service.py
index af4d396..05cfae1 100644
--- a/src/sagemaker_huggingface_inference_toolkit/handler_service.py
+++ b/src/sagemaker_huggingface_inference_toolkit/handler_service.py
@@ -18,6 +18,7 @@
 import sys
 import time
 from abc import ABC
+from inspect import signature
 
 from sagemaker_inference import environment, utils
 from transformers.pipelines import SUPPORTED_TASKS
@@ -57,6 +58,11 @@ def __init__(self):
         self.context = None
         self.manifest = None
         self.environment = environment.Environment()
+        self.load_extra_arg = []
+        self.preprocess_extra_arg = []
+        self.predict_extra_arg = []
+        self.postprocess_extra_arg = []
+        self.transform_extra_arg = []
 
     def initialize(self, context):
         """
@@ -74,7 +80,7 @@
         self.validate_and_initialize_user_module()
 
         self.device = self.get_device()
-        self.model = self.load(self.model_dir)
+        self.model = self.load(*([self.model_dir] + self.load_extra_arg))
         self.initialized = True
         # # Load methods from file
         # if (not self._initialized) and ENABLE_MULTI_MODEL:
@@ -92,10 +98,15 @@ def get_device(self):
         else:
             return -1
 
-    def load(self, model_dir):
+    def load(self, model_dir, context=None):
         """
         The Load handler is responsible for loading the Hugging Face transformer model.
-        It can be overridden to load the model from storage
+        It can be overridden to load the model from storage.
+
+        Args:
+            model_dir (str): The directory where model files are stored.
+            context (obj): metadata on the incoming request data (default: None).
+
         Returns:
             hf_pipeline (Pipeline): A Hugging Face Transformer pipeline.
         """
@@ -111,14 +122,16 @@ def load(self, model_dir):
         )
         return hf_pipeline
 
-    def preprocess(self, input_data, content_type):
+    def preprocess(self, input_data, content_type, context=None):
         """
         The preprocess handler is responsible for deserializing the input data into
         an object for prediction, can handle JSON.
-        The preprocess handler can be overridden for data or feature transformation,
+        The preprocess handler can be overridden for data or feature transformation.
+
         Args:
-            input_data: the request payload serialized in the content_type format
-            content_type: the request content_type
+            input_data: the request payload serialized in the content_type format.
+            content_type: the request content_type.
+            context (obj): metadata on the incoming request data (default: None).
 
         Returns:
             decoded_input_data (dict): deserialized input_data into a Python dictonary.
@@ -136,13 +149,16 @@ def preprocess(self, input_data, content_type):
         decoded_input_data = decoder_encoder.decode(input_data, content_type)
         return decoded_input_data
 
-    def predict(self, data, model):
+    def predict(self, data, model, context=None):
         """The predict handler is responsible for model predictions. Calls the `__call__` method of the provided `Pipeline`
         on decoded_input_data deserialized in input_fn. Runs prediction on GPU if is available.
         The predict handler can be overridden to implement the model inference.
+
         Args:
             data (dict): deserialized decoded_input_data returned by the input_fn
             model : Model returned by the `load` method or if it is a custom module `model_fn`.
+            context (obj): metadata on the incoming request data (default: None).
+
         Returns:
             obj (dict): prediction result.
""" @@ -158,38 +174,42 @@ def predict(self, data, model): prediction = model(inputs) return prediction - def postprocess(self, prediction, accept): + def postprocess(self, prediction, accept, context=None): """ The postprocess handler is responsible for serializing the prediction result to the desired accept type, can handle JSON. - The postprocess handler can be overridden for inference response transformation + The postprocess handler can be overridden for inference response transformation. + Args: - prediction (dict): a prediction result from predict - accept (str): type which the output data needs to be serialized + prediction (dict): a prediction result from predict. + accept (str): type which the output data needs to be serialized. + context (obj): metadata on the incoming request data (default: None). Returns: output data serialized """ return decoder_encoder.encode(prediction, accept) - def transform_fn(self, model, input_data, content_type, accept): + def transform_fn(self, model, input_data, content_type, accept, context=None): """ Transform function ("transform_fn") can be used to write one function with pre/post-processing steps and predict step in it. - This fuction can't be mixed with "input_fn", "output_fn" or "predict_fn" + This fuction can't be mixed with "input_fn", "output_fn" or "predict_fn". + Args: model: Model returned by the model_fn above input_data: Data received for inference content_type: The content type of the inference data accept: The response accept type. + context (obj): metadata on the incoming request data (default: None). Returns: Response in the "accept" format type. """ # run pipeline start_time = time.time() - processed_data = self.preprocess(input_data, content_type) + processed_data = self.preprocess(*([input_data, content_type] + self.preprocess_extra_arg)) preprocess_time = time.time() - start_time - predictions = self.predict(processed_data, model) + predictions = self.predict(*([processed_data, model] + self.predict_extra_arg)) predict_time = time.time() - preprocess_time - start_time - response = self.postprocess(predictions, accept) + response = self.postprocess(*([predictions, accept] + self.postprocess_extra_arg)) postprocess_time = time.time() - predict_time - preprocess_time - start_time logger.info( @@ -231,7 +251,7 @@ def handle(self, data, context): input_data = input_data.decode("utf-8") predict_start = time.time() - response = self.transform_fn(self.model, input_data, content_type, accept) + response = self.transform_fn(*([self.model, input_data, content_type, accept] + self.transform_extra_arg)) predict_end = time.time() context.metrics.add_time("Transform Fn", round((predict_end - predict_start) * 1000, 2)) @@ -263,12 +283,38 @@ def validate_and_initialize_user_module(self): ) if load_fn is not None: + self.load_extra_arg = self.function_extra_arg(self.load, load_fn) self.load = load_fn if preprocess_fn is not None: + self.preprocess_extra_arg = self.function_extra_arg(self.preprocess, preprocess_fn) self.preprocess = preprocess_fn if predict_fn is not None: + self.predict_extra_arg = self.function_extra_arg(self.predict, predict_fn) self.predict = predict_fn if postprocess_fn is not None: + self.postprocess_extra_arg = self.function_extra_arg(self.postprocess, postprocess_fn) self.postprocess = postprocess_fn if transform_fn is not None: + self.transform_extra_arg = self.function_extra_arg(self.transform_fn, transform_fn) self.transform_fn = transform_fn + + def function_extra_arg(self, default_func, func): + """Helper to call the 
+        1. the handler function takes context
+        2. the handler function does not take context
+        """
+        num_default_func_input = len(signature(default_func).parameters)
+        num_func_input = len(signature(func).parameters)
+        if num_default_func_input == num_func_input:
+            # function takes context
+            extra_args = [self.context]
+        elif num_default_func_input == num_func_input + 1:
+            # function does not take context
+            extra_args = []
+        else:
+            raise TypeError(
+                "{} definition takes {} or {} arguments but {} were given.".format(
+                    func.__name__, num_default_func_input - 1, num_default_func_input, num_func_input
+                )
+            )
+        return extra_args
diff --git a/tests/resources/model_input_predict_output_fn_with_context/code/inference.py b/tests/resources/model_input_predict_output_fn_with_context/code/inference.py
new file mode 100644
index 0000000..7de1037
--- /dev/null
+++ b/tests/resources/model_input_predict_output_fn_with_context/code/inference.py
@@ -0,0 +1,14 @@
+def model_fn(model_dir, context=None):
+    return "model"
+
+
+def input_fn(data, content_type, context=None):
+    return "data"
+
+
+def predict_fn(data, model, context=None):
+    return "output"
+
+
+def output_fn(prediction, accept, context=None):
+    return prediction
diff --git a/tests/resources/model_input_predict_output_fn/code/inference.py b/tests/resources/model_input_predict_output_fn_without_context/code/inference.py
similarity index 100%
rename from tests/resources/model_input_predict_output_fn/code/inference.py
rename to tests/resources/model_input_predict_output_fn_without_context/code/inference.py
diff --git a/tests/resources/model_transform_fn_with_context/code/inference_tranform_fn.py b/tests/resources/model_transform_fn_with_context/code/inference_tranform_fn.py
new file mode 100644
index 0000000..ef93e62
--- /dev/null
+++ b/tests/resources/model_transform_fn_with_context/code/inference_tranform_fn.py
@@ -0,0 +1,9 @@
+import os
+
+
+def model_fn(model_dir, context=None):
+    return f"Loading {os.path.basename(__file__)}"
+
+
+def transform_fn(a, b, c, d, context=None):
+    return f"output {b}"
diff --git a/tests/resources/model_transform_fn/code/inference_tranform_fn.py b/tests/resources/model_transform_fn_without_context/code/inference_tranform_fn.py
similarity index 100%
rename from tests/resources/model_transform_fn/code/inference_tranform_fn.py
rename to tests/resources/model_transform_fn_without_context/code/inference_tranform_fn.py
diff --git a/tests/unit/test_handler_service_with_context.py b/tests/unit/test_handler_service_with_context.py
new file mode 100644
index 0000000..a8b5b71
--- /dev/null
+++ b/tests/unit/test_handler_service_with_context.py
@@ -0,0 +1,168 @@
+# Copyright 2021 The HuggingFace Team, Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+import os
+import tempfile
+
+import pytest
+from sagemaker_inference import content_types
+from transformers.testing_utils import require_torch, slow
+
+from mms.context import Context, RequestProcessor
+from mms.metrics.metrics_store import MetricsStore
+from mock import Mock
+from sagemaker_huggingface_inference_toolkit import handler_service
+from sagemaker_huggingface_inference_toolkit.transformers_utils import _load_model_from_hub, get_pipeline
+
+
+TASK = "text-classification"
+MODEL = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"
+INPUT = {"inputs": "My name is Wolfgang and I live in Berlin"}
+OUTPUT = [
+    {"word": "Wolfgang", "score": 0.99, "entity": "I-PER", "index": 4, "start": 11, "end": 19},
+    {"word": "Berlin", "score": 0.99, "entity": "I-LOC", "index": 9, "start": 34, "end": 40},
+]
+
+
+@pytest.fixture()
+def inference_handler():
+    return handler_service.HuggingFaceHandlerService()
+
+
+def test_get_device_cpu(inference_handler):
+    device = inference_handler.get_device()
+    assert device == -1
+
+
+@slow
+def test_get_device_gpu(inference_handler):
+    device = inference_handler.get_device()
+    assert device > -1
+
+
+@require_torch
+def test_test_initialize(inference_handler):
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        storage_folder = _load_model_from_hub(
+            model_id=MODEL,
+            model_dir=tmpdirname,
+        )
+        CONTEXT = Context(MODEL, storage_folder, {}, 1, -1, "1.1.4")
+
+        inference_handler.initialize(CONTEXT)
+        assert inference_handler.initialized is True
+
+
+@require_torch
+def test_handle(inference_handler):
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        storage_folder = _load_model_from_hub(
+            model_id=MODEL,
+            model_dir=tmpdirname,
+        )
+        CONTEXT = Context(MODEL, storage_folder, {}, 1, -1, "1.1.4")
+        CONTEXT.request_processor = [RequestProcessor({"Content-Type": "application/json"})]
+        CONTEXT.metrics = MetricsStore(1, MODEL)
+
+        inference_handler.initialize(CONTEXT)
+        json_data = json.dumps(INPUT)
+        prediction = inference_handler.handle([{"body": json_data.encode()}], CONTEXT)
+        loaded_response = json.loads(prediction[0])
+        assert "entity" in loaded_response[0]
+        assert "score" in loaded_response[0]
+
+
+@require_torch
+def test_load(inference_handler):
+    context = Mock()
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        storage_folder = _load_model_from_hub(
+            model_id=MODEL,
+            model_dir=tmpdirname,
+        )
+        # test with automatically inferred task
+        hf_pipeline_without_task = inference_handler.load(storage_folder, context)
+        assert hf_pipeline_without_task.task == "token-classification"
+
+        # test with task set via the HF_TASK environment variable
+        os.environ["HF_TASK"] = TASK
+        hf_pipeline_with_task = inference_handler.load(storage_folder, context)
+        assert hf_pipeline_with_task.task == TASK
+
+
+def test_preprocess(inference_handler):
+    context = Mock()
+    json_data = json.dumps(INPUT)
+    decoded_input_data = inference_handler.preprocess(json_data, content_types.JSON, context)
+    assert "inputs" in decoded_input_data
+
+
+def test_preprocess_bad_content_type(inference_handler):
+    context = Mock()
+    with pytest.raises(json.decoder.JSONDecodeError):
+        inference_handler.preprocess(b"", content_types.JSON, context)
+
+
+@require_torch
+def test_predict(inference_handler):
+    context = Mock()
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        storage_folder = _load_model_from_hub(
+            model_id=MODEL,
+            model_dir=tmpdirname,
+        )
+        inference_handler.model = get_pipeline(task=TASK, device=-1, model_dir=storage_folder)
+        prediction = inference_handler.predict(INPUT, inference_handler.model, context)
+        assert "label" in prediction[0]
+        assert "score" in prediction[0]
+
+
+def test_postprocess(inference_handler):
+    context = Mock()
+    output = inference_handler.postprocess(OUTPUT, content_types.JSON, context)
+    assert isinstance(output, str)
+
+
+def test_validate_and_initialize_user_module(inference_handler):
+    model_dir = os.path.join(os.getcwd(), "tests/resources/model_input_predict_output_fn_with_context")
+    CONTEXT = Context("", model_dir, {}, 1, -1, "1.1.4")
+
+    inference_handler.initialize(CONTEXT)
+    CONTEXT.request_processor = [RequestProcessor({"Content-Type": "application/json"})]
+    CONTEXT.metrics = MetricsStore(1, MODEL)
+
+    prediction = inference_handler.handle([{"body": b""}], CONTEXT)
+    assert "output" in prediction[0]
+
+    assert inference_handler.load({}, CONTEXT) == "model"
+    assert inference_handler.preprocess({}, "", CONTEXT) == "data"
+    assert inference_handler.predict({}, "model", CONTEXT) == "output"
+    assert inference_handler.postprocess("output", "", CONTEXT) == "output"
+
+
+def test_validate_and_initialize_user_module_transform_fn():
+    os.environ["SAGEMAKER_PROGRAM"] = "inference_tranform_fn.py"
+    inference_handler = handler_service.HuggingFaceHandlerService()
+    model_dir = os.path.join(os.getcwd(), "tests/resources/model_transform_fn_with_context")
+    CONTEXT = Context("dummy", model_dir, {}, 1, -1, "1.1.4")
+
+    inference_handler.initialize(CONTEXT)
+    CONTEXT.request_processor = [RequestProcessor({"Content-Type": "application/json"})]
+    CONTEXT.metrics = MetricsStore(1, MODEL)
+    assert "output" in inference_handler.handle([{"body": b"dummy"}], CONTEXT)[0]
+    assert inference_handler.load({}, CONTEXT) == "Loading inference_tranform_fn.py"
+    assert (
+        inference_handler.transform_fn("model", "dummy", "application/json", "application/json", CONTEXT)
+        == "output dummy"
+    )
diff --git a/tests/unit/test_handler_service.py b/tests/unit/test_handler_service_without_context.py
similarity index 98%
rename from tests/unit/test_handler_service.py
rename to tests/unit/test_handler_service_without_context.py
index bc30608..cda8360 100644
--- a/tests/unit/test_handler_service.py
+++ b/tests/unit/test_handler_service_without_context.py
@@ -129,7 +129,7 @@ def test_postprocess(inference_handler):
 
 
 def test_validate_and_initialize_user_module(inference_handler):
-    model_dir = os.path.join(os.getcwd(), "tests/resources/model_input_predict_output_fn")
+    model_dir = os.path.join(os.getcwd(), "tests/resources/model_input_predict_output_fn_without_context")
     CONTEXT = Context("", model_dir, {}, 1, -1, "1.1.4")
 
     inference_handler.initialize(CONTEXT)
@@ -148,7 +148,7 @@ def test_validate_and_initialize_user_module(inference_handler):
 def test_validate_and_initialize_user_module_transform_fn():
     os.environ["SAGEMAKER_PROGRAM"] = "inference_tranform_fn.py"
     inference_handler = handler_service.HuggingFaceHandlerService()
-    model_dir = os.path.join(os.getcwd(), "tests/resources/model_transform_fn")
+    model_dir = os.path.join(os.getcwd(), "tests/resources/model_transform_fn_without_context")
    CONTEXT = Context("dummy", model_dir, {}, 1, -1, "1.1.4")
 
     inference_handler.initialize(CONTEXT)
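
Taken together, the patch lets a user-provided inference.py opt in to the optional `context` argument: `function_extra_arg` compares the parameter count of each user handler with the default handler and only passes `self.context` when the extra parameter is present. Below is a minimal sketch of such a user module; it is not part of the diff, the handler bodies are hypothetical, and the `context.system_properties` lookup is an assumption about the MMS `Context` object.

# inference.py -- hypothetical user module; each handler may accept context or omit it
from transformers import pipeline


def model_fn(model_dir, context=None):
    # Default to CPU; pick the GPU assigned by the serving container when a context is passed
    # (the system_properties lookup is an assumption about the MMS Context object).
    device = -1
    if context is not None:
        gpu_id = context.system_properties.get("gpu_id")
        if gpu_id is not None:
            device = int(gpu_id)
    return pipeline("text-classification", model=model_dir, device=device)


def predict_fn(data, model):
    # No context parameter: the toolkit detects the shorter signature and calls predict_fn(data, model).
    return model(data["inputs"])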