Skip to content

Commit

Permalink
Add local tests for openapi protocol headers
Browse files Browse the repository at this point in the history
  • Loading branch information
JordonPhillips committed Aug 3, 2021
1 parent 64477cb commit 7f7ea98
Show file tree
Hide file tree
Showing 8 changed files with 323 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,7 @@ private static <T extends Trait> Map<CorsHeader, String> deduceCorsHeaders(
TopDownIndex topDownIndex = TopDownIndex.of(model);
Map<String, OperationShape> operations = topDownIndex.getContainedOperations(context.getService()).stream()
.collect(Collectors.toMap(o -> o.getId().getName(), o -> o));
for (Map.Entry<String, OperationObject> method : pathItem.getOperations().entrySet()) {
OperationObject operationObject = method.getValue();
for (OperationObject operationObject : pathItem.getOperations().values()) {
if (operationObject.getOperationId().isPresent()) {
OperationShape operationShape = operations.get(operationObject.getOperationId().get());
headerNames.addAll(context.getOpenApiProtocol().getProtocolRequestHeaders(context, operationShape));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
"statusCode": "200",
"responseParameters": {
"method.response.header.Access-Control-Allow-Origin": "'https://foo.com'",
"method.response.header.Access-Control-Expose-Headers": "'X-Amzn-Errortype,X-Amzn-Requestid,X-Hd'"
"method.response.header.Access-Control-Expose-Headers": "'Content-Length,Content-Type,X-Amzn-Errortype,X-Amzn-Requestid,X-Hd'"
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
"statusCode": "200",
"responseParameters": {
"method.response.header.Access-Control-Allow-Origin": "'https://foo.com'",
"method.response.header.Access-Control-Expose-Headers": "'X-Amzn-Errortype,X-Amzn-Requestid,X-Hd'"
"method.response.header.Access-Control-Expose-Headers": "'Content-Length,Content-Type,X-Amzn-Errortype,X-Amzn-Requestid,X-Hd'"
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
import software.amazon.smithy.model.shapes.OperationShape;
import software.amazon.smithy.model.shapes.ServiceShape;
import software.amazon.smithy.model.shapes.Shape;
import software.amazon.smithy.model.shapes.ShapeId;
import software.amazon.smithy.model.shapes.StructureShape;
import software.amazon.smithy.model.traits.ErrorTrait;
import software.amazon.smithy.model.traits.HttpChecksumProperty;
Expand All @@ -58,7 +57,6 @@
import software.amazon.smithy.openapi.model.Ref;
import software.amazon.smithy.openapi.model.RequestBodyObject;
import software.amazon.smithy.openapi.model.ResponseObject;
import software.amazon.smithy.utils.SetUtils;

/**
* Provides the shared functionality used across protocols that use Smithy's
Expand All @@ -79,8 +77,6 @@ abstract class AbstractRestProtocol<T extends Trait> implements OpenApiProtocol<

private static final String AWS_EVENT_STREAM_CONTENT_TYPE = "application/vnd.amazon.eventstream";
private static final Pattern NON_ALPHA_NUMERIC = Pattern.compile("[^A-Za-z0-9]");
// If a request returns / accepts a body then it should allow these un-modeled headers.
private static final Set<String> CONTENT_HEADERS = SetUtils.of("Content-Length", "Content-Type");

private static final Logger LOGGER = Logger.getLogger(AbstractRestProtocol.class.getName());

Expand Down Expand Up @@ -126,14 +122,13 @@ public Set<String> getProtocolRequestHeaders(Context<T> context, OperationShape
String documentMediaType = getDocumentMediaType(context, operationShape, MessageType.REQUEST);
// If the request has a body with a content type, allow the content-type and content-length headers.
bindingIndex.determineRequestContentType(operationShape, documentMediaType)
.ifPresent(c -> headers.addAll(CONTENT_HEADERS));
.ifPresent(c -> headers.addAll(ProtocolUtils.CONTENT_HEADERS));

if (operationShape.hasTrait(HttpChecksumRequiredTrait.class)) {
headers.add("Content-Md5");
}
if (operationShape.hasTrait(HttpChecksumTrait.class)) {
HttpChecksumTrait trait = operationShape.expectTrait(HttpChecksumTrait.class);
headers.addAll(getChecksumHeaders(trait.getRequestProperties()));
} else if (operationShape.hasTrait(HttpChecksumRequiredTrait.class)) {
headers.add("Content-Md5");
}
return headers;
}
Expand All @@ -142,15 +137,14 @@ public Set<String> getProtocolRequestHeaders(Context<T> context, OperationShape
public Set<String> getProtocolResponseHeaders(Context<T> context, OperationShape operationShape) {
Set<String> headers = new TreeSet<>(OpenApiProtocol.super.getProtocolResponseHeaders(context, operationShape));

// If the operation or any attached errors have a content type, then both the content-type and content-length
// headers need to be exposed.
if (willReturnContentType(context, operationShape)) {
headers.addAll(CONTENT_HEADERS);
// If the operation has any defined output or errors, it can return content-type.
if (operationShape.getOutput().isPresent() || !operationShape.getErrors().isEmpty()) {
headers.addAll(ProtocolUtils.CONTENT_HEADERS);
}

if (operationShape.hasTrait(HttpChecksumTrait.class)) {
HttpChecksumTrait trait = operationShape.expectTrait(HttpChecksumTrait.class);
headers.addAll(getChecksumHeaders(trait.getRequestProperties()));
headers.addAll(getChecksumHeaders(trait.getResponseProperties()));
}
return headers;
}
Expand All @@ -165,23 +159,6 @@ private Set<String> getChecksumHeaders(List<HttpChecksumProperty> httpChecksumPr
return headers;
}

private boolean willReturnContentType(Context<T> context, OperationShape operationShape) {
HttpBindingIndex bindingIndex = HttpBindingIndex.of(context.getModel());
String documentMediaType = getDocumentMediaType(context, operationShape, MessageType.RESPONSE);
Optional<String> contentType = bindingIndex.determineResponseContentType(operationShape, documentMediaType);
if (contentType.isPresent()) {
return true;
}
for (ShapeId error : operationShape.getErrors()) {
Shape errorShape = context.getModel().expectShape(error);
contentType = bindingIndex.determineResponseContentType(errorShape, documentMediaType);
if (contentType.isPresent()) {
return true;
}
}
return false;
}

@Override
public Optional<Operation> createOperation(Context<T> context, OperationShape operation) {
ServiceShape serviceShape = context.getService();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import java.util.Set;
import java.util.TreeSet;
import java.util.stream.Collectors;
import software.amazon.smithy.aws.traits.clientendpointdiscovery.ClientEndpointDiscoveryTrait;
import software.amazon.smithy.aws.traits.clientendpointdiscovery.ClientDiscoveredEndpointTrait;
import software.amazon.smithy.aws.traits.protocols.RestJson1Trait;
import software.amazon.smithy.jsonschema.Schema;
import software.amazon.smithy.model.Model;
Expand Down Expand Up @@ -64,7 +64,7 @@ public Set<String> getProtocolRequestHeaders(Context<RestJson1Trait> context, Op
// x-amz-api-version if it is an endpoint operation
Set<String> headers = new TreeSet<>(super.getProtocolRequestHeaders(context, operationShape));
headers.addAll(AWS_REQUEST_HEADERS);
if (operationShape.hasTrait(ClientEndpointDiscoveryTrait.class)) {
if (operationShape.hasTrait(ClientDiscoveredEndpointTrait.class)) {
headers.add("X-Amz-Api-Version");
}
return headers;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package software.amazon.smithy.openapi.fromsmithy.protocols;

import java.util.Set;
import software.amazon.smithy.utils.SetUtils;

/**
* Protocol utilities for OpenAPI protocol support.
*/
final class ProtocolUtils {
// If a request returns / accepts a body then it should allow these un-modeled headers.
static final Set<String> CONTENT_HEADERS = SetUtils.of("Content-Length", "Content-Type");

private ProtocolUtils() {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,29 @@
import static org.junit.jupiter.api.Assertions.fail;

import java.io.InputStream;
import java.util.Set;
import java.util.stream.Stream;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import software.amazon.smithy.aws.traits.protocols.RestJson1Trait;
import software.amazon.smithy.model.Model;
import software.amazon.smithy.model.node.Node;
import software.amazon.smithy.model.node.ObjectNode;
import software.amazon.smithy.model.shapes.OperationShape;
import software.amazon.smithy.model.shapes.ShapeId;
import software.amazon.smithy.model.traits.Trait;
import software.amazon.smithy.openapi.OpenApiConfig;
import software.amazon.smithy.openapi.fromsmithy.Context;
import software.amazon.smithy.openapi.fromsmithy.OpenApiConverter;
import software.amazon.smithy.openapi.fromsmithy.OpenApiMapper;
import software.amazon.smithy.openapi.model.OpenApi;
import software.amazon.smithy.openapi.model.OperationObject;
import software.amazon.smithy.utils.IoUtils;
import software.amazon.smithy.utils.SetUtils;

public class AwsRestJson1ProtocolTest {

Expand Down Expand Up @@ -119,4 +130,166 @@ public void canRemoveNonAlphaNumericDocumentNames() {
Node.assertEquals(result, expectedNode);
}
}

private static Stream<Arguments> protocolHeaderCases() {
return Stream.of(
Arguments.of(
"NoInputOrOutput",
SetUtils.of(
"X-Amz-User-Agent",
"X-Amzn-Trace-Id",
"Amz-Sdk-Request",
"Amz-Sdk-Invocation-Id"
),
SetUtils.of(
"X-Amzn-Requestid",
"X-Amzn-Errortype"
)
),
Arguments.of(
"EmptyInputAndOutput",
SetUtils.of(
"X-Amz-User-Agent",
"X-Amzn-Trace-Id",
"Amz-Sdk-Request",
"Amz-Sdk-Invocation-Id"
),
SetUtils.of(
"X-Amzn-Requestid",
"X-Amzn-Errortype",
"Content-Length",
"Content-Type"
)
),
Arguments.of(
"OnlyErrorOutput",
SetUtils.of(
"X-Amz-User-Agent",
"X-Amzn-Trace-Id",
"Amz-Sdk-Request",
"Amz-Sdk-Invocation-Id"
),
SetUtils.of(
"X-Amzn-Requestid",
"X-Amzn-Errortype",
"Content-Length",
"Content-Type"
)
),
Arguments.of(
"HttpChecksumRequired",
SetUtils.of(
"X-Amz-User-Agent",
"X-Amzn-Trace-Id",
"Amz-Sdk-Request",
"Amz-Sdk-Invocation-Id",
"Content-Md5"
),
SetUtils.of(
"X-Amzn-Requestid",
"X-Amzn-Errortype"
)
),
Arguments.of(
"HttpChecksumInputOperation",
SetUtils.of(
"X-Amz-User-Agent",
"X-Amzn-Trace-Id",
"Amz-Sdk-Request",
"Amz-Sdk-Invocation-Id",
"x-amz-checksum-sha256"
),
SetUtils.of(
"X-Amzn-Requestid",
"X-Amzn-Errortype"
)
),
Arguments.of(
"HttpChecksumOutputOperation",
SetUtils.of(
"X-Amz-User-Agent",
"X-Amzn-Trace-Id",
"Amz-Sdk-Request",
"Amz-Sdk-Invocation-Id"
),
SetUtils.of(
"X-Amzn-Requestid",
"X-Amzn-Errortype",
"Content-Length",
"Content-Type",
"x-amz-checksum-sha256"
)
),
Arguments.of(
"HasDiscoveredEndpoint",
SetUtils.of(
"X-Amz-User-Agent",
"X-Amzn-Trace-Id",
"Amz-Sdk-Request",
"Amz-Sdk-Invocation-Id",
"X-Amz-Api-Version"
),
SetUtils.of(
"X-Amzn-Requestid",
"X-Amzn-Errortype",
"Content-Length",
"Content-Type"
)
)
);
}

@ParameterizedTest
@MethodSource("protocolHeaderCases")
public void assertProtocolHeaders(
String operationId,
Set<String> expectedRequestHeaders,
Set<String> expectedResponseHeaders
) {
Model model = Model.assembler()
.addImport(getClass().getResource("rest-json-protocol-headers.smithy"))
.discoverModels()
.assemble()
.unwrap();
OpenApiConfig config = new OpenApiConfig();
config.setService(ShapeId.from("smithy.example#Service"));
config.setAlphanumericOnlyRefs(true);

AwsRestJson1Protocol protocol = new AwsRestJson1Protocol();
OperationShape operation = model.expectShape(
ShapeId.fromParts("smithy.example", operationId), OperationShape.class);

ContextCapturingMapper contextCaptor = new ContextCapturingMapper();
OpenApiConverter.create()
.config(config)
.addOpenApiMapper(contextCaptor)
.convert(model);

Context<RestJson1Trait> context = (Context<RestJson1Trait>) contextCaptor.capturedContext;

Set<String> requestHeaders = protocol.getProtocolRequestHeaders(context, operation);
Assertions.assertEquals(expectedRequestHeaders, requestHeaders);

Set<String> responseHeaders = protocol.getProtocolResponseHeaders(context, operation);
Assertions.assertEquals(expectedResponseHeaders, responseHeaders);
}

private static class ContextCapturingMapper implements OpenApiMapper {

public Context<? extends Trait> capturedContext;

@Override
public byte getOrder() {
return 127;
}

@Override
public OperationObject updateOperation(
Context<? extends Trait> context, OperationShape shape, OperationObject operation,
String httpMethodName, String path
) {
this.capturedContext = context;
return OpenApiMapper.super.updateOperation(context, shape, operation, httpMethodName, path);
}
}
}
Loading

0 comments on commit 7f7ea98

Please sign in to comment.