diff --git a/docs/using-the-jdbc-driver/UsingTheJdbcDriver.md b/docs/using-the-jdbc-driver/UsingTheJdbcDriver.md index 9c51707c1..70fb2c279 100644 --- a/docs/using-the-jdbc-driver/UsingTheJdbcDriver.md +++ b/docs/using-the-jdbc-driver/UsingTheJdbcDriver.md @@ -100,6 +100,25 @@ DriverConfigurationProfiles.addOrReplaceProfile( CustomConnectionPluginFactory.class)); ``` +### Connection Initialization +In some cases it's necessary to configure a connection before a user application can use it. Some target drivers provides such functionality and allow to specify a configuration parameter with SQL statements that are executed when connection is established. However, not all drivers supports such functionality. Also, in some cases, additional conditions should be checked in order to identify what initialization is required for a particular connection. + +AWS JDBC Driver allows to specify a special function that can initialize a connection. It can be done with `ConnectionProviderManager.setConnectionInitFunc` method. `resetConnectionInitFunc` method is also available. + +The initialization function is called for all connections, including pre-opened connections provided by internal connection pool (see [Using Read Write Splitting Plugin Internal Connection Pooling](./using-plugins/UsingTheReadWriteSplittingPlugin.md#internal-connection-pooling)), and, thus, helping a user application to clean up connection session "contaminated" by previous use. + +> :warning: Executing CPU and network intensive code in the initialization function may cause significant performance degradation. + +```java +ConnectionProviderManager.setConnectionInitFunc((connection, protocol, hostSpec, props) -> { + // Set custom schema for connections to a test-database + if ("test-database".equals(props.getProperty("database"))) { + connection.setSchema("test-database-schema"); + } +}); +``` + + ### List of Available Plugins The AWS JDBC Driver has several built-in plugins that are available to use. Please visit the individual plugin page for more details. diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/DefaultConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/DefaultConnectionPlugin.java index cf95d0936..837adba9e 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/DefaultConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/DefaultConnectionPlugin.java @@ -67,6 +67,18 @@ public DefaultConnectionPlugin( final PluginService pluginService, final ConnectionProvider defaultConnProvider, final PluginManagerService pluginManagerService) { + this(pluginService, + defaultConnProvider, + pluginManagerService, + new ConnectionProviderManager(defaultConnProvider)); + } + + public DefaultConnectionPlugin( + final PluginService pluginService, + final ConnectionProvider defaultConnProvider, + final PluginManagerService pluginManagerService, + final ConnectionProviderManager connProviderManager) { + if (pluginService == null) { throw new IllegalArgumentException("pluginService"); } @@ -79,7 +91,7 @@ public DefaultConnectionPlugin( this.pluginService = pluginService; this.pluginManagerService = pluginManagerService; - this.connProviderManager = new ConnectionProviderManager(defaultConnProvider); + this.connProviderManager = connProviderManager; } @Override diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/DefaultConnectionPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/DefaultConnectionPluginTest.java index 5157448b9..cf30ae7e6 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/DefaultConnectionPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/DefaultConnectionPluginTest.java @@ -21,6 +21,7 @@ import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -32,6 +33,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Properties; import java.util.stream.Stream; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; @@ -42,6 +44,9 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; import software.amazon.jdbc.ConnectionProvider; +import software.amazon.jdbc.ConnectionProviderManager; +import software.amazon.jdbc.HostRole; +import software.amazon.jdbc.HostSpec; import software.amazon.jdbc.JdbcCallable; import software.amazon.jdbc.PluginManagerService; import software.amazon.jdbc.PluginService; @@ -59,12 +64,16 @@ class DefaultConnectionPluginTest { @Mock ConnectionProvider connectionProvider; @Mock PluginManagerService pluginManagerService; @Mock JdbcCallable mockSqlFunction; + @Mock JdbcCallable mockConnectFunction; @Mock Connection conn; @Mock Connection oldConn; @Mock private TelemetryFactory mockTelemetryFactory; @Mock TelemetryContext mockTelemetryContext; @Mock TelemetryCounter mockTelemetryCounter; @Mock TelemetryGauge mockTelemetryGauge; + @Mock ConnectionProviderManager mockConnectionProviderManager; + @Mock ConnectionProvider mockConnectionProvider; + @Mock HostSpec mockHostSpec; private AutoCloseable closeable; @@ -79,8 +88,11 @@ void setUp() { when(mockTelemetryFactory.createCounter(anyString())).thenReturn(mockTelemetryCounter); // noinspection unchecked when(mockTelemetryFactory.createGauge(anyString(), any(GaugeCallable.class))).thenReturn(mockTelemetryGauge); + when(mockConnectionProviderManager.getConnectionProvider(anyString(), any(), any())) + .thenReturn(mockConnectionProvider); - plugin = new DefaultConnectionPlugin(pluginService, connectionProvider, pluginManagerService); + plugin = new DefaultConnectionPlugin( + pluginService, connectionProvider, pluginManagerService, mockConnectionProviderManager); } @AfterEach @@ -109,6 +121,19 @@ void testExecute_closeOldConnection() throws SQLException { verify(pluginManagerService, never()).setInTransaction(anyBoolean()); } + @Test + void testConnect() throws SQLException { + plugin.connect("anyProtocol", mockHostSpec, new Properties(), true, mockConnectFunction); + verify(mockConnectionProvider, atLeastOnce()).connect(anyString(), any(), any(), any()); + verify(mockConnectionProviderManager, atLeastOnce()).initConnection(any(), anyString(), any(), any()); + + ConnectionProviderManager.setConnectionInitFunc((connection, protocol, hostSpec, props) -> { + if ("test-database".equals(props.getProperty("database"))) { + connection.setSchema("test-database-schema"); + } + }); + } + private static Stream multiStatementQueries() { return Stream.of( Arguments.of("", new ArrayList()),