diff --git a/samcli/lib/utils/stream_writer.py b/samcli/lib/utils/stream_writer.py index 78bae0d8ff..dc2c2b9b5c 100644 --- a/samcli/lib/utils/stream_writer.py +++ b/samcli/lib/utils/stream_writer.py @@ -1,11 +1,12 @@ """ This class acts like a wrapper around output streams to provide any flexibility with output we need """ -from typing import TextIO +from io import BytesIO, TextIOWrapper +from typing import Optional, TextIO, Union class StreamWriter: - def __init__(self, stream: TextIO, auto_flush: bool = False): + def __init__(self, stream: TextIO, stream_bytes: Optional[Union[TextIO, BytesIO]] = None, auto_flush: bool = False): """ Instatiates new StreamWriter to the specified stream @@ -13,16 +14,40 @@ def __init__(self, stream: TextIO, auto_flush: bool = False): ---------- stream io.RawIOBase Stream to wrap + stream_bytes io.TextIO | io.BytesIO + Stream to wrap if bytes are being written auto_flush bool Whether to autoflush the stream upon writing """ self._stream = stream + self._stream_bytes = stream if isinstance(stream, TextIOWrapper) else stream_bytes self._auto_flush = auto_flush @property def stream(self) -> TextIO: return self._stream + def write_bytes(self, output: bytes): + """ + Writes specified text to the underlying stream + Parameters + ---------- + output bytes-like object + Bytes to write into buffer + """ + # all these ifs are to satisfy the linting/type checking + if not self._stream_bytes: + return + if isinstance(self._stream_bytes, TextIOWrapper): + self._stream_bytes.buffer.write(output) + if self._auto_flush: + self._stream_bytes.flush() + + elif isinstance(self._stream_bytes, BytesIO): + self._stream_bytes.write(output) + if self._auto_flush: + self._stream_bytes.flush() + def write_str(self, output: str): """ Writes specified text to the underlying stream @@ -39,3 +64,5 @@ def write_str(self, output: str): def flush(self): self._stream.flush() + if self._stream_bytes: + self._stream_bytes.flush() diff --git a/samcli/local/apigw/authorizers/lambda_authorizer.py b/samcli/local/apigw/authorizers/lambda_authorizer.py index ed3483eee5..8b7b92c6ea 100644 --- a/samcli/local/apigw/authorizers/lambda_authorizer.py +++ b/samcli/local/apigw/authorizers/lambda_authorizer.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from json import JSONDecodeError, loads -from typing import Any, Dict, List, Optional, Tuple, Type, cast +from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast from urllib.parse import parse_qsl from samcli.commands.local.lib.validators.identity_source_validator import IdentitySourceValidator @@ -321,13 +321,13 @@ def _parse_identity_sources(self, identity_sources: List[str]) -> None: break - def is_valid_response(self, response: str, method_arn: str) -> bool: + def is_valid_response(self, response: Union[str, bytes], method_arn: str) -> bool: """ Validates whether a Lambda authorizer request is authenticated or not. Parameters ---------- - response: str + response: Union[str, bytes] JSON string containing the output from a Lambda authorizer method_arn: str The method ARN of the route that invoked the Lambda authorizer @@ -418,13 +418,13 @@ def _validate_simple_response(self, response: dict) -> bool: return cast(bool, is_authorized) - def get_context(self, response: str) -> Dict[str, Any]: + def get_context(self, response: Union[str, bytes]) -> Dict[str, Any]: """ Returns the context (if set) from the authorizer response and appends the principalId to it. Parameters ---------- - response: str + response: Union[str, bytes] Output from Lambda authorizer Returns diff --git a/samcli/local/apigw/local_apigw_service.py b/samcli/local/apigw/local_apigw_service.py index df82b68ae4..1e0f871fcd 100644 --- a/samcli/local/apigw/local_apigw_service.py +++ b/samcli/local/apigw/local_apigw_service.py @@ -6,7 +6,7 @@ from datetime import datetime from io import StringIO from time import time -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union from flask import Flask, Request, request from werkzeug.datastructures import Headers @@ -594,7 +594,7 @@ def _valid_identity_sources(self, request: Request, route: Route) -> bool: return True - def _invoke_lambda_function(self, lambda_function_name: str, event: dict) -> str: + def _invoke_lambda_function(self, lambda_function_name: str, event: dict) -> Union[str, bytes]: """ Helper method to invoke a function and setup stdout+stderr @@ -607,8 +607,8 @@ def _invoke_lambda_function(self, lambda_function_name: str, event: dict) -> str Returns ------- - str - A string containing the output from the Lambda function + Union[str, bytes] + A string or bytes containing the output from the Lambda function """ with StringIO() as stdout: event_str = json.dumps(event, sort_keys=True) diff --git a/samcli/local/docker/container.py b/samcli/local/docker/container.py index b1777279b0..239c92fcea 100644 --- a/samcli/local/docker/container.py +++ b/samcli/local/docker/container.py @@ -369,7 +369,7 @@ def start(self, input_data=None): raise ex @retry(exc=requests.exceptions.RequestException, exc_raise=ContainerResponseException) - def wait_for_http_response(self, name, event, stdout) -> Union[str, bytes]: + def wait_for_http_response(self, name, event, stdout) -> Tuple[Union[str, bytes], bool]: # TODO(sriram-mv): `aws-lambda-rie` is in a mode where the function_name is always "function" # NOTE(sriram-mv): There is a connection timeout set on the http call to `aws-lambda-rie`, however there is not # a read time out for the response received from the server. @@ -380,10 +380,13 @@ def wait_for_http_response(self, name, event, stdout) -> Union[str, bytes]: timeout=(self.RAPID_CONNECTION_TIMEOUT, None), ) try: - return json.dumps(json.loads(resp.content), ensure_ascii=False) + # if response is an image then json.loads/dumps will throw a UnicodeDecodeError so return raw content + if "image" in resp.headers["Content-Type"]: + return resp.content, True + return json.dumps(json.loads(resp.content), ensure_ascii=False), False except json.JSONDecodeError: LOG.debug("Failed to deserialize response from RIE, returning the raw response as is") - return resp.content + return resp.content, False def wait_for_result(self, full_path, event, stdout, stderr, start_timer=None): # NOTE(sriram-mv): Let logging happen in its own thread, so that a http request can be sent. @@ -406,13 +409,15 @@ def wait_for_result(self, full_path, event, stdout, stderr, start_timer=None): # start the timer for function timeout right before executing the function, as waiting for the socket # can take some time timer = start_timer() if start_timer else None - response = self.wait_for_http_response(full_path, event, stdout) + response, is_image = self.wait_for_http_response(full_path, event, stdout) if timer: timer.cancel() self._logs_thread_event.wait(timeout=1) if isinstance(response, str): stdout.write_str(response) + elif isinstance(response, bytes) and is_image: + stdout.write_bytes(response) elif isinstance(response, bytes): stdout.write_str(response.decode("utf-8")) stdout.flush() diff --git a/samcli/local/lambda_service/local_lambda_invoke_service.py b/samcli/local/lambda_service/local_lambda_invoke_service.py index 1e46a10507..a847802b3c 100644 --- a/samcli/local/lambda_service/local_lambda_invoke_service.py +++ b/samcli/local/lambda_service/local_lambda_invoke_service.py @@ -165,8 +165,9 @@ def _invoke_request_handler(self, function_name): request_data = request_data.decode("utf-8") - stdout_stream = io.StringIO() - stdout_stream_writer = StreamWriter(stdout_stream, auto_flush=True) + stdout_stream_string = io.StringIO() + stdout_stream_bytes = io.BytesIO() + stdout_stream_writer = StreamWriter(stdout_stream_string, stdout_stream_bytes, auto_flush=True) try: self.lambda_runner.invoke(function_name, request_data, stdout=stdout_stream_writer, stderr=self.stderr) @@ -178,7 +179,9 @@ def _invoke_request_handler(self, function_name): "Inline code is not supported for sam local commands. Please write your code in a separate file." ) - lambda_response, is_lambda_user_error_response = LambdaOutputParser.get_lambda_output(stdout_stream) + lambda_response, is_lambda_user_error_response = LambdaOutputParser.get_lambda_output( + stdout_stream_string, stdout_stream_bytes + ) if is_lambda_user_error_response: return self.service_response( diff --git a/samcli/local/services/base_local_service.py b/samcli/local/services/base_local_service.py index 5de5beb7dd..573c24b445 100644 --- a/samcli/local/services/base_local_service.py +++ b/samcli/local/services/base_local_service.py @@ -2,7 +2,7 @@ import io import json import logging -from typing import Tuple +from typing import Optional, Tuple, Union from flask import Response @@ -85,7 +85,9 @@ def service_response(body, headers, status_code): class LambdaOutputParser: @staticmethod - def get_lambda_output(stdout_stream: io.StringIO) -> Tuple[str, bool]: + def get_lambda_output( + stdout_stream_str: io.StringIO, stdout_stream_bytes: Optional[io.BytesIO] = None + ) -> Tuple[Union[str, bytes], bool]: """ This method will extract read the given stream and return the response from Lambda function separated out from any log statements it might have outputted. Logs end up in the stdout stream if the Lambda function @@ -93,9 +95,12 @@ def get_lambda_output(stdout_stream: io.StringIO) -> Tuple[str, bool]: Parameters ---------- - stdout_stream : io.BaseIO + stdout_stream_str : io.BaseIO Stream to fetch data from + stdout_stream_bytes : Optional[io.BytesIO], optional + Stream to fetch raw bytes data from + Returns ------- str @@ -103,7 +108,9 @@ def get_lambda_output(stdout_stream: io.StringIO) -> Tuple[str, bool]: bool If the response is an error/exception from the container """ - lambda_response = stdout_stream.getvalue() + lambda_response: Union[str, bytes] = stdout_stream_str.getvalue() + if stdout_stream_bytes and not lambda_response: + lambda_response = stdout_stream_bytes.getvalue() # When the Lambda Function returns an Error/Exception, the output is added to the stdout of the container. From # our perspective, the container returned some value, which is not always true. Since the output is the only diff --git a/tests/integration/local/invoke/test_integrations_cli.py b/tests/integration/local/invoke/test_integrations_cli.py index c1a114b5d4..3c33a25cea 100644 --- a/tests/integration/local/invoke/test_integrations_cli.py +++ b/tests/integration/local/invoke/test_integrations_cli.py @@ -1197,6 +1197,41 @@ def test_invoke_inline_code_function(self): self.assertEqual(process.returncode, 1) +class TestInvokeFunctionWithImageBytesAsReturn(InvokeIntegBase): + template = Path("template-return-image.yaml") + + @pytest.mark.flaky(reruns=3) + def test_invoke_returncode_is_zero(self): + command_list = InvokeIntegBase.get_command_list( + "GetImageFunction", template_path=self.template_path, event_path=self.event_path + ) + + process = Popen(command_list, stdout=PIPE) + try: + process.communicate(timeout=TIMEOUT) + except TimeoutExpired: + process.kill() + raise + + self.assertEqual(process.returncode, 0) + + @pytest.mark.flaky(reruns=3) + def test_invoke_image_is_returned(self): + command_list = InvokeIntegBase.get_command_list( + "GetImageFunction", template_path=self.template_path, event_path=self.event_path + ) + + process = Popen(command_list, stdout=PIPE) + try: + stdout, _ = process.communicate(timeout=TIMEOUT) + except TimeoutExpired: + process.kill() + raise + + # The first byte of a png image file is \x89 so we can check that to verify that it returned an image + self.assertEqual(stdout[0:1], b"\x89") + + class TestInvokeFunctionWithError(InvokeIntegBase): template = Path("template.yml") diff --git a/tests/integration/testdata/invoke/image-for-lambda.png b/tests/integration/testdata/invoke/image-for-lambda.png new file mode 100644 index 0000000000..56a3af614b Binary files /dev/null and b/tests/integration/testdata/invoke/image-for-lambda.png differ diff --git a/tests/integration/testdata/invoke/main.py b/tests/integration/testdata/invoke/main.py index e33635ccb4..0f2753a6ff 100644 --- a/tests/integration/testdata/invoke/main.py +++ b/tests/integration/testdata/invoke/main.py @@ -60,3 +60,10 @@ def execute_git(event, context): def no_response(event, context): print("lambda called") + + +def image_handler(event, context): + with open("image-for-lambda.png", "rb") as f: + image_bytes = f.read() + + return image_bytes \ No newline at end of file diff --git a/tests/integration/testdata/invoke/template-return-image.yaml b/tests/integration/testdata/invoke/template-return-image.yaml new file mode 100644 index 0000000000..f59f8e5453 --- /dev/null +++ b/tests/integration/testdata/invoke/template-return-image.yaml @@ -0,0 +1,12 @@ +AWSTemplateFormatVersion : '2010-09-09' +Transform: AWS::Serverless-2016-10-31 +Description: A hello world application. + +Resources: + GetImageFunction: + Type: AWS::Serverless::Function + Properties: + Handler: main.image_handler + Runtime: python3.11 + CodeUri: . + Timeout: 600 \ No newline at end of file diff --git a/tests/unit/lib/utils/test_stream_writer.py b/tests/unit/lib/utils/test_stream_writer.py index 0459a44c0e..c586e48a42 100644 --- a/tests/unit/lib/utils/test_stream_writer.py +++ b/tests/unit/lib/utils/test_stream_writer.py @@ -1,8 +1,8 @@ """ Tests for StreamWriter """ -import io +from io import BytesIO, TextIOWrapper from unittest import TestCase from samcli.lib.utils.stream_writer import StreamWriter @@ -20,6 +20,35 @@ def test_must_write_to_stream(self): stream_mock.write.assert_called_once_with(buffer.decode("utf-8")) + def test_must_write_to_stream_bytes(self): + img_bytes = b"\xff\xab\x11" + stream_mock = Mock() + byte_stream_mock = Mock(spec=BytesIO) + + writer = StreamWriter(stream_mock, byte_stream_mock) + writer.write_bytes(img_bytes) + + byte_stream_mock.write.assert_called_once_with(img_bytes) + + def test_must_write_to_stream_bytes_for_stdout(self): + img_bytes = b"\xff\xab\x11" + stream_mock = Mock() + byte_stream_mock = Mock(spec=TextIOWrapper) + + writer = StreamWriter(stream_mock, byte_stream_mock) + writer.write_bytes(img_bytes) + + byte_stream_mock.buffer.write.assert_called_once_with(img_bytes) + + def test_must_not_write_to_stream_bytes_if_not_defined(self): + img_bytes = b"\xff\xab\x11" + stream_mock = Mock() + + writer = StreamWriter(stream_mock) + writer.write_bytes(img_bytes) + + stream_mock.write.assert_not_called() + def test_must_flush_underlying_stream(self): stream_mock = Mock() writer = StreamWriter(stream_mock) @@ -44,7 +73,7 @@ def test_when_auto_flush_on_flush_after_each_write(self): lines = ["first", "second", "third"] - writer = StreamWriter(stream_mock, True) + writer = StreamWriter(stream_mock, auto_flush=True) for line in lines: writer.write_str(line) diff --git a/tests/unit/local/docker/test_container.py b/tests/unit/local/docker/test_container.py index cdcf22c5d9..ddf02f91e3 100644 --- a/tests/unit/local/docker/test_container.py +++ b/tests/unit/local/docker/test_container.py @@ -1,6 +1,7 @@ """ Unit test for Container class """ +import base64 import json from unittest import TestCase from unittest.mock import MagicMock, Mock, call, patch, ANY @@ -584,22 +585,77 @@ def setUp(self): self.socket_mock = Mock() self.socket_mock.connect_ex.return_value = 0 + @patch("socket.socket") + @patch("samcli.local.docker.container.requests") + def test_wait_for_result_no_error_image_response(self, mock_requests, patched_socket): + self.container.is_created.return_value = True + + rie_response = b"\xff\xab" + resp_headers = { + "Date": "Tue, 02 Jan 2024 21:23:31 GMT", + "Content-Type": "image/jpeg", + "Transfer-Encoding": "chunked", + } + + real_container_mock = Mock() + self.mock_docker_client.containers.get.return_value = real_container_mock + + output_itr = Mock() + real_container_mock.attach.return_value = output_itr + self.container._write_container_output = Mock() + self.container._create_threading_event = Mock() + self.container._create_threading_event.return_value = Mock() + + stdout_mock = Mock() + stdout_mock.write_bytes = Mock() + stderr_mock = Mock() + response = Mock() + response.content = rie_response + response.headers = resp_headers + mock_requests.post.return_value = response + + patched_socket.return_value = self.socket_mock + + start_timer = Mock() + timer = Mock() + start_timer.return_value = timer + + self.container.wait_for_result( + event=self.event, full_path=self.name, stdout=stdout_mock, stderr=stderr_mock, start_timer=start_timer + ) + + # since we passed in a start_timer function, ensure it's called and + # the timer is cancelled once execution is done + start_timer.assert_called() + timer.cancel.assert_called() + + # make sure we wait for the same host+port that we make the post request to + host = self.container._container_host + port = self.container.rapid_port_host + self.socket_mock.connect_ex.assert_called_with((host, port)) + mock_requests.post.assert_called_with( + self.container.URL.format(host=host, port=port, function_name="function"), + data=b"{}", + timeout=(self.container.RAPID_CONNECTION_TIMEOUT, None), + ) + stdout_mock.write_bytes.assert_called_with(rie_response) + @parameterized.expand( [ - ( - True, - b'{"hello":"world"}', - ), + (True, b'{"hello":"world"}', {"Date": "Tue, 02 Jan 2024 21:23:31 GMT", "Content-Type": "text"}), ( False, b"non-json-deserializable", + {"Date": "Tue, 02 Jan 2024 21:23:31 GMT", "Content-Type": "text/plain"}, ), - (False, b""), + (False, b"", {"Date": "Tue, 02 Jan 2024 21:23:31 GMT", "Content-Type": "text/plain"}), ] ) @patch("socket.socket") @patch("samcli.local.docker.container.requests") - def test_wait_for_result_no_error(self, response_deserializable, rie_response, mock_requests, patched_socket): + def test_wait_for_result_no_error( + self, response_deserializable, rie_response, resp_headers, mock_requests, patched_socket + ): self.container.is_created.return_value = True real_container_mock = Mock() @@ -616,6 +672,7 @@ def test_wait_for_result_no_error(self, response_deserializable, rie_response, m stderr_mock = Mock() response = Mock() response.content = rie_response + response.headers = resp_headers mock_requests.post.return_value = response patched_socket.return_value = self.socket_mock diff --git a/tests/unit/local/lambda_service/test_local_lambda_invoke_service.py b/tests/unit/local/lambda_service/test_local_lambda_invoke_service.py index 684e72f00c..e338762db9 100644 --- a/tests/unit/local/lambda_service/test_local_lambda_invoke_service.py +++ b/tests/unit/local/lambda_service/test_local_lambda_invoke_service.py @@ -135,7 +135,7 @@ def test_request_handler_returns_process_stdout_when_making_response( result = service._invoke_request_handler(function_name="HelloWorld") self.assertEqual(result, "request response") - lambda_output_parser_mock.get_lambda_output.assert_called_with(ANY) + lambda_output_parser_mock.get_lambda_output.assert_called_with(ANY, ANY) @patch("samcli.local.lambda_service.local_lambda_invoke_service.LambdaErrorResponses") def test_construct_error_handling(self, lambda_error_response_mock):