Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: Add support for Streaming Inference #4497

Merged
merged 3 commits into from
Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions src/sagemaker/base_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
JSONSerializer,
NumpySerializer,
)
from sagemaker.iterators import ByteIterator
from sagemaker.session import production_variant, Session
from sagemaker.utils import name_from_base, stringify_object, format_tags

Expand Down Expand Up @@ -225,6 +226,7 @@ def _create_request_args(
target_variant=None,
inference_id=None,
custom_attributes=None,
target_container_hostname=None,
):
"""Placeholder docstring"""

Expand Down Expand Up @@ -286,9 +288,89 @@ def _create_request_args(
if self._get_component_name():
args["InferenceComponentName"] = self.component_name

if target_container_hostname:
args["TargetContainerHostname"] = target_container_hostname

args["Body"] = data
return args

def predict_stream(
    self,
    data,
    initial_args=None,
    target_variant=None,
    inference_id=None,
    custom_attributes=None,
    component_name: Optional[str] = None,
    target_container_hostname=None,
    iterator=ByteIterator,
):
    """Invoke the endpoint in streaming mode and return an iterator over the response.

    The request arguments are assembled by ``_create_request_args`` and passed to the
    SageMaker runtime client's ``invoke_endpoint_with_response_stream`` call. The
    ``Body`` of the response is wrapped in ``iterator`` and returned.

    Args:
        data (object): Input data to run inference on. If a serializer was specified
            when creating the Predictor, the serialized result is sent as the request
            body; otherwise the data must be a sequence of bytes and is sent as is.
        initial_args (dict[str,str]): Optional. Default arguments for the boto3
            ``invoke_endpoint_with_response_stream`` call (Default: None).
        target_variant (str): Optional. Name of the production variant to run the
            inference request on (Default: None).
        inference_id (str): Optional. If provided, it is added to the captured data
            when data capture is enabled on the endpoint (Default: None).
        custom_attributes (str): Optional. Opaque value forwarded verbatim to the
            model, for example a trace ID. Must consist of no more than 1024 visible
            US-ASCII characters. The model code is responsible for setting or
            updating custom attributes in the response (Default: None).
        component_name (str): Optional. Name of the Amazon SageMaker inference
            component to invoke. When omitted, the predictor's own component name
            (if any) is used (Default: None).
        target_container_hostname (str): Optional. For endpoints hosting multiple
            containers with direct invocation, the host name of the container to
            invoke (Default: None).
        iterator (:class:`~sagemaker.iterators.BaseIterator`): Iterator class used to
            wrap the event stream. Any class from :class:`~sagemaker.iterators` or a
            custom subclass of :class:`~sagemaker.iterators.BaseIterator` may be
            given (Default: :class:`~sagemaker.iterators.ByteIterator`).

    Returns:
        object (:class:`~sagemaker.iterators.BaseIterator`): An instance of
            ``iterator`` built from the response's event stream.
    """
    # [TODO]: clean up component_name in _create_request_args
    invoke_kwargs = self._create_request_args(
        data=data,
        initial_args=initial_args,
        target_variant=target_variant,
        inference_id=inference_id,
        custom_attributes=custom_attributes,
        target_container_hostname=target_container_hostname,
    )

    # An explicitly passed component name wins over the predictor's own one.
    resolved_component = component_name if component_name else self._get_component_name()
    if resolved_component:
        invoke_kwargs["InferenceComponentName"] = resolved_component

    runtime_client = self.sagemaker_session.sagemaker_runtime_client
    stream_response = runtime_client.invoke_endpoint_with_response_stream(**invoke_kwargs)
    return iterator(stream_response["Body"])

def update_endpoint(
self,
initial_instance_count=None,
Expand Down
20 changes: 20 additions & 0 deletions src/sagemaker/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,23 @@

def __init__(self, message):
super().__init__(message=message)


class ModelStreamError(Exception):
    """Raised when invoke_endpoint_with_response_stream Response returns ModelStreamError"""

    def __init__(self, message="An error occurred", code=None):
        """Store the error details and build the exception text.

        Args:
            message (str): Description of the error (Default: "An error occurred").
            code: Optional error code. When it is not None it is appended to the
                exception text as ``" (Code: <code>)"`` (Default: None).
        """
        rendered = message if code is None else f"{message} (Code: {code})"
        super().__init__(rendered)
        self.message = message
        self.code = code

Check warning on line 100 in src/sagemaker/exceptions.py

View check run for this annotation

Codecov / codecov/patch

src/sagemaker/exceptions.py#L100

Added line #L100 was not covered by tests


class InternalStreamFailure(Exception):
    """Raised when invoke_endpoint_with_response_stream Response returns InternalStreamFailure"""

    def __init__(self, message="An error occurred"):
        """Store the failure description and use it as the exception text.

        Args:
            message (str): Description of the failure (Default: "An error occurred").
        """
        super().__init__(message)
        self.message = message
186 changes: 186 additions & 0 deletions src/sagemaker/iterators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Implements iterators for deserializing data returned from an inference streaming endpoint."""
from __future__ import absolute_import

from abc import ABC, abstractmethod
import io

from sagemaker.exceptions import ModelStreamError, InternalStreamFailure


def handle_stream_errors(chunk):
    """Raise the matching exception if a response chunk carries a stream error.

    Chunks that contain neither error key pass through silently.

    Args:
        chunk (dict): A chunk of response received as part of
            `botocore.eventstream.EventStream` response object.

    Raises:
        ModelStreamError: If the chunk contains a `ModelStreamError` entry; built
            from its `Message` and `ErrorCode` fields.
        InternalStreamFailure: If the chunk contains an `InternalStreamFailure`
            entry; built from its `Message` field.
    """
    if "ModelStreamError" in chunk:
        model_error = chunk["ModelStreamError"]
        raise ModelStreamError(model_error["Message"], code=model_error["ErrorCode"])

    if "InternalStreamFailure" in chunk:
        internal_failure = chunk["InternalStreamFailure"]
        raise InternalStreamFailure(internal_failure["Message"])


class BaseIterator(ABC):
    """Abstract base class for Inference Streaming iterators.

    Subclasses customise iteration by overriding both ``__iter__`` and ``__next__``.

    Guidelines for an iterator over the Streaming Inference API response
    (https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/
    sagemaker-runtime/client/invoke_endpoint_with_response_stream.html):
    1. It must accept a botocore.eventstream.EventStream response.
    2. Its ``__next__`` must:
        2.1. Concatenate and return the next chunk of the response, parsing
            response_chunk["PayloadPart"]["Bytes"] while doing so.
        2.2. Handle errors when PayloadPart is absent from an event
            [Recommended to use `iterators.handle_stream_errors` method].
    """

    def __init__(self, event_stream):
        """Keep a reference to the event stream that will be iterated.

        Args:
            event_stream: (botocore.eventstream.EventStream): Event Stream object to be iterated.
        """
        self.event_stream = event_stream

    @abstractmethod
    def __iter__(self):
        """Abstract method; the default body returns the iterator object itself."""
        return self

    @abstractmethod
    def __next__(self):
        """Abstract method; implementations return the next element of the iteration."""


class ByteIterator(BaseIterator):
    """A helper class for parsing the byte Event Stream input to provide Byte iteration."""

    def __init__(self, event_stream):
        """Initialises a ByteIterator object.

        Args:
            event_stream: (botocore.eventstream.EventStream): Event Stream object to be iterated.
        """
        super().__init__(event_stream)
        self.byte_iterator = iter(event_stream)

    def __iter__(self):
        """Returns the iterator object itself, which allows the object to be iterated.

        Returns:
            iter : object
                An iterator object representing the iterable.
        """
        return self

    def __next__(self):
        """Returns the bytes of the next ``PayloadPart`` event in the stream.

        Events without a ``PayloadPart`` are passed to ``handle_stream_errors``; if
        that does not raise, the event is reported on stdout and skipped.

        Returns:
            The value stored under ``chunk["PayloadPart"]["Bytes"]``.

        Raises:
            StopIteration: When the underlying event stream is exhausted.
        """
        # The loop only repeats to skip over non-payload events; every call
        # returns at most one payload chunk.
        while True:
            chunk = next(self.byte_iterator)
            if "PayloadPart" not in chunk:
                # handle API response errors and force terminate.
                handle_stream_errors(chunk)
                # Report and move on. The chunk is a dict, so it is formatted
                # rather than concatenated (str + dict raises TypeError).
                print(f"Unknown event type: {chunk}")
                continue
            return chunk["PayloadPart"]["Bytes"]


class LineIterator(BaseIterator):
    """A helper class for parsing the byte Event Stream input to provide Line iteration."""

    def __init__(self, event_stream):
        """Initialises a LineIterator object.

        Args:
            event_stream: (botocore.eventstream.EventStream): Event Stream object to be iterated.
        """
        super().__init__(event_stream)
        self.byte_iterator = iter(self.event_stream)
        # Accumulates every payload received so far; read_pos marks the first
        # byte that has not yet been returned to the caller.
        self.buffer = io.BytesIO()
        self.read_pos = 0

    def __iter__(self):
        """Returns the iterator object itself, which allows the object to be iterated.

        Returns:
            iter : object
                An iterator object representing the iterable.
        """
        return self

    def __next__(self):
        r"""Returns the next line from the event stream.

        The output of the event stream will be in the following format:

        ```
        b'{"outputs": [" a"]}\n'
        b'{"outputs": [" challenging"]}\n'
        b'{"outputs": [" problem"]}\n'
        ...
        ```

        While usually each PayloadPart event from the event stream will contain a byte array
        with a full json, this is not guaranteed and some of the json objects may be split across
        PayloadPart events. For example:
        ```
        {'PayloadPart': {'Bytes': b'{"outputs": '}}
        {'PayloadPart': {'Bytes': b'[" problem"]}\n'}}
        ```

        To account for this, payload bytes are appended to an internal buffer and a line is
        only returned once its terminating '\n' has been received. The position of the last
        returned byte is tracked so that previous bytes are not exposed again. If the stream
        ends while unterminated bytes remain in the buffer, they are returned once as the
        final line.

        Events without a ``PayloadPart`` are passed to ``handle_stream_errors``; if that does
        not raise, the event is reported on stdout and skipped.

        Returns:
            bytes: One line from the event stream, without its trailing '\n'.

        Raises:
            StopIteration: When the event stream is exhausted and no buffered bytes remain.
        """
        # The loop only repeats to pull more events until a full line is
        # available; every call returns at most one line.
        while True:
            self.buffer.seek(self.read_pos)
            line = self.buffer.readline()
            if line.endswith(b"\n"):
                self.read_pos += len(line)
                return line[:-1]
            try:
                chunk = next(self.byte_iterator)
            except StopIteration:
                if line:
                    # The stream ended without a final newline. Hand back the
                    # remainder once; retrying here would loop forever because
                    # no further bytes can ever arrive.
                    self.read_pos += len(line)
                    return line
                raise
            if "PayloadPart" not in chunk:
                # handle API response errors and force terminate.
                handle_stream_errors(chunk)
                # Report and move on. The chunk is a dict, so it is formatted
                # rather than concatenated (str + dict raises TypeError).
                print(f"Unknown event type: {chunk}")
                continue
            # Append at the end; the read position is restored at the top of the loop.
            self.buffer.seek(0, io.SEEK_END)
            self.buffer.write(chunk["PayloadPart"]["Bytes"])
4 changes: 3 additions & 1 deletion src/sagemaker/jumpstart/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ def __init__(
self,
region: Optional[str] = None,
max_s3_cache_items: int = JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS,
s3_cache_expiration_horizon: datetime.timedelta = JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON,
s3_cache_expiration_horizon: datetime.timedelta = (
JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON
),
max_semantic_version_cache_items: int = JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS,
semantic_version_cache_expiration_horizon: datetime.timedelta = (
JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON
Expand Down
Binary file added tests/data/lmi-model-falcon-7b/mymodel-7B.tar.gz
Binary file not shown.
Loading