diff --git a/README.md b/README.md index 772ab48d6..0cb356895 100644 --- a/README.md +++ b/README.md @@ -211,10 +211,11 @@ Where the following variables need to be substituted: This generates the configuration file in the specified output file. -If you want to use IDMSv2, then below field needs to be added to credential_source -section of credential configuration. +If you want to use the AWS IMDSv2 flow, you can add the field below to the credential_source in your AWS ADC configuration file: "aws_session_token_url": "http://169.254.169.254/latest/api/token" +The gcloud create-cred-config command will be updated to support this soon. + You can now [use the Auth library](#using-external-identities) to call Google Cloud resources from AWS. diff --git a/oauth2_http/java/com/google/auth/oauth2/AwsCredentials.java b/oauth2_http/java/com/google/auth/oauth2/AwsCredentials.java index 6302b1099..fa07eacf8 100644 --- a/oauth2_http/java/com/google/auth/oauth2/AwsCredentials.java +++ b/oauth2_http/java/com/google/auth/oauth2/AwsCredentials.java @@ -32,7 +32,9 @@ package com.google.auth.oauth2; import com.google.api.client.http.GenericUrl; +import com.google.api.client.http.HttpContent; import com.google.api.client.http.HttpHeaders; +import com.google.api.client.http.HttpMethods; import com.google.api.client.http.HttpRequest; import com.google.api.client.http.HttpRequestFactory; import com.google.api.client.http.HttpResponse; @@ -49,6 +51,7 @@ import java.util.Map; import java.util.regex.Matcher; import java.util.regex.Pattern; +import javax.annotation.Nullable; /** * AWS credentials representing a third-party identity for calling Google APIs. @@ -143,16 +146,34 @@ public AccessToken refreshAccessToken() throws IOException { @Override public String retrieveSubjectToken() throws IOException { - String awsSessionToken = null; + // AWS IDMSv2 introduced a requirement for a session token to be present + // with the requests made to metadata endpoints. This requirement is to help + // prevent SSRF attacks. + // Presence of "aws_session_token_url" in Credential Source of config file + // will trigger a flow with session token, else there will not be a session + // token with the metadata requests. + // Both flows work for IDMS v1 and v2. But if IDMSv2 is enabled, then if + // session token is not present, Unauthorized exception will be thrown. + Map metadataHeaders = new HashMap<>(); if (awsCredentialSource.awsSessionTokenUrl != null) { - awsSessionToken = getAwsSessionToken(awsCredentialSource.awsSessionTokenUrl); + Map tokenRequestHeaders = + Map.of("x-aws-ec2-metadata-token-ttl-seconds", "21600"); + + String awsSessionToken = + retrieveResource( + awsCredentialSource.awsSessionTokenUrl, + "Session Token", + HttpMethods.PUT, + tokenRequestHeaders, + /*content =*/ null); + metadataHeaders.put("x-aws-ec2-metadata-token", awsSessionToken); } // The targeted region is required to generate the signed request. The regional // endpoint must also be used. - String region = getAwsRegion(awsSessionToken); + String region = getAwsRegion(metadataHeaders); - AwsSecurityCredentials credentials = getAwsSecurityCredentials(awsSessionToken); + AwsSecurityCredentials credentials = getAwsSecurityCredentials(metadataHeaders); // Generate the signed request to the AWS STS GetCallerIdentity API. Map headers = new HashMap<>(); @@ -177,36 +198,32 @@ public GoogleCredentials createScoped(Collection newScopes) { return new AwsCredentials((AwsCredentials.Builder) newBuilder(this).setScopes(newScopes)); } - private String retrieveResource(String url, String resourceName, String sessionToken) + private String retrieveResource(String url, String resourceName, Map headers) throws IOException { - try { - HttpRequestFactory requestFactory = transportFactory.create().createRequestFactory(); - HttpRequest request = requestFactory.buildGetRequest(new GenericUrl(url)); - - if (sessionToken != null) { - HttpHeaders headers = request.getHeaders(); - headers.set("X-aws-ec2-metadata-token", sessionToken); - } - - HttpResponse response = request.execute(); - return response.parseAsString(); - } catch (IOException e) { - throw new IOException(String.format("Failed to retrieve AWS %s.", resourceName), e); - } + return retrieveResource(url, resourceName, HttpMethods.GET, headers, /*content =*/ null); } - private String getAwsSessionToken(String awsSessionTokenUrl) throws IOException { + private String retrieveResource( + String url, + String resourceName, + String requestMethod, + Map headers, + @Nullable HttpContent content) + throws IOException { try { HttpRequestFactory requestFactory = transportFactory.create().createRequestFactory(); HttpRequest request = - requestFactory.buildPutRequest(new GenericUrl(awsSessionTokenUrl), null); - HttpHeaders headers = request.getHeaders(); - headers.set("X-aws-ec2-metadata-token-ttl-seconds", "21600"); + requestFactory.buildRequest(requestMethod, new GenericUrl(url), content); + + HttpHeaders httpHeaders = request.getHeaders(); + for (Map.Entry header : headers.entrySet()) { + httpHeaders.set(header.getKey(), header.getValue()); + } HttpResponse response = request.execute(); return response.parseAsString(); } catch (IOException e) { - throw new IOException(String.format("Failed to fetch AWS Session Token"), e); + throw new IOException(String.format("Failed to retrieve AWS %s.", resourceName), e); } } @@ -236,7 +253,7 @@ private String buildSubjectToken(AwsRequestSignature signature) } @VisibleForTesting - String getAwsRegion(String awsSessionToken) throws IOException { + String getAwsRegion(Map metadataHeaders) throws IOException { // For AWS Lambda, the region is retrieved through the AWS_REGION environment variable. String region = getEnvironmentProvider().getEnv("AWS_REGION"); if (region != null) { @@ -253,7 +270,7 @@ String getAwsRegion(String awsSessionToken) throws IOException { "Unable to determine the AWS region. The credential source does not contain the region URL."); } - region = retrieveResource(awsCredentialSource.regionUrl, "region", awsSessionToken); + region = retrieveResource(awsCredentialSource.regionUrl, "region", metadataHeaders); // There is an extra appended character that must be removed. If `us-east-1b` is returned, // we want `us-east-1`. @@ -261,7 +278,8 @@ String getAwsRegion(String awsSessionToken) throws IOException { } @VisibleForTesting - AwsSecurityCredentials getAwsSecurityCredentials(String awsSessionToken) throws IOException { + AwsSecurityCredentials getAwsSecurityCredentials(Map metadataHeaders) + throws IOException { // Check environment variables for credentials first. String accessKeyId = getEnvironmentProvider().getEnv("AWS_ACCESS_KEY_ID"); String secretAccessKey = getEnvironmentProvider().getEnv("AWS_SECRET_ACCESS_KEY"); @@ -278,12 +296,12 @@ AwsSecurityCredentials getAwsSecurityCredentials(String awsSessionToken) throws "Unable to determine the AWS IAM role name. The credential source does not contain the" + " url field."); } - String roleName = retrieveResource(awsCredentialSource.url, "IAM role", awsSessionToken); + String roleName = retrieveResource(awsCredentialSource.url, "IAM role", metadataHeaders); // Retrieve the AWS security credentials by calling the endpoint specified by the credential // source. String awsCredentials = - retrieveResource(awsCredentialSource.url + "/" + roleName, "credentials", awsSessionToken); + retrieveResource(awsCredentialSource.url + "/" + roleName, "credentials", metadataHeaders); JsonParser parser = OAuth2Utils.JSON_FACTORY.createJsonParser(awsCredentials); GenericJson genericJson = parser.parseAndClose(GenericJson.class); diff --git a/oauth2_http/javatests/com/google/auth/oauth2/AwsCredentialsTest.java b/oauth2_http/javatests/com/google/auth/oauth2/AwsCredentialsTest.java index 87e7a71f4..2aeb78378 100644 --- a/oauth2_http/javatests/com/google/auth/oauth2/AwsCredentialsTest.java +++ b/oauth2_http/javatests/com/google/auth/oauth2/AwsCredentialsTest.java @@ -39,6 +39,7 @@ import com.google.api.client.json.GenericJson; import com.google.api.client.json.JsonParser; +import com.google.api.client.testing.http.MockLowLevelHttpRequest; import com.google.auth.TestUtils; import com.google.auth.oauth2.AwsCredentials.AwsCredentialSource; import com.google.auth.oauth2.ExternalAccountCredentialsTest.MockExternalAccountCredentialsTransportFactory; @@ -47,6 +48,7 @@ import java.net.URI; import java.net.URLDecoder; import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -56,6 +58,13 @@ class AwsCredentialsTest { private static final String STS_URL = "https://sts.googleapis.com"; + private static final String AWS_CREDENTIALS_URL = "https://www.aws-credentials.com"; + private static final String AWS_CREDENTIALS_URL_WITH_ROLE = + "https://www.aws-credentials.com/roleName"; + private static final String AWS_REGION_URL = "https://www.aws-region.com"; + private static final String AWS_SESSION_TOKEN_URL = "https://www.aws-session-token.com"; + private static final String AWS_SESSION_TOKEN = "sessiontoken"; + private static final String AWS_SESSION_TOKEN_HEADER = "x-aws-ec2-metadata-token"; private static final String GET_CALLER_IDENTITY_URL = "https://sts.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15"; @@ -73,6 +82,9 @@ class AwsCredentialsTest { } }; + private static final Map emptyMetadataHeaders = Collections.emptyMap(); + private static final Map emptyStringHeaders = Collections.emptyMap(); + private static final AwsCredentialSource AWS_CREDENTIAL_SOURCE = new AwsCredentialSource(AWS_CREDENTIAL_SOURCE_MAP); @@ -97,7 +109,7 @@ void refreshAccessToken_withoutServiceAccountImpersonation() throws IOException AwsCredentials.newBuilder(AWS_CREDENTIAL) .setTokenUrl(transportFactory.transport.getStsUrl()) .setHttpTransportFactory(transportFactory) - .setCredentialSource(buildAwsCredentialSource(transportFactory, false)) + .setCredentialSource(buildAwsCredentialSource(transportFactory)) .build(); AccessToken accessToken = awsCredential.refreshAccessToken(); @@ -119,7 +131,7 @@ void refreshAccessToken_withServiceAccountImpersonation() throws IOException { .setServiceAccountImpersonationUrl( transportFactory.transport.getServiceAccountImpersonationUrl()) .setHttpTransportFactory(transportFactory) - .setCredentialSource(buildAwsCredentialSource(transportFactory, false)) + .setCredentialSource(buildAwsCredentialSource(transportFactory)) .build(); AccessToken accessToken = awsCredential.refreshAccessToken(); @@ -137,7 +149,7 @@ void retrieveSubjectToken() throws IOException { (AwsCredentials) AwsCredentials.newBuilder(AWS_CREDENTIAL) .setHttpTransportFactory(transportFactory) - .setCredentialSource(buildAwsCredentialSource(transportFactory, false)) + .setCredentialSource(buildAwsCredentialSource(transportFactory)) .build(); String subjectToken = URLDecoder.decode(awsCredential.retrieveSubjectToken(), "UTF-8"); @@ -158,6 +170,18 @@ void retrieveSubjectToken() throws IOException { assertEquals(awsCredential.getAudience(), headers.get("x-goog-cloud-target-resource")); assertTrue(headers.containsKey("x-amz-date")); assertNotNull(headers.get("Authorization")); + + List requests = transportFactory.transport.getRequests(); + assertEquals(3, requests.size()); + + // Validate region request. + ValidateRequest(requests.get(0), AWS_REGION_URL, emptyStringHeaders); + + // Validate role request. + ValidateRequest(requests.get(1), AWS_CREDENTIALS_URL, emptyStringHeaders); + + // Validate security credentials request. + ValidateRequest(requests.get(2), AWS_CREDENTIALS_URL_WITH_ROLE, emptyStringHeaders); } @Test @@ -165,14 +189,19 @@ void retrieveSubjectTokenWithSessionTokenUrl() throws IOException { MockExternalAccountCredentialsTransportFactory transportFactory = new MockExternalAccountCredentialsTransportFactory(); - // Map credentialMap = new HashMap<>(AWS_CREDENTIAL_SOURCE_MAP); - // credentialMap.put("aws_session_token_url", "awsSesionTokenUrl") + Map credentialSourceMap = new HashMap<>(); + credentialSourceMap.put("environment_id", "aws1"); + credentialSourceMap.put("region_url", transportFactory.transport.getAwsRegionUrl()); + credentialSourceMap.put("url", transportFactory.transport.getAwsCredentialsUrl()); + credentialSourceMap.put("regional_cred_verification_url", GET_CALLER_IDENTITY_URL); + credentialSourceMap.put( + "aws_session_token_url", transportFactory.transport.getAwsSessionTokenUrl()); AwsCredentials awsCredential = (AwsCredentials) AwsCredentials.newBuilder(AWS_CREDENTIAL) .setHttpTransportFactory(transportFactory) - .setCredentialSource(buildAwsCredentialSource(transportFactory, true)) + .setCredentialSource(new AwsCredentialSource(credentialSourceMap)) .build(); String subjectToken = URLDecoder.decode(awsCredential.retrieveSubjectToken(), "UTF-8"); @@ -193,6 +222,29 @@ void retrieveSubjectTokenWithSessionTokenUrl() throws IOException { assertEquals(awsCredential.getAudience(), headers.get("x-goog-cloud-target-resource")); assertTrue(headers.containsKey("x-amz-date")); assertNotNull(headers.get("Authorization")); + + List requests = transportFactory.transport.getRequests(); + assertEquals(4, requests.size()); + + // Validate the session token request + ValidateRequest( + requests.get(0), + AWS_SESSION_TOKEN_URL, + Map.of("x-aws-ec2-metadata-token-ttl-seconds", "21600")); + + // Validate region request. + ValidateRequest( + requests.get(1), AWS_REGION_URL, Map.of(AWS_SESSION_TOKEN_HEADER, AWS_SESSION_TOKEN)); + + // Validate role request. + ValidateRequest( + requests.get(2), AWS_CREDENTIALS_URL, Map.of(AWS_SESSION_TOKEN_HEADER, AWS_SESSION_TOKEN)); + + // Validate security credentials request. + ValidateRequest( + requests.get(3), + AWS_CREDENTIALS_URL_WITH_ROLE, + Map.of(AWS_SESSION_TOKEN_HEADER, AWS_SESSION_TOKEN)); } @Test @@ -207,13 +259,19 @@ void retrieveSubjectToken_noRegion_expectThrows() { (AwsCredentials) AwsCredentials.newBuilder(AWS_CREDENTIAL) .setHttpTransportFactory(transportFactory) - .setCredentialSource(buildAwsCredentialSource(transportFactory, false)) + .setCredentialSource(buildAwsCredentialSource(transportFactory)) .build(); IOException exception = assertThrows( IOException.class, awsCredential::retrieveSubjectToken, "Exception should be thrown."); assertEquals("Failed to retrieve AWS region.", exception.getMessage()); + + List requests = transportFactory.transport.getRequests(); + assertEquals(1, requests.size()); + + // Validate region request. + ValidateRequest(requests.get(0), AWS_REGION_URL, emptyStringHeaders); } @Test @@ -229,13 +287,22 @@ void retrieveSubjectToken_noRole_expectThrows() { (AwsCredentials) AwsCredentials.newBuilder(AWS_CREDENTIAL) .setHttpTransportFactory(transportFactory) - .setCredentialSource(buildAwsCredentialSource(transportFactory, false)) + .setCredentialSource(buildAwsCredentialSource(transportFactory)) .build(); IOException exception = assertThrows( IOException.class, awsCredential::retrieveSubjectToken, "Exception should be thrown."); assertEquals("Failed to retrieve AWS IAM role.", exception.getMessage()); + + List requests = transportFactory.transport.getRequests(); + assertEquals(2, requests.size()); + + // Validate region request. + ValidateRequest(requests.get(0), AWS_REGION_URL, emptyStringHeaders); + + // Validate role request. + ValidateRequest(requests.get(1), AWS_CREDENTIALS_URL, emptyStringHeaders); } @Test @@ -251,13 +318,25 @@ void retrieveSubjectToken_noCredentials_expectThrows() { (AwsCredentials) AwsCredentials.newBuilder(AWS_CREDENTIAL) .setHttpTransportFactory(transportFactory) - .setCredentialSource(buildAwsCredentialSource(transportFactory, false)) + .setCredentialSource(buildAwsCredentialSource(transportFactory)) .build(); IOException exception = assertThrows( IOException.class, awsCredential::retrieveSubjectToken, "Exception should be thrown."); assertEquals("Failed to retrieve AWS credentials.", exception.getMessage()); + + List requests = transportFactory.transport.getRequests(); + assertEquals(3, requests.size()); + + // Validate region request. + ValidateRequest(requests.get(0), AWS_REGION_URL, emptyStringHeaders); + + // Validate role request. + ValidateRequest(requests.get(1), AWS_CREDENTIALS_URL, emptyStringHeaders); + + // Validate security credentials request. + ValidateRequest(requests.get(2), AWS_CREDENTIALS_URL_WITH_ROLE, emptyStringHeaders); } @Test @@ -283,6 +362,10 @@ void retrieveSubjectToken_noRegionUrlProvided() { "Unable to determine the AWS region. The credential source does not " + "contain the region URL.", exception.getMessage()); + + // No requests because the credential source does not contain region URL + List requests = transportFactory.transport.getRequests(); + assertEquals(true, requests.isEmpty()); } @Test @@ -298,7 +381,8 @@ void getAwsSecurityCredentials_fromEnvironmentVariablesNoToken() throws IOExcept .setEnvironmentProvider(environmentProvider) .build(); - AwsSecurityCredentials credentials = testAwsCredentials.getAwsSecurityCredentials(null); + AwsSecurityCredentials credentials = + testAwsCredentials.getAwsSecurityCredentials(emptyMetadataHeaders); assertEquals("awsAccessKeyId", credentials.getAccessKeyId()); assertEquals("awsSecretAccessKey", credentials.getSecretAccessKey()); @@ -319,7 +403,8 @@ void getAwsSecurityCredentials_fromEnvironmentVariablesWithToken() throws IOExce .setEnvironmentProvider(environmentProvider) .build(); - AwsSecurityCredentials credentials = testAwsCredentials.getAwsSecurityCredentials(null); + AwsSecurityCredentials credentials = + testAwsCredentials.getAwsSecurityCredentials(emptyMetadataHeaders); assertEquals("awsAccessKeyId", credentials.getAccessKeyId()); assertEquals("awsSecretAccessKey", credentials.getSecretAccessKey()); @@ -335,14 +420,24 @@ void getAwsSecurityCredentials_fromMetadataServer() throws IOException { (AwsCredentials) AwsCredentials.newBuilder(AWS_CREDENTIAL) .setHttpTransportFactory(transportFactory) - .setCredentialSource(buildAwsCredentialSource(transportFactory, false)) + .setCredentialSource(buildAwsCredentialSource(transportFactory)) .build(); - AwsSecurityCredentials credentials = awsCredential.getAwsSecurityCredentials(null); + AwsSecurityCredentials credentials = + awsCredential.getAwsSecurityCredentials(emptyMetadataHeaders); assertEquals("accessKeyId", credentials.getAccessKeyId()); assertEquals("secretAccessKey", credentials.getSecretAccessKey()); assertEquals("token", credentials.getToken()); + + List requests = transportFactory.transport.getRequests(); + assertEquals(2, requests.size()); + + // Validate role request. + ValidateRequest(requests.get(0), AWS_CREDENTIALS_URL, emptyStringHeaders); + + // Validate security credentials request. + ValidateRequest(requests.get(1), AWS_CREDENTIALS_URL_WITH_ROLE, emptyStringHeaders); } @Test @@ -365,12 +460,16 @@ void getAwsSecurityCredentials_fromMetadataServer_noUrlProvided() { assertThrows( IOException.class, () -> { - awsCredential.getAwsSecurityCredentials(null); + awsCredential.getAwsSecurityCredentials(emptyMetadataHeaders); }, "Exception should be thrown."); assertEquals( "Unable to determine the AWS IAM role name. The credential source does not contain the url field.", exception.getMessage()); + + // No requests because url field is not present in credential source + List requests = transportFactory.transport.getRequests(); + assertEquals(true, requests.isEmpty()); } @Test @@ -385,15 +484,19 @@ void getAwsRegion_awsRegionEnvironmentVariable() throws IOException { (AwsCredentials) AwsCredentials.newBuilder(AWS_CREDENTIAL) .setHttpTransportFactory(transportFactory) - .setCredentialSource(buildAwsCredentialSource(transportFactory, false)) + .setCredentialSource(buildAwsCredentialSource(transportFactory)) .setEnvironmentProvider(environmentProvider) .build(); - String region = awsCredentials.getAwsRegion(null); + String region = awsCredentials.getAwsRegion(emptyMetadataHeaders); // Should attempt to retrieve the region from AWS_REGION env var first. // Metadata server would return us-east-1b. assertEquals("region", region); + + // No requests because region is obtained from environment variables + List requests = transportFactory.transport.getRequests(); + assertEquals(true, requests.isEmpty()); } @Test @@ -407,15 +510,19 @@ void getAwsRegion_awsDefaultRegionEnvironmentVariable() throws IOException { (AwsCredentials) AwsCredentials.newBuilder(AWS_CREDENTIAL) .setHttpTransportFactory(transportFactory) - .setCredentialSource(buildAwsCredentialSource(transportFactory, false)) + .setCredentialSource(buildAwsCredentialSource(transportFactory)) .setEnvironmentProvider(environmentProvider) .build(); - String region = awsCredentials.getAwsRegion(null); + String region = awsCredentials.getAwsRegion(emptyMetadataHeaders); // Should attempt to retrieve the region from DEFAULT_AWS_REGION before calling the metadata // server. Metadata server would return us-east-1b. assertEquals("defaultRegion", region); + + // No requests because region is obtained from environment variables + List requests = transportFactory.transport.getRequests(); + assertEquals(true, requests.isEmpty()); } @Test @@ -426,10 +533,10 @@ void getAwsRegion_metadataServer() throws IOException { (AwsCredentials) AwsCredentials.newBuilder(AWS_CREDENTIAL) .setHttpTransportFactory(transportFactory) - .setCredentialSource(buildAwsCredentialSource(transportFactory, false)) + .setCredentialSource(buildAwsCredentialSource(transportFactory)) .build(); - String region = awsCredentials.getAwsRegion(null); + String region = awsCredentials.getAwsRegion(emptyMetadataHeaders); // Should retrieve the region from the Metadata server. String expectedRegion = @@ -438,6 +545,12 @@ void getAwsRegion_metadataServer() throws IOException { .getAwsRegion() .substring(0, transportFactory.transport.getAwsRegion().length() - 1); assertEquals(expectedRegion, region); + + List requests = transportFactory.transport.getRequests(); + assertEquals(1, requests.size()); + + // Validate region request. + ValidateRequest(requests.get(0), AWS_REGION_URL, emptyStringHeaders); } @Test @@ -547,20 +660,27 @@ public void builder() { assertEquals(credentials.getEnvironmentProvider(), SystemEnvironmentProvider.getInstance()); } + private static void ValidateRequest( + MockLowLevelHttpRequest request, String expectedUrl, Map expectedHeaders) { + assertEquals(expectedUrl, request.getUrl()); + Map> actualHeaders = request.getHeaders(); + + for (Map.Entry expectedHeader : expectedHeaders.entrySet()) { + assertEquals(true, actualHeaders.containsKey(expectedHeader.getKey())); + List actualValues = actualHeaders.get(expectedHeader.getKey()); + assertEquals(1, actualValues.size()); + assertEquals(expectedHeader.getValue(), actualValues.get(0)); + } + } + private static AwsCredentialSource buildAwsCredentialSource( - MockExternalAccountCredentialsTransportFactory transportFactory, - Boolean includeAwsSessionTokenUrl) { + MockExternalAccountCredentialsTransportFactory transportFactory) { Map credentialSourceMap = new HashMap<>(); credentialSourceMap.put("environment_id", "aws1"); credentialSourceMap.put("region_url", transportFactory.transport.getAwsRegionUrl()); credentialSourceMap.put("url", transportFactory.transport.getAwsCredentialsUrl()); credentialSourceMap.put("regional_cred_verification_url", GET_CALLER_IDENTITY_URL); - if (includeAwsSessionTokenUrl) { - credentialSourceMap.put( - "aws_session_token_url", transportFactory.transport.getAwsSessionTokenUrl()); - } - return new AwsCredentialSource(credentialSourceMap); } diff --git a/oauth2_http/javatests/com/google/auth/oauth2/ExternalAccountCredentialsTest.java b/oauth2_http/javatests/com/google/auth/oauth2/ExternalAccountCredentialsTest.java index c59560f56..fb94bb93d 100644 --- a/oauth2_http/javatests/com/google/auth/oauth2/ExternalAccountCredentialsTest.java +++ b/oauth2_http/javatests/com/google/auth/oauth2/ExternalAccountCredentialsTest.java @@ -411,7 +411,7 @@ void exchangeExternalCredentialForAccessToken() throws IOException { // Validate no internal options set. Map query = - TestUtils.parseQuery(transportFactory.transport.getRequest().getContentAsString()); + TestUtils.parseQuery(transportFactory.transport.getLastRequest().getContentAsString()); assertNull(query.get("options")); } @@ -435,7 +435,7 @@ void exchangeExternalCredentialForAccessToken_withInternalOptions() throws IOExc // Validate internal options set. Map query = - TestUtils.parseQuery(transportFactory.transport.getRequest().getContentAsString()); + TestUtils.parseQuery(transportFactory.transport.getLastRequest().getContentAsString()); assertNotNull(query.get("options")); assertEquals(internalOptions.toString(), query.get("options")); } @@ -457,7 +457,7 @@ void exchangeExternalCredentialForAccessToken_workforceCred_expectUserProjectPas // Validate internal options set. Map query = - TestUtils.parseQuery(transportFactory.transport.getRequest().getContentAsString()); + TestUtils.parseQuery(transportFactory.transport.getLastRequest().getContentAsString()); GenericJson internalOptions = new GenericJson(); internalOptions.setFactory(OAuth2Utils.JSON_FACTORY); internalOptions.put("userProject", "userProject"); @@ -485,7 +485,7 @@ void exchangeExternalCredentialForAccessToken_workforceCredWithInternalOptions_e // Validate internal options set. Map query = - TestUtils.parseQuery(transportFactory.transport.getRequest().getContentAsString()); + TestUtils.parseQuery(transportFactory.transport.getLastRequest().getContentAsString()); assertNotNull(query.get("options")); assertEquals(internalOptions.toString(), query.get("options")); } diff --git a/oauth2_http/javatests/com/google/auth/oauth2/IdentityPoolCredentialsTest.java b/oauth2_http/javatests/com/google/auth/oauth2/IdentityPoolCredentialsTest.java index 5f8dc3ca0..73a3d2f12 100644 --- a/oauth2_http/javatests/com/google/auth/oauth2/IdentityPoolCredentialsTest.java +++ b/oauth2_http/javatests/com/google/auth/oauth2/IdentityPoolCredentialsTest.java @@ -346,7 +346,7 @@ void refreshAccessToken_internalOptionsSet() throws IOException { // If the IdentityPoolCredential is initialized with a userProject, it must be passed // to STS via internal options. Map query = - TestUtils.parseQuery(transportFactory.transport.getRequest().getContentAsString()); + TestUtils.parseQuery(transportFactory.transport.getLastRequest().getContentAsString()); assertNotNull(query.get("options")); GenericJson expectedInternalOptions = new GenericJson(); diff --git a/oauth2_http/javatests/com/google/auth/oauth2/MockExternalAccountCredentialsTransport.java b/oauth2_http/javatests/com/google/auth/oauth2/MockExternalAccountCredentialsTransport.java index b3fae4895..bf40410f6 100644 --- a/oauth2_http/javatests/com/google/auth/oauth2/MockExternalAccountCredentialsTransport.java +++ b/oauth2_http/javatests/com/google/auth/oauth2/MockExternalAccountCredentialsTransport.java @@ -74,6 +74,7 @@ public class MockExternalAccountCredentialsTransport extends MockHttpTransport { private static final String SUBJECT_TOKEN = "subjectToken"; private static final String TOKEN_TYPE = "Bearer"; private static final String ACCESS_TOKEN = "accessToken"; + private static final String AWS_SESSION_TOKEN = "sessiontoken"; private static final String SERVICE_ACCOUNT_ACCESS_TOKEN = "serviceAccountAccessToken"; private static final String AWS_REGION = "us-east-1b"; private static final Long EXPIRES_IN = 3600L; @@ -87,7 +88,7 @@ public class MockExternalAccountCredentialsTransport extends MockHttpTransport { private Queue responseErrorSequence = new ArrayDeque<>(); private Queue refreshTokenSequence = new ArrayDeque<>(); private Queue> scopeSequence = new ArrayDeque<>(); - private MockLowLevelHttpRequest request; + private List requests = new ArrayList<>(); private String expireTime; private String metadataServerContentType; private String stsContent; @@ -110,7 +111,7 @@ public void addScopeSequence(List... scopes) { @Override public LowLevelHttpRequest buildRequest(final String method, final String url) { - this.request = + MockLowLevelHttpRequest request = new MockLowLevelHttpRequest(url) { @Override public LowLevelHttpResponse execute() throws IOException { @@ -123,7 +124,7 @@ public LowLevelHttpResponse execute() throws IOException { if (AWS_SESSION_TOKEN_URL.equals(url)) { return new MockLowLevelHttpResponse() .setContentType("text/html") - .setContent("sessiontoken"); + .setContent(AWS_SESSION_TOKEN); } if (AWS_REGION_URL.equals(url)) { return new MockLowLevelHttpResponse() @@ -214,15 +215,25 @@ public LowLevelHttpResponse execute() throws IOException { return null; } }; - return this.request; + + this.requests.add(request); + return request; } public String getStsContent() { return stsContent; } - public MockLowLevelHttpRequest getRequest() { - return request; + public MockLowLevelHttpRequest getLastRequest() { + if (requests.isEmpty()) { + return null; + } + + return requests.get(requests.size() - 1); + } + + public List getRequests() { + return Collections.unmodifiableList(requests); } public String getTokenType() {