Skip to content

Commit

Permalink
Fix codegen server unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
rcoh committed Oct 12, 2023
1 parent c43ec1b commit bcb089d
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -561,9 +561,9 @@ class ServerProtocolTestGenerator(

private fun checkResponse(rustWriter: RustWriter, testCase: HttpResponseTestCase) {
checkStatusCode(rustWriter, testCase.code)
checkHeaders(rustWriter, "&http_response.headers()", testCase.headers)
checkForbidHeaders(rustWriter, "&http_response.headers()", testCase.forbidHeaders)
checkRequiredHeaders(rustWriter, "&http_response.headers()", testCase.requireHeaders)
checkHeaders(rustWriter, "http_response.headers()", testCase.headers)
checkForbidHeaders(rustWriter, "http_response.headers()", testCase.forbidHeaders)
checkRequiredHeaders(rustWriter, "http_response.headers()", testCase.requireHeaders)

// We can't check that the `OperationExtension` is set in the response, because it is set in the implementation
// of the operation `Handler` trait, a code path that does not get exercised when we don't have a request to
Expand All @@ -579,7 +579,7 @@ class ServerProtocolTestGenerator(

private fun checkResponse(rustWriter: RustWriter, testCase: HttpMalformedResponseDefinition) {
checkStatusCode(rustWriter, testCase.code)
checkHeaders(rustWriter, "&http_response.headers()", testCase.headers)
checkHeaders(rustWriter, "http_response.headers()", testCase.headers)

// We can't check that the `OperationExtension` is set in the response, because it is set in the implementation
// of the operation `Handler` trait, a code path that does not get exercised when we don't have a request to
Expand Down
73 changes: 52 additions & 21 deletions rust-runtime/aws-smithy-protocol-test/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@
mod urlencoded;
mod xml;

use crate::sealed::GetNormalizedHeader;
use crate::xml::try_xml_equivalent;
use assert_json_diff::assert_json_eq_no_panic;
use aws_smithy_runtime_api::client::http::request::Headers;
use aws_smithy_runtime_api::client::orchestrator::HttpRequest;
use http::Uri;
use http::{HeaderMap, Uri};
use pretty_assertions::Comparison;
use std::collections::HashSet;
use std::fmt::{self, Debug};
Expand Down Expand Up @@ -211,14 +212,46 @@ pub fn require_query_params(
Ok(())
}

mod sealed {
pub trait GetNormalizedHeader {
fn get_header(&self, key: &str) -> Option<String>;
}
}

impl<'a> GetNormalizedHeader for &'a Headers {
fn get_header(&self, key: &str) -> Option<String> {
if !self.contains_key(key) {
None
} else {
Some(self.get_all(key).collect::<Vec<_>>().join(", "))
}
}
}

impl<'a> GetNormalizedHeader for &'a HeaderMap {
fn get_header(&self, key: &str) -> Option<String> {
if !self.contains_key(key) {
None
} else {
Some(
self.get_all(key)
.iter()
.map(|value| std::str::from_utf8(value.as_bytes()).expect("invalid utf-8"))
.collect::<Vec<_>>()
.join(", "),
)
}
}
}

pub fn validate_headers<'a>(
actual_headers: &Headers,
actual_headers: impl GetNormalizedHeader,
expected_headers: impl IntoIterator<Item = (impl AsRef<str> + 'a, impl AsRef<str> + 'a)>,
) -> Result<(), ProtocolTestFailure> {
for (key, expected_value) in expected_headers {
let key = key.as_ref();
let expected_value = expected_value.as_ref();
match normalized_header(actual_headers, key) {
match actual_headers.get_header(key) {
None => {
return Err(ProtocolTestFailure::MissingHeader {
expected: key.to_string(),
Expand All @@ -237,21 +270,13 @@ pub fn validate_headers<'a>(
Ok(())
}

fn normalized_header(headers: &Headers, key: &str) -> Option<String> {
if !headers.contains_key(key) {
None
} else {
Some(headers.get_all(key).collect::<Vec<_>>().join(", "))
}
}

pub fn forbid_headers(
headers: &Headers,
headers: impl GetNormalizedHeader,
forbidden_headers: &[&str],
) -> Result<(), ProtocolTestFailure> {
for key in forbidden_headers {
// Protocol tests store header lists as comma-delimited
if let Some(value) = normalized_header(headers, key) {
if let Some(value) = headers.get_header(key) {
return Err(ProtocolTestFailure::ForbiddenHeader {
forbidden: key.to_string(),
found: format!("{}: {}", key, value),
Expand All @@ -262,12 +287,12 @@ pub fn forbid_headers(
}

pub fn require_headers(
headers: &Headers,
headers: impl GetNormalizedHeader,
required_headers: &[&str],
) -> Result<(), ProtocolTestFailure> {
for key in required_headers {
// Protocol tests store header lists as comma-delimited
if normalized_header(headers, key).is_none() {
if headers.get_header(key).is_none() {
return Err(ProtocolTestFailure::MissingHeader {
expected: key.to_string(),
});
Expand Down Expand Up @@ -442,10 +467,10 @@ mod tests {
#[test]
fn test_validate_headers() {
let mut headers = Headers::new();
headers.append("X-Foo", "foo");
headers.append("X-Foo-List", "foo");
headers.append("X-Foo-List", "bar");
headers.append("X-Inline", "inline, other");
headers.append("x-foo", "foo");
headers.append("x-foo-list", "foo");
headers.append("x-foo-list", "bar");
headers.append("x-inline", "inline, other");

validate_headers(&headers, [("X-Foo", "foo")]).expect("header present");
validate_headers(&headers, [("X-Foo", "Foo")]).expect_err("case sensitive");
Expand All @@ -465,7 +490,7 @@ mod tests {
#[test]
fn test_forbidden_headers() {
let mut headers = Headers::new();
headers.append("X-Foo", "foo");
headers.append("x-foo", "foo");
assert_eq!(
forbid_headers(&headers, &["X-Foo"]).expect_err("should be error"),
ProtocolTestFailure::ForbiddenHeader {
Expand All @@ -479,7 +504,7 @@ mod tests {
#[test]
fn test_required_headers() {
let mut headers = Headers::new();
headers.append("X-Foo", "foo");
headers.append("x-foo", "foo");
require_headers(&headers, &["X-Foo"]).expect("header present");
require_headers(&headers, &["X-Bar"]).expect_err("header not present");
}
Expand Down Expand Up @@ -520,6 +545,12 @@ mod tests {
.expect("inputs matched exactly")
}

#[test]
fn test_validate_headers_http0x() {
let request = http::Request::builder().header("a", "b").body(()).unwrap();
validate_headers(request.headers(), [("a", "b")]).unwrap()
}

#[test]
fn test_float_equals() {
let a = f64::NAN;
Expand Down
13 changes: 10 additions & 3 deletions rust-runtime/aws-smithy-runtime-api/src/client/http/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -582,9 +582,16 @@ impl Error for HttpError {

fn header_name(name: impl AsHeaderComponent) -> Result<http0::HeaderName, HttpError> {
name.repr_as_http03x_header_name().or_else(|name| {
name.into_maybe_static().and_then(|cow| match cow {
Cow::Borrowed(staticc) => Ok(http0::HeaderName::from_static(staticc)),
Cow::Owned(s) => http0::HeaderName::try_from(s).map_err(HttpError::invalid_header_name),
name.into_maybe_static().and_then(|cow| {
if cow.chars().any(|c| c.is_uppercase()) {
return Err(HttpError::new("Header names must be all lower case"));
}
match cow {
Cow::Borrowed(staticc) => Ok(http0::HeaderName::from_static(staticc)),
Cow::Owned(s) => {
http0::HeaderName::try_from(s).map_err(HttpError::invalid_header_name)
}
}
})
})
}
Expand Down

0 comments on commit bcb089d

Please sign in to comment.