Skip to content

Commit

Permalink
Verify plugin presence based on actual plugin list (#1141)
Browse files Browse the repository at this point in the history
  • Loading branch information
sergiyvamz authored Oct 1, 2024
1 parent e1e5aa8 commit bc104c2
Show file tree
Hide file tree
Showing 8 changed files with 30 additions and 13 deletions.
2 changes: 2 additions & 0 deletions wrapper/src/main/java/software/amazon/jdbc/PluginService.java
Original file line number Diff line number Diff line change
Expand Up @@ -217,4 +217,6 @@ HostSpec getHostSpecByStrategy(List<HostSpec> hosts, HostRole role, String strat
String getTargetName();

@NonNull SessionStateService getSessionStateService();

<T> T getPlugin(final Class<T> pluginClazz);
}
Original file line number Diff line number Diff line change
Expand Up @@ -689,4 +689,13 @@ public String getTargetName() {
public @NonNull SessionStateService getSessionStateService() {
return this.sessionStateService;
}

public <T> T getPlugin(final Class<T> pluginClazz) {
for (ConnectionPlugin p : this.pluginManager.plugins) {
if (pluginClazz.isAssignableFrom(p.getClass())) {
return pluginClazz.cast(p);
}
}
return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
import java.sql.Statement;
import java.util.Collections;
import java.util.List;
import software.amazon.jdbc.ConnectionPluginChainBuilder;
import software.amazon.jdbc.hostlistprovider.AuroraHostListProvider;
import software.amazon.jdbc.hostlistprovider.monitoring.MonitoringRdsHostListProvider;
import software.amazon.jdbc.plugin.failover2.FailoverConnectionPlugin;

public class AuroraMysqlDialect extends MysqlDialect {

Expand Down Expand Up @@ -83,9 +83,9 @@ public boolean isDialect(final Connection connection) {
public HostListProviderSupplier getHostListProvider() {
return (properties, initialUrl, hostListProviderService, pluginService) -> {

final List<String> plugins = ConnectionPluginChainBuilder.getPluginCodes(properties);
final FailoverConnectionPlugin failover2Plugin = pluginService.getPlugin(FailoverConnectionPlugin.class);

if (plugins.contains("failover2")) {
if (failover2Plugin != null) {
return new MonitoringRdsHostListProvider(
properties,
initialUrl,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,10 @@
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.List;
import java.util.logging.Logger;
import software.amazon.jdbc.ConnectionPluginChainBuilder;
import software.amazon.jdbc.hostlistprovider.AuroraHostListProvider;
import software.amazon.jdbc.hostlistprovider.monitoring.MonitoringRdsHostListProvider;
import software.amazon.jdbc.plugin.failover2.FailoverConnectionPlugin;

/**
* Suitable for the following AWS PG configurations.
Expand Down Expand Up @@ -128,9 +127,9 @@ public boolean isDialect(final Connection connection) {
public HostListProviderSupplier getHostListProvider() {
return (properties, initialUrl, hostListProviderService, pluginService) -> {

final List<String> plugins = ConnectionPluginChainBuilder.getPluginCodes(properties);
final FailoverConnectionPlugin failover2Plugin = pluginService.getPlugin(FailoverConnectionPlugin.class);

if (plugins.contains("failover2")) {
if (failover2Plugin != null) {
return new MonitoringRdsHostListProvider(
properties,
initialUrl,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@
import java.util.List;
import java.util.Properties;
import org.checkerframework.checker.nullness.qual.NonNull;
import software.amazon.jdbc.ConnectionPluginChainBuilder;
import software.amazon.jdbc.HostSpec;
import software.amazon.jdbc.hostlistprovider.RdsMultiAzDbClusterListProvider;
import software.amazon.jdbc.hostlistprovider.monitoring.MonitoringRdsMultiAzHostListProvider;
import software.amazon.jdbc.plugin.failover.FailoverRestriction;
import software.amazon.jdbc.plugin.failover2.FailoverConnectionPlugin;
import software.amazon.jdbc.util.DriverInfo;

public class RdsMultiAzDbClusterMysqlDialect extends MysqlDialect {
Expand Down Expand Up @@ -98,9 +98,9 @@ public boolean isDialect(final Connection connection) {
public HostListProviderSupplier getHostListProvider() {
return (properties, initialUrl, hostListProviderService, pluginService) -> {

final List<String> plugins = ConnectionPluginChainBuilder.getPluginCodes(properties);
final FailoverConnectionPlugin failover2Plugin = pluginService.getPlugin(FailoverConnectionPlugin.class);

if (plugins.contains("failover2")) {
if (failover2Plugin != null) {
return new MonitoringRdsMultiAzHostListProvider(
properties,
initialUrl,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@
import java.sql.Statement;
import java.util.List;
import java.util.logging.Logger;
import software.amazon.jdbc.ConnectionPluginChainBuilder;
import software.amazon.jdbc.exceptions.ExceptionHandler;
import software.amazon.jdbc.exceptions.MultiAzDbClusterPgExceptionHandler;
import software.amazon.jdbc.hostlistprovider.RdsMultiAzDbClusterListProvider;
import software.amazon.jdbc.hostlistprovider.monitoring.MonitoringRdsMultiAzHostListProvider;
import software.amazon.jdbc.plugin.failover2.FailoverConnectionPlugin;
import software.amazon.jdbc.util.DriverInfo;

public class RdsMultiAzDbClusterPgDialect extends PgDialect {
Expand Down Expand Up @@ -113,9 +113,10 @@ public boolean isDialect(final Connection connection) {
@Override
public HostListProviderSupplier getHostListProvider() {
return (properties, initialUrl, hostListProviderService, pluginService) -> {
final List<String> plugins = ConnectionPluginChainBuilder.getPluginCodes(properties);

if (plugins.contains("failover2")) {
final FailoverConnectionPlugin failover2Plugin = pluginService.getPlugin(FailoverConnectionPlugin.class);

if (failover2Plugin != null) {
return new MonitoringRdsMultiAzHostListProvider(
properties,
initialUrl,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ void setUp() throws SQLException {
when(this.mockConnection.createStatement()).thenReturn(this.mockStatement);
when(this.mockHost.getUrl()).thenReturn("url");
when(this.failResultSet.next()).thenReturn(false);
pluginManager.plugins = new ArrayList<>();
}

@AfterEach
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,11 @@ public String getTargetName() {
return new TestSessionStateService();
}

@Override
public <T> T getPlugin(Class<T> pluginClazz) {
return null;
}

@Override
public boolean isNetworkException(Throwable throwable) {
return false;
Expand Down

0 comments on commit bc104c2

Please sign in to comment.