Skip to content

Commit

Permalink
Use case-insensitive sets for cors headers
Browse files Browse the repository at this point in the history
This updates cors header lists to use case-insensitive sets so that
duplicate entries aren't accidentally used.
  • Loading branch information
JordonPhillips committed Oct 25, 2021
1 parent 306fa1a commit e3772fe
Show file tree
Hide file tree
Showing 11 changed files with 49 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,12 @@ private static <T extends Trait> Map<CorsHeader, String> deduceCorsHeaders(
// Access-Control-Allow-Headers header list. Note that any further modifications that
// add headers during the Smithy to OpenAPI conversion process will need to update this
// list of headers accordingly.
Set<String> headerNames = new TreeSet<>(corsTrait.getAdditionalAllowedHeaders());
Set<String> headerNames = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
headerNames.addAll(corsTrait.getAdditionalAllowedHeaders());

// Sets additional allowed headers from the API Gateway config.
List<String> additionalAllowedHeaders = context.getConfig().getExtensions(ApiGatewayConfig.class)
.getAdditionalAllowedCorsHeaders();
Set<String> additionalAllowedHeaders = context.getConfig().getExtensions(ApiGatewayConfig.class)
.getAdditionalAllowedCorsHeadersSet();
headerNames.addAll(additionalAllowedHeaders);
headerNames.addAll(findAllHeaders(path, pathItem));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,8 @@ private ObjectNode updateGatewayResponse(

// Add the modeled additional headers. These could potentially be added by an
// apigateway feature, so they need to be present.
Set<String> exposedHeaders = new TreeSet<>(trait.getAdditionalExposedHeaders());
Set<String> exposedHeaders = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
exposedHeaders.addAll(trait.getAdditionalExposedHeaders());

// Find all headers exposed already in the response. These need to be added to the
// Access-Control-Expose-Headers header if any are found.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,8 @@ private ObjectNode updateIntegrationResponse(
ObjectNode responseParams = response.getObjectMember(RESPONSE_PARAMETERS_KEY).orElseGet(Node::objectNode);

// Created a sorted set of all headers exposed in the integration.
Set<String> headersToExpose = new TreeSet<>(deduced);
Set<String> headersToExpose = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
headersToExpose.addAll(deduced);
responseParams.getStringMap().keySet().stream()
.filter(parameterName -> parameterName.startsWith(HEADER_PREFIX))
.map(parameterName -> parameterName.substring(HEADER_PREFIX.length()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@
package software.amazon.smithy.aws.apigateway.openapi;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import software.amazon.smithy.utils.SetUtils;

/**
* API Gateway OpenAPI configuration.
Expand Down Expand Up @@ -55,7 +59,7 @@ public enum ApiType {

private ApiType apiGatewayType = ApiType.REST;
private boolean disableCloudFormationSubstitution;
private List<String> additionalAllowedCorsHeaders = new ArrayList<>();
private Set<String> additionalAllowedCorsHeaders = Collections.emptySet();

/**
* @return Returns true if CloudFormation substitutions are disabled.
Expand Down Expand Up @@ -94,21 +98,26 @@ public void setApiGatewayType(ApiType apiGatewayType) {
}

/**
* @return the list of additional allowed CORS headers.
* @deprecated Use {@link ApiGatewayConfig#getAdditionalAllowedCorsHeadersSet}
*/
@Deprecated
public List<String> getAdditionalAllowedCorsHeaders() {
return new ArrayList<>(additionalAllowedCorsHeaders);
}

/**
* @return the set of additional allowed CORS headers.
*/
public Set<String> getAdditionalAllowedCorsHeadersSet() {
return additionalAllowedCorsHeaders;
}

/**
* Sets the list of additional allowed CORS headers.
*
* <p>If not set, this value defaults to setting "amz-sdk-invocation-id" and
* "amz-sdk-request" as the additional allowed CORS headers.</p>
* Sets the additional allowed CORS headers.
*
* @param additionalAllowedCorsHeaders additional cors headers to be allowed.
*/
public void setAdditionalAllowedCorsHeaders(List<String> additionalAllowedCorsHeaders) {
this.additionalAllowedCorsHeaders = Objects.requireNonNull(additionalAllowedCorsHeaders);
public void setAdditionalAllowedCorsHeaders(Collection<String> additionalAllowedCorsHeaders) {
this.additionalAllowedCorsHeaders = SetUtils.caseInsensitiveCopyOf(additionalAllowedCorsHeaders);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ static <T extends Trait> Set<String> deduceOperationResponseHeaders(
// The deduced response headers of an operation consist of any headers
// returned by security schemes, any headers returned by the protocol,
// and any headers explicitly modeled on the operation.
Set<String> result = new TreeSet<>(cors.getAdditionalExposedHeaders());
Set<String> result = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
result.addAll(cors.getAdditionalExposedHeaders());
result.addAll(context.getOpenApiProtocol().getProtocolResponseHeaders(context, shape));
result.addAll(context.getAllSecuritySchemeResponseHeaders());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ public void setsConfiguredAdditionalAllowedHeaders() {
OpenApiConfig config = new OpenApiConfig();
config.setService(ShapeId.from("example.smithy#MyService"));
ApiGatewayConfig apiGatewayConfig = new ApiGatewayConfig();
apiGatewayConfig.setAdditionalAllowedCorsHeaders(ListUtils.of("foo","bar"));
apiGatewayConfig.setAdditionalAllowedCorsHeaders(ListUtils.of("foo", "bar", "content-length"));
config.putExtensions(apiGatewayConfig);
ObjectNode result = OpenApiConverter.create().config(config).convertToNode(model);
Node expectedNode = Node.parse(IoUtils.toUtf8String(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@
"statusCode": "200",
"responseParameters": {
"method.response.header.Access-Control-Max-Age": "'86400'",
"method.response.header.Access-Control-Allow-Headers": "'Amz-Sdk-Invocation-Id,Amz-Sdk-Request,Authorization,Date,Host,X-Amz-Content-Sha256,X-Amz-Date,X-Amz-Security-Token,X-Amz-Target,X-Amz-User-Agent,X-Amzn-Trace-Id,X-Service-Input-Metadata,bar,foo'",
"method.response.header.Access-Control-Allow-Headers": "'Amz-Sdk-Invocation-Id,Amz-Sdk-Request,Authorization,bar,content-length,Date,foo,Host,X-Amz-Content-Sha256,X-Amz-Date,X-Amz-Security-Token,X-Amz-Target,X-Amz-User-Agent,X-Amzn-Trace-Id,X-Service-Input-Metadata'",
"method.response.header.Access-Control-Allow-Origin": "'https://www.example.com'",
"method.response.header.Access-Control-Allow-Methods": "'GET'"
}
Expand Down Expand Up @@ -265,7 +265,7 @@
"statusCode": "200",
"responseParameters": {
"method.response.header.Access-Control-Max-Age": "'86400'",
"method.response.header.Access-Control-Allow-Headers": "'Amz-Sdk-Invocation-Id,Amz-Sdk-Request,Authorization,Content-Length,Content-Type,Date,Host,X-Amz-Content-Sha256,X-Amz-Date,X-Amz-Security-Token,X-Amz-Target,X-Amz-User-Agent,X-Amzn-Trace-Id,X-EnumString,X-Foo-Header,X-Service-Input-Metadata,bar,foo'",
"method.response.header.Access-Control-Allow-Headers": "'Amz-Sdk-Invocation-Id,Amz-Sdk-Request,Authorization,bar,content-length,Content-Type,Date,foo,Host,X-Amz-Content-Sha256,X-Amz-Date,X-Amz-Security-Token,X-Amz-Target,X-Amz-User-Agent,X-Amzn-Trace-Id,X-EnumString,X-Foo-Header,X-Service-Input-Metadata'",
"method.response.header.Access-Control-Allow-Origin": "'https://www.example.com'",
"method.response.header.Access-Control-Allow-Methods": "'DELETE,GET,PUT'"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

package software.amazon.smithy.model.traits;

import java.util.LinkedHashSet;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
Expand Down Expand Up @@ -121,12 +120,12 @@ public Builder maxAge(int maxAge) {
}

public Builder additionalAllowedHeaders(Set<String> additionalAllowedHeaders) {
this.additionalAllowedHeaders = new LinkedHashSet<>(Objects.requireNonNull(additionalAllowedHeaders));
this.additionalAllowedHeaders = SetUtils.caseInsensitiveCopyOf(additionalAllowedHeaders);
return this;
}

public Builder additionalExposedHeaders(Set<String> additionalExposedHeaders) {
this.additionalExposedHeaders = new LinkedHashSet<>(Objects.requireNonNull(additionalExposedHeaders));
this.additionalExposedHeaders = SetUtils.caseInsensitiveCopyOf(additionalExposedHeaders);
return this;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@ abstract Schema createDocumentSchema(

@Override
public Set<String> getProtocolRequestHeaders(Context<T> context, OperationShape operationShape) {
Set<String> headers = new TreeSet<>(OpenApiProtocol.super.getProtocolRequestHeaders(context, operationShape));
Set<String> headers = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
headers.addAll(OpenApiProtocol.super.getProtocolRequestHeaders(context, operationShape));

HttpBindingIndex bindingIndex = HttpBindingIndex.of(context.getModel());
String documentMediaType = getDocumentMediaType(context, operationShape, MessageType.REQUEST);
Expand All @@ -135,7 +136,8 @@ public Set<String> getProtocolRequestHeaders(Context<T> context, OperationShape

@Override
public Set<String> getProtocolResponseHeaders(Context<T> context, OperationShape operationShape) {
Set<String> headers = new TreeSet<>(OpenApiProtocol.super.getProtocolResponseHeaders(context, operationShape));
Set<String> headers = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
headers.addAll(OpenApiProtocol.super.getProtocolResponseHeaders(context, operationShape));

// If the operation has any defined output or errors, it can return content-type.
if (operationShape.getOutput().isPresent() || !operationShape.getErrors().isEmpty()) {
Expand All @@ -150,7 +152,7 @@ public Set<String> getProtocolResponseHeaders(Context<T> context, OperationShape
}

private Set<String> getChecksumHeaders(List<HttpChecksumProperty> httpChecksumProperties) {
Set<String> headers = new TreeSet<>();
Set<String> headers = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
for (HttpChecksumProperty property : httpChecksumProperties) {
if (property.getLocation().equals(Location.HEADER)) {
headers.add(property.getName());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ public Class<RestJson1Trait> getProtocolType() {
@Override
public Set<String> getProtocolRequestHeaders(Context<RestJson1Trait> context, OperationShape operationShape) {
// x-amz-api-version if it is an endpoint operation
Set<String> headers = new TreeSet<>(super.getProtocolRequestHeaders(context, operationShape));
Set<String> headers = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
headers.addAll(super.getProtocolRequestHeaders(context, operationShape));
headers.addAll(AWS_REQUEST_HEADERS);
if (operationShape.hasTrait(ClientDiscoveredEndpointTrait.class)) {
headers.add("X-Amz-Api-Version");
Expand All @@ -72,7 +73,8 @@ public Set<String> getProtocolRequestHeaders(Context<RestJson1Trait> context, Op

@Override
public Set<String> getProtocolResponseHeaders(Context<RestJson1Trait> context, OperationShape operationShape) {
Set<String> headers = new TreeSet<>(super.getProtocolResponseHeaders(context, operationShape));
Set<String> headers = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
headers.addAll(super.getProtocolResponseHeaders(context, operationShape));
headers.addAll(AWS_RESPONSE_HEADERS);
return headers;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.Objects;
import java.util.Set;
import java.util.TreeSet;
import java.util.stream.Collector;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -51,6 +53,12 @@ public static <T> Set<T> orderedCopyOf(Collection<? extends T> values) {
return values.isEmpty() ? Collections.emptySet() : Collections.unmodifiableSet(new LinkedHashSet<>(values));
}

public static Set<String> caseInsensitiveCopyOf(Collection<? extends String> values) {
Set<String> caseInsensitiveSet = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
caseInsensitiveSet.addAll(Objects.requireNonNull(values));
return Collections.unmodifiableSet(caseInsensitiveSet);
}

/**
* Returns an unmodifiable set containing zero entries.
*
Expand Down

0 comments on commit e3772fe

Please sign in to comment.