Skip to content

Commit

Permalink
feat: Add AWS Session Token to Metadata Requests
Browse files Browse the repository at this point in the history
  • Loading branch information
sai-sunder-s committed Feb 10, 2022
1 parent 4b236c1 commit abc8711
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 26 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,10 @@ Where the following variables need to be substituted:

This generates the configuration file in the specified output file.

If you want to use IDMSv2, then below field needs to be added to credential_source
section of credential configuration.
"aws_session_token_url": "http://169.254.169.254/latest/api/token"

You can now [use the Auth library](#using-external-identities) to call Google Cloud
resources from AWS.

Expand Down
51 changes: 43 additions & 8 deletions oauth2_http/java/com/google/auth/oauth2/AwsCredentials.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
package com.google.auth.oauth2;

import com.google.api.client.http.GenericUrl;
import com.google.api.client.http.HttpHeaders;
import com.google.api.client.http.HttpRequest;
import com.google.api.client.http.HttpRequestFactory;
import com.google.api.client.http.HttpResponse;
Expand Down Expand Up @@ -65,6 +66,7 @@ static class AwsCredentialSource extends CredentialSource {
private final String regionUrl;
private final String url;
private final String regionalCredentialVerificationUrl;
private final String awsSessionTokenUrl;

/**
* The source of the AWS credential. The credential source map must contain the
Expand Down Expand Up @@ -107,6 +109,12 @@ static class AwsCredentialSource extends CredentialSource {
this.url = (String) credentialSourceMap.get("url");
this.regionalCredentialVerificationUrl =
(String) credentialSourceMap.get("regional_cred_verification_url");

if (credentialSourceMap.containsKey("aws_session_token_url")) {
this.awsSessionTokenUrl = (String) credentialSourceMap.get("aws_session_token_url");
} else {
this.awsSessionTokenUrl = null;
}
}
}

Expand Down Expand Up @@ -135,11 +143,16 @@ public AccessToken refreshAccessToken() throws IOException {

@Override
public String retrieveSubjectToken() throws IOException {
String awsSessionToken = null;
if (awsCredentialSource.awsSessionTokenUrl != null) {
awsSessionToken = getAwsSessionToken(awsCredentialSource.awsSessionTokenUrl);
}

// The targeted region is required to generate the signed request. The regional
// endpoint must also be used.
String region = getAwsRegion();
String region = getAwsRegion(awsSessionToken);

AwsSecurityCredentials credentials = getAwsSecurityCredentials();
AwsSecurityCredentials credentials = getAwsSecurityCredentials(awsSessionToken);

// Generate the signed request to the AWS STS GetCallerIdentity API.
Map<String, String> headers = new HashMap<>();
Expand All @@ -164,17 +177,39 @@ public GoogleCredentials createScoped(Collection<String> newScopes) {
return new AwsCredentials((AwsCredentials.Builder) newBuilder(this).setScopes(newScopes));
}

private String retrieveResource(String url, String resourceName) throws IOException {
private String retrieveResource(String url, String resourceName, String sessionToken)
throws IOException {
try {
HttpRequestFactory requestFactory = transportFactory.create().createRequestFactory();
HttpRequest request = requestFactory.buildGetRequest(new GenericUrl(url));

if (sessionToken != null) {
HttpHeaders headers = request.getHeaders();
headers.set("X-aws-ec2-metadata-token", sessionToken);
}

HttpResponse response = request.execute();
return response.parseAsString();
} catch (IOException e) {
throw new IOException(String.format("Failed to retrieve AWS %s.", resourceName), e);
}
}

private String getAwsSessionToken(String awsSessionTokenUrl) throws IOException {
try {
HttpRequestFactory requestFactory = transportFactory.create().createRequestFactory();
HttpRequest request =
requestFactory.buildPutRequest(new GenericUrl(awsSessionTokenUrl), null);
HttpHeaders headers = request.getHeaders();
headers.set("X-aws-ec2-metadata-token-ttl-seconds", "21600");

HttpResponse response = request.execute();
return response.parseAsString();
} catch (IOException e) {
throw new IOException(String.format("Failed to fetch AWS Session Token"), e);
}
}

private String buildSubjectToken(AwsRequestSignature signature)
throws UnsupportedEncodingException {
Map<String, String> canonicalHeaders = signature.getCanonicalHeaders();
Expand All @@ -201,7 +236,7 @@ private String buildSubjectToken(AwsRequestSignature signature)
}

@VisibleForTesting
String getAwsRegion() throws IOException {
String getAwsRegion(String awsSessionToken) throws IOException {
// For AWS Lambda, the region is retrieved through the AWS_REGION environment variable.
String region = getEnvironmentProvider().getEnv("AWS_REGION");
if (region != null) {
Expand All @@ -218,15 +253,15 @@ String getAwsRegion() throws IOException {
"Unable to determine the AWS region. The credential source does not contain the region URL.");
}

region = retrieveResource(awsCredentialSource.regionUrl, "region");
region = retrieveResource(awsCredentialSource.regionUrl, "region", awsSessionToken);

// There is an extra appended character that must be removed. If `us-east-1b` is returned,
// we want `us-east-1`.
return region.substring(0, region.length() - 1);
}

@VisibleForTesting
AwsSecurityCredentials getAwsSecurityCredentials() throws IOException {
AwsSecurityCredentials getAwsSecurityCredentials(String awsSessionToken) throws IOException {
// Check environment variables for credentials first.
String accessKeyId = getEnvironmentProvider().getEnv("AWS_ACCESS_KEY_ID");
String secretAccessKey = getEnvironmentProvider().getEnv("AWS_SECRET_ACCESS_KEY");
Expand All @@ -243,12 +278,12 @@ AwsSecurityCredentials getAwsSecurityCredentials() throws IOException {
"Unable to determine the AWS IAM role name. The credential source does not contain the"
+ " url field.");
}
String roleName = retrieveResource(awsCredentialSource.url, "IAM role");
String roleName = retrieveResource(awsCredentialSource.url, "IAM role", awsSessionToken);

// Retrieve the AWS security credentials by calling the endpoint specified by the credential
// source.
String awsCredentials =
retrieveResource(awsCredentialSource.url + "/" + roleName, "credentials");
retrieveResource(awsCredentialSource.url + "/" + roleName, "credentials", awsSessionToken);

JsonParser parser = OAuth2Utils.JSON_FACTORY.createJsonParser(awsCredentials);
GenericJson genericJson = parser.parseAndClose(GenericJson.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ void refreshAccessToken_withoutServiceAccountImpersonation() throws IOException
AwsCredentials.newBuilder(AWS_CREDENTIAL)
.setTokenUrl(transportFactory.transport.getStsUrl())
.setHttpTransportFactory(transportFactory)
.setCredentialSource(buildAwsCredentialSource(transportFactory))
.setCredentialSource(buildAwsCredentialSource(transportFactory, false))
.build();

AccessToken accessToken = awsCredential.refreshAccessToken();
Expand All @@ -119,7 +119,7 @@ void refreshAccessToken_withServiceAccountImpersonation() throws IOException {
.setServiceAccountImpersonationUrl(
transportFactory.transport.getServiceAccountImpersonationUrl())
.setHttpTransportFactory(transportFactory)
.setCredentialSource(buildAwsCredentialSource(transportFactory))
.setCredentialSource(buildAwsCredentialSource(transportFactory, false))
.build();

AccessToken accessToken = awsCredential.refreshAccessToken();
Expand All @@ -137,7 +137,42 @@ void retrieveSubjectToken() throws IOException {
(AwsCredentials)
AwsCredentials.newBuilder(AWS_CREDENTIAL)
.setHttpTransportFactory(transportFactory)
.setCredentialSource(buildAwsCredentialSource(transportFactory))
.setCredentialSource(buildAwsCredentialSource(transportFactory, false))
.build();

String subjectToken = URLDecoder.decode(awsCredential.retrieveSubjectToken(), "UTF-8");

JsonParser parser = OAuth2Utils.JSON_FACTORY.createJsonParser(subjectToken);
GenericJson json = parser.parseAndClose(GenericJson.class);

List<Map<String, String>> headersList = (List<Map<String, String>>) json.get("headers");
Map<String, String> headers = new HashMap<>();
for (Map<String, String> header : headersList) {
headers.put(header.get("key"), header.get("value"));
}

assertEquals("POST", json.get("method"));
assertEquals(GET_CALLER_IDENTITY_URL, json.get("url"));
assertEquals(URI.create(GET_CALLER_IDENTITY_URL).getHost(), headers.get("host"));
assertEquals("token", headers.get("x-amz-security-token"));
assertEquals(awsCredential.getAudience(), headers.get("x-goog-cloud-target-resource"));
assertTrue(headers.containsKey("x-amz-date"));
assertNotNull(headers.get("Authorization"));
}

@Test
void retrieveSubjectTokenWithSessionTokenUrl() throws IOException {
MockExternalAccountCredentialsTransportFactory transportFactory =
new MockExternalAccountCredentialsTransportFactory();

// Map<String, Object> credentialMap = new HashMap<>(AWS_CREDENTIAL_SOURCE_MAP);
// credentialMap.put("aws_session_token_url", "awsSesionTokenUrl")

AwsCredentials awsCredential =
(AwsCredentials)
AwsCredentials.newBuilder(AWS_CREDENTIAL)
.setHttpTransportFactory(transportFactory)
.setCredentialSource(buildAwsCredentialSource(transportFactory, true))
.build();

String subjectToken = URLDecoder.decode(awsCredential.retrieveSubjectToken(), "UTF-8");
Expand Down Expand Up @@ -172,7 +207,7 @@ void retrieveSubjectToken_noRegion_expectThrows() {
(AwsCredentials)
AwsCredentials.newBuilder(AWS_CREDENTIAL)
.setHttpTransportFactory(transportFactory)
.setCredentialSource(buildAwsCredentialSource(transportFactory))
.setCredentialSource(buildAwsCredentialSource(transportFactory, false))
.build();

IOException exception =
Expand All @@ -194,7 +229,7 @@ void retrieveSubjectToken_noRole_expectThrows() {
(AwsCredentials)
AwsCredentials.newBuilder(AWS_CREDENTIAL)
.setHttpTransportFactory(transportFactory)
.setCredentialSource(buildAwsCredentialSource(transportFactory))
.setCredentialSource(buildAwsCredentialSource(transportFactory, false))
.build();

IOException exception =
Expand All @@ -216,7 +251,7 @@ void retrieveSubjectToken_noCredentials_expectThrows() {
(AwsCredentials)
AwsCredentials.newBuilder(AWS_CREDENTIAL)
.setHttpTransportFactory(transportFactory)
.setCredentialSource(buildAwsCredentialSource(transportFactory))
.setCredentialSource(buildAwsCredentialSource(transportFactory, false))
.build();

IOException exception =
Expand Down Expand Up @@ -263,7 +298,7 @@ void getAwsSecurityCredentials_fromEnvironmentVariablesNoToken() throws IOExcept
.setEnvironmentProvider(environmentProvider)
.build();

AwsSecurityCredentials credentials = testAwsCredentials.getAwsSecurityCredentials();
AwsSecurityCredentials credentials = testAwsCredentials.getAwsSecurityCredentials(null);

assertEquals("awsAccessKeyId", credentials.getAccessKeyId());
assertEquals("awsSecretAccessKey", credentials.getSecretAccessKey());
Expand All @@ -284,7 +319,7 @@ void getAwsSecurityCredentials_fromEnvironmentVariablesWithToken() throws IOExce
.setEnvironmentProvider(environmentProvider)
.build();

AwsSecurityCredentials credentials = testAwsCredentials.getAwsSecurityCredentials();
AwsSecurityCredentials credentials = testAwsCredentials.getAwsSecurityCredentials(null);

assertEquals("awsAccessKeyId", credentials.getAccessKeyId());
assertEquals("awsSecretAccessKey", credentials.getSecretAccessKey());
Expand All @@ -300,10 +335,10 @@ void getAwsSecurityCredentials_fromMetadataServer() throws IOException {
(AwsCredentials)
AwsCredentials.newBuilder(AWS_CREDENTIAL)
.setHttpTransportFactory(transportFactory)
.setCredentialSource(buildAwsCredentialSource(transportFactory))
.setCredentialSource(buildAwsCredentialSource(transportFactory, false))
.build();

AwsSecurityCredentials credentials = awsCredential.getAwsSecurityCredentials();
AwsSecurityCredentials credentials = awsCredential.getAwsSecurityCredentials(null);

assertEquals("accessKeyId", credentials.getAccessKeyId());
assertEquals("secretAccessKey", credentials.getSecretAccessKey());
Expand All @@ -329,7 +364,9 @@ void getAwsSecurityCredentials_fromMetadataServer_noUrlProvided() {
IOException exception =
assertThrows(
IOException.class,
awsCredential::getAwsSecurityCredentials,
() -> {
awsCredential.getAwsSecurityCredentials(null);
},
"Exception should be thrown.");
assertEquals(
"Unable to determine the AWS IAM role name. The credential source does not contain the url field.",
Expand All @@ -348,11 +385,11 @@ void getAwsRegion_awsRegionEnvironmentVariable() throws IOException {
(AwsCredentials)
AwsCredentials.newBuilder(AWS_CREDENTIAL)
.setHttpTransportFactory(transportFactory)
.setCredentialSource(buildAwsCredentialSource(transportFactory))
.setCredentialSource(buildAwsCredentialSource(transportFactory, false))
.setEnvironmentProvider(environmentProvider)
.build();

String region = awsCredentials.getAwsRegion();
String region = awsCredentials.getAwsRegion(null);

// Should attempt to retrieve the region from AWS_REGION env var first.
// Metadata server would return us-east-1b.
Expand All @@ -370,11 +407,11 @@ void getAwsRegion_awsDefaultRegionEnvironmentVariable() throws IOException {
(AwsCredentials)
AwsCredentials.newBuilder(AWS_CREDENTIAL)
.setHttpTransportFactory(transportFactory)
.setCredentialSource(buildAwsCredentialSource(transportFactory))
.setCredentialSource(buildAwsCredentialSource(transportFactory, false))
.setEnvironmentProvider(environmentProvider)
.build();

String region = awsCredentials.getAwsRegion();
String region = awsCredentials.getAwsRegion(null);

// Should attempt to retrieve the region from DEFAULT_AWS_REGION before calling the metadata
// server. Metadata server would return us-east-1b.
Expand All @@ -389,10 +426,10 @@ void getAwsRegion_metadataServer() throws IOException {
(AwsCredentials)
AwsCredentials.newBuilder(AWS_CREDENTIAL)
.setHttpTransportFactory(transportFactory)
.setCredentialSource(buildAwsCredentialSource(transportFactory))
.setCredentialSource(buildAwsCredentialSource(transportFactory, false))
.build();

String region = awsCredentials.getAwsRegion();
String region = awsCredentials.getAwsRegion(null);

// Should retrieve the region from the Metadata server.
String expectedRegion =
Expand Down Expand Up @@ -511,12 +548,19 @@ public void builder() {
}

private static AwsCredentialSource buildAwsCredentialSource(
MockExternalAccountCredentialsTransportFactory transportFactory) {
MockExternalAccountCredentialsTransportFactory transportFactory,
Boolean includeAwsSessionTokenUrl) {
Map<String, Object> credentialSourceMap = new HashMap<>();
credentialSourceMap.put("environment_id", "aws1");
credentialSourceMap.put("region_url", transportFactory.transport.getAwsRegionUrl());
credentialSourceMap.put("url", transportFactory.transport.getAwsCredentialsUrl());
credentialSourceMap.put("regional_cred_verification_url", GET_CALLER_IDENTITY_URL);

if (includeAwsSessionTokenUrl) {
credentialSourceMap.put(
"aws_session_token_url", transportFactory.transport.getAwsSessionTokenUrl());
}

return new AwsCredentialSource(credentialSourceMap);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ public class MockExternalAccountCredentialsTransport extends MockHttpTransport {
private static final String ISSUED_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:access_token";
private static final String AWS_CREDENTIALS_URL = "https://www.aws-credentials.com";
private static final String AWS_REGION_URL = "https://www.aws-region.com";
private static final String AWS_SESSION_TOKEN_URL = "https://www.aws-session-token.com";
private static final String METADATA_SERVER_URL = "https://www.metadata.google.com";
private static final String STS_URL = "https://sts.googleapis.com";

Expand Down Expand Up @@ -119,6 +120,11 @@ public LowLevelHttpResponse execute() throws IOException {
throw responseErrorSequence.poll();
}

if (AWS_SESSION_TOKEN_URL.equals(url)) {
return new MockLowLevelHttpResponse()
.setContentType("text/html")
.setContent("sessiontoken");
}
if (AWS_REGION_URL.equals(url)) {
return new MockLowLevelHttpResponse()
.setContentType("text/html")
Expand Down Expand Up @@ -255,6 +261,10 @@ public String getAwsRegionUrl() {
return AWS_REGION_URL;
}

public String getAwsSessionTokenUrl() {
return AWS_SESSION_TOKEN_URL;
}

public String getAwsRegion() {
return AWS_REGION;
}
Expand Down

0 comments on commit abc8711

Please sign in to comment.