Skip to content

Commit

Permalink
Revert "Fix: sam invoke local throw error 'utf-8' codec can't decode …
Browse files Browse the repository at this point in the history
…byte 0xff (aws#6509)" (aws#6543)

This reverts commit c49c218.
  • Loading branch information
sidhujus authored Jan 10, 2024
1 parent de61f7b commit 335e529
Show file tree
Hide file tree
Showing 13 changed files with 31 additions and 213 deletions.
31 changes: 2 additions & 29 deletions samcli/lib/utils/stream_writer.py
Original file line number Diff line number Diff line change
@@ -1,53 +1,28 @@
"""
This class acts like a wrapper around output streams to provide any flexibility with output we need
"""
from io import BytesIO, TextIOWrapper
from typing import Optional, TextIO, Union
from typing import TextIO


class StreamWriter:
def __init__(self, stream: TextIO, stream_bytes: Optional[Union[TextIO, BytesIO]] = None, auto_flush: bool = False):
def __init__(self, stream: TextIO, auto_flush: bool = False):
"""
Instatiates new StreamWriter to the specified stream
Parameters
----------
stream io.RawIOBase
Stream to wrap
stream_bytes io.TextIO | io.BytesIO
Stream to wrap if bytes are being written
auto_flush bool
Whether to autoflush the stream upon writing
"""
self._stream = stream
self._stream_bytes = stream if isinstance(stream, TextIOWrapper) else stream_bytes
self._auto_flush = auto_flush

@property
def stream(self) -> TextIO:
return self._stream

def write_bytes(self, output: bytes):
"""
Writes specified text to the underlying stream
Parameters
----------
output bytes-like object
Bytes to write into buffer
"""
# all these ifs are to satisfy the linting/type checking
if not self._stream_bytes:
return
if isinstance(self._stream_bytes, TextIOWrapper):
self._stream_bytes.buffer.write(output)
if self._auto_flush:
self._stream_bytes.flush()

elif isinstance(self._stream_bytes, BytesIO):
self._stream_bytes.write(output)
if self._auto_flush:
self._stream_bytes.flush()

def write_str(self, output: str):
"""
Writes specified text to the underlying stream
Expand All @@ -64,5 +39,3 @@ def write_str(self, output: str):

def flush(self):
self._stream.flush()
if self._stream_bytes:
self._stream_bytes.flush()
10 changes: 5 additions & 5 deletions samcli/local/apigw/authorizers/lambda_authorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from json import JSONDecodeError, loads
from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast
from typing import Any, Dict, List, Optional, Tuple, Type, cast
from urllib.parse import parse_qsl

from samcli.commands.local.lib.validators.identity_source_validator import IdentitySourceValidator
Expand Down Expand Up @@ -321,13 +321,13 @@ def _parse_identity_sources(self, identity_sources: List[str]) -> None:

break

def is_valid_response(self, response: Union[str, bytes], method_arn: str) -> bool:
def is_valid_response(self, response: str, method_arn: str) -> bool:
"""
Validates whether a Lambda authorizer request is authenticated or not.
Parameters
----------
response: Union[str, bytes]
response: str
JSON string containing the output from a Lambda authorizer
method_arn: str
The method ARN of the route that invoked the Lambda authorizer
Expand Down Expand Up @@ -418,13 +418,13 @@ def _validate_simple_response(self, response: dict) -> bool:

return cast(bool, is_authorized)

def get_context(self, response: Union[str, bytes]) -> Dict[str, Any]:
def get_context(self, response: str) -> Dict[str, Any]:
"""
Returns the context (if set) from the authorizer response and appends the principalId to it.
Parameters
----------
response: Union[str, bytes]
response: str
Output from Lambda authorizer
Returns
Expand Down
8 changes: 4 additions & 4 deletions samcli/local/apigw/local_apigw_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from datetime import datetime
from io import StringIO
from time import time
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple

from flask import Flask, Request, request
from werkzeug.datastructures import Headers
Expand Down Expand Up @@ -594,7 +594,7 @@ def _valid_identity_sources(self, request: Request, route: Route) -> bool:

return True

def _invoke_lambda_function(self, lambda_function_name: str, event: dict) -> Union[str, bytes]:
def _invoke_lambda_function(self, lambda_function_name: str, event: dict) -> str:
"""
Helper method to invoke a function and setup stdout+stderr
Expand All @@ -607,8 +607,8 @@ def _invoke_lambda_function(self, lambda_function_name: str, event: dict) -> Uni
Returns
-------
Union[str, bytes]
A string or bytes containing the output from the Lambda function
str
A string containing the output from the Lambda function
"""
with StringIO() as stdout:
event_str = json.dumps(event, sort_keys=True)
Expand Down
13 changes: 4 additions & 9 deletions samcli/local/docker/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ def start(self, input_data=None):
raise ex

@retry(exc=requests.exceptions.RequestException, exc_raise=ContainerResponseException)
def wait_for_http_response(self, name, event, stdout) -> Tuple[Union[str, bytes], bool]:
def wait_for_http_response(self, name, event, stdout) -> Union[str, bytes]:
# TODO(sriram-mv): `aws-lambda-rie` is in a mode where the function_name is always "function"
# NOTE(sriram-mv): There is a connection timeout set on the http call to `aws-lambda-rie`, however there is not
# a read time out for the response received from the server.
Expand All @@ -374,13 +374,10 @@ def wait_for_http_response(self, name, event, stdout) -> Tuple[Union[str, bytes]
timeout=(self.RAPID_CONNECTION_TIMEOUT, None),
)
try:
# if response is an image then json.loads/dumps will throw a UnicodeDecodeError so return raw content
if "image" in resp.headers["Content-Type"]:
return resp.content, True
return json.dumps(json.loads(resp.content), ensure_ascii=False), False
return json.dumps(json.loads(resp.content), ensure_ascii=False)
except json.JSONDecodeError:
LOG.debug("Failed to deserialize response from RIE, returning the raw response as is")
return resp.content, False
return resp.content

def wait_for_result(self, full_path, event, stdout, stderr, start_timer=None):
# NOTE(sriram-mv): Let logging happen in its own thread, so that a http request can be sent.
Expand All @@ -403,15 +400,13 @@ def wait_for_result(self, full_path, event, stdout, stderr, start_timer=None):
# start the timer for function timeout right before executing the function, as waiting for the socket
# can take some time
timer = start_timer() if start_timer else None
response, is_image = self.wait_for_http_response(full_path, event, stdout)
response = self.wait_for_http_response(full_path, event, stdout)
if timer:
timer.cancel()

self._logs_thread_event.wait(timeout=1)
if isinstance(response, str):
stdout.write_str(response)
elif isinstance(response, bytes) and is_image:
stdout.write_bytes(response)
elif isinstance(response, bytes):
stdout.write_str(response.decode("utf-8"))
stdout.flush()
Expand Down
9 changes: 3 additions & 6 deletions samcli/local/lambda_service/local_lambda_invoke_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,8 @@ def _invoke_request_handler(self, function_name):

request_data = request_data.decode("utf-8")

stdout_stream_string = io.StringIO()
stdout_stream_bytes = io.BytesIO()
stdout_stream_writer = StreamWriter(stdout_stream_string, stdout_stream_bytes, auto_flush=True)
stdout_stream = io.StringIO()
stdout_stream_writer = StreamWriter(stdout_stream, auto_flush=True)

try:
self.lambda_runner.invoke(function_name, request_data, stdout=stdout_stream_writer, stderr=self.stderr)
Expand All @@ -179,9 +178,7 @@ def _invoke_request_handler(self, function_name):
"Inline code is not supported for sam local commands. Please write your code in a separate file."
)

lambda_response, is_lambda_user_error_response = LambdaOutputParser.get_lambda_output(
stdout_stream_string, stdout_stream_bytes
)
lambda_response, is_lambda_user_error_response = LambdaOutputParser.get_lambda_output(stdout_stream)

if is_lambda_user_error_response:
return self.service_response(
Expand Down
15 changes: 4 additions & 11 deletions samcli/local/services/base_local_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import io
import json
import logging
from typing import Optional, Tuple, Union
from typing import Tuple

from flask import Response

Expand Down Expand Up @@ -85,32 +85,25 @@ def service_response(body, headers, status_code):

class LambdaOutputParser:
@staticmethod
def get_lambda_output(
stdout_stream_str: io.StringIO, stdout_stream_bytes: Optional[io.BytesIO] = None
) -> Tuple[Union[str, bytes], bool]:
def get_lambda_output(stdout_stream: io.StringIO) -> Tuple[str, bool]:
"""
This method will extract read the given stream and return the response from Lambda function separated out
from any log statements it might have outputted. Logs end up in the stdout stream if the Lambda function
wrote directly to stdout using System.out.println or equivalents.
Parameters
----------
stdout_stream_str : io.BaseIO
stdout_stream : io.BaseIO
Stream to fetch data from
stdout_stream_bytes : Optional[io.BytesIO], optional
Stream to fetch raw bytes data from
Returns
-------
str
String data containing response from Lambda function
bool
If the response is an error/exception from the container
"""
lambda_response: Union[str, bytes] = stdout_stream_str.getvalue()
if stdout_stream_bytes and not lambda_response:
lambda_response = stdout_stream_bytes.getvalue()
lambda_response = stdout_stream.getvalue()

# When the Lambda Function returns an Error/Exception, the output is added to the stdout of the container. From
# our perspective, the container returned some value, which is not always true. Since the output is the only
Expand Down
35 changes: 0 additions & 35 deletions tests/integration/local/invoke/test_integrations_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -1197,41 +1197,6 @@ def test_invoke_inline_code_function(self):
self.assertEqual(process.returncode, 1)


class TestInvokeFunctionWithImageBytesAsReturn(InvokeIntegBase):
template = Path("template-return-image.yaml")

@pytest.mark.flaky(reruns=3)
def test_invoke_returncode_is_zero(self):
command_list = InvokeIntegBase.get_command_list(
"GetImageFunction", template_path=self.template_path, event_path=self.event_path
)

process = Popen(command_list, stdout=PIPE)
try:
process.communicate(timeout=TIMEOUT)
except TimeoutExpired:
process.kill()
raise

self.assertEqual(process.returncode, 0)

@pytest.mark.flaky(reruns=3)
def test_invoke_image_is_returned(self):
command_list = InvokeIntegBase.get_command_list(
"GetImageFunction", template_path=self.template_path, event_path=self.event_path
)

process = Popen(command_list, stdout=PIPE)
try:
stdout, _ = process.communicate(timeout=TIMEOUT)
except TimeoutExpired:
process.kill()
raise

# The first byte of a png image file is \x89 so we can check that to verify that it returned an image
self.assertEqual(stdout[0:1], b"\x89")


class TestInvokeFunctionWithError(InvokeIntegBase):
template = Path("template.yml")

Expand Down
Binary file not shown.
7 changes: 0 additions & 7 deletions tests/integration/testdata/invoke/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,3 @@ def execute_git(event, context):

def no_response(event, context):
print("lambda called")


def image_handler(event, context):
with open("image-for-lambda.png", "rb") as f:
image_bytes = f.read()

return image_bytes
12 changes: 0 additions & 12 deletions tests/integration/testdata/invoke/template-return-image.yaml

This file was deleted.

33 changes: 2 additions & 31 deletions tests/unit/lib/utils/test_stream_writer.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""
Tests for StreamWriter
"""
import io

from io import BytesIO, TextIOWrapper
from unittest import TestCase

from samcli.lib.utils.stream_writer import StreamWriter
Expand All @@ -20,35 +20,6 @@ def test_must_write_to_stream(self):

stream_mock.write.assert_called_once_with(buffer.decode("utf-8"))

def test_must_write_to_stream_bytes(self):
img_bytes = b"\xff\xab\x11"
stream_mock = Mock()
byte_stream_mock = Mock(spec=BytesIO)

writer = StreamWriter(stream_mock, byte_stream_mock)
writer.write_bytes(img_bytes)

byte_stream_mock.write.assert_called_once_with(img_bytes)

def test_must_write_to_stream_bytes_for_stdout(self):
img_bytes = b"\xff\xab\x11"
stream_mock = Mock()
byte_stream_mock = Mock(spec=TextIOWrapper)

writer = StreamWriter(stream_mock, byte_stream_mock)
writer.write_bytes(img_bytes)

byte_stream_mock.buffer.write.assert_called_once_with(img_bytes)

def test_must_not_write_to_stream_bytes_if_not_defined(self):
img_bytes = b"\xff\xab\x11"
stream_mock = Mock()

writer = StreamWriter(stream_mock)
writer.write_bytes(img_bytes)

stream_mock.write.assert_not_called()

def test_must_flush_underlying_stream(self):
stream_mock = Mock()
writer = StreamWriter(stream_mock)
Expand All @@ -73,7 +44,7 @@ def test_when_auto_flush_on_flush_after_each_write(self):

lines = ["first", "second", "third"]

writer = StreamWriter(stream_mock, auto_flush=True)
writer = StreamWriter(stream_mock, True)

for line in lines:
writer.write_str(line)
Expand Down
Loading

0 comments on commit 335e529

Please sign in to comment.