diff --git a/src/sagemaker/base_predictor.py b/src/sagemaker/base_predictor.py index 76b83c25cd..1a7eea9cd7 100644 --- a/src/sagemaker/base_predictor.py +++ b/src/sagemaker/base_predictor.py @@ -52,6 +52,7 @@ JSONSerializer, NumpySerializer, ) +from sagemaker.iterators import ByteIterator from sagemaker.session import production_variant, Session from sagemaker.utils import name_from_base, stringify_object, format_tags @@ -225,6 +226,7 @@ def _create_request_args( target_variant=None, inference_id=None, custom_attributes=None, + target_container_hostname=None, ): """Placeholder docstring""" @@ -286,9 +288,89 @@ def _create_request_args( if self._get_component_name(): args["InferenceComponentName"] = self.component_name + if target_container_hostname: + args["TargetContainerHostname"] = target_container_hostname + args["Body"] = data return args + def predict_stream( + self, + data, + initial_args=None, + target_variant=None, + inference_id=None, + custom_attributes=None, + component_name: Optional[str] = None, + target_container_hostname=None, + iterator=ByteIterator, + ): + """Return the inference from the specified endpoint. + + Args: + data (object): Input data for which you want the model to provide + inference. If a serializer was specified when creating the + Predictor, the result of the serializer is sent as input + data. Otherwise the data must be sequence of bytes, and the + predict method then sends the bytes in the request body as is. + initial_args (dict[str,str]): Optional. Default arguments for boto3 + ``invoke_endpoint_with_response_stream`` call. Default is None (no default + arguments). (Default: None) + target_variant (str): Optional. The name of the production variant to run an inference + request on (Default: None). Note that the ProductionVariant identifies the + model you want to host and the resources you want to deploy for hosting it. + inference_id (str): Optional. If you provide a value, it is added to the captured data + when you enable data capture on the endpoint (Default: None). + custom_attributes (str): Optional. Provides additional information about a request for + an inference submitted to a model hosted at an Amazon SageMaker endpoint. + The information is an opaque value that is forwarded verbatim. You could use this + value, for example, to provide an ID that you can use to track a request or to + provide other metadata that a service endpoint was programmed to process. The value + must consist of no more than 1024 visible US-ASCII characters. + + The code in your model is responsible for setting or updating any custom attributes + in the response. If your code does not set this value in the response, an empty + value is returned. For example, if a custom attribute represents the trace ID, your + model can prepend the custom attribute with Trace ID: in your post-processing + function (Default: None). + component_name (str): Optional. Name of the Amazon SageMaker inference component + corresponding the predictor. (Default: None) + target_container_hostname (str): Optional. If the endpoint hosts multiple containers + and is configured to use direct invocation, this parameter specifies the host name + of the container to invoke. (Default: None). + iterator (:class:`~sagemaker.iterators.BaseIterator`): An iterator class which provides + an iterable interface to iterate Event stream response from Inference Endpoint. + An object of the iterator class provided will be returned by the predict_stream + method (Default::class:`~sagemaker.iterators.ByteIterator`). Iterators defined in + :class:`~sagemaker.iterators` or custom iterators (needs to inherit + :class:`~sagemaker.iterators.BaseIterator`) can be specified as an input. + + Returns: + object (:class:`~sagemaker.iterators.BaseIterator`): An iterator object which would + allow iteration on EventStream response will be returned. The object would be + instantiated from `predict_stream` method's `iterator` parameter. + """ + # [TODO]: clean up component_name in _create_request_args + request_args = self._create_request_args( + data=data, + initial_args=initial_args, + target_variant=target_variant, + inference_id=inference_id, + custom_attributes=custom_attributes, + target_container_hostname=target_container_hostname, + ) + + inference_component_name = component_name or self._get_component_name() + if inference_component_name: + request_args["InferenceComponentName"] = inference_component_name + + response = ( + self.sagemaker_session.sagemaker_runtime_client.invoke_endpoint_with_response_stream( + **request_args + ) + ) + return iterator(response["Body"]) + def update_endpoint( self, initial_instance_count=None, diff --git a/src/sagemaker/exceptions.py b/src/sagemaker/exceptions.py index b9d97cc241..88ffa0a591 100644 --- a/src/sagemaker/exceptions.py +++ b/src/sagemaker/exceptions.py @@ -86,3 +86,23 @@ class AsyncInferenceModelError(AsyncInferenceError): def __init__(self, message): super().__init__(message=message) + + +class ModelStreamError(Exception): + """Raised when invoke_endpoint_with_response_stream Response returns ModelStreamError""" + + def __init__(self, message="An error occurred", code=None): + self.message = message + self.code = code + if code is not None: + super().__init__(f"{message} (Code: {code})") + else: + super().__init__(message) + + +class InternalStreamFailure(Exception): + """Raised when invoke_endpoint_with_response_stream Response returns InternalStreamFailure""" + + def __init__(self, message="An error occurred"): + self.message = message + super().__init__(self.message) diff --git a/src/sagemaker/iterators.py b/src/sagemaker/iterators.py new file mode 100644 index 0000000000..38a43121a1 --- /dev/null +++ b/src/sagemaker/iterators.py @@ -0,0 +1,186 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Implements iterators for deserializing data returned from an inference streaming endpoint.""" +from __future__ import absolute_import + +from abc import ABC, abstractmethod +import io + +from sagemaker.exceptions import ModelStreamError, InternalStreamFailure + + +def handle_stream_errors(chunk): + """Handle API Response errors within `invoke_endpoint_with_response_stream` API if any. + + Args: + chunk (dict): A chunk of response received as part of `botocore.eventstream.EventStream` + response object. + + Raises: + ModelStreamError: If `ModelStreamError` error is detected in a chunk of + `botocore.eventstream.EventStream` response object. + InternalStreamFailure: If `InternalStreamFailure` error is detected in a chunk of + `botocore.eventstream.EventStream` response object. + """ + if "ModelStreamError" in chunk: + raise ModelStreamError( + chunk["ModelStreamError"]["Message"], code=chunk["ModelStreamError"]["ErrorCode"] + ) + if "InternalStreamFailure" in chunk: + raise InternalStreamFailure(chunk["InternalStreamFailure"]["Message"]) + + +class BaseIterator(ABC): + """Abstract base class for Inference Streaming iterators. + + Provides a skeleton for customization requiring the overriding of iterator methods + __iter__ and __next__. + + Tenets of iterator class for Streaming Inference API Response + (https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ + sagemaker-runtime/client/invoke_endpoint_with_response_stream.html): + 1. Needs to accept an botocore.eventstream.EventStream response. + 2. Needs to implement logic in __next__ to: + 2.1. Concatenate and provide next chunk of response from botocore.eventstream.EventStream. + While doing so parse the response_chunk["PayloadPart"]["Bytes"]. + 2.2. If PayloadPart not in EventStream response, handle Errors + [Recommended to use `iterators.handle_stream_errors` method]. + """ + + def __init__(self, event_stream): + """Initialises a Iterator object to help parse the byte event stream input. + + Args: + event_stream: (botocore.eventstream.EventStream): Event Stream object to be iterated. + """ + self.event_stream = event_stream + + @abstractmethod + def __iter__(self): + """Abstract method, returns an iterator object itself""" + return self + + @abstractmethod + def __next__(self): + """Abstract method, is responsible for returning the next element in the iteration""" + + +class ByteIterator(BaseIterator): + """A helper class for parsing the byte Event Stream input to provide Byte iteration.""" + + def __init__(self, event_stream): + """Initialises a BytesIterator Iterator object + + Args: + event_stream: (botocore.eventstream.EventStream): Event Stream object to be iterated. + """ + super().__init__(event_stream) + self.byte_iterator = iter(event_stream) + + def __iter__(self): + """Returns an iterator object itself, which allows the object to be iterated. + + Returns: + iter : object + An iterator object representing the iterable. + """ + return self + + def __next__(self): + """Returns the next chunk of Byte directly.""" + # Even with "while True" loop the function still behaves like a generator + # and sends the next new byte chunk. + while True: + chunk = next(self.byte_iterator) + if "PayloadPart" not in chunk: + # handle API response errors and force terminate. + handle_stream_errors(chunk) + # print and move on to next response byte + print("Unknown event type:" + chunk) + continue + return chunk["PayloadPart"]["Bytes"] + + +class LineIterator(BaseIterator): + """A helper class for parsing the byte Event Stream input to provide Line iteration.""" + + def __init__(self, event_stream): + """Initialises a LineIterator Iterator object + + Args: + event_stream: (botocore.eventstream.EventStream): Event Stream object to be iterated. + """ + super().__init__(event_stream) + self.byte_iterator = iter(self.event_stream) + self.buffer = io.BytesIO() + self.read_pos = 0 + + def __iter__(self): + """Returns an iterator object itself, which allows the object to be iterated. + + Returns: + iter : object + An iterator object representing the iterable. + """ + return self + + def __next__(self): + r"""Returns the next Line for an Line iterable. + + The output of the event stream will be in the following format: + + ``` + b'{"outputs": [" a"]}\n' + b'{"outputs": [" challenging"]}\n' + b'{"outputs": [" problem"]}\n' + ... + ``` + + While usually each PayloadPart event from the event stream will contain a byte array + with a full json, this is not guaranteed and some of the json objects may be split across + PayloadPart events. For example: + ``` + {'PayloadPart': {'Bytes': b'{"outputs": '}} + {'PayloadPart': {'Bytes': b'[" problem"]}\n'}} + ``` + + This class accounts for this by concatenating bytes written via the 'write' function + and then exposing a method which will return lines (ending with a '\n' character) within + the buffer via the 'scan_lines' function. It maintains the position of the last read + position to ensure that previous bytes are not exposed again. + + Returns: + str: Read and return one line from the event stream. + """ + # Even with "while True" loop the function still behaves like a generator + # and sends the next new concatenated line + while True: + self.buffer.seek(self.read_pos) + line = self.buffer.readline() + if line and line[-1] == ord("\n"): + self.read_pos += len(line) + return line[:-1] + try: + chunk = next(self.byte_iterator) + except StopIteration: + if self.read_pos < self.buffer.getbuffer().nbytes: + continue + raise + if "PayloadPart" not in chunk: + # handle API response errors and force terminate. + handle_stream_errors(chunk) + # print and move on to next response byte + print("Unknown event type:" + chunk) + continue + self.buffer.seek(0, io.SEEK_END) + self.buffer.write(chunk["PayloadPart"]["Bytes"]) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index fff421ab32..e9a34a21a8 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -65,7 +65,9 @@ def __init__( self, region: Optional[str] = None, max_s3_cache_items: int = JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS, - s3_cache_expiration_horizon: datetime.timedelta = JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON, + s3_cache_expiration_horizon: datetime.timedelta = ( + JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON + ), max_semantic_version_cache_items: int = JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS, semantic_version_cache_expiration_horizon: datetime.timedelta = ( JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON diff --git a/tests/data/lmi-model-falcon-7b/mymodel-7B.tar.gz b/tests/data/lmi-model-falcon-7b/mymodel-7B.tar.gz new file mode 100644 index 0000000000..6a66178b47 Binary files /dev/null and b/tests/data/lmi-model-falcon-7b/mymodel-7B.tar.gz differ diff --git a/tests/integ/test_predict_stream.py b/tests/integ/test_predict_stream.py new file mode 100644 index 0000000000..bdd19187df --- /dev/null +++ b/tests/integ/test_predict_stream.py @@ -0,0 +1,109 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import json +import os +import pytest + +import tests.integ +import tests.integ.timeout + +from sagemaker import image_uris +from sagemaker.iterators import LineIterator +from sagemaker.model import Model +from sagemaker.predictor import Predictor +from sagemaker.utils import unique_name_from_base + +from tests.integ import DATA_DIR + + +ROLE = "SageMakerRole" +INSTANCE_COUNT = 1 +INSTANCE_TYPE = "ml.g5.2xlarge" +LMI_FALCON_7B_DATA_PATH = os.path.join(DATA_DIR, "lmi-model-falcon-7b") + + +@pytest.yield_fixture(scope="module") +def endpoint_name(sagemaker_session): + lmi_endpoint_name = unique_name_from_base("lmi-model-falcon-7b") + model_data = sagemaker_session.upload_data( + path=os.path.join(LMI_FALCON_7B_DATA_PATH, "mymodel-7B.tar.gz"), + key_prefix="large-model-lmi/code", + ) + + image_uri = image_uris.retrieve( + framework="djl-deepspeed", region=sagemaker_session.boto_region_name, version="0.23.0" + ) + + with tests.integ.timeout.timeout_and_delete_endpoint_by_name( + endpoint_name=lmi_endpoint_name, sagemaker_session=sagemaker_session, hours=2 + ): + lmi_model = Model( + sagemaker_session=sagemaker_session, + model_data=model_data, + image_uri=image_uri, + name=lmi_endpoint_name, # model name + role=ROLE, + ) + lmi_model.deploy( + INSTANCE_COUNT, + INSTANCE_TYPE, + endpoint_name=lmi_endpoint_name, + container_startup_health_check_timeout=900, + ) + yield lmi_endpoint_name + + +def test_predict_stream(sagemaker_session, endpoint_name): + data = {"inputs": "what does AWS stand for?", "parameters": {"max_new_tokens": 400}} + initial_args = {"ContentType": "application/json"} + predictor = Predictor( + endpoint_name=endpoint_name, + sagemaker_session=sagemaker_session, + ) + + # Validate that no exception is raised when the target_variant is specified. + stream_iterator = predictor.predict_stream( + data=json.dumps(data), + initial_args=initial_args, + iterator=LineIterator, + ) + + response = "" + for line in stream_iterator: + resp = json.loads(line) + response += resp.get("outputs")[0] + + assert "AWS stands for Amazon Web Services." in response + + data = {"inputs": "what does AWS stand for?", "parameters": {"max_new_tokens": 400}} + initial_args = {"ContentType": "application/json"} + predictor = Predictor( + endpoint_name=endpoint_name, + sagemaker_session=sagemaker_session, + ) + + # Validate that no exception is raised when the target_variant is specified. + # uses the default `sagemaker.iterator.ByteIterator` + stream_iterator = predictor.predict_stream( + data=json.dumps(data), + initial_args=initial_args, + ) + + response = "" + for line in stream_iterator: + resp = json.loads(line) + response += resp.get("outputs")[0] + + assert "AWS stands for Amazon Web Services." in response diff --git a/tests/unit/sagemaker/iterators/test_iterators.py b/tests/unit/sagemaker/iterators/test_iterators.py new file mode 100644 index 0000000000..89e9d43a47 --- /dev/null +++ b/tests/unit/sagemaker/iterators/test_iterators.py @@ -0,0 +1,126 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import unittest +from unittest.mock import MagicMock + +from sagemaker.exceptions import ModelStreamError, InternalStreamFailure +from sagemaker.iterators import ByteIterator, LineIterator + + +class TestByteIterator(unittest.TestCase): + def test_iteration_with_payload_parts(self): + # Mocking the stream object + self.stream = MagicMock() + self.stream.__iter__.return_value = [ + {"PayloadPart": {"Bytes": b'{"outputs": [" a"]}\n'}}, + {"PayloadPart": {"Bytes": b'{"outputs": [" challenging"]}\n'}}, + {"PayloadPart": {"Bytes": b'{"outputs": [" problem"]}\n'}}, + ] + self.iterator = ByteIterator(self.stream) + + lines = list(self.iterator) + expected_lines = [ + b'{"outputs": [" a"]}\n', + b'{"outputs": [" challenging"]}\n', + b'{"outputs": [" problem"]}\n', + ] + self.assertEqual(lines, expected_lines) + + def test_iteration_with_model_stream_error(self): + # Mocking the stream object + self.stream = MagicMock() + self.stream.__iter__.return_value = [ + {"PayloadPart": {"Bytes": b'{"outputs": [" a"]}\n'}}, + {"PayloadPart": {"Bytes": b'{"outputs": [" challenging"]}\n'}}, + {"ModelStreamError": {"Message": "Error message", "ErrorCode": "500"}}, + {"PayloadPart": {"Bytes": b'{"outputs": [" problem"]}\n'}}, + ] + self.iterator = ByteIterator(self.stream) + + with self.assertRaises(ModelStreamError) as e: + list(self.iterator) + + self.assertEqual(str(e.exception.message), "Error message") + self.assertEqual(str(e.exception.code), "500") + + def test_iteration_with_internal_stream_failure(self): + # Mocking the stream object + self.stream = MagicMock() + self.stream.__iter__.return_value = [ + {"PayloadPart": {"Bytes": b'{"outputs": [" a"]}\n'}}, + {"PayloadPart": {"Bytes": b'{"outputs": [" challenging"]}\n'}}, + {"InternalStreamFailure": {"Message": "Error internal stream failure"}}, + {"PayloadPart": {"Bytes": b'{"outputs": [" problem"]}\n'}}, + ] + self.iterator = ByteIterator(self.stream) + + with self.assertRaises(InternalStreamFailure) as e: + list(self.iterator) + + self.assertEqual(str(e.exception.message), "Error internal stream failure") + + +class TestLineIterator(unittest.TestCase): + def test_iteration_with_payload_parts(self): + # Mocking the stream object + self.stream = MagicMock() + self.stream.__iter__.return_value = [ + {"PayloadPart": {"Bytes": b'{"outputs": [" a"]}\n'}}, + {"PayloadPart": {"Bytes": b'{"outputs": [" challenging"]}\n'}}, + {"PayloadPart": {"Bytes": b'{"outputs": '}}, + {"PayloadPart": {"Bytes": b'[" problem"]}\n'}}, + ] + self.iterator = LineIterator(self.stream) + + lines = list(self.iterator) + expected_lines = [ + b'{"outputs": [" a"]}', + b'{"outputs": [" challenging"]}', + b'{"outputs": [" problem"]}', + ] + self.assertEqual(lines, expected_lines) + + def test_iteration_with_model_stream_error(self): + # Mocking the stream object + self.stream = MagicMock() + self.stream.__iter__.return_value = [ + {"PayloadPart": {"Bytes": b'{"outputs": [" a"]}\n'}}, + {"PayloadPart": {"Bytes": b'{"outputs": [" challenging"]}\n'}}, + {"ModelStreamError": {"Message": "Error message", "ErrorCode": "500"}}, + {"PayloadPart": {"Bytes": b'{"outputs": [" problem"]}\n'}}, + ] + self.iterator = LineIterator(self.stream) + + with self.assertRaises(ModelStreamError) as e: + list(self.iterator) + + self.assertEqual(str(e.exception.message), "Error message") + self.assertEqual(str(e.exception.code), "500") + + def test_iteration_with_internal_stream_failure(self): + # Mocking the stream object + self.stream = MagicMock() + self.stream.__iter__.return_value = [ + {"PayloadPart": {"Bytes": b'{"outputs": [" a"]}\n'}}, + {"PayloadPart": {"Bytes": b'{"outputs": [" challenging"]}\n'}}, + {"InternalStreamFailure": {"Message": "Error internal stream failure"}}, + {"PayloadPart": {"Bytes": b'{"outputs": [" problem"]}\n'}}, + ] + self.iterator = LineIterator(self.stream) + + with self.assertRaises(InternalStreamFailure) as e: + list(self.iterator) + + self.assertEqual(str(e.exception.message), "Error internal stream failure") diff --git a/tests/unit/test_predictor.py b/tests/unit/test_predictor.py index 1ee9babdf7..1e4f6d0f0a 100644 --- a/tests/unit/test_predictor.py +++ b/tests/unit/test_predictor.py @@ -34,6 +34,7 @@ CSV_RETURN_VALUE = "1,2,3\r\n" PRODUCTION_VARIANT_1 = "PRODUCTION_VARIANT_1" INFERENCE_ID = "inference-id" +STREAM_ITERABLE_BODY = ["This", "is", "stream", "response"] ENDPOINT_DESC = {"EndpointArn": "foo", "EndpointConfigName": ENDPOINT} @@ -56,6 +57,11 @@ def empty_sagemaker_session(): ims.sagemaker_runtime_client.invoke_endpoint = Mock( name="invoke_endpoint", return_value={"Body": response_body} ) + + stream_response_body = STREAM_ITERABLE_BODY + ims.sagemaker_runtime_client.invoke_endpoint_with_response_stream = Mock( + name="invoke_endpoint_with_response_stream", return_value={"Body": stream_response_body} + ) return ims @@ -260,6 +266,75 @@ def test_predict_call_with_multiple_accept_types(): assert kwargs == expected_request_args +def test_predict_stream_call_pass_through(): + sagemaker_session = empty_sagemaker_session() + predictor = Predictor(ENDPOINT, sagemaker_session) + + data = "dummy" + result = predictor.predict_stream(data, iterator=list) + + assert sagemaker_session.sagemaker_runtime_client.invoke_endpoint_with_response_stream.called + assert sagemaker_session.sagemaker_client.describe_endpoint.not_called + assert sagemaker_session.sagemaker_client.describe_endpoint_config.not_called + + expected_request_args = { + "Accept": DEFAULT_ACCEPT, + "Body": data, + "ContentType": DEFAULT_CONTENT_TYPE, + "EndpointName": ENDPOINT, + } + + ( + call_args, + kwargs, + ) = sagemaker_session.sagemaker_runtime_client.invoke_endpoint_with_response_stream.call_args + assert kwargs == expected_request_args + + assert result == STREAM_ITERABLE_BODY + + +def test_predict_stream_call_all_args(): + sagemaker_session = empty_sagemaker_session() + predictor = Predictor(ENDPOINT, sagemaker_session) + + data = "dummy" + initial_args = {"ContentType": "application/json"} + result = predictor.predict_stream( + data, + initial_args=initial_args, + target_variant=PRODUCTION_VARIANT_1, + inference_id=INFERENCE_ID, + custom_attributes="custom-attribute", + component_name="test_component_name", + target_container_hostname="test_target_container_hostname", + iterator=list, + ) + + assert sagemaker_session.sagemaker_runtime_client.invoke_endpoint_with_response_stream.called + assert sagemaker_session.sagemaker_client.describe_endpoint.not_called + assert sagemaker_session.sagemaker_client.describe_endpoint_config.not_called + + expected_request_args = { + "Accept": DEFAULT_ACCEPT, + "Body": data, + "ContentType": "application/json", + "EndpointName": ENDPOINT, + "TargetVariant": PRODUCTION_VARIANT_1, + "InferenceId": INFERENCE_ID, + "CustomAttributes": "custom-attribute", + "InferenceComponentName": "test_component_name", + "TargetContainerHostname": "test_target_container_hostname", + } + + ( + call_args, + kwargs, + ) = sagemaker_session.sagemaker_runtime_client.invoke_endpoint_with_response_stream.call_args + assert kwargs == expected_request_args + + assert result == STREAM_ITERABLE_BODY + + @patch("sagemaker.base_predictor.name_from_base") def test_update_endpoint_no_args(name_from_base): new_endpoint_config_name = "new-endpoint-config"