Skip to content

Commit

Permalink
refresh when initializing credentials
Browse files Browse the repository at this point in the history
  • Loading branch information
shubha-rajan committed Nov 1, 2022
1 parent 1870f48 commit f909ac8
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 26 deletions.
12 changes: 9 additions & 3 deletions core/src/main/java/com/google/cloud/sql/CredentialFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,17 @@

import com.google.api.client.auth.oauth2.Credential;
import com.google.api.client.http.HttpRequestInitializer;
import java.io.IOException;

/** Factory for creating {@link Credential}s for interaction with Cloud SQL Admin API. */
/**
* Factory for creating {@link Credential}s for interaction with Cloud SQL Admin API.
*/
public interface CredentialFactory {
/** Name of system property that can specify an alternative credential factory. */

/**
* Name of system property that can specify an alternative credential factory.
*/
String CREDENTIAL_FACTORY_PROPERTY = "cloudSql.socketFactory.credentialFactory";

HttpRequestInitializer create();
HttpRequestInitializer create() throws IOException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ class CloudSqlInstance {
boolean enableIamAuth,
CredentialFactory tokenSourceFactory,
ListeningScheduledExecutorService executor,
ListenableFuture<KeyPair> keyPair) {
ListenableFuture<KeyPair> keyPair) throws IOException {

Matcher matcher = CONNECTION_NAME.matcher(connectionName);
checkArgument(
Expand Down Expand Up @@ -279,10 +279,10 @@ SSLSocket createSslSocket() throws IOException {
* preferredTypes.
*
* @param preferredTypes Preferred instance IP types to use. Valid IP types include "Public" and
* "Private".
* "Private".
* @return returns a string representing the IP address for the instance
* @throws IllegalArgumentException If the instance has no IP addresses matching the provided
* preferences.
* preferences.
*/
String getPreferredIp(List<String> preferredTypes) {
Map<String, String> ipAddrs = getInstanceData().getIpAddrs();
Expand Down Expand Up @@ -527,9 +527,7 @@ private Certificate fetchEphemeralCertificate(KeyPair keyPair) {

if (enableIamAuth) {
try {
OAuth2Credentials creds = credentials.get();
creds.refresh();
GoogleCredentials downscoped = getDownscopedCredentials(creds);
GoogleCredentials downscoped = getDownscopedCredentials(credentials.get());
downscoped.refresh();
String token = downscoped.getAccessToken().getTokenValue();
// TODO: remove this once issue with OAuth2 Tokens is resolved.
Expand Down Expand Up @@ -580,7 +578,7 @@ static GoogleCredentials getDownscopedCredentials(OAuth2Credentials credentials)
}
return downscoped;
}

private Date getTokenExpirationTime() {
return credentials.get().getAccessToken().getExpirationTime();
}
Expand Down Expand Up @@ -608,7 +606,7 @@ private long secondsUntilRefresh() {
*
* @param ex exception thrown by the Admin API request
* @param fallbackDesc generic description used as a fallback if no additional information can be
* provided to the user
* provided to the user
*/
private RuntimeException addExceptionContext(IOException ex, String fallbackDesc) {
// Verify we are able to extract a reason from an exception, or fallback to a generic desc
Expand Down
37 changes: 26 additions & 11 deletions core/src/main/java/com/google/cloud/sql/core/CoreSocketFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public final class CoreSocketFactory {
* Property used to set the application name for the underlying SQLAdmin client.
*
* @deprecated Use {@link #setApplicationName(String)} to set the application name
* programmatically.
* programmatically.
*/

@Deprecated
Expand Down Expand Up @@ -117,7 +117,7 @@ public final class CoreSocketFactory {
/**
* Returns the {@link CoreSocketFactory} singleton.
*/
public static synchronized CoreSocketFactory getInstance() {
public static synchronized CoreSocketFactory getInstance() throws IOException {
if (coreSocketFactory == null) {
logger.info("First Cloud SQL connection, generating RSA key pair.");

Expand Down Expand Up @@ -155,14 +155,27 @@ public static synchronized CoreSocketFactory getInstance() {
private CloudSqlInstance getCloudSqlInstance(String instanceName, boolean enableIamAuth) {
return instances.computeIfAbsent(
instanceName,
k -> new CloudSqlInstance(k, adminApi, enableIamAuth, credentialFactory, executor,
localKeyPair));
k -> {
try {
return new CloudSqlInstance(k, adminApi, enableIamAuth, credentialFactory, executor,
localKeyPair);
} catch (IOException e) {
throw new RuntimeException(e);
}
});
}

private CloudSqlInstance getCloudSqlInstance(String instanceName) {
return instances.computeIfAbsent(
instanceName,
k -> new CloudSqlInstance(k, adminApi, false, credentialFactory, executor, localKeyPair));
k -> {
try {
return new CloudSqlInstance(k, adminApi, false, credentialFactory, executor,
localKeyPair);
} catch (IOException e) {
throw new RuntimeException(e);
}
});
}

static int getDefaultServerProxyPort() {
Expand Down Expand Up @@ -215,7 +228,7 @@ public static Socket connect(Properties props) throws IOException {
*
* <p>Depending on the given properties, it may return either a SSL Socket or a Unix Socket.
*
* @param props Properties used to configure the connection.
* @param props Properties used to configure the connection.
* @param unixPathSuffix suffix to add the the Unix socket path. Unused if null.
* @return the newly created Socket.
* @throws IOException if error occurs during socket creation.
Expand Down Expand Up @@ -255,18 +268,19 @@ public static Socket connect(Properties props, String unixPathSuffix) throws IOE
/**
* Returns data that can be used to establish Cloud SQL SSL connection.
*/
public static SslData getSslData(String csqlInstanceName, boolean enableIamAuth) {
public static SslData getSslData(String csqlInstanceName, boolean enableIamAuth)
throws IOException {
return getInstance().getCloudSqlInstance(csqlInstanceName, enableIamAuth).getSslData();
}

public static SslData getSslData(String csqlInstanceName) {
public static SslData getSslData(String csqlInstanceName) throws IOException {
return getSslData(csqlInstanceName, false);
}

/**
* Returns preferred ip address that can be used to establish Cloud SQL connection.
*/
public static String getHostIp(String csqlInstanceName) {
public static String getHostIp(String csqlInstanceName) throws IOException {
return getInstance().getHostIp(csqlInstanceName, listIpTypes(DEFAULT_IP_TYPES));
}

Expand All @@ -280,7 +294,7 @@ private String getHostIp(String instanceName, List<String> ipTypes) {
* Creates a secure socket representing a connection to a Cloud SQL instance.
*
* @param instanceName Name of the Cloud SQL instance.
* @param ipTypes Preferred type of IP to use ("PRIVATE", "PUBLIC")
* @param ipTypes Preferred type of IP to use ("PRIVATE", "PUBLIC")
* @return the newly created Socket.
* @throws IOException if error occurs during socket creation.
*/
Expand Down Expand Up @@ -368,7 +382,7 @@ private static SQLAdmin createAdminApiClient(HttpRequestInitializer requestIniti
private static class ApplicationDefaultCredentialFactory implements CredentialFactory {

@Override
public HttpRequestInitializer create() {
public HttpRequestInitializer create() throws IOException {
GoogleCredentials credentials;
try {
credentials = GoogleCredentials.getApplicationDefault();
Expand All @@ -382,6 +396,7 @@ public HttpRequestInitializer create() {
SQLAdminScopes.SQLSERVICE_ADMIN,
SQLAdminScopes.CLOUD_PLATFORM)
);
credentials.refresh();
}
return new HttpCredentialsAdapter(credentials);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,12 @@
import io.r2dbc.spi.ConnectionFactoryMetadata;
import io.r2dbc.spi.ConnectionFactoryOptions;
import io.r2dbc.spi.ConnectionFactoryOptions.Builder;
import java.io.IOException;
import java.util.function.Function;
import org.reactivestreams.Publisher;

/**
* * {@link ConnectionFactory} for accessing Cloud SQL instances via R2DBC protocol.
* * {@link ConnectionFactory} for accessing Cloud SQL instances via R2DBC protocol.
*/
public class CloudSqlConnectionFactory implements ConnectionFactory {

Expand All @@ -50,15 +51,23 @@ public CloudSqlConnectionFactory(

@Override
public Publisher<? extends Connection> create() {
return getConnectionFactory().create();
try {
return getConnectionFactory().create();
} catch (IOException e) {
throw new RuntimeException(e);
}
}

@Override
public ConnectionFactoryMetadata getMetadata() {
return getConnectionFactory().getMetadata();
try {
return getConnectionFactory().getMetadata();
} catch (IOException e) {
throw new RuntimeException(e);
}
}

private ConnectionFactory getConnectionFactory() {
private ConnectionFactory getConnectionFactory() throws IOException {
String hostIp = CoreSocketFactory.getHostIp(csqlHostName);
builder.option(HOST, hostIp).option(PORT, CoreSocketFactory.getDefaultServerProxyPort());
return connectionFactoryFactory.apply(builder.build());
Expand Down

0 comments on commit f909ac8

Please sign in to comment.