Skip to content

Commit

Permalink
feat: allow custom scopes for compute engine creds (#514)
Browse files Browse the repository at this point in the history
* feat: allow custom scopes for compute engine creds

* update
  • Loading branch information
arithmetic1728 authored Dec 16, 2020
1 parent 5e49463 commit edc8d6e
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,17 @@
import com.google.auth.ServiceAccountSigner;
import com.google.auth.http.HttpTransportFactory;
import com.google.common.annotations.Beta;
import com.google.common.base.Joiner;
import com.google.common.base.MoreObjects;
import com.google.common.collect.ImmutableSet;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.net.SocketTimeoutException;
import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Date;
import java.util.List;
Expand Down Expand Up @@ -94,6 +99,8 @@ public class ComputeEngineCredentials extends GoogleCredentials

private final String transportFactoryClassName;

private final Collection<String> scopes;

private transient HttpTransportFactory transportFactory;
private transient String serviceAccountEmail;

Expand All @@ -102,13 +109,28 @@ public class ComputeEngineCredentials extends GoogleCredentials
*
* @param transportFactory HTTP transport factory, creates the transport used to get access
* tokens.
* @param scopes scope strings for the APIs to be called. May be null or an empty collection.
*/
private ComputeEngineCredentials(HttpTransportFactory transportFactory) {
private ComputeEngineCredentials(
HttpTransportFactory transportFactory, Collection<String> scopes) {
this.transportFactory =
firstNonNull(
transportFactory,
getFromServiceLoader(HttpTransportFactory.class, OAuth2Utils.HTTP_TRANSPORT_FACTORY));
this.transportFactoryClassName = this.transportFactory.getClass().getName();
if (scopes == null) {
this.scopes = ImmutableSet.<String>of();
} else {
List<String> scopeList = new ArrayList<String>(scopes);
scopeList.removeAll(Arrays.asList("", null));
this.scopes = ImmutableSet.<String>copyOf(scopeList);
}
}

/** Clones the compute engine account with the specified scopes. */
@Override
public GoogleCredentials createScoped(Collection<String> newScopes) {
return new ComputeEngineCredentials(this.transportFactory, newScopes);
}

/**
Expand All @@ -117,13 +139,30 @@ private ComputeEngineCredentials(HttpTransportFactory transportFactory) {
* @return new ComputeEngineCredentials
*/
public static ComputeEngineCredentials create() {
return new ComputeEngineCredentials(null);
return new ComputeEngineCredentials(null, null);
}

public final Collection<String> getScopes() {
return scopes;
}

/**
* If scopes is specified, add "?scopes=comma-separated-list-of-scopes" to the token url.
*
* @return token url with the given scopes
*/
String createTokenUrlWithScopes() {
GenericUrl tokenUrl = new GenericUrl(getTokenServerEncodedUrl());
if (!scopes.isEmpty()) {
tokenUrl.set("scopes", Joiner.on(',').join(scopes));
}
return tokenUrl.toString();
}

/** Refresh the access token by getting it from the GCE metadata server */
@Override
public AccessToken refreshAccessToken() throws IOException {
HttpResponse response = getMetadataResponse(getTokenServerEncodedUrl());
HttpResponse response = getMetadataResponse(createTokenUrlWithScopes());
int statusCode = response.getStatusCode();
if (statusCode == HttpStatusCodes.STATUS_CODE_NOT_FOUND) {
throw new IOException(
Expand Down Expand Up @@ -307,7 +346,8 @@ public boolean equals(Object obj) {
return false;
}
ComputeEngineCredentials other = (ComputeEngineCredentials) obj;
return Objects.equals(this.transportFactoryClassName, other.transportFactoryClassName);
return Objects.equals(this.transportFactoryClassName, other.transportFactoryClassName)
&& Objects.equals(this.scopes, other.scopes);
}

private void readObject(ObjectInputStream input) throws IOException, ClassNotFoundException {
Expand Down Expand Up @@ -399,24 +439,35 @@ private String getDefaultServiceAccount() throws IOException {

public static class Builder extends GoogleCredentials.Builder {
private HttpTransportFactory transportFactory;
private Collection<String> scopes;

protected Builder() {}

protected Builder(ComputeEngineCredentials credentials) {
this.transportFactory = credentials.transportFactory;
this.scopes = credentials.scopes;
}

public Builder setHttpTransportFactory(HttpTransportFactory transportFactory) {
this.transportFactory = transportFactory;
return this;
}

public Builder setScopes(Collection<String> scopes) {
this.scopes = scopes;
return this;
}

public HttpTransportFactory getHttpTransportFactory() {
return transportFactory;
}

public Collection<String> getScopes() {
return scopes;
}

public ComputeEngineCredentials build() {
return new ComputeEngineCredentials(transportFactory);
return new ComputeEngineCredentials(transportFactory, scopes);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@
import java.io.IOException;
import java.net.URI;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.junit.Test;
Expand All @@ -68,6 +70,9 @@ public class ComputeEngineCredentialsTest extends BaseSerializationTest {

private static final URI CALL_URI = URI.create("http://googleapis.com/testapi/v1/foo");

private static final String TOKEN_URL =
"http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/token";

// Id Token which includes basic default claims
public static final String STANDARD_ID_TOKEN =
"eyJhbGciOiJSUzI1NiIsImtpZCI6ImRmMzc1ODkwOGI3OTIyO"
Expand Down Expand Up @@ -113,6 +118,69 @@ public HttpTransport create() {
}
}

@Test
public void createTokenUrlWithScopes_null_scopes() {
ComputeEngineCredentials credentials =
ComputeEngineCredentials.newBuilder().setScopes(null).build();
Collection<String> scopes = credentials.getScopes();
String tokenUrlWithScopes = credentials.createTokenUrlWithScopes();

assertEquals(TOKEN_URL, tokenUrlWithScopes);
assertTrue(scopes.isEmpty());
}

@Test
public void createTokenUrlWithScopes_empty_scopes() {
ComputeEngineCredentials.Builder builder =
ComputeEngineCredentials.newBuilder().setScopes(Collections.<String>emptyList());
ComputeEngineCredentials credentials = builder.build();
Collection<String> scopes = credentials.getScopes();
String tokenUrlWithScopes = credentials.createTokenUrlWithScopes();

assertEquals(TOKEN_URL, tokenUrlWithScopes);
assertTrue(scopes.isEmpty());
assertTrue(builder.getScopes().isEmpty());
}

@Test
public void createTokenUrlWithScopes_single_scope() {
ComputeEngineCredentials credentials =
ComputeEngineCredentials.newBuilder().setScopes(Arrays.asList("foo")).build();
String tokenUrlWithScopes = credentials.createTokenUrlWithScopes();
Collection<String> scopes = credentials.getScopes();

assertEquals(TOKEN_URL + "?scopes=foo", tokenUrlWithScopes);
assertEquals(1, scopes.size());
assertEquals("foo", scopes.toArray()[0]);
}

@Test
public void createTokenUrlWithScopes_multiple_scopes() {
ComputeEngineCredentials credentials =
ComputeEngineCredentials.newBuilder()
.setScopes(Arrays.asList(null, "foo", "", "bar"))
.build();
Collection<String> scopes = credentials.getScopes();
String tokenUrlWithScopes = credentials.createTokenUrlWithScopes();

assertEquals(TOKEN_URL + "?scopes=foo,bar", tokenUrlWithScopes);
assertEquals(2, scopes.size());
assertEquals("foo", scopes.toArray()[0]);
assertEquals("bar", scopes.toArray()[1]);
}

@Test
public void createScoped() {
ComputeEngineCredentials credentials =
ComputeEngineCredentials.newBuilder().setScopes(null).build();
ComputeEngineCredentials credentialsWithScopes =
(ComputeEngineCredentials) credentials.createScoped(Arrays.asList("foo"));
Collection<String> scopes = credentialsWithScopes.getScopes();

assertEquals(1, scopes.size());
assertEquals("foo", scopes.toArray()[0]);
}

@Test
public void getRequestMetadata_hasAccessToken() throws IOException {
String accessToken = "1/MkSJoj1xsli0AccessToken_NKPY2";
Expand Down

0 comments on commit edc8d6e

Please sign in to comment.