From 4572a10f72b00b12bbed217da8d11c34357837a9 Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Thu, 7 Dec 2023 15:04:20 +0100 Subject: [PATCH] Add support for dynamic usernames and passwords. [closes #613] Signed-off-by: Mark Paluch --- .../PostgresqlConnectionConfiguration.java | 74 +++++++++++++++---- .../PostgresqlConnectionFactoryProvider.java | 25 ++++++- .../SingleHostConnectionFunction.java | 45 +++++++++-- ...resqlConnectionConfigurationUnitTests.java | 8 +- ...sqlConnectionFactoryProviderUnitTests.java | 31 ++++++++ .../ReactorNettyClientIntegrationTests.java | 2 +- 6 files changed, 154 insertions(+), 31 deletions(-) diff --git a/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionConfiguration.java b/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionConfiguration.java index 1acbeab5..aa323787 100644 --- a/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionConfiguration.java +++ b/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionConfiguration.java @@ -36,6 +36,8 @@ import io.r2dbc.postgresql.message.backend.NoticeResponse; import io.r2dbc.postgresql.util.Assert; import io.r2dbc.postgresql.util.LogLevel; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Mono; import reactor.netty.resources.LoopResources; import reactor.util.annotation.Nullable; @@ -103,7 +105,7 @@ public final class PostgresqlConnectionConfiguration { private final Map options; - private final CharSequence password; + private final Publisher password; private final boolean preferAttachedBuffers; @@ -123,18 +125,18 @@ public final class PostgresqlConnectionConfiguration { private final TimeZone timeZone; - private final String username; + private final Publisher username; private PostgresqlConnectionConfiguration(String applicationName, boolean autodetectExtensions, @Nullable boolean compatibilityMode, @Nullable Duration connectTimeout, @Nullable String database, LogLevel errorResponseLogLevel, List extensions, ToIntFunction fetchSize, boolean forceBinary, @Nullable Duration lockWaitTimeout, @Nullable LoopResources loopResources, @Nullable MultiHostConfiguration multiHostConfiguration, - LogLevel noticeLogLevel, @Nullable Map options, @Nullable CharSequence password, boolean preferAttachedBuffers, + LogLevel noticeLogLevel, @Nullable Map options, Publisher password, boolean preferAttachedBuffers, int preparedStatementCacheQueries, @Nullable String schema, @Nullable SingleHostConfiguration singleHostConfiguration, SSLConfig sslConfig, @Nullable Duration statementTimeout, boolean tcpKeepAlive, boolean tcpNoDelay, TimeZone timeZone, - String username) { + Publisher username) { this.applicationName = Assert.requireNonNull(applicationName, "applicationName must not be null"); this.autodetectExtensions = autodetectExtensions; this.compatibilityMode = compatibilityMode; @@ -200,7 +202,7 @@ public String toString() { ", multiHostConfiguration='" + this.multiHostConfiguration + '\'' + ", noticeLogLevel='" + this.noticeLogLevel + '\'' + ", options='" + this.options + '\'' + - ", password='" + obfuscate(this.password != null ? this.password.length() : 0) + '\'' + + ", password='" + obfuscate(this.password != null ? 4 : 0) + '\'' + ", preferAttachedBuffers=" + this.preferAttachedBuffers + ", singleHostConfiguration=" + this.singleHostConfiguration + ", statementTimeout=" + this.statementTimeout + @@ -261,8 +263,7 @@ Map getOptions() { return Collections.unmodifiableMap(this.options); } - @Nullable - CharSequence getPassword() { + Publisher getPassword() { return this.password; } @@ -290,7 +291,7 @@ SingleHostConfiguration getRequiredSingleHostConfiguration() { return config; } - String getUsername() { + Publisher getUsername() { return this.username; } @@ -380,7 +381,7 @@ public static final class Builder { private Map options; @Nullable - private CharSequence password; + private Publisher password; private boolean preferAttachedBuffers = false; @@ -423,7 +424,7 @@ public static final class Builder { private LoopResources loopResources = null; @Nullable - private String username; + private Publisher username; private Builder() { } @@ -743,7 +744,31 @@ public Builder options(Map options) { * @return this {@link Builder} */ public Builder password(@Nullable CharSequence password) { - this.password = password; + this.password = Mono.justOrEmpty(password); + return this; + } + + /** + * Configure the password publisher. The publisher is used on each authentication attempt. + * + * @param password the password + * @return this {@link Builder} + * @since 1.0.3 + */ + public Builder password(Publisher password) { + this.password = Mono.from(password); + return this; + } + + /** + * Configure the password supplier. The supplier is used on each authentication attempt. + * + * @param password the password + * @return this {@link Builder} + * @since 1.0.3 + */ + public Builder password(Supplier password) { + this.password = Mono.fromSupplier(password); return this; } @@ -780,7 +805,6 @@ public Builder preferAttachedBuffers(boolean preferAttachedBuffers) { * * @param preparedStatementCacheQueries the preparedStatementCacheQueries * @return this {@link Builder} - * @throws IllegalArgumentException if {@code username} is {@code null} * @since 0.8.1 */ public Builder preparedStatementCacheQueries(int preparedStatementCacheQueries) { @@ -1023,10 +1047,34 @@ public Builder timeZone(TimeZone timeZone) { * @throws IllegalArgumentException if {@code username} is {@code null} */ public Builder username(String username) { + this.username = Mono.just(Assert.requireNonNull(username, "username must not be null")); + return this; + } + + /** + * Configure the username publisher. The publisher is used on each authentication attempt. + * + * @param username the username + * @return this {@link Builder} + * @throws IllegalArgumentException if {@code username} is {@code null} + */ + public Builder username(Publisher username) { this.username = Assert.requireNonNull(username, "username must not be null"); return this; } + /** + * Configure the username supplier. The supplier is used on each authentication attempt. + * + * @param username the username + * @return this {@link Builder} + * @throws IllegalArgumentException if {@code username} is {@code null} + */ + public Builder username(Supplier username) { + this.username = Mono.fromSupplier(Assert.requireNonNull(username, "username must not be null")); + return this; + } + @Override public String toString() { return "Builder{" + @@ -1044,7 +1092,7 @@ public String toString() { ", multiHostConfiguration='" + this.multiHostConfiguration + '\'' + ", noticeLogLevel='" + this.noticeLogLevel + '\'' + ", parameters='" + this.options + '\'' + - ", password='" + obfuscate(this.password != null ? this.password.length() : 0) + '\'' + + ", password='" + obfuscate(this.password != null ? 4 : 0) + '\'' + ", preparedStatementCacheQueries='" + this.preparedStatementCacheQueries + '\'' + ", schema='" + this.schema + '\'' + ", singleHostConfiguration='" + this.singleHostConfiguration + '\'' + diff --git a/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryProvider.java b/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryProvider.java index 02db95fd..b5180fb4 100644 --- a/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryProvider.java +++ b/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryProvider.java @@ -27,6 +27,7 @@ import io.r2dbc.spi.ConnectionFactoryOptions; import io.r2dbc.spi.ConnectionFactoryProvider; import io.r2dbc.spi.Option; +import org.reactivestreams.Publisher; import reactor.netty.resources.LoopResources; import javax.net.ssl.HostnameVerifier; @@ -37,6 +38,7 @@ import java.util.Map; import java.util.TimeZone; import java.util.function.Function; +import java.util.function.Supplier; import static io.r2dbc.spi.ConnectionFactoryOptions.CONNECT_TIMEOUT; import static io.r2dbc.spi.ConnectionFactoryOptions.DATABASE; @@ -290,6 +292,7 @@ public boolean supports(ConnectionFactoryOptions connectionFactoryOptions) { * @return this {@link PostgresqlConnectionConfiguration.Builder} * @throws IllegalArgumentException if {@code options} is {@code null} */ + @SuppressWarnings("unchecked") private static PostgresqlConnectionConfiguration.Builder fromConnectionFactoryOptions(ConnectionFactoryOptions options) { Assert.requireNonNull(options, "connectionFactoryOptions must not be null"); @@ -344,7 +347,6 @@ private static PostgresqlConnectionConfiguration.Builder fromConnectionFactoryOp mapper.fromTyped(LOOP_RESOURCES).to(builder::loopResources); mapper.from(NOTICE_LOG_LEVEL).map(it -> OptionMapper.toEnum(it, LogLevel.class)).to(builder::noticeLogLevel); mapper.from(OPTIONS).map(PostgresqlConnectionFactoryProvider::convertToMap).to(builder::options); - mapper.fromTyped(PASSWORD).to(builder::password); mapper.from(PORT).map(OptionMapper::toInteger).to(builder::port); mapper.from(PREFER_ATTACHED_BUFFERS).map(OptionMapper::toBoolean).to(builder::preferAttachedBuffers); mapper.from(PREPARED_STATEMENT_CACHE_QUERIES).map(OptionMapper::toInteger).to(builder::preparedStatementCacheQueries); @@ -363,7 +365,26 @@ private static PostgresqlConnectionConfiguration.Builder fromConnectionFactoryOp return TimeZone.getTimeZone(it.toString()); }).to(builder::timeZone); - builder.username("" + options.getRequiredValue(USER)); + + Object user = options.getRequiredValue(USER); + Object password = options.getValue(PASSWORD); + + if (user instanceof Supplier) { + builder.username((Supplier) user); + } else if (user instanceof Publisher) { + builder.username((Publisher) user); + } else { + builder.username("" + user); + } + if (password != null) { + if (password instanceof Supplier) { + builder.password((Supplier) password); + } else if (password instanceof Publisher) { + builder.password((Publisher) password); + } else { + builder.password((CharSequence) password); + } + } return builder; } diff --git a/src/main/java/io/r2dbc/postgresql/SingleHostConnectionFunction.java b/src/main/java/io/r2dbc/postgresql/SingleHostConnectionFunction.java index 6f43212e..e3383c18 100644 --- a/src/main/java/io/r2dbc/postgresql/SingleHostConnectionFunction.java +++ b/src/main/java/io/r2dbc/postgresql/SingleHostConnectionFunction.java @@ -26,6 +26,7 @@ import io.r2dbc.postgresql.message.backend.AuthenticationMessage; import io.r2dbc.postgresql.util.Assert; import reactor.core.publisher.Mono; +import reactor.util.annotation.Nullable; import java.net.SocketAddress; @@ -44,9 +45,9 @@ final class SingleHostConnectionFunction implements ConnectionFunction { public Mono connect(SocketAddress endpoint, ConnectionSettings settings) { return this.upstreamFunction.connect(endpoint, settings) - .delayUntil(client -> StartupMessageFlow - .exchange(this::getAuthenticationHandler, client, this.configuration.getDatabase(), this.configuration.getUsername(), - getParameterProvider(this.configuration, settings)) + .delayUntil(client -> getCredentials().flatMapMany(credentials -> StartupMessageFlow + .exchange(auth -> getAuthenticationHandler(auth, credentials), client, this.configuration.getDatabase(), credentials.getUsername(), + getParameterProvider(this.configuration, settings))) .handle(ExceptionFactory.INSTANCE::handleErrorResponse)); } @@ -54,16 +55,44 @@ private static PostgresStartupParameterProvider getParameterProvider(PostgresqlC return new PostgresStartupParameterProvider(configuration.getApplicationName(), configuration.getTimeZone(), settings); } - protected AuthenticationHandler getAuthenticationHandler(AuthenticationMessage message) { + protected AuthenticationHandler getAuthenticationHandler(AuthenticationMessage message, UsernameAndPassword usernameAndPassword) { if (PasswordAuthenticationHandler.supports(message)) { - CharSequence password = Assert.requireNonNull(this.configuration.getPassword(), "Password must not be null"); - return new PasswordAuthenticationHandler(password, this.configuration.getUsername()); + CharSequence password = Assert.requireNonNull(usernameAndPassword.getPassword(), "Password must not be null"); + return new PasswordAuthenticationHandler(password, usernameAndPassword.getUsername()); } else if (SASLAuthenticationHandler.supports(message)) { - CharSequence password = Assert.requireNonNull(this.configuration.getPassword(), "Password must not be null"); - return new SASLAuthenticationHandler(password, this.configuration.getUsername()); + CharSequence password = Assert.requireNonNull(usernameAndPassword.getPassword(), "Password must not be null"); + return new SASLAuthenticationHandler(password, usernameAndPassword.getUsername()); } else { throw new IllegalStateException(String.format("Unable to provide AuthenticationHandler capable of handling %s", message)); } } + Mono getCredentials() { + + return Mono.zip(Mono.from(this.configuration.getUsername()).single(), Mono.from(this.configuration.getPassword()).singleOptional()).map(it -> { + return new UsernameAndPassword(it.getT1(), it.getT2().orElse(null)); + }); + } + + static class UsernameAndPassword { + + final String username; + + final @Nullable CharSequence password; + + public UsernameAndPassword(String username, @Nullable CharSequence password) { + this.username = username; + this.password = password; + } + + public String getUsername() { + return this.username; + } + + @Nullable + public CharSequence getPassword() { + return this.password; + } + } + } diff --git a/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionConfigurationUnitTests.java b/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionConfigurationUnitTests.java index d7059922..f6c346f3 100644 --- a/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionConfigurationUnitTests.java +++ b/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionConfigurationUnitTests.java @@ -53,7 +53,7 @@ void builderHostAndSocket() { @Test void builderNoUsername() { - assertThatIllegalArgumentException().isThrownBy(() -> PostgresqlConnectionConfiguration.builder().username(null)) + assertThatIllegalArgumentException().isThrownBy(() -> PostgresqlConnectionConfiguration.builder().username((String) null)) .withMessage("username must not be null"); } @@ -84,9 +84,7 @@ void configuration() { .hasFieldOrPropertyWithValue("database", "test-database") .hasFieldOrPropertyWithValue("singleHostConfiguration.host", "test-host") .hasFieldOrProperty("options") - .hasFieldOrPropertyWithValue("password", null) .hasFieldOrPropertyWithValue("singleHostConfiguration.port", 100) - .hasFieldOrPropertyWithValue("username", "test-username") .hasFieldOrProperty("sslConfig") .hasFieldOrPropertyWithValue("tcpKeepAlive", true) .hasFieldOrPropertyWithValue("tcpNoDelay", false) @@ -116,9 +114,7 @@ void configureStatementAndLockTimeouts() { .hasFieldOrPropertyWithValue("database", "test-database") .hasFieldOrPropertyWithValue("singleHostConfiguration.host", "test-host") .hasFieldOrProperty("options") - .hasFieldOrPropertyWithValue("password", null) .hasFieldOrPropertyWithValue("singleHostConfiguration.port", 100) - .hasFieldOrPropertyWithValue("username", "test-username") .hasFieldOrProperty("sslConfig") .hasFieldOrPropertyWithValue("tcpKeepAlive", true) .hasFieldOrPropertyWithValue("tcpNoDelay", false) @@ -160,10 +156,8 @@ void configurationDefaults() { .hasFieldOrPropertyWithValue("applicationName", "r2dbc-postgresql") .hasFieldOrPropertyWithValue("database", "test-database") .hasFieldOrPropertyWithValue("singleHostConfiguration.host", "test-host") - .hasFieldOrPropertyWithValue("password", "test-password") .hasFieldOrPropertyWithValue("singleHostConfiguration.port", 5432) .hasFieldOrProperty("options") - .hasFieldOrPropertyWithValue("username", "test-username") .hasFieldOrProperty("sslConfig") .hasFieldOrPropertyWithValue("tcpKeepAlive", false) .hasFieldOrPropertyWithValue("tcpNoDelay", true) diff --git a/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryProviderUnitTests.java b/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryProviderUnitTests.java index 7da0e55e..7a1b36bc 100644 --- a/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryProviderUnitTests.java +++ b/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryProviderUnitTests.java @@ -25,6 +25,8 @@ import io.r2dbc.spi.ConnectionFactoryOptions; import io.r2dbc.spi.Option; import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; import java.time.Duration; import java.util.Arrays; @@ -33,6 +35,7 @@ import java.util.Map; import java.util.Objects; import java.util.TimeZone; +import java.util.function.Supplier; import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.AUTODETECT_EXTENSIONS; import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.COMPATIBILITY_MODE; @@ -617,6 +620,34 @@ void shouldConfigureExtensions() { assertThat(factory.getConfiguration().getExtensions()).containsExactly(testExtension1, testExtension2); } + @Test + void supportsUsernameAndPasswordSupplier() { + PostgresqlConnectionFactory factory = this.provider.create(builder() + .option(DRIVER, LEGACY_POSTGRESQL_DRIVER) + .option(HOST, "test-host") + .option(Option.valueOf("password"), (Supplier) () -> "test-password") + .option(Option.valueOf("user"), (Supplier) () -> "test-user") + .option(USER, "test-user") + .build()); + + StepVerifier.create(factory.getConfiguration().getPassword()).expectNext("test-password").verifyComplete(); + StepVerifier.create(factory.getConfiguration().getUsername()).expectNext("test-user").verifyComplete(); + } + + @Test + void supportsUsernameAndPasswordPublisher() { + PostgresqlConnectionFactory factory = this.provider.create(builder() + .option(DRIVER, LEGACY_POSTGRESQL_DRIVER) + .option(HOST, "test-host") + .option(Option.valueOf("password"), Mono.just("test-password")) + .option(Option.valueOf("user"), Mono.just("test-user")) + .option(USER, "test-user") + .build()); + + StepVerifier.create(factory.getConfiguration().getPassword()).expectNext("test-password").verifyComplete(); + StepVerifier.create(factory.getConfiguration().getUsername()).expectNext("test-user").verifyComplete(); + } + private static class TestExtension implements Extension { private final String name; diff --git a/src/test/java/io/r2dbc/postgresql/client/ReactorNettyClientIntegrationTests.java b/src/test/java/io/r2dbc/postgresql/client/ReactorNettyClientIntegrationTests.java index c5c64fd8..ecb94ef0 100644 --- a/src/test/java/io/r2dbc/postgresql/client/ReactorNettyClientIntegrationTests.java +++ b/src/test/java/io/r2dbc/postgresql/client/ReactorNettyClientIntegrationTests.java @@ -507,7 +507,7 @@ private void client(Function