From 96484faeed1d7cc42a5bcd3f906bd27d06564868 Mon Sep 17 00:00:00 2001 From: Mufaddal Rohawala Date: Wed, 13 Mar 2024 10:25:24 -0700 Subject: [PATCH 1/3] feature: Add support for Streaming Inference --- src/sagemaker/base_predictor.py | 82 ++++++++++ src/sagemaker/exceptions.py | 16 ++ src/sagemaker/iterators.py | 154 ++++++++++++++++++ .../lmi-model-falcon-7b/mymodel-7B.tar.gz | Bin 0 -> 382 bytes tests/integ/test_predict_stream.py | 88 ++++++++++ .../sagemaker/iterators/test_iterators.py | 58 +++++++ tests/unit/test_predictor.py | 75 +++++++++ 7 files changed, 473 insertions(+) create mode 100644 src/sagemaker/iterators.py create mode 100644 tests/data/lmi-model-falcon-7b/mymodel-7B.tar.gz create mode 100644 tests/integ/test_predict_stream.py create mode 100644 tests/unit/sagemaker/iterators/test_iterators.py diff --git a/src/sagemaker/base_predictor.py b/src/sagemaker/base_predictor.py index 76b83c25cd..f91674c142 100644 --- a/src/sagemaker/base_predictor.py +++ b/src/sagemaker/base_predictor.py @@ -52,6 +52,7 @@ JSONSerializer, NumpySerializer, ) +from sagemaker.iterators import LineIterator from sagemaker.session import production_variant, Session from sagemaker.utils import name_from_base, stringify_object, format_tags @@ -225,6 +226,7 @@ def _create_request_args( target_variant=None, inference_id=None, custom_attributes=None, + target_container_hostname=None, ): """Placeholder docstring""" @@ -286,9 +288,89 @@ def _create_request_args( if self._get_component_name(): args["InferenceComponentName"] = self.component_name + if target_container_hostname: + args["TargetContainerHostname"] = target_container_hostname + args["Body"] = data return args + def predict_stream( + self, + data, + initial_args=None, + target_variant=None, + inference_id=None, + custom_attributes=None, + component_name: Optional[str] = None, + target_container_hostname=None, + iterator=LineIterator, + ): + """Return the inference from the specified endpoint. + + Args: + data (object): Input data for which you want the model to provide + inference. If a serializer was specified when creating the + Predictor, the result of the serializer is sent as input + data. Otherwise the data must be sequence of bytes, and the + predict method then sends the bytes in the request body as is. + initial_args (dict[str,str]): Optional. Default arguments for boto3 + ``invoke_endpoint_with_response_stream`` call. Default is None (no default + arguments). + target_variant (str): The name of the production variant to run an inference + request on (Default: None). Note that the ProductionVariant identifies the + model you want to host and the resources you want to deploy for hosting it. + inference_id (str): If you provide a value, it is added to the captured data + when you enable data capture on the endpoint (Default: None). + custom_attributes (str): Provides additional information about a request for an + inference submitted to a model hosted at an Amazon SageMaker endpoint. + The information is an opaque value that is forwarded verbatim. You could use this + value, for example, to provide an ID that you can use to track a request or to + provide other metadata that a service endpoint was programmed to process. The value + must consist of no more than 1024 visible US-ASCII characters. + + The code in your model is responsible for setting or updating any custom attributes + in the response. If your code does not set this value in the response, an empty + value is returned. 
For example, if a custom attribute represents the trace ID, your + model can prepend the custom attribute with Trace ID: in your post-processing + function (Default: None). + component_name (str): Optional. Name of the Amazon SageMaker inference component + corresponding the predictor. + target_container_hostname (str): If the endpoint hosts multiple containers and is + configured to use direct invocation, this parameter specifies the host name of the + container to invoke. (Default: None). + iterator (:class:`~sagemaker.iterators.BaseIterator`): An iterator class which provides + an iterable interface to deserialize a stream response from Inference Endpoint. + An object of the iterator class provided will be returned by the predict_stream + method (Default::class:`~sagemaker.iterators.LineIterator`). Iterators defined in + :class:`~sagemaker.iterators` or custom iterators (needs to inherit + :class:`~sagemaker.iterators.BaseIterator`) can be specified as an input. + + Returns: + object (:class:`~sagemaker.iterators.BaseIterator`): An iterator object which would + allow iteration on EventStream response will be returned. The object would be + instantiated from `predict_stream` method's `iterator` parameter. + """ + # [TODO]: clean up component_name in _create_request_args + request_args = self._create_request_args( + data=data, + initial_args=initial_args, + target_variant=target_variant, + inference_id=inference_id, + custom_attributes=custom_attributes, + target_container_hostname=target_container_hostname, + ) + + inference_component_name = component_name or self._get_component_name() + if inference_component_name: + request_args["InferenceComponentName"] = inference_component_name + + response = ( + self.sagemaker_session.sagemaker_runtime_client.invoke_endpoint_with_response_stream( + **request_args + ) + ) + return iterator(response["Body"]) + def update_endpoint( self, initial_instance_count=None, diff --git a/src/sagemaker/exceptions.py b/src/sagemaker/exceptions.py index b9d97cc241..7e1ed2c5fd 100644 --- a/src/sagemaker/exceptions.py +++ b/src/sagemaker/exceptions.py @@ -86,3 +86,19 @@ class AsyncInferenceModelError(AsyncInferenceError): def __init__(self, message): super().__init__(message=message) + + +class ModelStreamError(Exception): + def __init__(self, message="An error occurred", code=None): + self.message = message + self.code = code + if code is not None: + super().__init__(f"{message} (Code: {code})") + else: + super().__init__(message) + + +class InternalStreamFailure(Exception): + def __init__(self, message="An error occurred"): + self.message = message + super().__init__(self.message) diff --git a/src/sagemaker/iterators.py b/src/sagemaker/iterators.py new file mode 100644 index 0000000000..5680f14fc9 --- /dev/null +++ b/src/sagemaker/iterators.py @@ -0,0 +1,154 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""Implements iterators for deserializing data returned from an inference streaming endpoint.""" +from __future__ import absolute_import + +from abc import ABC, abstractmethod +import io + +from sagemaker.exceptions import ModelStreamError, InternalStreamFailure + + +def handle_stream_errors(chunk): + """Handle API Response errors within `invoke_endpoint_with_response_stream` API if any. + + Args: + chunk (dict): A chunk of response received as part of `botocore.eventstream.EventStream` + response object. + + Raises: + ModelStreamError: If `ModelStreamError` error is detected in a chunk of + `botocore.eventstream.EventStream` response object. + InternalStreamFailure: If `InternalStreamFailure` error is detected in a chunk of + `botocore.eventstream.EventStream` response object. + """ + if "ModelStreamError" in chunk: + raise ModelStreamError( + chunk["ModelStreamError"]["Message"], code=chunk["ModelStreamError"]["ErrorCode"] + ) + if "InternalStreamFailure" in chunk: + raise InternalStreamFailure(chunk["InternalStreamFailure"]["Message"]) + + +class BaseIterator(ABC): + """Abstract base class for creation of new iterators. + + Provides a skeleton for customization requiring the overriding of iterator methods + __iter__ and __next__. + + Tenets of iterator class for Streaming Inference API Response + (https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ + sagemaker-runtime/client/invoke_endpoint_with_response_stream.html): + 1. Needs to accept an botocore.eventstream.EventStream response. + 2. Needs to implement logic in __next__ to: + 2.1. Concatenate and provide next chunk of response from botocore.eventstream.EventStream. + While doing so parse the response_chunk["PayloadPart"]["Bytes"]. + 2.2. Perform deserialization of response chunk based on expected response type. + 2.3. If PayloadPart not in EventStream response, handle Errors. + """ + + def __init__(self, stream): + """Initialises a Iterator object to help parse the byte event stream input. + + Args: + stream: (botocore.eventstream.EventStream): Event Stream object to be iterated. + """ + self.stream = stream + + @abstractmethod + def __iter__(self): + """Abstract __iter__ method, returns an iterator object itself""" + return self + + @abstractmethod + def __next__(self): + """Abstract __next__ method, is responsible for returning the next element in the + iteration""" + pass + + +class LineIterator(BaseIterator): + """ + A helper class for parsing the byte stream input and provide iteration on lines with + '\n' separators. + """ + + def __init__(self, stream): + """Initialises a Iterator object to help parse the byte stream input and + provide iteration on lines with '\n' separators + + Args: + stream: (botocore.eventstream.EventStream): Event Stream object to be iterated. + """ + super().__init__(stream) + self.byte_iterator = iter(self.stream) + self.buffer = io.BytesIO() + self.read_pos = 0 + + def __iter__(self): + """Returns an iterator object itself, which allows the object to be iterated. + + Returns: + iter : object + An iterator object representing the iterable. + """ + return self + + def __next__(self): + """ + The output of the event stream will be in the following format: + + ``` + b'{"outputs": [" a"]}\n' + b'{"outputs": [" challenging"]}\n' + b'{"outputs": [" problem"]}\n' + ... + ``` + + While usually each PayloadPart event from the event stream will contain a byte array + with a full json, this is not guaranteed and some of the json objects may be split across + PayloadPart events. 
For example: + ``` + {'PayloadPart': {'Bytes': b'{"outputs": '}} + {'PayloadPart': {'Bytes': b'[" problem"]}\n'}} + ``` + + This class accounts for this by concatenating bytes written via the 'write' function + and then exposing a method which will return lines (ending with a '\n' character) within + the buffer via the 'scan_lines' function. It maintains the position of the last read + position to ensure that previous bytes are not exposed again. + + Returns: + str: Read and return one line from the event stream. + """ + # Even with "while True" loop the function still behaves like a generator + # and sends the next new concatenated line + while True: + self.buffer.seek(self.read_pos) + line = self.buffer.readline() + if line and line[-1] == ord("\n"): + self.read_pos += len(line) + return line[:-1] + try: + chunk = next(self.byte_iterator) + except StopIteration: + if self.read_pos < self.buffer.getbuffer().nbytes: + continue + raise + if "PayloadPart" not in chunk: + # handle errors within API Response if any. + handle_stream_errors(chunk) + print("Unknown event type:" + chunk) + continue + self.buffer.seek(0, io.SEEK_END) + self.buffer.write(chunk["PayloadPart"]["Bytes"]) diff --git a/tests/data/lmi-model-falcon-7b/mymodel-7B.tar.gz b/tests/data/lmi-model-falcon-7b/mymodel-7B.tar.gz new file mode 100644 index 0000000000000000000000000000000000000000..6a66178b47a086851a140b8dba294307ac31ace1 GIT binary patch literal 382 zcmV-^0fGJ>iwFP!000001MSkyPV68Q2k@?aioU=oOg|VGY}~mp@eK$qr*knaG(e5| z^fF_nPHu8_qcJ!Be;Zmj^qg{-o+oc;+=!d2;=8a+G|h3${vMCdylzC*3NGrlV4OF+ zF3RTHDmt^oq(fO2!Ta=4+-K|msp-A{k;0>O`^!1_nL@G@zbMC{!EIgtv;R$1o%KK8 z6W&xz6eatj{2%(|{U^7#j^y3_?S-F{_3rX`ACxsRS-WVu8uZwEw-MdOx|qV!r&DC0 zM;r5lq^{;{=-Oe>*KE6^YoSfIV|-wyrM_+r+YY;qiPOgXm6%kZ$tO~M&L{H>t*hjs z4{Fvyk7F*y&^{1Jz80vTRPf`N@2cu_>i?){Ur1KlwXX9;IZk$CY+S4MOPZIY1|KG! z5(W7Xz02_wPZ1_P&s55Cn0b4eoAsWII&5% Date: Wed, 13 Mar 2024 15:31:28 -0700 Subject: [PATCH 2/3] fix: codestyle-docs-test --- src/sagemaker/base_predictor.py | 26 +++---- src/sagemaker/exceptions.py | 4 + src/sagemaker/iterators.py | 76 +++++++++++++------ tests/integ/test_predict_stream.py | 21 +++++ .../sagemaker/iterators/test_iterators.py | 72 +++++++++++++++++- 5 files changed, 162 insertions(+), 37 deletions(-) diff --git a/src/sagemaker/base_predictor.py b/src/sagemaker/base_predictor.py index f91674c142..1a7eea9cd7 100644 --- a/src/sagemaker/base_predictor.py +++ b/src/sagemaker/base_predictor.py @@ -52,7 +52,7 @@ JSONSerializer, NumpySerializer, ) -from sagemaker.iterators import LineIterator +from sagemaker.iterators import ByteIterator from sagemaker.session import production_variant, Session from sagemaker.utils import name_from_base, stringify_object, format_tags @@ -303,7 +303,7 @@ def predict_stream( custom_attributes=None, component_name: Optional[str] = None, target_container_hostname=None, - iterator=LineIterator, + iterator=ByteIterator, ): """Return the inference from the specified endpoint. @@ -315,14 +315,14 @@ def predict_stream( predict method then sends the bytes in the request body as is. initial_args (dict[str,str]): Optional. Default arguments for boto3 ``invoke_endpoint_with_response_stream`` call. Default is None (no default - arguments). - target_variant (str): The name of the production variant to run an inference + arguments). (Default: None) + target_variant (str): Optional. The name of the production variant to run an inference request on (Default: None). 
Note that the ProductionVariant identifies the model you want to host and the resources you want to deploy for hosting it. - inference_id (str): If you provide a value, it is added to the captured data + inference_id (str): Optional. If you provide a value, it is added to the captured data when you enable data capture on the endpoint (Default: None). - custom_attributes (str): Provides additional information about a request for an - inference submitted to a model hosted at an Amazon SageMaker endpoint. + custom_attributes (str): Optional. Provides additional information about a request for + an inference submitted to a model hosted at an Amazon SageMaker endpoint. The information is an opaque value that is forwarded verbatim. You could use this value, for example, to provide an ID that you can use to track a request or to provide other metadata that a service endpoint was programmed to process. The value @@ -334,14 +334,14 @@ def predict_stream( model can prepend the custom attribute with Trace ID: in your post-processing function (Default: None). component_name (str): Optional. Name of the Amazon SageMaker inference component - corresponding the predictor. - target_container_hostname (str): If the endpoint hosts multiple containers and is - configured to use direct invocation, this parameter specifies the host name of the - container to invoke. (Default: None). + corresponding the predictor. (Default: None) + target_container_hostname (str): Optional. If the endpoint hosts multiple containers + and is configured to use direct invocation, this parameter specifies the host name + of the container to invoke. (Default: None). iterator (:class:`~sagemaker.iterators.BaseIterator`): An iterator class which provides - an iterable interface to deserialize a stream response from Inference Endpoint. + an iterable interface to iterate Event stream response from Inference Endpoint. An object of the iterator class provided will be returned by the predict_stream - method (Default::class:`~sagemaker.iterators.LineIterator`). Iterators defined in + method (Default::class:`~sagemaker.iterators.ByteIterator`). Iterators defined in :class:`~sagemaker.iterators` or custom iterators (needs to inherit :class:`~sagemaker.iterators.BaseIterator`) can be specified as an input. diff --git a/src/sagemaker/exceptions.py b/src/sagemaker/exceptions.py index 7e1ed2c5fd..88ffa0a591 100644 --- a/src/sagemaker/exceptions.py +++ b/src/sagemaker/exceptions.py @@ -89,6 +89,8 @@ def __init__(self, message): class ModelStreamError(Exception): + """Raised when invoke_endpoint_with_response_stream Response returns ModelStreamError""" + def __init__(self, message="An error occurred", code=None): self.message = message self.code = code @@ -99,6 +101,8 @@ def __init__(self, message="An error occurred", code=None): class InternalStreamFailure(Exception): + """Raised when invoke_endpoint_with_response_stream Response returns InternalStreamFailure""" + def __init__(self, message="An error occurred"): self.message = message super().__init__(self.message) diff --git a/src/sagemaker/iterators.py b/src/sagemaker/iterators.py index 5680f14fc9..38a43121a1 100644 --- a/src/sagemaker/iterators.py +++ b/src/sagemaker/iterators.py @@ -41,7 +41,7 @@ def handle_stream_errors(chunk): class BaseIterator(ABC): - """Abstract base class for creation of new iterators. + """Abstract base class for Inference Streaming iterators. Provides a skeleton for customization requiring the overriding of iterator methods __iter__ and __next__. 
@@ -53,45 +53,75 @@ class BaseIterator(ABC): 2. Needs to implement logic in __next__ to: 2.1. Concatenate and provide next chunk of response from botocore.eventstream.EventStream. While doing so parse the response_chunk["PayloadPart"]["Bytes"]. - 2.2. Perform deserialization of response chunk based on expected response type. - 2.3. If PayloadPart not in EventStream response, handle Errors. + 2.2. If PayloadPart not in EventStream response, handle Errors + [Recommended to use `iterators.handle_stream_errors` method]. """ - def __init__(self, stream): + def __init__(self, event_stream): """Initialises a Iterator object to help parse the byte event stream input. Args: - stream: (botocore.eventstream.EventStream): Event Stream object to be iterated. + event_stream: (botocore.eventstream.EventStream): Event Stream object to be iterated. """ - self.stream = stream + self.event_stream = event_stream @abstractmethod def __iter__(self): - """Abstract __iter__ method, returns an iterator object itself""" + """Abstract method, returns an iterator object itself""" return self @abstractmethod def __next__(self): - """Abstract __next__ method, is responsible for returning the next element in the - iteration""" - pass + """Abstract method, is responsible for returning the next element in the iteration""" + + +class ByteIterator(BaseIterator): + """A helper class for parsing the byte Event Stream input to provide Byte iteration.""" + + def __init__(self, event_stream): + """Initialises a BytesIterator Iterator object + + Args: + event_stream: (botocore.eventstream.EventStream): Event Stream object to be iterated. + """ + super().__init__(event_stream) + self.byte_iterator = iter(event_stream) + + def __iter__(self): + """Returns an iterator object itself, which allows the object to be iterated. + + Returns: + iter : object + An iterator object representing the iterable. + """ + return self + + def __next__(self): + """Returns the next chunk of Byte directly.""" + # Even with "while True" loop the function still behaves like a generator + # and sends the next new byte chunk. + while True: + chunk = next(self.byte_iterator) + if "PayloadPart" not in chunk: + # handle API response errors and force terminate. + handle_stream_errors(chunk) + # print and move on to next response byte + print("Unknown event type:" + chunk) + continue + return chunk["PayloadPart"]["Bytes"] class LineIterator(BaseIterator): - """ - A helper class for parsing the byte stream input and provide iteration on lines with - '\n' separators. - """ + """A helper class for parsing the byte Event Stream input to provide Line iteration.""" - def __init__(self, stream): - """Initialises a Iterator object to help parse the byte stream input and - provide iteration on lines with '\n' separators + def __init__(self, event_stream): + """Initialises a LineIterator Iterator object Args: - stream: (botocore.eventstream.EventStream): Event Stream object to be iterated. + event_stream: (botocore.eventstream.EventStream): Event Stream object to be iterated. """ - super().__init__(stream) - self.byte_iterator = iter(self.stream) + super().__init__(event_stream) + self.byte_iterator = iter(self.event_stream) self.buffer = io.BytesIO() self.read_pos = 0 @@ -105,7 +135,8 @@ def __iter__(self): return self def __next__(self): - """ + r"""Returns the next Line for an Line iterable. 
+ The output of the event stream will be in the following format: ``` @@ -146,8 +177,9 @@ def __next__(self): continue raise if "PayloadPart" not in chunk: - # handle errors within API Response if any. + # handle API response errors and force terminate. handle_stream_errors(chunk) + # print and move on to next response byte print("Unknown event type:" + chunk) continue self.buffer.seek(0, io.SEEK_END) diff --git a/tests/integ/test_predict_stream.py b/tests/integ/test_predict_stream.py index f3a3e35ff6..bdd19187df 100644 --- a/tests/integ/test_predict_stream.py +++ b/tests/integ/test_predict_stream.py @@ -86,3 +86,24 @@ def test_predict_stream(sagemaker_session, endpoint_name): response += resp.get("outputs")[0] assert "AWS stands for Amazon Web Services." in response + + data = {"inputs": "what does AWS stand for?", "parameters": {"max_new_tokens": 400}} + initial_args = {"ContentType": "application/json"} + predictor = Predictor( + endpoint_name=endpoint_name, + sagemaker_session=sagemaker_session, + ) + + # Validate that no exception is raised when the target_variant is specified. + # uses the default `sagemaker.iterator.ByteIterator` + stream_iterator = predictor.predict_stream( + data=json.dumps(data), + initial_args=initial_args, + ) + + response = "" + for line in stream_iterator: + resp = json.loads(line) + response += resp.get("outputs")[0] + + assert "AWS stands for Amazon Web Services." in response diff --git a/tests/unit/sagemaker/iterators/test_iterators.py b/tests/unit/sagemaker/iterators/test_iterators.py index aa22656fdf..89e9d43a47 100644 --- a/tests/unit/sagemaker/iterators/test_iterators.py +++ b/tests/unit/sagemaker/iterators/test_iterators.py @@ -1,11 +1,25 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + import unittest from unittest.mock import MagicMock from sagemaker.exceptions import ModelStreamError, InternalStreamFailure -from sagemaker.iterators import LineIterator +from sagemaker.iterators import ByteIterator, LineIterator -class TestLineIterator(unittest.TestCase): +class TestByteIterator(unittest.TestCase): def test_iteration_with_payload_parts(self): # Mocking the stream object self.stream = MagicMock() @@ -14,6 +28,60 @@ def test_iteration_with_payload_parts(self): {"PayloadPart": {"Bytes": b'{"outputs": [" challenging"]}\n'}}, {"PayloadPart": {"Bytes": b'{"outputs": [" problem"]}\n'}}, ] + self.iterator = ByteIterator(self.stream) + + lines = list(self.iterator) + expected_lines = [ + b'{"outputs": [" a"]}\n', + b'{"outputs": [" challenging"]}\n', + b'{"outputs": [" problem"]}\n', + ] + self.assertEqual(lines, expected_lines) + + def test_iteration_with_model_stream_error(self): + # Mocking the stream object + self.stream = MagicMock() + self.stream.__iter__.return_value = [ + {"PayloadPart": {"Bytes": b'{"outputs": [" a"]}\n'}}, + {"PayloadPart": {"Bytes": b'{"outputs": [" challenging"]}\n'}}, + {"ModelStreamError": {"Message": "Error message", "ErrorCode": "500"}}, + {"PayloadPart": {"Bytes": b'{"outputs": [" problem"]}\n'}}, + ] + self.iterator = ByteIterator(self.stream) + + with self.assertRaises(ModelStreamError) as e: + list(self.iterator) + + self.assertEqual(str(e.exception.message), "Error message") + self.assertEqual(str(e.exception.code), "500") + + def test_iteration_with_internal_stream_failure(self): + # Mocking the stream object + self.stream = MagicMock() + self.stream.__iter__.return_value = [ + {"PayloadPart": {"Bytes": b'{"outputs": [" a"]}\n'}}, + {"PayloadPart": {"Bytes": b'{"outputs": [" challenging"]}\n'}}, + {"InternalStreamFailure": {"Message": "Error internal stream failure"}}, + {"PayloadPart": {"Bytes": b'{"outputs": [" problem"]}\n'}}, + ] + self.iterator = ByteIterator(self.stream) + + with self.assertRaises(InternalStreamFailure) as e: + list(self.iterator) + + self.assertEqual(str(e.exception.message), "Error internal stream failure") + + +class TestLineIterator(unittest.TestCase): + def test_iteration_with_payload_parts(self): + # Mocking the stream object + self.stream = MagicMock() + self.stream.__iter__.return_value = [ + {"PayloadPart": {"Bytes": b'{"outputs": [" a"]}\n'}}, + {"PayloadPart": {"Bytes": b'{"outputs": [" challenging"]}\n'}}, + {"PayloadPart": {"Bytes": b'{"outputs": '}}, + {"PayloadPart": {"Bytes": b'[" problem"]}\n'}}, + ] self.iterator = LineIterator(self.stream) lines = list(self.iterator) From 9b43eabf9fea695cda4b3277d210bf58616ce125 Mon Sep 17 00:00:00 2001 From: Mufaddal Rohawala Date: Wed, 13 Mar 2024 16:01:34 -0700 Subject: [PATCH 3/3] fix: codestyle-docs-test --- src/sagemaker/jumpstart/cache.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index fff421ab32..e9a34a21a8 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -65,7 +65,9 @@ def __init__( self, region: Optional[str] = None, max_s3_cache_items: int = JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS, - s3_cache_expiration_horizon: datetime.timedelta = JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON, + s3_cache_expiration_horizon: datetime.timedelta = ( + JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON + ), max_semantic_version_cache_items: int = JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS, 
semantic_version_cache_expiration_horizon: datetime.timedelta = ( JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON
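
---

Usage sketch for reviewers: a minimal example of the new `predict_stream` API introduced by this series, adapted from tests/integ/test_predict_stream.py. The endpoint name is hypothetical, and the example assumes an already-deployed container (e.g. an LMI/TGI-style model server) that streams newline-delimited JSON; it is an illustration of the intended call pattern, not part of the patch itself.

```python
import json

from sagemaker.predictor import Predictor
from sagemaker.iterators import LineIterator  # added by this patch

# Hypothetical, pre-existing streaming endpoint.
predictor = Predictor(endpoint_name="my-streaming-endpoint")

payload = {"inputs": "what does AWS stand for?", "parameters": {"max_new_tokens": 400}}

# LineIterator yields one complete b'...\n' line per iteration, reassembling JSON
# objects even when they are split across PayloadPart events. Omitting `iterator`
# falls back to the default ByteIterator, which yields raw payload byte chunks.
stream = predictor.predict_stream(
    data=json.dumps(payload),
    initial_args={"ContentType": "application/json"},
    iterator=LineIterator,
)

response = ""
for line in stream:
    chunk = json.loads(line)
    response += chunk.get("outputs", [""])[0]

print(response)
```

A custom iterator can also be supplied by subclassing `sagemaker.iterators.BaseIterator` and implementing `__iter__`/`__next__` (using `iterators.handle_stream_errors` for non-PayloadPart events, as the built-in iterators do).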