Skip to content

Commit

Permalink
Lambda-http: vary type of response based on request origin (#269)
Browse files Browse the repository at this point in the history
* Lambda-http: vary type of response based on request origin

ApiGatewayV2, ApiGateway and Alb all expect different types of responses to
be returned from the invoked lambda function. Thus, it makes sense to pass
the request origin to the creation of the response, so that the correct
type of LambdaResponse is returned from the function.

This commit also adds support for the "cookies" attribute which can be used
for returning multiple Set-cookie headers from a lambda invoked via
ApiGatewayV2, since ApiGatewayV2 no longer seems to recognize the
"multiValueHeaders" attribute.

Closes: #267.

* Fix Serialize import

* Fix missing reference on self

* Fix import order

* Add missing comma for fmt check

Co-authored-by: Blake Hildebrand <38637276+bahildebrand@users.noreply.github.com>
  • Loading branch information
l3ku and bahildebrand authored Feb 7, 2021
1 parent 6033ce3 commit 5556ab2
Show file tree
Hide file tree
Showing 3 changed files with 193 additions and 66 deletions.
18 changes: 11 additions & 7 deletions lambda-http/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,10 @@ pub mod request;
mod response;
mod strmap;
pub use crate::{body::Body, ext::RequestExt, response::IntoResponse, strmap::StrMap};
use crate::{request::LambdaRequest, response::LambdaResponse};
use crate::{
request::{LambdaRequest, RequestOrigin},
response::LambdaResponse,
};
use std::{
future::Future,
pin::Pin,
Expand Down Expand Up @@ -124,7 +127,7 @@ where

#[doc(hidden)]
pub struct TransformResponse<R, E> {
is_alb: bool,
request_origin: RequestOrigin,
fut: Pin<Box<dyn Future<Output = Result<R, E>> + Send + Sync>>,
}

Expand All @@ -135,9 +138,9 @@ where
type Output = Result<LambdaResponse, E>;
fn poll(mut self: Pin<&mut Self>, cx: &mut TaskContext) -> Poll<Self::Output> {
match self.fut.as_mut().poll(cx) {
Poll::Ready(result) => {
Poll::Ready(result.map(|resp| LambdaResponse::from_response(self.is_alb, resp.into_response())))
}
Poll::Ready(result) => Poll::Ready(
result.map(|resp| LambdaResponse::from_response(&self.request_origin, resp.into_response())),
),
Poll::Pending => Poll::Pending,
}
}
Expand Down Expand Up @@ -166,9 +169,10 @@ impl<H: Handler> Handler for Adapter<H> {
impl<H: Handler> LambdaHandler<LambdaRequest<'_>, LambdaResponse> for Adapter<H> {
type Error = H::Error;
type Fut = TransformResponse<H::Response, Self::Error>;

fn call(&self, event: LambdaRequest<'_>, context: Context) -> Self::Fut {
let is_alb = event.is_alb();
let request_origin = event.request_origin();
let fut = Box::pin(self.handler.call(event.into(), context));
TransformResponse { is_alb, fut }
TransformResponse { request_origin, fut }
}
}
29 changes: 21 additions & 8 deletions lambda-http/src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,17 +91,30 @@ pub enum LambdaRequest<'a> {
}

impl LambdaRequest<'_> {
/// Return true if this request represents an ALB event
///
/// Alb responses have unique requirements for responses that
/// vary only slightly from APIGateway responses. We serialize
/// responses capturing a hint that the request was an alb triggered
/// event.
pub fn is_alb(&self) -> bool {
matches!(self, LambdaRequest::Alb { .. })
/// Return the `RequestOrigin` of the request to determine where the `LambdaRequest`
/// originated from, so that the appropriate response can be selected based on what
/// type of response the request origin expects.
pub fn request_origin(&self) -> RequestOrigin {
match self {
LambdaRequest::ApiGatewayV2 { .. } => RequestOrigin::ApiGatewayV2,
LambdaRequest::Alb { .. } => RequestOrigin::Alb,
LambdaRequest::ApiGateway { .. } => RequestOrigin::ApiGateway,
}
}
}

/// Represents the origin from which the lambda was requested from.
#[doc(hidden)]
#[derive(Debug)]
pub enum RequestOrigin {
/// API Gateway v2 request origin
ApiGatewayV2,
/// API Gateway request origin
ApiGateway,
/// ALB request origin
Alb,
}

/// See [context-variable-reference](https://docs.aws.amazon.com/apigateway/latest/developerguide/api-gateway-mapping-template-reference.html) for more detail.
#[derive(Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
Expand Down
212 changes: 161 additions & 51 deletions lambda-http/src/response.rs
Original file line number Diff line number Diff line change
@@ -1,47 +1,67 @@
//! Response types

use crate::body::Body;
use crate::{body::Body, request::RequestOrigin};
use http::{
header::{HeaderMap, HeaderValue, CONTENT_TYPE},
header::{HeaderMap, HeaderValue, CONTENT_TYPE, SET_COOKIE},
Response,
};
use serde::{
ser::{Error as SerError, SerializeMap},
ser::{Error as SerError, SerializeMap, SerializeSeq},
Serialize, Serializer,
};

/// Representation of API Gateway response
/// Representation of Lambda response
#[doc(hidden)]
#[derive(Serialize, Debug)]
#[serde(untagged)]
pub enum LambdaResponse {
ApiGatewayV2(ApiGatewayV2Response),
Alb(AlbResponse),
ApiGateway(ApiGatewayResponse),
}

/// Representation of API Gateway v2 lambda response
#[doc(hidden)]
#[derive(Serialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct LambdaResponse {
pub status_code: u16,
// ALB requires a statusDescription i.e. "200 OK" field but API Gateway returns an error
// when one is provided. only populate this for ALB responses
pub struct ApiGatewayV2Response {
status_code: u16,
#[serde(serialize_with = "serialize_headers")]
headers: HeaderMap<HeaderValue>,
#[serde(serialize_with = "serialize_headers_slice")]
cookies: Vec<HeaderValue>,
#[serde(skip_serializing_if = "Option::is_none")]
pub status_description: Option<String>,
body: Option<Body>,
is_base64_encoded: bool,
}

/// Representation of ALB lambda response
#[doc(hidden)]
#[derive(Serialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct AlbResponse {
status_code: u16,
status_description: String,
#[serde(serialize_with = "serialize_headers")]
pub headers: HeaderMap<HeaderValue>,
#[serde(serialize_with = "serialize_multi_value_headers")]
pub multi_value_headers: HeaderMap<HeaderValue>,
headers: HeaderMap<HeaderValue>,
#[serde(skip_serializing_if = "Option::is_none")]
pub body: Option<Body>,
// This field is optional for API Gateway but required for ALB
pub is_base64_encoded: bool,
body: Option<Body>,
is_base64_encoded: bool,
}

