Skip to content

Commit

Permalink
Allow user to set up a lambda to initialize new connections (#705)
Browse files Browse the repository at this point in the history
Co-authored-by: crystall-bitquill <97126568+crystall-bitquill@users.noreply.github.com>
Co-authored-by: Karen <64801825+karenc-bq@users.noreply.github.com>
  • Loading branch information
3 people authored Oct 27, 2023
1 parent 87ec172 commit bf9b25c
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 2 deletions.
22 changes: 22 additions & 0 deletions docs/using-the-jdbc-driver/UsingTheJdbcDriver.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,28 @@ DriverConfigurationProfiles.addOrReplaceProfile(
CustomConnectionPluginFactory.class));
```

### Executing Custom Code When Initializing a Connection
In some use cases you may need to define a specific configuration for a new driver connection before your application can use it. For instance:
- you might need to run some initial SQL queries when a connection is established, or;
- you might need to check for some additional conditions to determine the initialization configuration required for a particular connection.

The AWS JDBC Driver allows specifying a special function that can initialize a connection. It can be done with `ConnectionProviderManager.setConnectionInitFunc` method. The `resetConnectionInitFunc` method is also available to remove the function.

The initialization function is called for all connections, including connections opened by the internal connection pools (see [Using Read Write Splitting Plugin and Internal Connection Pooling](./using-plugins/UsingTheReadWriteSplittingPlugin.md#internal-connection-pooling)). This helps user applications clean up connection sessions that have been altered by previous operations, as returning a connection to a pool will reset the state and retrieving it will call the initialization function again.

> [!WARNING]\
> Executing CPU and network intensive code in the initialization function may significantly impact the AWS JDBC Driver's overall performance.
```java
ConnectionProviderManager.setConnectionInitFunc((connection, protocol, hostSpec, props) -> {
// Set custom schema for connections to a test-database
if ("test-database".equals(props.getProperty("database"))) {
connection.setSchema("test-database-schema");
}
});
```


### List of Available Plugins
The AWS JDBC Driver has several built-in plugins that are available to use. Please visit the individual plugin page for more details.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@

package software.amazon.jdbc;

import java.sql.Connection;
import java.sql.SQLException;
import java.util.List;
import java.util.Properties;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.checkerframework.checker.nullness.qual.Nullable;
import software.amazon.jdbc.cleanup.CanReleaseResources;

public class ConnectionProviderManager {
Expand All @@ -28,6 +31,8 @@ public class ConnectionProviderManager {
private static ConnectionProvider connProvider = null;
private final ConnectionProvider defaultProvider;

private static ConnectionInitFunc connectionInitFunc = null;

/**
* {@link ConnectionProviderManager} constructor.
*
Expand Down Expand Up @@ -193,4 +198,34 @@ public static void releaseResources() {
}
}
}

public static void setConnectionInitFunc(final @NonNull ConnectionInitFunc func) {
connectionInitFunc = func;
}

public static void resetConnectionInitFunc() {
connectionInitFunc = null;
}

public void initConnection(
final @Nullable Connection connection,
final @NonNull String protocol,
final @NonNull HostSpec hostSpec,
final @NonNull Properties props) throws SQLException {

final ConnectionInitFunc copy = connectionInitFunc;
if (copy == null) {
return;
}

copy.initConnection(connection, protocol, hostSpec, props);
}

public interface ConnectionInitFunc {
void initConnection(
final @Nullable Connection connection,
final @NonNull String protocol,
final @NonNull HostSpec hostSpec,
final @NonNull Properties props) throws SQLException;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,18 @@ public DefaultConnectionPlugin(
final PluginService pluginService,
final ConnectionProvider defaultConnProvider,
final PluginManagerService pluginManagerService) {
this(pluginService,
defaultConnProvider,
pluginManagerService,
new ConnectionProviderManager(defaultConnProvider));
}

public DefaultConnectionPlugin(
final PluginService pluginService,
final ConnectionProvider defaultConnProvider,
final PluginManagerService pluginManagerService,
final ConnectionProviderManager connProviderManager) {

if (pluginService == null) {
throw new IllegalArgumentException("pluginService");
}
Expand All @@ -79,7 +91,7 @@ public DefaultConnectionPlugin(

this.pluginService = pluginService;
this.pluginManagerService = pluginManagerService;
this.connProviderManager = new ConnectionProviderManager(defaultConnProvider);
this.connProviderManager = connProviderManager;
}

@Override
Expand Down Expand Up @@ -173,6 +185,8 @@ private Connection connectInternal(
telemetryContext.closeContext();
}

this.connProviderManager.initConnection(conn, driverProtocol, hostSpec, props);

this.pluginService.setAvailability(hostSpec.asAliases(), HostAvailability.AVAILABLE);
this.pluginService.updateDialect(conn);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
Expand All @@ -32,6 +33,7 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Properties;
import java.util.stream.Stream;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
Expand All @@ -42,6 +44,8 @@
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import software.amazon.jdbc.ConnectionProvider;
import software.amazon.jdbc.ConnectionProviderManager;
import software.amazon.jdbc.HostSpec;
import software.amazon.jdbc.JdbcCallable;
import software.amazon.jdbc.PluginManagerService;
import software.amazon.jdbc.PluginService;
Expand All @@ -59,12 +63,16 @@ class DefaultConnectionPluginTest {
@Mock ConnectionProvider connectionProvider;
@Mock PluginManagerService pluginManagerService;
@Mock JdbcCallable<Void, SQLException> mockSqlFunction;
@Mock JdbcCallable<Connection, SQLException> mockConnectFunction;
@Mock Connection conn;
@Mock Connection oldConn;
@Mock private TelemetryFactory mockTelemetryFactory;
@Mock TelemetryContext mockTelemetryContext;
@Mock TelemetryCounter mockTelemetryCounter;
@Mock TelemetryGauge mockTelemetryGauge;
@Mock ConnectionProviderManager mockConnectionProviderManager;
@Mock ConnectionProvider mockConnectionProvider;
@Mock HostSpec mockHostSpec;


private AutoCloseable closeable;
Expand All @@ -79,8 +87,11 @@ void setUp() {
when(mockTelemetryFactory.createCounter(anyString())).thenReturn(mockTelemetryCounter);
// noinspection unchecked
when(mockTelemetryFactory.createGauge(anyString(), any(GaugeCallable.class))).thenReturn(mockTelemetryGauge);
when(mockConnectionProviderManager.getConnectionProvider(anyString(), any(), any()))
.thenReturn(mockConnectionProvider);

plugin = new DefaultConnectionPlugin(pluginService, connectionProvider, pluginManagerService);
plugin = new DefaultConnectionPlugin(
pluginService, connectionProvider, pluginManagerService, mockConnectionProviderManager);
}

@AfterEach
Expand Down Expand Up @@ -109,6 +120,13 @@ void testExecute_closeOldConnection() throws SQLException {
verify(pluginManagerService, never()).setInTransaction(anyBoolean());
}

@Test
void testConnect() throws SQLException {
plugin.connect("anyProtocol", mockHostSpec, new Properties(), true, mockConnectFunction);
verify(mockConnectionProvider, atLeastOnce()).connect(anyString(), any(), any(), any());
verify(mockConnectionProviderManager, atLeastOnce()).initConnection(any(), anyString(), any(), any());
}

private static Stream<Arguments> multiStatementQueries() {
return Stream.of(
Arguments.of("", new ArrayList<String>()),
Expand Down

0 comments on commit bf9b25c

Please sign in to comment.