Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use case-insensitive sets for cors headers #950

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,12 @@ private static <T extends Trait> Map<CorsHeader, String> deduceCorsHeaders(
// Access-Control-Allow-Headers header list. Note that any further modifications that
// add headers during the Smithy to OpenAPI conversion process will need to update this
// list of headers accordingly.
Set<String> headerNames = new TreeSet<>(corsTrait.getAdditionalAllowedHeaders());
Set<String> headerNames = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
headerNames.addAll(corsTrait.getAdditionalAllowedHeaders());

// Sets additional allowed headers from the API Gateway config.
List<String> additionalAllowedHeaders = context.getConfig().getExtensions(ApiGatewayConfig.class)
.getAdditionalAllowedCorsHeaders();
Set<String> additionalAllowedHeaders = context.getConfig().getExtensions(ApiGatewayConfig.class)
.getAdditionalAllowedCorsHeadersSet();
headerNames.addAll(additionalAllowedHeaders);
headerNames.addAll(findAllHeaders(path, pathItem));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,8 @@ private ObjectNode updateGatewayResponse(

// Add the modeled additional headers. These could potentially be added by an
// apigateway feature, so they need to be present.
Set<String> exposedHeaders = new TreeSet<>(trait.getAdditionalExposedHeaders());
Set<String> exposedHeaders = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
exposedHeaders.addAll(trait.getAdditionalExposedHeaders());

// Find all headers exposed already in the response. These need to be added to the
// Access-Control-Expose-Headers header if any are found.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,8 @@ private ObjectNode updateIntegrationResponse(
ObjectNode responseParams = response.getObjectMember(RESPONSE_PARAMETERS_KEY).orElseGet(Node::objectNode);

// Created a sorted set of all headers exposed in the integration.
Set<String> headersToExpose = new TreeSet<>(deduced);
Set<String> headersToExpose = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
headersToExpose.addAll(deduced);
responseParams.getStringMap().keySet().stream()
.filter(parameterName -> parameterName.startsWith(HEADER_PREFIX))
.map(parameterName -> parameterName.substring(HEADER_PREFIX.length()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@
package software.amazon.smithy.aws.apigateway.openapi;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import software.amazon.smithy.utils.SetUtils;

/**
* API Gateway OpenAPI configuration.
Expand Down Expand Up @@ -55,7 +59,7 @@ public enum ApiType {

private ApiType apiGatewayType = ApiType.REST;
private boolean disableCloudFormationSubstitution;
private List<String> additionalAllowedCorsHeaders = new ArrayList<>();
private Set<String> additionalAllowedCorsHeaders = Collections.emptySet();

/**
* @return Returns true if CloudFormation substitutions are disabled.
Expand Down Expand Up @@ -94,21 +98,26 @@ public void setApiGatewayType(ApiType apiGatewayType) {
}

/**
* @return the list of additional allowed CORS headers.
* @deprecated Use {@link ApiGatewayConfig#getAdditionalAllowedCorsHeadersSet}
*/
@Deprecated
public List<String> getAdditionalAllowedCorsHeaders() {
return new ArrayList<>(additionalAllowedCorsHeaders);
}

/**
* @return the set of additional allowed CORS headers.
*/
public Set<String> getAdditionalAllowedCorsHeadersSet() {
return additionalAllowedCorsHeaders;
}

/**
* Sets the list of additional allowed CORS headers.
*
* <p>If not set, this value defaults to setting "amz-sdk-invocation-id" and
* "amz-sdk-request" as the additional allowed CORS headers.</p>
* Sets the additional allowed CORS headers.
*
* @param additionalAllowedCorsHeaders additional cors headers to be allowed.
*/
public void setAdditionalAllowedCorsHeaders(List<String> additionalAllowedCorsHeaders) {
this.additionalAllowedCorsHeaders = Objects.requireNonNull(additionalAllowedCorsHeaders);
public void setAdditionalAllowedCorsHeaders(Collection<String> additionalAllowedCorsHeaders) {
this.additionalAllowedCorsHeaders = SetUtils.caseInsensitiveCopyOf(additionalAllowedCorsHeaders);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ static <T extends Trait> Set<String> deduceOperationResponseHeaders(
// The deduced response headers of an operation consist of any headers
// returned by security schemes, any headers returned by the protocol,
// and any headers explicitly modeled on the operation.
Set<String> result = new TreeSet<>(cors.getAdditionalExposedHeaders());
Set<String> result = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
result.addAll(cors.getAdditionalExposedHeaders());
result.addAll(context.getOpenApiProtocol().getProtocolResponseHeaders(context, shape));
result.addAll(context.getAllSecuritySchemeResponseHeaders());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ public void setsConfiguredAdditionalAllowedHeaders() {
OpenApiConfig config = new OpenApiConfig();
config.setService(ShapeId.from("example.smithy#MyService"));
ApiGatewayConfig apiGatewayConfig = new ApiGatewayConfig();
apiGatewayConfig.setAdditionalAllowedCorsHeaders(ListUtils.of("foo","bar"));
apiGatewayConfig.setAdditionalAllowedCorsHeaders(ListUtils.of("foo", "bar", "content-length"));
config.putExtensions(apiGatewayConfig);
ObjectNode result = OpenApiConverter.create().config(config).convertToNode(model);
Node expectedNode = Node.parse(IoUtils.toUtf8String(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@
"statusCode": "200",
"responseParameters": {
"method.response.header.Access-Control-Max-Age": "'86400'",
"method.response.header.Access-Control-Allow-Headers": "'Amz-Sdk-Invocation-Id,Amz-Sdk-Request,Authorization,Date,Host,X-Amz-Content-Sha256,X-Amz-Date,X-Amz-Security-Token,X-Amz-Target,X-Amz-User-Agent,X-Amzn-Trace-Id,X-Service-Input-Metadata,bar,foo'",
"method.response.header.Access-Control-Allow-Headers": "'Amz-Sdk-Invocation-Id,Amz-Sdk-Request,Authorization,bar,content-length,Date,foo,Host,X-Amz-Content-Sha256,X-Amz-Date,X-Amz-Security-Token,X-Amz-Target,X-Amz-User-Agent,X-Amzn-Trace-Id,X-Service-Input-Metadata'",
"method.response.header.Access-Control-Allow-Origin": "'https://www.example.com'",
"method.response.header.Access-Control-Allow-Methods": "'GET'"
}
Expand Down Expand Up @@ -265,7 +265,7 @@
"statusCode": "200",
"responseParameters": {
"method.response.header.Access-Control-Max-Age": "'86400'",
"method.response.header.Access-Control-Allow-Headers": "'Amz-Sdk-Invocation-Id,Amz-Sdk-Request,Authorization,Content-Length,Content-Type,Date,Host,X-Amz-Content-Sha256,X-Amz-Date,X-Amz-Security-Token,X-Amz-Target,X-Amz-User-Agent,X-Amzn-Trace-Id,X-EnumString,X-Foo-Header,X-Service-Input-Metadata,bar,foo'",
"method.response.header.Access-Control-Allow-Headers": "'Amz-Sdk-Invocation-Id,Amz-Sdk-Request,Authorization,bar,content-length,Content-Type,Date,foo,Host,X-Amz-Content-Sha256,X-Amz-Date,X-Amz-Security-Token,X-Amz-Target,X-Amz-User-Agent,X-Amzn-Trace-Id,X-EnumString,X-Foo-Header,X-Service-Input-Metadata'",
"method.response.header.Access-Control-Allow-Origin": "'https://www.example.com'",
"method.response.header.Access-Control-Allow-Methods": "'DELETE,GET,PUT'"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

package software.amazon.smithy.model.traits;

import java.util.LinkedHashSet;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
Expand Down Expand Up @@ -121,12 +120,12 @@ public Builder maxAge(int maxAge) {
}

public Builder additionalAllowedHeaders(Set<String> additionalAllowedHeaders) {
this.additionalAllowedHeaders = new LinkedHashSet<>(Objects.requireNonNull(additionalAllowedHeaders));
this.additionalAllowedHeaders = SetUtils.caseInsensitiveCopyOf(additionalAllowedHeaders);
return this;
}

public Builder additionalExposedHeaders(Set<String> additionalExposedHeaders) {
this.additionalExposedHeaders = new LinkedHashSet<>(Objects.requireNonNull(additionalExposedHeaders));
this.additionalExposedHeaders = SetUtils.caseInsensitiveCopyOf(additionalExposedHeaders);
return this;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@ abstract Schema createDocumentSchema(

@Override
public Set<String> getProtocolRequestHeaders(Context<T> context, OperationShape operationShape) {
Set<String> headers = new TreeSet<>(OpenApiProtocol.super.getProtocolRequestHeaders(context, operationShape));
Set<String> headers = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
headers.addAll(OpenApiProtocol.super.getProtocolRequestHeaders(context, operationShape));

HttpBindingIndex bindingIndex = HttpBindingIndex.of(context.getModel());
String documentMediaType = getDocumentMediaType(context, operationShape, MessageType.REQUEST);
Expand All @@ -135,7 +136,8 @@ public Set<String> getProtocolRequestHeaders(Context<T> context, OperationShape

@Override
public Set<String> getProtocolResponseHeaders(Context<T> context, OperationShape operationShape) {
Set<String> headers = new TreeSet<>(OpenApiProtocol.super.getProtocolResponseHeaders(context, operationShape));
Set<String> headers = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
headers.addAll(OpenApiProtocol.super.getProtocolResponseHeaders(context, operationShape));

// If the operation has any defined output or errors, it can return content-type.
if (operationShape.getOutput().isPresent() || !operationShape.getErrors().isEmpty()) {
Expand All @@ -150,7 +152,7 @@ public Set<String> getProtocolResponseHeaders(Context<T> context, OperationShape
}

private Set<String> getChecksumHeaders(List<HttpChecksumProperty> httpChecksumProperties) {
Set<String> headers = new TreeSet<>();
Set<String> headers = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
for (HttpChecksumProperty property : httpChecksumProperties) {
if (property.getLocation().equals(Location.HEADER)) {
headers.add(property.getName());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ public Class<RestJson1Trait> getProtocolType() {
@Override
public Set<String> getProtocolRequestHeaders(Context<RestJson1Trait> context, OperationShape operationShape) {
// x-amz-api-version if it is an endpoint operation
Set<String> headers = new TreeSet<>(super.getProtocolRequestHeaders(context, operationShape));
Set<String> headers = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
headers.addAll(super.getProtocolRequestHeaders(context, operationShape));
headers.addAll(AWS_REQUEST_HEADERS);
if (operationShape.hasTrait(ClientDiscoveredEndpointTrait.class)) {
headers.add("X-Amz-Api-Version");
Expand All @@ -72,7 +73,8 @@ public Set<String> getProtocolRequestHeaders(Context<RestJson1Trait> context, Op

@Override
public Set<String> getProtocolResponseHeaders(Context<RestJson1Trait> context, OperationShape operationShape) {
Set<String> headers = new TreeSet<>(super.getProtocolResponseHeaders(context, operationShape));
Set<String> headers = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
headers.addAll(super.getProtocolResponseHeaders(context, operationShape));
headers.addAll(AWS_RESPONSE_HEADERS);
return headers;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.Objects;
import java.util.Set;
import java.util.TreeSet;
import java.util.stream.Collector;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -51,6 +53,12 @@ public static <T> Set<T> orderedCopyOf(Collection<? extends T> values) {
return values.isEmpty() ? Collections.emptySet() : Collections.unmodifiableSet(new LinkedHashSet<>(values));
}

public static Set<String> caseInsensitiveCopyOf(Collection<? extends String> values) {
JordonPhillips marked this conversation as resolved.
Show resolved Hide resolved
Set<String> caseInsensitiveSet = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
caseInsensitiveSet.addAll(Objects.requireNonNull(values));
return Collections.unmodifiableSet(caseInsensitiveSet);
}

/**
* Returns an unmodifiable set containing zero entries.
*
Expand Down