#[cfg(test)]
impl Default for LambdaResponse {
fn default() -> Self {
Self {
status_code: 200,
status_description: Default::default(),
headers: Default::default(),
multi_value_headers: Default::default(),
body: Default::default(),
is_base64_encoded: Default::default(),
}
}
/// Representation of API Gateway lambda response
#[doc(hidden)]
#[derive(Serialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct ApiGatewayResponse {
status_code: u16,
#[serde(serialize_with = "serialize_headers")]
headers: HeaderMap<HeaderValue>,
#[serde(serialize_with = "serialize_multi_value_headers")]
multi_value_headers: HeaderMap<HeaderValue>,
#[serde(skip_serializing_if = "Option::is_none")]
body: Option<Body>,
is_base64_encoded: bool,
}

/// Serialize a http::HeaderMap into a serde str => str map
Expand Down Expand Up @@ -73,9 +93,21 @@ where
map.end()
}

/// Serialize a &[HeaderValue] into a Vec<str>
fn serialize_headers_slice<S>(headers: &[HeaderValue], serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut seq = serializer.serialize_seq(Some(headers.len()))?;
for header in headers {
seq.serialize_element(header.to_str().map_err(S::Error::custom)?)?;
}
seq.end()
}

/// tranformation from http type to internal type
impl LambdaResponse {
pub(crate) fn from_response<T>(is_alb: bool, value: Response<T>) -> Self
pub(crate) fn from_response<T>(request_origin: &RequestOrigin, value: Response<T>) -> Self
where
T: Into<Body>,
{
Expand All @@ -85,21 +117,43 @@ impl LambdaResponse {
b @ Body::Text(_) => (false, Some(b)),
b @ Body::Binary(_) => (true, Some(b)),
};
Self {
status_code: parts.status.as_u16(),
status_description: if is_alb {
Some(format!(

let mut headers = parts.headers;
let status_code = parts.status.as_u16();

match request_origin {
RequestOrigin::ApiGatewayV2 => {
// ApiGatewayV2 expects the set-cookies headers to be in the "cookies" attribute,
// so remove them from the headers.
let cookies: Vec<HeaderValue> = headers.get_all(SET_COOKIE).iter().cloned().collect();
headers.remove(SET_COOKIE);

LambdaResponse::ApiGatewayV2(ApiGatewayV2Response {
body,
status_code,
is_base64_encoded,
cookies,
headers,
})
}
RequestOrigin::ApiGateway => LambdaResponse::ApiGateway(ApiGatewayResponse {
body,
status_code,
is_base64_encoded,
headers: headers.clone(),
multi_value_headers: headers,
}),
RequestOrigin::Alb => LambdaResponse::Alb(AlbResponse {
body,
status_code,
is_base64_encoded,
headers,
status_description: format!(
"{} {}",
parts.status.as_u16(),
status_code,
parts.status.canonical_reason().unwrap_or_default()
))
} else {
None
},
body,
headers: parts.headers.clone(),
multi_value_headers: parts.headers,
is_base64_encoded,
),
}),
}
}
}
Expand Down Expand Up @@ -159,10 +213,42 @@ impl IntoResponse for serde_json::Value {

#[cfg(test)]
mod tests {
use super::{Body, IntoResponse, LambdaResponse};
use super::{
AlbResponse, ApiGatewayResponse, ApiGatewayV2Response, Body, IntoResponse, LambdaResponse, RequestOrigin,
};
use http::{header::CONTENT_TYPE, Response};
use serde_json::{self, json};

fn api_gateway_response() -> ApiGatewayResponse {
ApiGatewayResponse {
status_code: 200,
headers: Default::default(),
multi_value_headers: Default::default(),
body: Default::default(),
is_base64_encoded: Default::default(),
}
}

fn alb_response() -> AlbResponse {
AlbResponse {
status_code: 200,
status_description: "200 OK".to_string(),
headers: Default::default(),
body: Default::default(),
is_base64_encoded: Default::default(),
}
}

fn api_gateway_v2_response() -> ApiGatewayV2Response {
ApiGatewayV2Response {
status_code: 200,
headers: Default::default(),
body: Default::default(),
cookies: Default::default(),
is_base64_encoded: Default::default(),
}
}

#[test]
fn json_into_response() {
let response = json!({ "hello": "lambda"}).into_response();
Expand All @@ -189,32 +275,39 @@ mod tests {
}

#[test]
fn default_response() {
assert_eq!(LambdaResponse::default().status_code, 200)
fn serialize_body_for_api_gateway() {
let mut resp = api_gateway_response();
resp.body = Some("foo".into());
assert_eq!(
serde_json::to_string(&resp).expect("failed to serialize response"),
r#"{"statusCode":200,"headers":{},"multiValueHeaders":{},"body":"foo","isBase64Encoded":false}"#
);
}

#[test]
fn serialize_default() {
fn serialize_body_for_alb() {
let mut resp = alb_response();
resp.body = Some("foo".into());
assert_eq!(
serde_json::to_string(&LambdaResponse::default()).expect("failed to serialize response"),
r#"{"statusCode":200,"headers":{},"multiValueHeaders":{},"isBase64Encoded":false}"#
serde_json::to_string(&resp).expect("failed to serialize response"),
r#"{"statusCode":200,"statusDescription":"200 OK","headers":{},"body":"foo","isBase64Encoded":false}"#
);
}

#[test]
fn serialize_body() {
let mut resp = LambdaResponse::default();
fn serialize_body_for_api_gateway_v2() {
let mut resp = api_gateway_v2_response();
resp.body = Some("foo".into());
assert_eq!(
serde_json::to_string(&resp).expect("failed to serialize response"),
r#"{"statusCode":200,"headers":{},"multiValueHeaders":{},"body":"foo","isBase64Encoded":false}"#
r#"{"statusCode":200,"headers":{},"cookies":[],"body":"foo","isBase64Encoded":false}"#
);
}

#[test]
fn serialize_multi_value_headers() {
let res = LambdaResponse::from_response(
false,
&RequestOrigin::ApiGateway,
Response::builder()
.header("multi", "a")
.header("multi", "b")
Expand All @@ -227,4 +320,21 @@ mod tests {
r#"{"statusCode":200,"headers":{"multi":"a"},"multiValueHeaders":{"multi":["a","b"]},"isBase64Encoded":false}"#
)
}

#[test]
fn serialize_cookies() {
let res = LambdaResponse::from_response(
&RequestOrigin::ApiGatewayV2,
Response::builder()
.header("set-cookie", "cookie1=a")
.header("set-cookie", "cookie2=b")
.body(Body::from(()))
.expect("failed to create response"),
);
let json = serde_json::to_string(&res).expect("failed to serialize to json");
assert_eq!(
json,
r#"{"statusCode":200,"headers":{},"cookies":["cookie1=a","cookie2=b"],"isBase64Encoded":false}"#
)
}
}

0 comments on commit 5556ab2

Please sign in to comment.