Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: downscope credentials used for IAM AuthN login #999

Merged
merged 15 commits into from
Nov 2, 2022
10 changes: 8 additions & 2 deletions core/src/main/java/com/google/cloud/sql/CredentialFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,16 @@

import com.google.api.client.auth.oauth2.Credential;
import com.google.api.client.http.HttpRequestInitializer;
import java.io.IOException;

/** Factory for creating {@link Credential}s for interaction with Cloud SQL Admin API. */
/**
* Factory for creating {@link Credential}s for interaction with Cloud SQL Admin API.
*/
public interface CredentialFactory {
/** Name of system property that can specify an alternative credential factory. */

/**
* Name of system property that can specify an alternative credential factory.
*/
String CREDENTIAL_FACTORY_PROPERTY = "cloudSql.socketFactory.credentialFactory";

HttpRequestInitializer create();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import com.google.api.services.sqladmin.model.GenerateEphemeralCertResponse;
import com.google.api.services.sqladmin.model.IpMapping;
import com.google.auth.http.HttpCredentialsAdapter;
import com.google.auth.oauth2.GoogleCredentials;
import com.google.auth.oauth2.OAuth2Credentials;
import com.google.cloud.sql.CredentialFactory;
import com.google.common.base.CharMatcher;
Expand Down Expand Up @@ -81,6 +82,7 @@
*/
class CloudSqlInstance {

private static final String SQL_LOGIN_SCOPE = "https://www.googleapis.com/auth/sqlservice.login";
private static final Logger logger = Logger.getLogger(CloudSqlInstance.class.getName());

// Unique identifier for each Cloud SQL instance in the format "PROJECT:REGION:INSTANCE"
Expand Down Expand Up @@ -131,7 +133,7 @@ class CloudSqlInstance {
boolean enableIamAuth,
CredentialFactory tokenSourceFactory,
ListeningScheduledExecutorService executor,
ListenableFuture<KeyPair> keyPair) {
ListenableFuture<KeyPair> keyPair) throws IOException {

Matcher matcher = CONNECTION_NAME.matcher(connectionName);
checkArgument(
Expand All @@ -155,6 +157,7 @@ class CloudSqlInstance {
HttpCredentialsAdapter credentialsAdapter = (HttpCredentialsAdapter) tokenSourceFactory
.create();
this.credentials = Optional.of((OAuth2Credentials) credentialsAdapter.getCredentials());
this.credentials.get().refresh();
} else {
this.credentials = Optional.empty();
}
Expand Down Expand Up @@ -277,10 +280,10 @@ SSLSocket createSslSocket() throws IOException {
* preferredTypes.
*
* @param preferredTypes Preferred instance IP types to use. Valid IP types include "Public" and
* "Private".
* "Private".
* @return returns a string representing the IP address for the instance
* @throws IllegalArgumentException If the instance has no IP addresses matching the provided
* preferences.
* preferences.
*/
String getPreferredIp(List<String> preferredTypes) {
Map<String, String> ipAddrs = getInstanceData().getIpAddrs();
Expand Down Expand Up @@ -525,8 +528,9 @@ private Certificate fetchEphemeralCertificate(KeyPair keyPair) {

if (enableIamAuth) {
try {
credentials.get().refresh();
String token = credentials.get().getAccessToken().getTokenValue();
GoogleCredentials downscoped = getDownscopedCredentials(credentials.get());
downscoped.refresh();
String token = downscoped.getAccessToken().getTokenValue();
// TODO: remove this once issue with OAuth2 Tokens is resolved.
// See: https://github.com/GoogleCloudPlatform/cloud-sql-jdbc-socket-factory/issues/565
request.setAccessToken(CharMatcher.is('.').trimTrailingFrom(token));
Expand Down Expand Up @@ -563,6 +567,19 @@ private Certificate fetchEphemeralCertificate(KeyPair keyPair) {
return ephemeralCertificate;
}

static GoogleCredentials getDownscopedCredentials(OAuth2Credentials credentials) {
GoogleCredentials downscoped;
try {
GoogleCredentials oldCredentials = (GoogleCredentials) credentials;
downscoped = oldCredentials.createScoped(SQL_LOGIN_SCOPE);
} catch (ClassCastException ex) {
throw new RuntimeException(
"Failed to downscope credentials for IAM Authentication:",
ex);
}
return downscoped;
}

private Date getTokenExpirationTime() {
return credentials.get().getAccessToken().getExpirationTime();
}
Expand Down Expand Up @@ -590,7 +607,7 @@ private long secondsUntilRefresh() {
*
* @param ex exception thrown by the Admin API request
* @param fallbackDesc generic description used as a fallback if no additional information can be
* provided to the user
* provided to the user
*/
private RuntimeException addExceptionContext(IOException ex, String fallbackDesc) {
// Verify we are able to extract a reason from an exception, or fallback to a generic desc
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public final class CoreSocketFactory {
* Property used to set the application name for the underlying SQLAdmin client.
*
* @deprecated Use {@link #setApplicationName(String)} to set the application name
* programmatically.
* programmatically.
*/

@Deprecated
Expand Down Expand Up @@ -117,7 +117,7 @@ public final class CoreSocketFactory {
/**
* Returns the {@link CoreSocketFactory} singleton.
*/
public static synchronized CoreSocketFactory getInstance() {
public static synchronized CoreSocketFactory getInstance() throws IOException {
if (coreSocketFactory == null) {
logger.info("First Cloud SQL connection, generating RSA key pair.");

Expand Down Expand Up @@ -155,14 +155,27 @@ public static synchronized CoreSocketFactory getInstance() {
private CloudSqlInstance getCloudSqlInstance(String instanceName, boolean enableIamAuth) {
return instances.computeIfAbsent(
instanceName,
k -> new CloudSqlInstance(k, adminApi, enableIamAuth, credentialFactory, executor,
localKeyPair));
k -> {
try {
return new CloudSqlInstance(k, adminApi, enableIamAuth, credentialFactory, executor,
localKeyPair);
} catch (IOException e) {
throw new RuntimeException(e);
}
});
}

private CloudSqlInstance getCloudSqlInstance(String instanceName) {
return instances.computeIfAbsent(
instanceName,
k -> new CloudSqlInstance(k, adminApi, false, credentialFactory, executor, localKeyPair));
k -> {
try {
return new CloudSqlInstance(k, adminApi, false, credentialFactory, executor,
localKeyPair);
} catch (IOException e) {
throw new RuntimeException(e);
}
});
}

static int getDefaultServerProxyPort() {
Expand Down Expand Up @@ -215,7 +228,7 @@ public static Socket connect(Properties props) throws IOException {
*
* <p>Depending on the given properties, it may return either a SSL Socket or a Unix Socket.
*
* @param props Properties used to configure the connection.
* @param props Properties used to configure the connection.
* @param unixPathSuffix suffix to add the the Unix socket path. Unused if null.
* @return the newly created Socket.
* @throws IOException if error occurs during socket creation.
Expand Down Expand Up @@ -255,18 +268,19 @@ public static Socket connect(Properties props, String unixPathSuffix) throws IOE
/**
* Returns data that can be used to establish Cloud SQL SSL connection.
*/
public static SslData getSslData(String csqlInstanceName, boolean enableIamAuth) {
public static SslData getSslData(String csqlInstanceName, boolean enableIamAuth)
throws IOException {
return getInstance().getCloudSqlInstance(csqlInstanceName, enableIamAuth).getSslData();
}

public static SslData getSslData(String csqlInstanceName) {
public static SslData getSslData(String csqlInstanceName) throws IOException {
return getSslData(csqlInstanceName, false);
}

/**
* Returns preferred ip address that can be used to establish Cloud SQL connection.
*/
public static String getHostIp(String csqlInstanceName) {
public static String getHostIp(String csqlInstanceName) throws IOException {
return getInstance().getHostIp(csqlInstanceName, listIpTypes(DEFAULT_IP_TYPES));
}

Expand All @@ -280,7 +294,7 @@ private String getHostIp(String instanceName, List<String> ipTypes) {
* Creates a secure socket representing a connection to a Cloud SQL instance.
*
* @param instanceName Name of the Cloud SQL instance.
* @param ipTypes Preferred type of IP to use ("PRIVATE", "PUBLIC")
* @param ipTypes Preferred type of IP to use ("PRIVATE", "PUBLIC")
* @return the newly created Socket.
* @throws IOException if error occurs during socket creation.
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
* Copyright 2022 Google LLC. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.cloud.sql.core;

import static com.google.common.truth.Truth.assertThat;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import com.google.auth.oauth2.GoogleCredentials;
import com.google.auth.oauth2.OAuth2Credentials;
import java.io.IOException;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;

@RunWith(JUnit4.class)
public class CloudSqlInstanceTest {

@Mock
private GoogleCredentials googleCredentials;

@Mock
private GoogleCredentials scopedCredentials;

@Mock
private OAuth2Credentials oAuth2Credentials;

@Before
public void setup() throws IOException {
MockitoAnnotations.openMocks(this);
when(googleCredentials.createScoped(
"https://www.googleapis.com/auth/sqlservice.login")).thenReturn(scopedCredentials);
kurtisvg marked this conversation as resolved.
Show resolved Hide resolved
}

@Test
public void downscopesGoogleCredentials() {
GoogleCredentials downscoped = CloudSqlInstance.getDownscopedCredentials(googleCredentials);
assertThat(downscoped).isEqualTo(scopedCredentials);
verify(googleCredentials, times(1)).createScoped(
"https://www.googleapis.com/auth/sqlservice.login");
}


@Test
public void throwsErrorForWrongCredentialType() {
try {
CloudSqlInstance.getDownscopedCredentials(oAuth2Credentials);
} catch (RuntimeException ex) {
assertThat(ex)
.hasMessageThat()
.contains("Failed to downscope credentials for IAM Authentication");
}
}


}
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,12 @@
import io.r2dbc.spi.ConnectionFactoryMetadata;
import io.r2dbc.spi.ConnectionFactoryOptions;
import io.r2dbc.spi.ConnectionFactoryOptions.Builder;
import java.io.IOException;
import java.util.function.Function;
import org.reactivestreams.Publisher;

/**
* * {@link ConnectionFactory} for accessing Cloud SQL instances via R2DBC protocol.
* * {@link ConnectionFactory} for accessing Cloud SQL instances via R2DBC protocol.
*/
public class CloudSqlConnectionFactory implements ConnectionFactory {

Expand All @@ -50,15 +51,23 @@ public CloudSqlConnectionFactory(

@Override
public Publisher<? extends Connection> create() {
return getConnectionFactory().create();
try {
return getConnectionFactory().create();
} catch (IOException e) {
throw new RuntimeException(e);
}
}

@Override
public ConnectionFactoryMetadata getMetadata() {
return getConnectionFactory().getMetadata();
try {
return getConnectionFactory().getMetadata();
} catch (IOException e) {
throw new RuntimeException(e);
}
}

private ConnectionFactory getConnectionFactory() {
private ConnectionFactory getConnectionFactory() throws IOException {
String hostIp = CoreSocketFactory.getHostIp(csqlHostName);
builder.option(HOST, hostIp).option(PORT, CoreSocketFactory.getDefaultServerProxyPort());
return connectionFactoryFactory.apply(builder.build());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import io.r2dbc.spi.ConnectionFactoryOptions;
import io.r2dbc.spi.ConnectionFactoryProvider;
import io.r2dbc.spi.Option;
import java.io.IOException;
import java.util.function.Function;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;
Expand All @@ -45,7 +46,13 @@ private static Function<SslContextBuilder, SslContextBuilder> createSslCustomize
sslContextBuilder -> {
// Execute in a default scheduler to prevent it from blocking event loop
SslData sslData = Mono
.fromSupplier(() -> CoreSocketFactory.getSslData(connectionName, enableIamAuth))
.fromSupplier(() -> {
try {
return CoreSocketFactory.getSslData(connectionName, enableIamAuth);
} catch (IOException e) {
throw new RuntimeException(e);
}
})
.subscribeOn(Schedulers.boundedElastic())
.share()
.block();
Expand Down Expand Up @@ -93,11 +100,15 @@ public ConnectionFactory create(ConnectionFactoryOptions connectionFactoryOption
"Cannot create ConnectionFactory: unsupported protocol (" + protocol + ")");
}

return createFactory(connectionFactoryOptions);
try {
return createFactory(connectionFactoryOptions);
} catch (IOException e) {
throw new RuntimeException(e);
}
}

private ConnectionFactory createFactory(
ConnectionFactoryOptions connectionFactoryOptions) {
ConnectionFactoryOptions connectionFactoryOptions) throws IOException {
String connectionName = (String) connectionFactoryOptions.getRequiredValue(HOST);
String socket = (String) connectionFactoryOptions.getValue(UNIX_SOCKET);

Expand Down