diff --git a/http-client-core/src/main/java/io/micronaut/http/client/HttpClientConfiguration.java b/http-client-core/src/main/java/io/micronaut/http/client/HttpClientConfiguration.java index 16119133eae..29a284cc561 100644 --- a/http-client-core/src/main/java/io/micronaut/http/client/HttpClientConfiguration.java +++ b/http-client-core/src/main/java/io/micronaut/http/client/HttpClientConfiguration.java @@ -36,8 +36,11 @@ import java.nio.charset.StandardCharsets; import java.time.Duration; import java.time.temporal.ChronoUnit; +import java.util.Arrays; import java.util.Collections; +import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.OptionalInt; import java.util.concurrent.ThreadFactory; @@ -145,7 +148,16 @@ public abstract class HttpClientConfiguration { private String eventLoopGroup = "default"; - private HttpVersion httpVersion = HttpVersion.HTTP_1_1; + @Deprecated + @Nullable + private HttpVersion httpVersion = null; + + private HttpVersionSelection.PlaintextMode plaintextMode = HttpVersionSelection.PlaintextMode.HTTP_1; + + private List alpnModes = Arrays.asList( + HttpVersionSelection.ALPN_HTTP_2, + HttpVersionSelection.ALPN_HTTP_1 + ); private LogLevel logLevel; @@ -201,7 +213,11 @@ public HttpClientConfiguration(HttpClientConfiguration copy) { /** * The HTTP version to use. Defaults to {@link HttpVersion#HTTP_1_1}. * @return The http version + * @deprecated There are now separate settings for HTTP and HTTPS connections. To configure + * HTTP connections (e.g. for h2c), use {@link #plaintextMode}. To configure ALPN, set + * {@link #alpnModes}. */ + @Deprecated public HttpVersion getHttpVersion() { return httpVersion; } @@ -209,7 +225,11 @@ public HttpVersion getHttpVersion() { /** * Sets the HTTP version to use. Defaults to {@link HttpVersion#HTTP_1_1}. * @param httpVersion The http version + * @deprecated There are now separate settings for HTTP and HTTPS connections. To configure + * HTTP connections (e.g. for h2c), use {@link #plaintextMode}. To configure ALPN, set + * {@link #alpnModes}. */ + @Deprecated public void setHttpVersion(HttpVersion httpVersion) { if (httpVersion != null) { this.httpVersion = httpVersion; @@ -637,6 +657,58 @@ public Proxy resolveProxy(boolean isSsl, String host, int port) { } } + /** + * The connection mode to use for plaintext (http as opposed to https) connections. + *
+ * Note: If {@link #httpVersion} is set, this setting is ignored! + * + * @return The plaintext connection mode. + * @since 4.0.0 + */ + @NonNull + public HttpVersionSelection.PlaintextMode getPlaintextMode() { + return plaintextMode; + } + + /** + * The connection mode to use for plaintext (http as opposed to https) connections. + *
+ * Note: If {@link #httpVersion} is set, this setting is ignored! + * + * @param plaintextMode The plaintext connection mode. + * @since 4.0.0 + */ + public void setPlaintextMode(@NonNull HttpVersionSelection.PlaintextMode plaintextMode) { + this.plaintextMode = Objects.requireNonNull(plaintextMode, "plaintextMode"); + } + + /** + * The protocols to support for TLS ALPN. If HTTP 2 is included, this will also restrict the + * TLS cipher suites to those supported by the HTTP 2 standard. + *
+ * Note: If {@link #httpVersion} is set, this setting is ignored! + * + * @return The supported ALPN protocols. + * @since 4.0.0 + */ + @NonNull + public List getAlpnModes() { + return alpnModes; + } + + /** + * The protocols to support for TLS ALPN. If HTTP 2 is included, this will also restrict the + * TLS cipher suites to those supported by the HTTP 2 standard. + *
+ * Note: If {@link #httpVersion} is set, this setting is ignored! + * + * @param alpnModes The supported ALPN protocols. + * @since 4.0.0 + */ + public void setAlpnModes(@NonNull List alpnModes) { + this.alpnModes = Objects.requireNonNull(alpnModes, "alpnModes"); + } + /** * Configuration for the HTTP client connnection pool. */ @@ -650,15 +722,13 @@ public static class ConnectionPoolConfiguration implements Toggleable { * The default enable value. */ @SuppressWarnings("WeakerAccess") - public static final boolean DEFAULT_ENABLED = false; + public static final boolean DEFAULT_ENABLED = true; - /** - * The default max connections value. - */ - @SuppressWarnings("WeakerAccess") - public static final int DEFAULT_MAXCONNECTIONS = -1; + private int maxPendingConnections = 4; - private int maxConnections = DEFAULT_MAXCONNECTIONS; + private int maxConcurrentRequestsPerHttp2Connection = Integer.MAX_VALUE; + private int maxConcurrentHttp1Connections = Integer.MAX_VALUE; + private int maxConcurrentHttp2Connections = 1; private int maxPendingAcquires = Integer.MAX_VALUE; @@ -685,24 +755,6 @@ public void setEnabled(boolean enabled) { this.enabled = enabled; } - /** - * The maximum number of connections. Defaults to ({@value io.micronaut.http.client.HttpClientConfiguration.ConnectionPoolConfiguration#DEFAULT_MAXCONNECTIONS}); no maximum. - * - * @return The max connections - */ - public int getMaxConnections() { - return maxConnections; - } - - /** - * Sets the maximum number of connections. Defaults to no maximum. - * - * @param maxConnections The count - */ - public void setMaxConnections(int maxConnections) { - this.maxConnections = maxConnections; - } - /** * Maximum number of futures awaiting connection acquisition. Defaults to no maximum. * @@ -738,5 +790,90 @@ public Optional getAcquireTimeout() { public void setAcquireTimeout(@Nullable Duration acquireTimeout) { this.acquireTimeout = acquireTimeout; } + + /** + * The maximum number of pending (new) connections before they are assigned to a + * pool. + * + * @return The maximum number of pending connections + * @since 4.0.0 + */ + public int getMaxPendingConnections() { + return maxPendingConnections; + } + + /** + * The maximum number of pending (new) connections before they are assigned to a + * pool. + * + * @param maxPendingConnections The maximum number of pending connections + * @since 4.0.0 + */ + public void setMaxPendingConnections(int maxPendingConnections) { + this.maxPendingConnections = maxPendingConnections; + } + + /** + * The maximum number of requests (streams) that can run concurrently on one HTTP2 + * connection. + * + * @return The maximum concurrent request count + * @since 4.0.0 + */ + public int getMaxConcurrentRequestsPerHttp2Connection() { + return maxConcurrentRequestsPerHttp2Connection; + } + + /** + * The maximum number of requests (streams) that can run concurrently on one HTTP2 + * connection. + * + * @param maxConcurrentRequestsPerHttp2Connection The maximum concurrent request count + * @since 4.0.0 + */ + public void setMaxConcurrentRequestsPerHttp2Connection(int maxConcurrentRequestsPerHttp2Connection) { + this.maxConcurrentRequestsPerHttp2Connection = maxConcurrentRequestsPerHttp2Connection; + } + + /** + * The maximum number of concurrent HTTP1 connections in the pool. + * + * @return The maximum concurrent connection count + * @since 4.0.0 + */ + public int getMaxConcurrentHttp1Connections() { + return maxConcurrentHttp1Connections; + } + + /** + * The maximum number of concurrent HTTP1 connections in the pool. + * + * @param maxConcurrentHttp1Connections The maximum concurrent connection count + * @since 4.0.0 + */ + public void setMaxConcurrentHttp1Connections(int maxConcurrentHttp1Connections) { + this.maxConcurrentHttp1Connections = maxConcurrentHttp1Connections; + } + + /** + * The maximum number of concurrent HTTP2 connections in the pool. + * + * @return The maximum concurrent connection count + * @since 4.0.0 + */ + public int getMaxConcurrentHttp2Connections() { + return maxConcurrentHttp2Connections; + } + + /** + * The maximum number of concurrent HTTP2 connections in the pool. + * + * @param maxConcurrentHttp2Connections The maximum concurrent connection count + * @since 4.0.0 + */ + public void setMaxConcurrentHttp2Connections(int maxConcurrentHttp2Connections) { + this.maxConcurrentHttp2Connections = maxConcurrentHttp2Connections; + } } + } diff --git a/http-client-core/src/main/java/io/micronaut/http/client/HttpClientRegistry.java b/http-client-core/src/main/java/io/micronaut/http/client/HttpClientRegistry.java index 8872c9927bf..068a1c315de 100644 --- a/http-client-core/src/main/java/io/micronaut/http/client/HttpClientRegistry.java +++ b/http-client-core/src/main/java/io/micronaut/http/client/HttpClientRegistry.java @@ -47,9 +47,25 @@ public interface HttpClientRegistry { * @param clientId The client ID * @param path The path (Optional) * @return The client + * @deprecated Use {@link #getClient(HttpVersionSelection, String, String)} instead */ + @Deprecated @NonNull - T getClient(HttpVersion httpVersion, @NonNull String clientId, @Nullable String path); + default T getClient(HttpVersion httpVersion, @NonNull String clientId, @Nullable String path) { + return getClient(HttpVersionSelection.forLegacyVersion(httpVersion), clientId, path); + } + + /** + * Return the client for the client ID and path. + * + * @param httpVersion The HTTP version + * @param clientId The client ID + * @param path The path (Optional) + * @return The client + * @since 4.0.0 + */ + @NonNull + T getClient(@NonNull HttpVersionSelection httpVersion, @NonNull String clientId, @Nullable String path); /** * Resolves a {@link HttpClient} for the given injection point. diff --git a/http-client-core/src/main/java/io/micronaut/http/client/HttpVersionSelection.java b/http-client-core/src/main/java/io/micronaut/http/client/HttpVersionSelection.java new file mode 100644 index 00000000000..4d8d8afce7a --- /dev/null +++ b/http-client-core/src/main/java/io/micronaut/http/client/HttpVersionSelection.java @@ -0,0 +1,197 @@ +/* + * Copyright 2017-2022 original authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.micronaut.http.client; + +import io.micronaut.core.annotation.AnnotationMetadata; +import io.micronaut.core.annotation.Internal; +import io.micronaut.core.annotation.NonNull; +import io.micronaut.core.annotation.Nullable; +import io.micronaut.http.HttpVersion; +import io.micronaut.http.client.annotation.Client; + +import java.util.Arrays; + +/** + * This class collects information about HTTP client protocol version settings, such as the + * {@link PlaintextMode} and the ALPN configuration. + * + * @author Jonas Konrad + * @since 4.0 + */ +public final class HttpVersionSelection { + /** + * ALPN protocol ID for HTTP/1.1. + */ + public static final String ALPN_HTTP_1 = "http/1.1"; + /** + * ALPN protocol ID for HTTP/2. + */ + public static final String ALPN_HTTP_2 = "h2"; + + private static final HttpVersionSelection LEGACY_1 = new HttpVersionSelection( + PlaintextMode.HTTP_1, + false, + new String[]{ALPN_HTTP_1}, + false + ); + + private static final HttpVersionSelection LEGACY_2 = new HttpVersionSelection( + PlaintextMode.H2C, + true, + new String[]{ALPN_HTTP_1, ALPN_HTTP_2}, + true + ); + + private final PlaintextMode plaintextMode; + private final boolean alpn; + private final String[] alpnSupportedProtocols; + private final boolean http2CipherSuites; + + private HttpVersionSelection(@NonNull PlaintextMode plaintextMode, boolean alpn, @NonNull String[] alpnSupportedProtocols, boolean http2CipherSuites) { + this.plaintextMode = plaintextMode; + this.alpn = alpn; + this.alpnSupportedProtocols = alpnSupportedProtocols; + this.http2CipherSuites = http2CipherSuites; + } + + /** + * Get the {@link HttpVersionSelection} that matches Micronaut HTTP client 3.x behavior for the + * given version setting. + * + * @param httpVersion The HTTP version as configured for Micronaut HTTP client 3.x + * @return The version selection + */ + @NonNull + public static HttpVersionSelection forLegacyVersion(@NonNull HttpVersion httpVersion) { + switch (httpVersion) { + case HTTP_1_0: + case HTTP_1_1: + return LEGACY_1; + case HTTP_2_0: + return LEGACY_2; + default: + throw new IllegalArgumentException("HTTP version " + httpVersion + " not supported here"); + } + } + + /** + * Construct a version selection from the given client configuration. + * + * @param clientConfiguration The client configuration + * @return The configured version selection + */ + public static HttpVersionSelection forClientConfiguration(HttpClientConfiguration clientConfiguration) { + @SuppressWarnings("deprecation") + HttpVersion legacyHttpVersion = clientConfiguration.getHttpVersion(); + if (legacyHttpVersion != null) { + return forLegacyVersion(legacyHttpVersion); + } else { + String[] alpnModes = clientConfiguration.getAlpnModes().toArray(new String[0]); + return new HttpVersionSelection( + clientConfiguration.getPlaintextMode(), + true, + alpnModes, + Arrays.asList(alpnModes).contains(ALPN_HTTP_2) + ); + } + } + + /** + * Infer the version selection for the given {@link Client} annotation, if any version settings + * are set. + * + * @param metadata The annotation metadata possibly containing a {@link Client} annotation + * @return The configured version selection, or {@code null} if the version is not explicitly + * set and should be inherited from the normal configuration instead. + */ + @Internal + @Nullable + public static HttpVersionSelection forClientAnnotation(AnnotationMetadata metadata) { + HttpVersion legacyHttpVersion = + metadata.enumValue(Client.class, "httpVersion", HttpVersion.class).orElse(null); + if (legacyHttpVersion != null) { + return forLegacyVersion(legacyHttpVersion); + } else { + String[] alpnModes = metadata.stringValues(Client.class, "alpnModes"); + PlaintextMode plaintextMode = metadata.enumValue(Client.class, "plaintextMode", PlaintextMode.class) + .orElse(null); + if (alpnModes.length == 0 && plaintextMode == null) { + // nothing set at all, default to client configuration + return null; + } + + // defaults + if (alpnModes.length == 0) { + alpnModes = new String[]{ALPN_HTTP_2, ALPN_HTTP_1}; + } + if (plaintextMode == null) { + plaintextMode = PlaintextMode.HTTP_1; + } + return new HttpVersionSelection( + plaintextMode, + true, + alpnModes, + Arrays.asList(alpnModes).contains(ALPN_HTTP_2) + ); + } + } + + /** + * @return Connection mode to use for plaintext connections + */ + @Internal + public PlaintextMode getPlaintextMode() { + return plaintextMode; + } + + /** + * @return Protocols that should be shown as supported via ALPN + */ + @Internal + public String[] getAlpnSupportedProtocols() { + return alpnSupportedProtocols; + } + + /** + * @return Whether ALPN should be used + */ + @Internal + public boolean isAlpn() { + return alpn; + } + + /** + * @return Whether TLS cipher suites should be constrained to those defined by the HTTP/2 spec + */ + @Internal + public boolean isHttp2CipherSuites() { + return http2CipherSuites; + } + + /** + * The connection mode to use for plaintext (non-TLS) connections. + */ + public enum PlaintextMode { + /** + * Normal HTTP/1.1 connection. + */ + HTTP_1, + /** + * HTTP/2 cleartext upgrade from HTTP/1.1. + */ + H2C, + } +} diff --git a/http-client-core/src/main/java/io/micronaut/http/client/ServiceHttpClientFactory.java b/http-client-core/src/main/java/io/micronaut/http/client/ServiceHttpClientFactory.java index ca9f17c7d95..2aed9421670 100644 --- a/http-client-core/src/main/java/io/micronaut/http/client/ServiceHttpClientFactory.java +++ b/http-client-core/src/main/java/io/micronaut/http/client/ServiceHttpClientFactory.java @@ -98,7 +98,7 @@ ApplicationEventListener healthCheckStarter(@Parameter Servi Collection loadBalancedURIs = instanceList.getLoadBalancedURIs(); final HttpClient httpClient = clientFactory.get() .getClient( - configuration.getHttpVersion(), + HttpVersionSelection.forClientConfiguration(configuration), configuration.getServiceId(), configuration.getPath().orElse(null)); final Duration initialDelay = configuration.getHealthCheckInterval(); diff --git a/http-client-core/src/main/java/io/micronaut/http/client/annotation/Client.java b/http-client-core/src/main/java/io/micronaut/http/client/annotation/Client.java index 71794cdb4f4..5eb3cb824ac 100644 --- a/http-client-core/src/main/java/io/micronaut/http/client/annotation/Client.java +++ b/http-client-core/src/main/java/io/micronaut/http/client/annotation/Client.java @@ -18,8 +18,10 @@ import io.micronaut.aop.Introduction; import io.micronaut.context.annotation.AliasFor; import io.micronaut.context.annotation.Type; +import io.micronaut.core.annotation.NonNull; import io.micronaut.http.HttpVersion; import io.micronaut.http.client.HttpClientConfiguration; +import io.micronaut.http.client.HttpVersionSelection; import io.micronaut.http.client.interceptor.HttpClientIntroductionAdvice; import io.micronaut.http.hateoas.JsonError; import io.micronaut.retry.annotation.Recoverable; @@ -80,6 +82,36 @@ * The HTTP version. * * @return The HTTP version of the client. + * @deprecated There are now separate settings for HTTP and HTTPS connections. To configure + * HTTP connections (e.g. for h2c), use {@link #plaintextMode}. To configure ALPN, set + * {@link #alpnModes}. */ + @Deprecated HttpVersion httpVersion() default HttpVersion.HTTP_1_1; + + /** + * The connection mode to use for plaintext (http as opposed to https) connections. + *
+ * Note: If {@link #httpVersion} is set, this setting is ignored! + * + * @return The plaintext connection mode. + * @since 4.0.0 + */ + @NonNull + HttpVersionSelection.PlaintextMode plaintextMode() default HttpVersionSelection.PlaintextMode.HTTP_1; + + /** + * The protocols to support for TLS ALPN. If HTTP 2 is included, this will also restrict the + * TLS cipher suites to those supported by the HTTP 2 standard. + *
+ * Note: If {@link #httpVersion} is set, this setting is ignored! + * + * @return The supported ALPN protocols. + * @since 4.0.0 + */ + @NonNull + String[] alpnModes() default { + HttpVersionSelection.ALPN_HTTP_2, + HttpVersionSelection.ALPN_HTTP_1 + }; } diff --git a/http-client/src/main/java/io/micronaut/http/client/netty/CancellableMonoSink.java b/http-client/src/main/java/io/micronaut/http/client/netty/CancellableMonoSink.java new file mode 100644 index 00000000000..c56a84accf5 --- /dev/null +++ b/http-client/src/main/java/io/micronaut/http/client/netty/CancellableMonoSink.java @@ -0,0 +1,140 @@ +/* + * Copyright 2017-2022 original authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.micronaut.http.client.netty; + +import io.micronaut.core.annotation.Internal; +import io.micronaut.core.annotation.NonNull; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; + +/** + * Version of {@link Sinks#one()} where cancellation of the {@link Mono} will make future emit + * calls fail. + * + * @param Element type + */ +@Internal +final class CancellableMonoSink implements Publisher, Sinks.One, Subscription { + private static final Object EMPTY = new Object(); + + private T value; + private Throwable failure; + private boolean complete = false; + private Subscriber subscriber = null; + private boolean subscriberWaiting = false; + + @Override + public synchronized void subscribe(Subscriber s) { + if (this.subscriber != null) { + s.onError(new IllegalStateException("Only one subscriber allowed")); + } + subscriber = s; + subscriber.onSubscribe(this); + } + + private void tryForward() { + if (subscriberWaiting && complete) { + if (failure == null) { + if (value != EMPTY) { + subscriber.onNext(value); + } + subscriber.onComplete(); + } else { + subscriber.onError(failure); + } + } + } + + @NonNull + @Override + public synchronized Sinks.EmitResult tryEmitValue(T value) { + if (complete) { + return Sinks.EmitResult.FAIL_OVERFLOW; + } else { + this.value = value; + complete = true; + tryForward(); + return Sinks.EmitResult.OK; + } + } + + @Override + public void emitValue(T value, @NonNull Sinks.EmitFailureHandler failureHandler) { + throw new UnsupportedOperationException(); + } + + @SuppressWarnings("unchecked") + @NonNull + @Override + public Sinks.EmitResult tryEmitEmpty() { + return tryEmitValue((T) EMPTY); + } + + @NonNull + @Override + public synchronized Sinks.EmitResult tryEmitError(@NonNull Throwable error) { + if (complete) { + return Sinks.EmitResult.FAIL_OVERFLOW; + } else { + this.failure = error; + complete = true; + tryForward(); + return Sinks.EmitResult.OK; + } + } + + @Override + public void emitEmpty(@NonNull Sinks.EmitFailureHandler failureHandler) { + throw new UnsupportedOperationException(); + } + + @Override + public void emitError(@NonNull Throwable error, @NonNull Sinks.EmitFailureHandler failureHandler) { + throw new UnsupportedOperationException(); + } + + @Override + public synchronized int currentSubscriberCount() { + return subscriber == null ? 0 : 1; + } + + @NonNull + @Override + public Mono asMono() { + return Mono.from(this); + } + + @Override + public Object scanUnsafe(@NonNull Attr key) { + return null; + } + + @Override + public synchronized void request(long n) { + if (n > 0 && !subscriberWaiting) { + subscriberWaiting = true; + tryForward(); + } + } + + @Override + public synchronized void cancel() { + complete = true; + } +} diff --git a/http-client/src/main/java/io/micronaut/http/client/netty/ConnectTTLHandler.java b/http-client/src/main/java/io/micronaut/http/client/netty/ConnectTTLHandler.java deleted file mode 100644 index 724611fb18a..00000000000 --- a/http-client/src/main/java/io/micronaut/http/client/netty/ConnectTTLHandler.java +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Copyright 2017-2020 original authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.micronaut.http.client.netty; - -import io.netty.channel.Channel; -import io.netty.channel.ChannelDuplexHandler; -import io.netty.channel.ChannelHandlerContext; -import io.netty.util.AttributeKey; -import io.netty.util.concurrent.ScheduledFuture; - -import java.util.concurrent.TimeUnit; - -/** - * A handler that will close channels after they have reached their time-to-live, regardless of usage. - * - * channels that are in use will be closed when they are next - * released to the underlying connection pool. - */ -public class ConnectTTLHandler extends ChannelDuplexHandler { - - public static final AttributeKey RELEASE_CHANNEL = AttributeKey.newInstance("release_channel"); - - private final Long connectionTtlMillis; - private ScheduledFuture channelKiller; - - /** - * Construct ConnectTTLHandler for given arguments. - * @param connectionTtlMillis The configured connect-ttl - */ - public ConnectTTLHandler(Long connectionTtlMillis) { - if (connectionTtlMillis <= 0) { - throw new IllegalArgumentException("connectTTL must be positive"); - } - this.connectionTtlMillis = connectionTtlMillis; - } - - /** - * Will schedule a task when the handler added. - * @param ctx The context to use - * @throws Exception - */ - @Override - public void handlerAdded(ChannelHandlerContext ctx) throws Exception { - super.handlerAdded(ctx); - channelKiller = ctx.channel().eventLoop().schedule(() -> markChannelExpired(ctx), connectionTtlMillis, TimeUnit.MILLISECONDS); - } - - /** - * Will cancel the scheduled tasks when handler removed. - * @param ctx The context to use - */ - @Override - public void handlerRemoved(ChannelHandlerContext ctx) { - channelKiller.cancel(false); - } - - /** - * Will set RELEASE_CHANNEL as true for the channel attribute when connect-ttl is reached. - * @param ctx The context to use - */ - private void markChannelExpired(ChannelHandlerContext ctx) { - if (ctx.channel().isOpen()) { - ctx.channel().attr(RELEASE_CHANNEL).set(true); - } - } - - /** - * Indicates whether the channels connection ttl has expired. - * @param channel The channel to check - * @return true if the channels ttl has expired - */ - public static boolean isChannelExpired(Channel channel) { - return Boolean.TRUE.equals(channel.attr(ConnectTTLHandler.RELEASE_CHANNEL).get()); - } -} diff --git a/http-client/src/main/java/io/micronaut/http/client/netty/ConnectionManager.java b/http-client/src/main/java/io/micronaut/http/client/netty/ConnectionManager.java index 28d8755172c..7221852440c 100644 --- a/http-client/src/main/java/io/micronaut/http/client/netty/ConnectionManager.java +++ b/http-client/src/main/java/io/micronaut/http/client/netty/ConnectionManager.java @@ -20,71 +20,47 @@ import io.micronaut.core.annotation.Nullable; import io.micronaut.core.reflect.InstantiationUtils; import io.micronaut.core.util.StringUtils; -import io.micronaut.http.HttpVersion; +import io.micronaut.core.util.SupplierUtil; import io.micronaut.http.client.HttpClientConfiguration; +import io.micronaut.http.client.HttpVersionSelection; import io.micronaut.http.client.exceptions.HttpClientException; import io.micronaut.http.client.netty.ssl.NettyClientSslBuilder; import io.micronaut.http.netty.channel.ChannelPipelineCustomizer; -import io.micronaut.http.netty.channel.ChannelPipelineListener; import io.micronaut.http.netty.channel.NettyThreadFactory; -import io.micronaut.http.netty.stream.DefaultHttp2Content; -import io.micronaut.http.netty.stream.Http2Content; -import io.micronaut.http.netty.stream.HttpStreamsClientHandler; -import io.micronaut.http.netty.stream.StreamingInboundHttp2ToHttpAdapter; import io.micronaut.scheduling.instrument.Instrumentation; import io.micronaut.scheduling.instrument.InvocationInstrumenter; import io.micronaut.websocket.exceptions.WebSocketSessionException; import io.netty.bootstrap.Bootstrap; -import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.Channel; import io.netty.channel.ChannelFactory; import io.netty.channel.ChannelFuture; -import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelOption; import io.netty.channel.ChannelPipeline; -import io.netty.channel.ChannelPromise; import io.netty.channel.EventLoopGroup; import io.netty.channel.nio.NioEventLoopGroup; -import io.netty.channel.pool.AbstractChannelPoolHandler; -import io.netty.channel.pool.AbstractChannelPoolMap; -import io.netty.channel.pool.ChannelHealthChecker; -import io.netty.channel.pool.ChannelPool; -import io.netty.channel.pool.ChannelPoolMap; -import io.netty.channel.pool.FixedChannelPool; -import io.netty.channel.pool.SimpleChannelPool; -import io.netty.channel.socket.SocketChannel; -import io.netty.handler.codec.LineBasedFrameDecoder; +import io.netty.handler.codec.DecoderException; import io.netty.handler.codec.http.DefaultFullHttpRequest; -import io.netty.handler.codec.http.DefaultHttpContent; -import io.netty.handler.codec.http.FullHttpMessage; import io.netty.handler.codec.http.HttpClientCodec; import io.netty.handler.codec.http.HttpClientUpgradeHandler; -import io.netty.handler.codec.http.HttpContent; import io.netty.handler.codec.http.HttpContentDecompressor; import io.netty.handler.codec.http.HttpHeaderNames; -import io.netty.handler.codec.http.HttpMessage; import io.netty.handler.codec.http.HttpMethod; import io.netty.handler.codec.http.HttpObjectAggregator; -import io.netty.handler.codec.http.HttpUtil; -import io.netty.handler.codec.http.LastHttpContent; import io.netty.handler.codec.http.websocketx.extensions.compression.WebSocketClientCompressionHandler; -import io.netty.handler.codec.http2.DefaultHttp2Connection; -import io.netty.handler.codec.http2.DelegatingDecompressorFrameListener; import io.netty.handler.codec.http2.Http2ClientUpgradeCodec; -import io.netty.handler.codec.http2.Http2Connection; -import io.netty.handler.codec.http2.Http2FrameListener; +import io.netty.handler.codec.http2.Http2FrameCodec; +import io.netty.handler.codec.http2.Http2FrameCodecBuilder; import io.netty.handler.codec.http2.Http2FrameLogger; -import io.netty.handler.codec.http2.Http2Settings; -import io.netty.handler.codec.http2.Http2Stream; -import io.netty.handler.codec.http2.HttpConversionUtil; -import io.netty.handler.codec.http2.HttpToHttp2ConnectionHandler; -import io.netty.handler.codec.http2.HttpToHttp2ConnectionHandlerBuilder; -import io.netty.handler.codec.http2.InboundHttp2ToHttpAdapterBuilder; +import io.netty.handler.codec.http2.Http2MultiplexHandler; +import io.netty.handler.codec.http2.Http2SettingsFrame; +import io.netty.handler.codec.http2.Http2StreamChannel; +import io.netty.handler.codec.http2.Http2StreamChannelBootstrap; +import io.netty.handler.codec.http2.Http2StreamFrameToHttpObjectCodec; import io.netty.handler.logging.LoggingHandler; import io.netty.handler.proxy.HttpProxyHandler; import io.netty.handler.proxy.Socks5ProxyHandler; @@ -92,89 +68,97 @@ import io.netty.handler.ssl.ApplicationProtocolNegotiationHandler; import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.SslHandler; -import io.netty.handler.timeout.IdleStateEvent; +import io.netty.handler.ssl.SslHandshakeCompletionEvent; import io.netty.handler.timeout.IdleStateHandler; +import io.netty.handler.timeout.ReadTimeoutException; import io.netty.handler.timeout.ReadTimeoutHandler; import io.netty.resolver.NoopAddressResolverGroup; -import io.netty.util.Attribute; -import io.netty.util.AttributeKey; import io.netty.util.ReferenceCountUtil; +import io.netty.util.ResourceLeakDetector; +import io.netty.util.ResourceLeakDetectorFactory; +import io.netty.util.ResourceLeakTracker; import io.netty.util.concurrent.Future; import io.netty.util.concurrent.GenericFutureListener; -import io.netty.util.concurrent.Promise; -import org.reactivestreams.Publisher; +import io.netty.util.concurrent.ScheduledFuture; import org.slf4j.Logger; -import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.publisher.Sinks; -import reactor.core.scheduler.Scheduler; import reactor.core.scheduler.Schedulers; +import javax.net.ssl.SSLException; import java.net.InetSocketAddress; import java.net.Proxy; import java.net.SocketAddress; import java.time.Duration; -import java.util.Collection; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.Optional; import java.util.OptionalInt; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ThreadFactory; import java.util.concurrent.TimeUnit; -import java.util.function.Consumer; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Supplier; /** * Connection manager for {@link DefaultHttpClient}. This class manages the lifecycle of netty * channels (wrapped in {@link PoolHandle}s), including pooling and timeouts. */ @Internal -final class ConnectionManager { - final ChannelPoolMap poolMap; +class ConnectionManager { final InvocationInstrumenter instrumenter; - final HttpVersion httpVersion; - - // not static to avoid build-time initialization by native image - private final AttributeKey CHANNEL_CUSTOMIZER_KEY = - AttributeKey.valueOf("micronaut.http.customizer"); - /** - * Future on a pooled channel that will be completed when the channel has fully connected (e.g. - * TLS handshake has completed). If unset, then no handshake is needed or it has already - * completed. - */ - private final AttributeKey> STREAM_CHANNEL_INITIALIZED = - AttributeKey.valueOf("micronaut.http.streamChannelInitialized"); - private final AttributeKey STREAM_KEY = AttributeKey.valueOf("micronaut.http2.stream"); + private final HttpVersionSelection httpVersion; private final Logger log; + private final Map pools = new ConcurrentHashMap<>(); private EventLoopGroup group; private final boolean shutdownGroup; private final ThreadFactory threadFactory; private final ChannelFactory socketChannelFactory; private Bootstrap bootstrap; private final HttpClientConfiguration configuration; - @Nullable - private final Long readTimeoutMillis; - @Nullable - private final Long connectionTimeAliveMillis; private final SslContext sslContext; private final NettyClientCustomizer clientCustomizer; - private final Collection pipelineListeners; private final String informationalServiceId; + /** + * Copy constructor used by the test suite to patch this manager. + * + * @param from Original connection manager + */ + ConnectionManager(ConnectionManager from) { + this.instrumenter = from.instrumenter; + this.httpVersion = from.httpVersion; + this.log = from.log; + this.group = from.group; + this.shutdownGroup = from.shutdownGroup; + this.threadFactory = from.threadFactory; + this.socketChannelFactory = from.socketChannelFactory; + this.bootstrap = from.bootstrap; + this.configuration = from.configuration; + this.sslContext = from.sslContext; + this.clientCustomizer = from.clientCustomizer; + this.informationalServiceId = from.informationalServiceId; + } + ConnectionManager( Logger log, @Nullable EventLoopGroup eventLoopGroup, - ThreadFactory threadFactory, + @Nullable ThreadFactory threadFactory, HttpClientConfiguration configuration, - HttpVersion httpVersion, + @Nullable HttpVersionSelection httpVersion, InvocationInstrumenter instrumenter, ChannelFactory socketChannelFactory, NettyClientSslBuilder nettyClientSslBuilder, NettyClientCustomizer clientCustomizer, - Collection pipelineListeners, String informationalServiceId) { if (httpVersion == null) { - httpVersion = configuration.getHttpVersion(); + httpVersion = HttpVersionSelection.forClientConfiguration(configuration); } this.log = log; @@ -184,16 +168,9 @@ final class ConnectionManager { this.configuration = configuration; this.instrumenter = instrumenter; this.clientCustomizer = clientCustomizer; - this.pipelineListeners = pipelineListeners; this.informationalServiceId = informationalServiceId; - this.connectionTimeAliveMillis = configuration.getConnectTtl() - .map(duration -> !duration.isNegative() ? duration.toMillis() : null) - .orElse(null); - this.readTimeoutMillis = configuration.getReadTimeout() - .map(duration -> !duration.isNegative() ? duration.toMillis() : null) - .orElse(null); - this.sslContext = nettyClientSslBuilder.build(configuration.getSslConfiguration(), httpVersion).orElse(null); + this.sslContext = nettyClientSslBuilder.build(configuration.getSslConfiguration(), httpVersion); if (eventLoopGroup != null) { group = eventLoopGroup; @@ -205,55 +182,6 @@ final class ConnectionManager { initBootstrap(); - final ChannelHealthChecker channelHealthChecker = channel -> channel.eventLoop().newSucceededFuture(channel.isActive() && !ConnectTTLHandler.isChannelExpired(channel)); - - HttpClientConfiguration.ConnectionPoolConfiguration connectionPoolConfiguration = configuration.getConnectionPoolConfiguration(); - // HTTP/2 defaults to keep alive connections so should we should always use a pool - if (connectionPoolConfiguration.isEnabled() || httpVersion == io.micronaut.http.HttpVersion.HTTP_2_0) { - int maxConnections = connectionPoolConfiguration.getMaxConnections(); - if (maxConnections > -1) { - poolMap = new AbstractChannelPoolMap() { - @Override - protected ChannelPool newPool(DefaultHttpClient.RequestKey key) { - Bootstrap newBootstrap = bootstrap.clone(group); - initBootstrapForProxy(newBootstrap, key.isSecure(), key.getHost(), key.getPort()); - newBootstrap.remoteAddress(key.getRemoteAddress()); - - AbstractChannelPoolHandler channelPoolHandler = newPoolHandler(key); - final long acquireTimeoutMillis = connectionPoolConfiguration.getAcquireTimeout().map(Duration::toMillis).orElse(-1L); - return new FixedChannelPool( - newBootstrap, - channelPoolHandler, - channelHealthChecker, - acquireTimeoutMillis > -1 ? FixedChannelPool.AcquireTimeoutAction.FAIL : null, - acquireTimeoutMillis, - maxConnections, - connectionPoolConfiguration.getMaxPendingAcquires() - - ); - } - }; - } else { - poolMap = new AbstractChannelPoolMap() { - @Override - protected ChannelPool newPool(DefaultHttpClient.RequestKey key) { - Bootstrap newBootstrap = bootstrap.clone(group); - initBootstrapForProxy(newBootstrap, key.isSecure(), key.getHost(), key.getPort()); - newBootstrap.remoteAddress(key.getRemoteAddress()); - - AbstractChannelPoolHandler channelPoolHandler = newPoolHandler(key); - return new SimpleChannelPool( - newBootstrap, - channelPoolHandler, - channelHealthChecker - ); - } - }; - } - } else { - this.poolMap = null; - } - Optional connectTimeout = configuration.getConnectTimeout(); connectTimeout.ifPresent(duration -> bootstrap.option( ChannelOption.CONNECT_TIMEOUT_MILLIS, @@ -301,6 +229,45 @@ private static NioEventLoopGroup createEventLoopGroup(HttpClientConfiguration co return group; } + /** + * For testing. + * + * @return Connected channels in all pools + * @since 4.0.0 + */ + @NonNull + @SuppressWarnings("unused") + List getChannels() { + List channels = new ArrayList<>(); + for (Pool pool : pools.values()) { + pool.forEachConnection(c -> channels.add(((Pool.ConnectionHolder) c).channel)); + } + return channels; + } + + /** + * For testing. + * + * @return Number of running requests + * @since 4.0.0 + */ + @SuppressWarnings("unused") + int liveRequestCount() { + AtomicInteger count = new AtomicInteger(); + for (Pool pool : pools.values()) { + pool.forEachConnection(c -> { + if (c instanceof Pool.Http1ConnectionHolder) { + if (((Pool.Http1ConnectionHolder) c).hasLiveRequests()) { + count.incrementAndGet(); + } + } else { + count.addAndGet(((Pool.Http2ConnectionHolder) c).liveRequests.get()); + } + }); + } + return count.get(); + } + /** * @see DefaultHttpClient#start() */ @@ -323,27 +290,8 @@ private void initBootstrap() { * @see DefaultHttpClient#stop() */ public void shutdown() { - if (poolMap instanceof Iterable) { - Iterable> i = (Iterable) poolMap; - for (Map.Entry entry : i) { - ChannelPool cp = entry.getValue(); - try { - if (cp instanceof SimpleChannelPool) { - addInstrumentedListener(((SimpleChannelPool) cp).closeAsync(), future -> { - if (!future.isSuccess()) { - final Throwable cause = future.cause(); - if (cause != null) { - log.error("Error shutting down HTTP client connection pool: " + cause.getMessage(), cause); - } - } - }); - } else { - cp.close(); - } - } catch (Exception cause) { - log.error("Error shutting down HTTP client connection pool: " + cause.getMessage(), cause); - } - } + for (Pool pool : pools.values()) { + pool.shutdown(); } if (shutdownGroup) { Duration shutdownTimeout = configuration.getShutdownTimeout() @@ -374,54 +322,23 @@ public boolean isRunning() { } /** - * Get a reactive scheduler that runs on the event loop group of this connection manager. - * - * @return A scheduler that runs on the event loop - */ - public Scheduler getEventLoopScheduler() { - return Schedulers.fromExecutor(group); - } - - /** - * Creates an initial connection to the given remote host. + * Use the bootstrap to connect to the given host. Also does some proxy setup. This method is + * protected: The test suite overrides it to return embedded channels instead. * - * @param requestKey The request key to connect to - * @param isStream Is the connection a stream connection - * @param isProxy Is this a streaming proxy - * @param acceptsEvents Whether the connection will accept events - * @param contextConsumer The logic to run once the channel is configured correctly - * @return A ChannelFuture - * @throws HttpClientException If the URI is invalid + * @param requestKey The host to connect to + * @param channelInitializer The initializer to use + * @return Future that terminates when the TCP connection is established. */ - private ChannelFuture doConnect( - DefaultHttpClient.RequestKey requestKey, - boolean isStream, - boolean isProxy, - boolean acceptsEvents, - Consumer contextConsumer) throws HttpClientException { - - SslContext sslCtx = buildSslContext(requestKey); + protected ChannelFuture doConnect(DefaultHttpClient.RequestKey requestKey, ChannelInitializer channelInitializer) { String host = requestKey.getHost(); int port = requestKey.getPort(); Bootstrap localBootstrap = bootstrap.clone(); - initBootstrapForProxy(localBootstrap, sslCtx != null, host, port); - localBootstrap.handler(new HttpClientInitializer( - sslCtx, - host, - port, - isStream, - isProxy, - acceptsEvents, - contextConsumer) - ); - return localBootstrap.connect(host, port); - } - - private void initBootstrapForProxy(Bootstrap localBootstrap, boolean sslCtx, String host, int port) { - Proxy proxy = configuration.resolveProxy(sslCtx, host, port); + Proxy proxy = configuration.resolveProxy(requestKey.isSecure(), host, port); if (proxy.type() != Proxy.Type.DIRECT) { localBootstrap.resolver(NoopAddressResolverGroup.INSTANCE); } + localBootstrap.handler(channelInitializer); + return localBootstrap.connect(host, port); } /** @@ -429,6 +346,7 @@ private void initBootstrapForProxy(Bootstrap localBootstrap, boolean sslCtx, Str * * @return The {@link SslContext} instance */ + @Nullable private SslContext buildSslContext(DefaultHttpClient.RequestKey requestKey) { final SslContext sslCtx; if (requestKey.isSecure()) { @@ -443,128 +361,14 @@ private SslContext buildSslContext(DefaultHttpClient.RequestKey requestKey) { return sslCtx; } - private PoolHandle mockPoolHandle(Channel channel) { - return new PoolHandle(null, channel); - } - - /** - * Get a connection for exchange-like (non-streaming) http client methods. - * - * @param requestKey The remote to connect to - * @param multipart Whether the request should be multipart - * @param acceptEvents Whether the response may be an event stream - * @return A mono that will complete once the channel is ready for transmission - */ - Mono connectForExchange(DefaultHttpClient.RequestKey requestKey, boolean multipart, boolean acceptEvents) { - return Mono.create(emitter -> { - if (poolMap != null && !multipart) { - try { - ChannelPool channelPool = poolMap.get(requestKey); - addInstrumentedListener(channelPool.acquire(), future -> { - if (future.isSuccess()) { - Channel channel = future.get(); - PoolHandle poolHandle = new PoolHandle(channelPool, channel); - Future initFuture = channel.attr(STREAM_CHANNEL_INITIALIZED).get(); - if (initFuture == null) { - emitter.success(poolHandle); - } else { - // we should wait until the handshake completes - addInstrumentedListener(initFuture, f -> { - emitter.success(poolHandle); - }); - } - } else { - Throwable cause = future.cause(); - emitter.error(customizeException(new HttpClientException("Connect Error: " + cause.getMessage(), cause))); - } - }); - } catch (HttpClientException e) { - emitter.error(e); - } - } else { - ChannelFuture connectionFuture = doConnect(requestKey, false, false, acceptEvents, null); - addInstrumentedListener(connectionFuture, future -> { - if (!future.isSuccess()) { - Throwable cause = future.cause(); - emitter.error(customizeException(new HttpClientException("Connect Error: " + cause.getMessage(), cause))); - } else { - emitter.success(mockPoolHandle(connectionFuture.channel())); - } - }); - } - }) - .delayUntil(this::delayUntilHttp2Ready) - .map(poolHandle -> { - addReadTimeoutHandler(poolHandle.channel.pipeline()); - return poolHandle; - }); - } - - private Publisher delayUntilHttp2Ready(PoolHandle poolHandle) { - Http2SettingsHandler settingsHandler = (Http2SettingsHandler) poolHandle.channel.pipeline().get(ChannelPipelineCustomizer.HANDLER_HTTP2_SETTINGS); - if (settingsHandler == null) { - return Flux.empty(); - } - Sinks.Empty empty = Sinks.empty(); - addInstrumentedListener(settingsHandler.promise, future -> { - if (future.isSuccess()) { - empty.tryEmitEmpty(); - } else { - poolHandle.taint(); - poolHandle.release(); - empty.tryEmitError(future.cause()); - } - }); - return empty.asMono(); - } - /** - * Get a connection for streaming http client methods. + * Get a connection for non-websocket http client methods. * * @param requestKey The remote to connect to - * @param isProxy Whether the request is for a {@link io.micronaut.http.client.ProxyHttpClient} call - * @param acceptEvents Whether the response may be an event stream * @return A mono that will complete once the channel is ready for transmission */ - Mono connectForStream(DefaultHttpClient.RequestKey requestKey, boolean isProxy, boolean acceptEvents) { - return Mono.create(emitter -> { - ChannelFuture channelFuture; - try { - if (httpVersion == HttpVersion.HTTP_2_0) { - - channelFuture = doConnect(requestKey, true, isProxy, acceptEvents, channelHandlerContext -> { - try { - final Channel channel = channelHandlerContext.channel(); - emitter.success(mockPoolHandle(channel)); - } catch (Exception e) { - emitter.error(e); - } - }); - } else { - channelFuture = doConnect(requestKey, true, isProxy, acceptEvents, null); - addInstrumentedListener(channelFuture, - (ChannelFutureListener) f -> { - if (f.isSuccess()) { - Channel channel = f.channel(); - emitter.success(mockPoolHandle(channel)); - } else { - Throwable cause = f.cause(); - emitter.error(customizeException(new HttpClientException("Connect error:" + cause.getMessage(), cause))); - } - }); - } - } catch (HttpClientException e) { - emitter.error(e); - return; - } - - // todo: on emitter dispose/cancel, close channel - }) - .delayUntil(this::delayUntilHttp2Ready) - .map(poolHandle -> { - addReadTimeoutHandler(poolHandle.channel.pipeline()); - return poolHandle; - }); + Mono connect(DefaultHttpClient.RequestKey requestKey) { + return pools.computeIfAbsent(requestKey, Pool::new).acquire(); } /** @@ -576,49 +380,48 @@ Mono connectForStream(DefaultHttpClient.RequestKey requestKey, boole * @return A mono that will complete when the handshakes complete */ Mono connectForWebsocket(DefaultHttpClient.RequestKey requestKey, ChannelHandler handler) { - Sinks.Empty initial = Sinks.empty(); - - Bootstrap bootstrap = this.bootstrap.clone(); - SslContext sslContext = buildSslContext(requestKey); - - bootstrap.remoteAddress(requestKey.getHost(), requestKey.getPort()); - initBootstrapForProxy(bootstrap, sslContext != null, requestKey.getHost(), requestKey.getPort()); - bootstrap.handler(new HttpClientInitializer( - sslContext, - requestKey.getHost(), - requestKey.getPort(), - false, - false, - false, - null - ) { + Sinks.Empty initial = new CancellableMonoSink<>(); + + ChannelFuture connectFuture = doConnect(requestKey, new ChannelInitializer() { @Override - protected void addFinalHandler(ChannelPipeline pipeline) { - pipeline.remove(ChannelPipelineCustomizer.HANDLER_HTTP_DECODER); - ReadTimeoutHandler readTimeoutHandler = pipeline.get(ReadTimeoutHandler.class); - if (readTimeoutHandler != null) { - pipeline.remove(readTimeoutHandler); + protected void initChannel(@NonNull Channel ch) { + addLogHandler(ch); + + SslContext sslContext = buildSslContext(requestKey); + if (sslContext != null) { + SslHandler sslHandler = sslContext.newHandler(ch.alloc(), requestKey.getHost(), requestKey.getPort()); + sslHandler.setHandshakeTimeoutMillis(configuration.getSslConfiguration().getHandshakeTimeout().toMillis()); + ch.pipeline().addLast(sslHandler); } + ch.pipeline() + .addLast(ChannelPipelineCustomizer.HANDLER_HTTP_CLIENT_CODEC, new HttpClientCodec()) + .addLast(ChannelPipelineCustomizer.HANDLER_HTTP_AGGREGATOR, new HttpObjectAggregator(configuration.getMaxContentLength())); + Optional readIdleTime = configuration.getReadIdleTimeout(); if (readIdleTime.isPresent()) { Duration duration = readIdleTime.get(); if (!duration.isNegative()) { - pipeline.addLast(ChannelPipelineCustomizer.HANDLER_IDLE_STATE, new IdleStateHandler(duration.toMillis(), duration.toMillis(), duration.toMillis(), TimeUnit.MILLISECONDS)); + ch.pipeline() + .addLast(ChannelPipelineCustomizer.HANDLER_IDLE_STATE, new IdleStateHandler(duration.toMillis(), duration.toMillis(), duration.toMillis(), TimeUnit.MILLISECONDS)); } } try { - pipeline.addLast(WebSocketClientCompressionHandler.INSTANCE); - pipeline.addLast(ChannelPipelineCustomizer.HANDLER_MICRONAUT_WEBSOCKET_CLIENT, handler); - initial.tryEmitEmpty(); + ch.pipeline().addLast(WebSocketClientCompressionHandler.INSTANCE); + ch.pipeline().addLast(ChannelPipelineCustomizer.HANDLER_MICRONAUT_WEBSOCKET_CLIENT, handler); + clientCustomizer.specializeForChannel(ch, NettyClientCustomizer.ChannelRole.CONNECTION).onInitialPipelineBuilt(); + if (initial.tryEmitEmpty().isSuccess()) { + return; + } } catch (Throwable e) { initial.tryEmitError(new WebSocketSessionException("Error opening WebSocket client session: " + e.getMessage(), e)); } + // failed + ch.close(); } }); - - addInstrumentedListener(bootstrap.connect(), future -> { + addInstrumentedListener(connectFuture, future -> { if (!future.isSuccess()) { initial.tryEmitError(future.cause()); } @@ -627,235 +430,6 @@ protected void addFinalHandler(ChannelPipeline pipeline) { return initial.asMono(); } - private AbstractChannelPoolHandler newPoolHandler(DefaultHttpClient.RequestKey key) { - return new AbstractChannelPoolHandler() { - @Override - public void channelCreated(Channel ch) { - Promise streamPipelineBuilt = ch.newPromise(); - ch.attr(STREAM_CHANNEL_INITIALIZED).set(streamPipelineBuilt); - - // make sure the future completes eventually - ChannelHandler failureHandler = new ChannelInboundHandlerAdapter() { - @Override - public void handlerRemoved(ChannelHandlerContext ctx) { - streamPipelineBuilt.trySuccess(null); - } - - @Override - public void channelInactive(ChannelHandlerContext ctx) { - streamPipelineBuilt.trySuccess(null); - ctx.fireChannelInactive(); - } - }; - ch.pipeline().addLast(failureHandler); - - ch.pipeline().addLast(ChannelPipelineCustomizer.HANDLER_HTTP_CLIENT_INIT, new HttpClientInitializer( - key.isSecure() ? sslContext : null, - key.getHost(), - key.getPort(), - false, - false, - false, - null - ) { - @Override - protected void addFinalHandler(ChannelPipeline pipeline) { - // no-op, don't add the stream handler which is not supported - // in the connection pooled scenario - } - - @Override - void onStreamPipelineBuilt() { - super.onStreamPipelineBuilt(); - streamPipelineBuilt.trySuccess(null); - ch.pipeline().remove(failureHandler); - ch.attr(STREAM_CHANNEL_INITIALIZED).set(null); - } - }); - - if (connectionTimeAliveMillis != null) { - ch.pipeline() - .addLast( - ChannelPipelineCustomizer.HANDLER_CONNECT_TTL, - new ConnectTTLHandler(connectionTimeAliveMillis) - ); - } - } - - @Override - public void channelReleased(Channel ch) { - Duration idleTimeout = configuration.getConnectionPoolIdleTimeout().orElse(Duration.ofNanos(0)); - ChannelPipeline pipeline = ch.pipeline(); - if (ch.isOpen()) { - ch.config().setAutoRead(true); - pipeline.addLast(IdlingConnectionHandler.INSTANCE); - if (idleTimeout.toNanos() > 0) { - pipeline.addLast(ChannelPipelineCustomizer.HANDLER_IDLE_STATE, new IdleStateHandler(idleTimeout.toNanos(), idleTimeout.toNanos(), 0, TimeUnit.NANOSECONDS)); - pipeline.addLast(IdleTimeoutHandler.INSTANCE); - } - } - - if (ConnectTTLHandler.isChannelExpired(ch) && ch.isOpen() && !ch.eventLoop().isShuttingDown()) { - ch.close(); - } - - removeReadTimeoutHandler(pipeline); - } - - @Override - public void channelAcquired(Channel ch) throws Exception { - ChannelPipeline pipeline = ch.pipeline(); - if (pipeline.context(IdlingConnectionHandler.INSTANCE) != null) { - pipeline.remove(IdlingConnectionHandler.INSTANCE); - } - if (pipeline.context(ChannelPipelineCustomizer.HANDLER_IDLE_STATE) != null) { - pipeline.remove(ChannelPipelineCustomizer.HANDLER_IDLE_STATE); - } - if (pipeline.context(IdleTimeoutHandler.INSTANCE) != null) { - pipeline.remove(IdleTimeoutHandler.INSTANCE); - } - } - }; - } - - /** - * Configures HTTP/2 for the channel when SSL is enabled. - * - * @param httpClientInitializer The client initializer - * @param ch The channel - * @param sslCtx The SSL context - * @param host The host - * @param port The port - * @param connectionHandler The connection handler - */ - private void configureHttp2Ssl( - HttpClientInitializer httpClientInitializer, - @NonNull SocketChannel ch, - @NonNull SslContext sslCtx, - String host, - int port, - HttpToHttp2ConnectionHandler connectionHandler) { - ChannelPipeline pipeline = ch.pipeline(); - // Specify Host in SSLContext New Handler to add TLS SNI Extension - pipeline.addLast(ChannelPipelineCustomizer.HANDLER_SSL, sslCtx.newHandler(ch.alloc(), host, port)); - // We must wait for the handshake to finish and the protocol to be negotiated before configuring - // the HTTP/2 components of the pipeline. - pipeline.addLast( - ChannelPipelineCustomizer.HANDLER_HTTP2_PROTOCOL_NEGOTIATOR, - new ApplicationProtocolNegotiationHandler(ApplicationProtocolNames.HTTP_2) { - - @Override - public void handlerRemoved(ChannelHandlerContext ctx) { - // the logic to send the request should only be executed once the HTTP/2 - // Connection Preface request has been sent. Once the Preface has been sent and - // removed then this handler is removed so we invoke the remaining logic once - // this handler removed - final Consumer contextConsumer = - httpClientInitializer.contextConsumer; - if (contextConsumer != null) { - contextConsumer.accept(ctx); - } - } - - @Override - protected void configurePipeline(ChannelHandlerContext ctx, String protocol) { - if (ApplicationProtocolNames.HTTP_2.equals(protocol)) { - ChannelPipeline p = ctx.pipeline(); - if (httpClientInitializer.stream) { - // stream consumer manages backpressure and reads - ctx.channel().config().setAutoRead(false); - } - p.addLast( - ChannelPipelineCustomizer.HANDLER_HTTP2_SETTINGS, - new Http2SettingsHandler(ch.newPromise()) - ); - httpClientInitializer.addEventStreamHandlerIfNecessary(p); - httpClientInitializer.addFinalHandler(p); - for (ChannelPipelineListener pipelineListener : pipelineListeners) { - pipelineListener.onConnect(p); - } - } else if (ApplicationProtocolNames.HTTP_1_1.equals(protocol)) { - ChannelPipeline p = ctx.pipeline(); - httpClientInitializer.addHttp1Handlers(p); - } else { - ctx.close(); - throw customizeException(new HttpClientException("Unknown Protocol: " + protocol)); - } - httpClientInitializer.onStreamPipelineBuilt(); - } - }); - - pipeline.addLast(ChannelPipelineCustomizer.HANDLER_HTTP2_CONNECTION, connectionHandler); - } - - /** - * Configures HTTP/2 handling for plaintext (non-SSL) connections. - * - * @param httpClientInitializer The client initializer - * @param ch The channel - * @param connectionHandler The connection handler - */ - private void configureHttp2ClearText( - HttpClientInitializer httpClientInitializer, - @NonNull SocketChannel ch, - @NonNull HttpToHttp2ConnectionHandler connectionHandler) { - HttpClientCodec sourceCodec = new HttpClientCodec(); - Http2ClientUpgradeCodec upgradeCodec = new Http2ClientUpgradeCodec(ChannelPipelineCustomizer.HANDLER_HTTP2_CONNECTION, connectionHandler); - HttpClientUpgradeHandler upgradeHandler = new HttpClientUpgradeHandler(sourceCodec, upgradeCodec, 65536); - - final ChannelPipeline pipeline = ch.pipeline(); - pipeline.addLast(ChannelPipelineCustomizer.HANDLER_HTTP_CLIENT_CODEC, sourceCodec); - httpClientInitializer.settingsHandler = new Http2SettingsHandler(ch.newPromise()); - pipeline.addLast(upgradeHandler); - pipeline.addLast(new ChannelInboundHandlerAdapter() { - @Override - public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { - ctx.fireUserEventTriggered(evt); - if (evt instanceof HttpClientUpgradeHandler.UpgradeEvent) { - httpClientInitializer.onStreamPipelineBuilt(); - ctx.pipeline().remove(this); - } - } - }); - pipeline.addLast(ChannelPipelineCustomizer.HANDLER_HTTP2_UPGRADE_REQUEST, new H2cUpgradeRequestHandler(httpClientInitializer)); - } - - /** - * Creates a new {@link HttpToHttp2ConnectionHandlerBuilder} for the given HTTP/2 connection object and config. - * - * @param connection The connection - * @param configuration The configuration - * @param stream Whether this is a stream request - * @return The {@link HttpToHttp2ConnectionHandlerBuilder} - */ - @NonNull - private static HttpToHttp2ConnectionHandlerBuilder newHttp2ConnectionHandlerBuilder( - @NonNull Http2Connection connection, @NonNull HttpClientConfiguration configuration, boolean stream) { - final HttpToHttp2ConnectionHandlerBuilder builder = new HttpToHttp2ConnectionHandlerBuilder(); - builder.validateHeaders(true); - final Http2FrameListener http2ToHttpAdapter; - - if (!stream) { - http2ToHttpAdapter = new InboundHttp2ToHttpAdapterBuilder(connection) - .maxContentLength(configuration.getMaxContentLength()) - .validateHttpHeaders(true) - .propagateSettings(true) - .build(); - - } else { - http2ToHttpAdapter = new StreamingInboundHttp2ToHttpAdapter( - connection, - configuration.getMaxContentLength() - ); - } - return builder - .connection(connection) - .frameListener(new DelegatingDecompressorFrameListener( - connection, - http2ToHttpAdapter)); - - } - private void configureProxy(ChannelPipeline pipeline, boolean secure, String host, int port) { Proxy proxy = configuration.resolveProxy(secure, host, port); if (Proxy.NO_PROXY.equals(proxy)) { @@ -898,10 +472,11 @@ private void configureProxy(ChannelPipeline pipeline, boolean secure, String hos } } - > Future addInstrumentedListener( + > void addInstrumentedListener( Future channelFuture, GenericFutureListener listener) { - return channelFuture.addListener(f -> { + channelFuture.addListener(f -> { try (Instrumentation ignored = instrumenter.newInstrumentation()) { + //noinspection unchecked listener.operationComplete((C) f); } }); @@ -912,423 +487,758 @@ private E customizeException(E exc) { return exc; } - private void addReadTimeoutHandler(ChannelPipeline pipeline) { - if (readTimeoutMillis != null) { - if (httpVersion == HttpVersion.HTTP_2_0) { - pipeline.addBefore( - ChannelPipelineCustomizer.HANDLER_HTTP2_CONNECTION, - ChannelPipelineCustomizer.HANDLER_READ_TIMEOUT, - new ReadTimeoutHandler(readTimeoutMillis, TimeUnit.MILLISECONDS) - ); - } else { - pipeline.addBefore( - ChannelPipelineCustomizer.HANDLER_HTTP_CLIENT_CODEC, - ChannelPipelineCustomizer.HANDLER_READ_TIMEOUT, - new ReadTimeoutHandler(readTimeoutMillis, TimeUnit.MILLISECONDS)); + private Http2FrameCodec makeFrameCodec() { + Http2FrameCodecBuilder builder = Http2FrameCodecBuilder.forClient(); + configuration.getLogLevel().ifPresent(logLevel -> { + try { + final io.netty.handler.logging.LogLevel nettyLevel = + io.netty.handler.logging.LogLevel.valueOf(logLevel.name()); + builder.frameLogger(new Http2FrameLogger(nettyLevel, DefaultHttpClient.class)); + } catch (IllegalArgumentException e) { + throw customizeException(new HttpClientException("Unsupported log level: " + logLevel)); } - } + }); + return builder.build(); } - private void removeReadTimeoutHandler(ChannelPipeline pipeline) { - if (readTimeoutMillis != null && pipeline.context(ChannelPipelineCustomizer.HANDLER_READ_TIMEOUT) != null) { - pipeline.remove(ChannelPipelineCustomizer.HANDLER_READ_TIMEOUT); - } + /** + * Initializer for HTTP1.1, called either in plaintext mode, or after ALPN in TLS. + * + * @param ch The plaintext channel + */ + private void initHttp1(Channel ch) { + addLogHandler(ch); + + ch.pipeline() + .addLast(ChannelPipelineCustomizer.HANDLER_HTTP_CLIENT_CODEC, new HttpClientCodec()) + .addLast(ChannelPipelineCustomizer.HANDLER_HTTP_DECODER, new HttpContentDecompressor()); + } + + private void addLogHandler(Channel ch) { + configuration.getLogLevel().ifPresent(logLevel -> { + try { + final io.netty.handler.logging.LogLevel nettyLevel = + io.netty.handler.logging.LogLevel.valueOf(logLevel.name()); + ch.pipeline().addLast(new LoggingHandler(DefaultHttpClient.class, nettyLevel)); + } catch (IllegalArgumentException e) { + throw customizeException(new HttpClientException("Unsupported log level: " + logLevel)); + } + }); } /** - * A handler that triggers the cleartext upgrade to HTTP/2 by sending an initial HTTP request. + * Initializer for HTTP2 multiplexing, called either in h2c mode, or after ALPN in TLS. The + * channel should already contain a {@link #makeFrameCodec() frame codec} that does the HTTP2 + * parsing, this method adds the handlers that do multiplexing, error handling, etc. + * + * @param pool The pool to add the connection to once the handshake is done + * @param ch The plaintext channel + * @param connectionCustomizer Customizer for the connection */ - private class H2cUpgradeRequestHandler extends ChannelInboundHandlerAdapter { + private void initHttp2(Pool pool, Channel ch, NettyClientCustomizer connectionCustomizer) { + Http2MultiplexHandler multiplexHandler = new Http2MultiplexHandler(new ChannelInitializer() { + @Override + protected void initChannel(@NonNull Http2StreamChannel ch) throws Exception { + log.warn("Server opened HTTP2 stream {}, closing immediately", ch.stream().id()); + ch.close(); + } + }, new ChannelInitializer() { + @Override + protected void initChannel(@NonNull Http2StreamChannel ch) throws Exception { + // discard any response data for the upgrade request + ch.close(); + } + }); + ch.pipeline().addLast(multiplexHandler); + ch.pipeline().addLast(ChannelPipelineCustomizer.HANDLER_HTTP2_SETTINGS, new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(@NonNull ChannelHandlerContext ctx, @NonNull Object msg) throws Exception { + if (msg instanceof Http2SettingsFrame) { + ctx.pipeline().remove(ChannelPipelineCustomizer.HANDLER_HTTP2_SETTINGS); + ctx.pipeline().remove(ChannelPipelineCustomizer.HANDLER_INITIAL_ERROR); + pool.new Http2ConnectionHolder(ch, connectionCustomizer).init(); + return; + } else { + log.warn("Premature frame: {}", msg.getClass()); + } + + super.channelRead(ctx, msg); + } + }); + // stream frames should be handled by the multiplexer + ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + ctx.read(); + } - private final HttpClientInitializer initializer; + @Override + public void channelRead(@NonNull ChannelHandlerContext ctx, @NonNull Object msg) throws Exception { + log.warn("Unexpected message on HTTP2 connection channel: {}", msg); + ReferenceCountUtil.release(msg); + ctx.read(); + } + }); + } + + /** + * Initializer for TLS channels. After ALPN we will proceed either with + * {@link #initHttp1(Channel)} or {@link #initHttp2(Pool, Channel, NettyClientCustomizer)}. + */ + private final class AdaptiveAlpnChannelInitializer extends ChannelInitializer { + private final Pool pool; + + private final SslContext sslContext; + private final String host; + private final int port; + + AdaptiveAlpnChannelInitializer(Pool pool, + SslContext sslContext, + String host, + int port) { + this.pool = pool; + this.sslContext = sslContext; + this.host = host; + this.port = port; + } /** - * Default constructor. - * - * @param initializer The initializer + * @param ch The channel */ - public H2cUpgradeRequestHandler(HttpClientInitializer initializer) { - this.initializer = initializer; + @Override + protected void initChannel(@NonNull Channel ch) { + NettyClientCustomizer channelCustomizer = clientCustomizer.specializeForChannel(ch, NettyClientCustomizer.ChannelRole.CONNECTION); + + configureProxy(ch.pipeline(), true, host, port); + + SslHandler sslHandler = sslContext.newHandler(ch.alloc(), host, port); + sslHandler.setHandshakeTimeoutMillis(configuration.getSslConfiguration().getHandshakeTimeout().toMillis()); + ch.pipeline() + .addLast(ChannelPipelineCustomizer.HANDLER_SSL, sslHandler) + .addLast( + ChannelPipelineCustomizer.HANDLER_HTTP2_PROTOCOL_NEGOTIATOR, + // if the server doesn't do ALPN, fall back to HTTP 1 + new ApplicationProtocolNegotiationHandler(ApplicationProtocolNames.HTTP_1_1) { + @Override + protected void configurePipeline(ChannelHandlerContext ctx, String protocol) { + if (ApplicationProtocolNames.HTTP_2.equals(protocol)) { + ctx.pipeline().addLast(ChannelPipelineCustomizer.HANDLER_HTTP2_CONNECTION, makeFrameCodec()); + initHttp2(pool, ctx.channel(), channelCustomizer); + } else if (ApplicationProtocolNames.HTTP_1_1.equals(protocol)) { + initHttp1(ctx.channel()); + pool.new Http1ConnectionHolder(ch, channelCustomizer).init(false); + ctx.pipeline().remove(ChannelPipelineCustomizer.HANDLER_INITIAL_ERROR); + } else { + ctx.close(); + throw customizeException(new HttpClientException("Unknown Protocol: " + protocol)); + } + } + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + if (evt instanceof SslHandshakeCompletionEvent) { + SslHandshakeCompletionEvent event = (SslHandshakeCompletionEvent) evt; + if (!event.isSuccess()) { + InitialConnectionErrorHandler.setFailureCause(ctx.channel(), event.cause()); + } + } + super.userEventTriggered(ctx, evt); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + // let the HANDLER_INITIAL_ERROR handle the failure + if (cause instanceof DecoderException && cause.getCause() instanceof SSLException) { + // unwrap DecoderException + cause = cause.getCause(); + } + ctx.fireExceptionCaught(cause); + } + }) + .addLast(ChannelPipelineCustomizer.HANDLER_INITIAL_ERROR, pool.initialErrorHandler); + + channelCustomizer.onInitialPipelineBuilt(); } + } - @Override - public void channelActive(ChannelHandlerContext ctx) { - // Done with this handler, remove it from the pipeline. - final ChannelPipeline pipeline = ctx.pipeline(); - - pipeline.addLast(ChannelPipelineCustomizer.HANDLER_HTTP2_SETTINGS, initializer.settingsHandler); - DefaultFullHttpRequest upgradeRequest = - new DefaultFullHttpRequest(io.netty.handler.codec.http.HttpVersion.HTTP_1_1, HttpMethod.GET, "/", Unpooled.EMPTY_BUFFER); - - // Set HOST header as the remote peer may require it. - InetSocketAddress remote = (InetSocketAddress) ctx.channel().remoteAddress(); - String hostString = remote.getHostString(); - if (hostString == null) { - hostString = remote.getAddress().getHostAddress(); - } - upgradeRequest.headers().set(HttpHeaderNames.HOST, hostString + ':' + remote.getPort()); - ctx.writeAndFlush(upgradeRequest); + /** + * Initializer for H2C connections. Will proceed with + * {@link #initHttp2(Pool, Channel, NettyClientCustomizer)} when the upgrade is done. + */ + private final class Http2UpgradeInitializer extends ChannelInitializer { + private final Pool pool; - ctx.fireChannelActive(); - if (initializer.contextConsumer != null) { - initializer.contextConsumer.accept(ctx); - } - initializer.addFinalHandler(pipeline); + Http2UpgradeInitializer(Pool pool) { + this.pool = pool; } @Override - public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { - if (msg instanceof HttpMessage) { - int streamId = ((HttpMessage) msg).headers().getInt(HttpConversionUtil.ExtensionHeaderNames.STREAM_ID.text(), -1); - if (streamId == 1) { - // ignore this message - if (log.isDebugEnabled()) { - log.debug("Received response on HTTP2 stream 1, the stream used to respond to the initial upgrade request. Ignoring."); - } - ReferenceCountUtil.release(msg); - if (msg instanceof LastHttpContent) { - ctx.pipeline().remove(this); + protected void initChannel(@NonNull Channel ch) throws Exception { + NettyClientCustomizer connectionCustomizer = clientCustomizer.specializeForChannel(ch, NettyClientCustomizer.ChannelRole.CONNECTION); + + Http2FrameCodec frameCodec = makeFrameCodec(); + + HttpClientCodec sourceCodec = new HttpClientCodec(); + Http2ClientUpgradeCodec upgradeCodec = new Http2ClientUpgradeCodec(frameCodec, + new ChannelInitializer() { + @Override + protected void initChannel(@NonNull Channel ch) throws Exception { + ch.pipeline().addLast(ChannelPipelineCustomizer.HANDLER_HTTP2_CONNECTION, frameCodec); + initHttp2(pool, ch, connectionCustomizer); } - return; + }); + HttpClientUpgradeHandler upgradeHandler = new HttpClientUpgradeHandler(sourceCodec, upgradeCodec, 65536); + + ch.pipeline().addLast(ChannelPipelineCustomizer.HANDLER_HTTP_CLIENT_CODEC, sourceCodec); + ch.pipeline().addLast(upgradeHandler); + + ch.pipeline().addLast(ChannelPipelineCustomizer.HANDLER_HTTP2_UPGRADE_REQUEST, new ChannelInboundHandlerAdapter() { + @Override + public void channelActive(@NonNull ChannelHandlerContext ctx) throws Exception { + DefaultFullHttpRequest upgradeRequest = + new DefaultFullHttpRequest(io.netty.handler.codec.http.HttpVersion.HTTP_1_1, HttpMethod.GET, "/", Unpooled.EMPTY_BUFFER); + + // Set HOST header as the remote peer may require it. + upgradeRequest.headers().set(HttpHeaderNames.HOST, pool.requestKey.getHost() + ':' + pool.requestKey.getPort()); + ctx.writeAndFlush(upgradeRequest); + ctx.pipeline().remove(ChannelPipelineCustomizer.HANDLER_HTTP2_UPGRADE_REQUEST); + // read the upgrade response + ctx.read(); + + super.channelActive(ctx); } - } + }); + ch.pipeline().addLast(ChannelPipelineCustomizer.HANDLER_INITIAL_ERROR, pool.initialErrorHandler); - super.channelRead(ctx, msg); + connectionCustomizer.onInitialPipelineBuilt(); } } /** - * Reads the first {@link Http2Settings} object and notifies a {@link io.netty.channel.ChannelPromise}. + * Handle for a pooled connection. One pool handle generally corresponds to one request, and + * once the request and response are done, the handle is {@link #release() released} and a new + * request can claim the same connection. */ - private class Http2SettingsHandler extends - SimpleChannelInboundHandlerInstrumented { - final ChannelPromise promise; + abstract static class PoolHandle { + private static final Supplier> LEAK_DETECTOR = SupplierUtil.memoized(() -> + ResourceLeakDetectorFactory.instance().newResourceLeakDetector(PoolHandle.class)); - /** - * Create new instance. - * - * @param promise Promise object used to notify when first settings are received - */ - Http2SettingsHandler(ChannelPromise promise) { - super(instrumenter); - this.promise = promise; - } + final boolean http2; + final Channel channel; - @Override - protected void channelReadInstrumented(ChannelHandlerContext ctx, Http2Settings msg) { - promise.setSuccess(); + boolean released = false; - // Only care about the first settings message - ctx.pipeline().remove(this); - } + private final ResourceLeakTracker tracker = LEAK_DETECTOR.get().track(this); - @Override - public void channelInactive(ChannelHandlerContext ctx) throws Exception { - super.channelInactive(ctx); - if (!promise.isDone()) { - promise.tryFailure(new HttpClientException("Channel became inactive before settings frame was received")); - } + private PoolHandle(boolean http2, Channel channel) { + this.http2 = http2; + this.channel = channel; } - @Override - public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { - super.handlerRemoved(ctx); - if (!promise.isDone()) { - promise.tryFailure(new HttpClientException("Handler was removed before settings frame was received")); + /** + * Prevent this connection from being reused, e.g. because garbage was written because of + * an error. + */ + abstract void taint(); + + /** + * Close this connection or release it back to the pool. + */ + void release() { + if (released) { + throw new IllegalStateException("Already released"); + } + released = true; + if (tracker != null) { + tracker.close(this); } } + + /** + * Whether this connection may be returned to a connection pool (i.e. should be kept + * keepalive). + * + * @return Whether this connection may be reused + */ + abstract boolean canReturn(); + + /** + * Notify any {@link NettyClientCustomizer} that the request pipeline has been built. + */ + abstract void notifyRequestPipelineBuilt(); } /** - * Initializes the HTTP client channel. + * This class represents one pool, and matches to exactly one + * {@link io.micronaut.http.client.netty.DefaultHttpClient.RequestKey} (i.e. host, port and + * protocol are the same for one pool). + *

+ * The superclass {@link PoolResizer} handles pool size management, this class just implements + * the HTTP parts. */ - private class HttpClientInitializer extends ChannelInitializer { - - final SslContext sslContext; - final String host; - final int port; - final boolean stream; - final boolean proxy; - final boolean acceptsEvents; - Http2SettingsHandler settingsHandler; - final Consumer contextConsumer; - private NettyClientCustomizer channelCustomizer; + private final class Pool extends PoolResizer { + private final DefaultHttpClient.RequestKey requestKey; /** - * @param sslContext The ssl context - * @param host The host - * @param port The port - * @param stream Whether is stream - * @param proxy Is this a streaming proxy - * @param acceptsEvents Whether an event stream is accepted - * @param contextConsumer The context consumer + * {@link ChannelHandler} that is added to a connection to report failures during + * handshakes. It's removed once the connection is established and processes requests. */ - protected HttpClientInitializer(SslContext sslContext, - String host, - int port, - boolean stream, - boolean proxy, - boolean acceptsEvents, - Consumer contextConsumer) { - this.sslContext = sslContext; - this.stream = stream; - this.host = host; - this.port = port; - this.proxy = proxy; - this.acceptsEvents = acceptsEvents; - this.contextConsumer = contextConsumer; + private final InitialConnectionErrorHandler initialErrorHandler = new InitialConnectionErrorHandler() { + @Override + protected void onNewConnectionFailure(@Nullable Throwable cause) throws Exception { + Pool.this.onNewConnectionFailure(cause); + } + }; + + Pool(DefaultHttpClient.RequestKey requestKey) { + super(log, configuration.getConnectionPoolConfiguration()); + this.requestKey = requestKey; + } + + Mono acquire() { + Sinks.One sink = new CancellableMonoSink<>(); + addPendingRequest(sink); + Optional acquireTimeout = configuration.getConnectionPoolConfiguration().getAcquireTimeout(); + //noinspection OptionalIsPresent + if (acquireTimeout.isPresent()) { + return sink.asMono().timeout(acquireTimeout.get(), Schedulers.fromExecutor(group)); + } else { + return sink.asMono(); + } } - /** - * @param ch The channel - */ @Override - protected void initChannel(SocketChannel ch) { - channelCustomizer = clientCustomizer.specializeForChannel(ch, NettyClientCustomizer.ChannelRole.CONNECTION); - ch.attr(CHANNEL_CUSTOMIZER_KEY).set(channelCustomizer); - - ChannelPipeline p = ch.pipeline(); - - configureProxy(p, sslContext != null, host, port); - - if (httpVersion == HttpVersion.HTTP_2_0) { - final Http2Connection connection = new DefaultHttp2Connection(false); - final HttpToHttp2ConnectionHandlerBuilder builder = - newHttp2ConnectionHandlerBuilder(connection, configuration, stream); - - configuration.getLogLevel().ifPresent(logLevel -> { - try { - final io.netty.handler.logging.LogLevel nettyLevel = io.netty.handler.logging.LogLevel.valueOf( - logLevel.name() - ); - builder.frameLogger(new Http2FrameLogger(nettyLevel, DefaultHttpClient.class)); - } catch (IllegalArgumentException e) { - throw customizeException(new HttpClientException("Unsupported log level: " + logLevel)); - } - }); - HttpToHttp2ConnectionHandler connectionHandler = builder - .build(); - if (sslContext != null) { - configureHttp2Ssl(this, ch, sslContext, host, port, connectionHandler); + void onNewConnectionFailure(@Nullable Throwable error) throws Exception { + super.onNewConnectionFailure(error); + // to avoid an infinite loop, fail one pending request. + Sinks.One pending = pollPendingRequest(); + if (pending != null) { + HttpClientException wrapped; + if (error == null) { + // no failure observed, but channel closed + wrapped = new HttpClientException("Unknown connect error"); } else { - configureHttp2ClearText(this, ch, connectionHandler); + wrapped = new HttpClientException("Connect Error: " + error.getMessage(), error); } - channelCustomizer.onInitialPipelineBuilt(); - } else { - if (stream) { - // for streaming responses we disable auto read - // so that the consumer is in charge of back pressure - ch.config().setAutoRead(false); - } - - configuration.getLogLevel().ifPresent(logLevel -> { - try { - final io.netty.handler.logging.LogLevel nettyLevel = io.netty.handler.logging.LogLevel.valueOf( - logLevel.name() - ); - p.addLast(new LoggingHandler(DefaultHttpClient.class, nettyLevel)); - } catch (IllegalArgumentException e) { - throw customizeException(new HttpClientException("Unsupported log level: " + logLevel)); - } - }); - - if (sslContext != null) { - SslHandler sslHandler = sslContext.newHandler(ch.alloc(), host, port); - sslHandler.setHandshakeTimeoutMillis(configuration.getSslConfiguration().getHandshakeTimeout().toMillis()); - p.addLast(ChannelPipelineCustomizer.HANDLER_SSL, sslHandler); + if (pending.tryEmitError(customizeException(wrapped)) == Sinks.EmitResult.OK) { + // no need to log + return; } + } + log.error("Failed to connect to remote", error); + } - // Pool connections require alternative timeout handling - if (poolMap == null) { - // read timeout settings are not applied to streamed requests. - // instead idle timeout settings are applied. - if (stream) { - Optional readIdleTime = configuration.getReadIdleTimeout(); - if (readIdleTime.isPresent()) { - Duration duration = readIdleTime.get(); - if (!duration.isNegative()) { - p.addLast(ChannelPipelineCustomizer.HANDLER_IDLE_STATE, new IdleStateHandler( - duration.toMillis(), - duration.toMillis(), - duration.toMillis(), - TimeUnit.MILLISECONDS - )); + @Override + void openNewConnection() { + // open a new connection + ChannelInitializer initializer; + if (requestKey.isSecure()) { + initializer = new AdaptiveAlpnChannelInitializer( + this, + buildSslContext(requestKey), + requestKey.getHost(), + requestKey.getPort() + ); + } else { + switch (httpVersion.getPlaintextMode()) { + case HTTP_1: + initializer = new ChannelInitializer() { + @Override + protected void initChannel(@NonNull Channel ch) throws Exception { + configureProxy(ch.pipeline(), false, requestKey.getHost(), requestKey.getPort()); + initHttp1(ch); + ch.pipeline().addLast(ChannelPipelineCustomizer.HANDLER_ACTIVITY_LISTENER, new ChannelInboundHandlerAdapter() { + @Override + public void channelActive(@NonNull ChannelHandlerContext ctx) throws Exception { + super.channelActive(ctx); + ctx.pipeline().remove(this); + NettyClientCustomizer channelCustomizer = clientCustomizer.specializeForChannel(ch, NettyClientCustomizer.ChannelRole.CONNECTION); + new Http1ConnectionHolder(ch, channelCustomizer).init(true); + } + }); } - } - } + }; + break; + case H2C: + initializer = new Http2UpgradeInitializer(this); + break; + default: + throw new AssertionError("Unknown plaintext mode"); } - - addHttp1Handlers(p); - channelCustomizer.onInitialPipelineBuilt(); - onStreamPipelineBuilt(); } + addInstrumentedListener(doConnect(requestKey, initializer), future -> { + if (!future.isSuccess()) { + onNewConnectionFailure(future.cause()); + } + }); } - /** - * Called when the stream pipeline is fully set up (all handshakes completed) and we can - * start processing requests. - */ - void onStreamPipelineBuilt() { - channelCustomizer.onStreamPipelineBuilt(); + public void shutdown() { + forEachConnection(c -> ((ConnectionHolder) c).channel.close()); } - void addHttp1Handlers(ChannelPipeline p) { - p.addLast(ChannelPipelineCustomizer.HANDLER_HTTP_CLIENT_CODEC, new HttpClientCodec()); + /** + * Base class for one HTTP1/HTTP2 connection. + */ + abstract class ConnectionHolder extends ResizerConnection { + final Channel channel; + final NettyClientCustomizer connectionCustomizer; + /** + * Future for the scheduled task that runs when the configured time-to-live for the + * connection passes. + */ + @Nullable + ScheduledFuture ttlFuture; + volatile boolean windDownConnection = false; + + ConnectionHolder(Channel channel, NettyClientCustomizer connectionCustomizer) { + this.channel = channel; + this.connectionCustomizer = connectionCustomizer; + } - p.addLast(ChannelPipelineCustomizer.HANDLER_HTTP_DECODER, new HttpContentDecompressor()); + /** + * Add connection-level timeout-related handlers to the channel + * (read timeout, TTL, ...). + * + * @param before Reference handler name, the timeout handlers will be placed before + * this handler. + */ + final void addTimeoutHandlers(String before) { + // read timeout handles timeouts *during* a request + configuration.getReadTimeout() + .ifPresent(dur -> channel.pipeline().addBefore(before, ChannelPipelineCustomizer.HANDLER_READ_TIMEOUT, new ReadTimeoutHandler(dur.toNanos(), TimeUnit.NANOSECONDS) { + @Override + protected void readTimedOut(ChannelHandlerContext ctx) { + if (hasLiveRequests()) { + fireReadTimeout(ctx); + ctx.close(); + } + } + })); + // pool idle timeout happens *outside* a request + configuration.getConnectionPoolIdleTimeout() + .ifPresent(dur -> channel.pipeline().addBefore(before, ChannelPipelineCustomizer.HANDLER_IDLE_STATE, new ReadTimeoutHandler(dur.toNanos(), TimeUnit.NANOSECONDS) { + @Override + protected void readTimedOut(ChannelHandlerContext ctx) { + if (!hasLiveRequests()) { + ctx.close(); + } + } + })); + configuration.getConnectTtl().ifPresent(ttl -> + ttlFuture = channel.eventLoop().schedule(this::windDownConnection, ttl.toNanos(), TimeUnit.NANOSECONDS)); + channel.pipeline().addBefore(before, "connection-cleaner", new ChannelInboundHandlerAdapter() { + boolean inactiveCalled = false; - int maxContentLength = configuration.getMaxContentLength(); + @Override + public void channelInactive(@NonNull ChannelHandlerContext ctx) throws Exception { + super.channelInactive(ctx); + if (!inactiveCalled) { + inactiveCalled = true; + onInactive(); + } + } - if (!stream) { - p.addLast(ChannelPipelineCustomizer.HANDLER_HTTP_AGGREGATOR, new HttpObjectAggregator(maxContentLength) { @Override - protected void finishAggregation(FullHttpMessage aggregated) throws Exception { - if (!HttpUtil.isContentLengthSet(aggregated)) { - if (aggregated.content().readableBytes() > 0) { - super.finishAggregation(aggregated); - } + public void handlerRemoved(ChannelHandlerContext ctx) { + if (!inactiveCalled) { + inactiveCalled = true; + onInactive(); } } }); } - addEventStreamHandlerIfNecessary(p); - addFinalHandler(p); - for (ChannelPipelineListener pipelineListener : pipelineListeners) { - pipelineListener.onConnect(p); + + /** + * Stop accepting new requests on this connection, but finish up the running requests + * if possible. + */ + void windDownConnection() { + windDownConnection = true; + } + + /** + * Send the finished pool handle to the given requester, if possible. + * + * @param sink The request for a pool handle + * @param ph The pool handle + */ + final void emitPoolHandle(Sinks.One sink, PoolHandle ph) { + Sinks.EmitResult emitResult = sink.tryEmitValue(ph); + if (emitResult.isFailure()) { + ph.release(); + } else { + if (!configuration.getConnectionPoolConfiguration().isEnabled()) { + // if pooling is off, release the connection after this. + windDownConnection(); + } + } + } + + @Override + public boolean dispatch(Sinks.One sink) { + if (!tryEarmarkForRequest()) { + return false; + } + + if (channel.eventLoop().inEventLoop()) { + dispatch0(sink); + } else { + channel.eventLoop().execute(() -> dispatch0(sink)); + } + return true; + } + + /** + * Called on event loop only. Dispatch a stream/connection to the given pool + * handle request. + * + * @param sink The request for a pool handle + */ + abstract void dispatch0(Sinks.One sink); + + /** + * Try to add a new request to this connection. This is called outside the event loop, + * and if this succeeds, we will proceed with a {@link #dispatch0} call on the + * event loop. + * + * @return {@code true} if the request may be added to this connection + */ + abstract boolean tryEarmarkForRequest(); + + /** + * @return {@code true} iff there are any requests running on this connection. + */ + abstract boolean hasLiveRequests(); + + /** + * Send a read timeout exception to all requests on this connection. + * + * @param ctx The connection-level channel handler context to use. + */ + abstract void fireReadTimeout(ChannelHandlerContext ctx); + + /** + * Called when the connection becomes inactive, i.e. on disconnect. + */ + void onInactive() { + if (ttlFuture != null) { + ttlFuture.cancel(false); + } + windDownConnection = true; } } - void addEventStreamHandlerIfNecessary(ChannelPipeline p) { - // if the content type is a SSE event stream we add a decoder - // to delimit the content by lines (unless we are proxying the stream) - if (acceptsEventStream() && !proxy) { - p.addLast(ChannelPipelineCustomizer.HANDLER_MICRONAUT_SSE_EVENT_STREAM, new LineBasedFrameDecoder(configuration.getMaxContentLength(), true, true) { + final class Http1ConnectionHolder extends ConnectionHolder { + private final AtomicBoolean hasLiveRequest = new AtomicBoolean(false); + + Http1ConnectionHolder(Channel channel, NettyClientCustomizer connectionCustomizer) { + super(channel, connectionCustomizer); + } + + void init(boolean fireInitialPipelineBuilt) { + addTimeoutHandlers( + requestKey.isSecure() ? + ChannelPipelineCustomizer.HANDLER_SSL : + ChannelPipelineCustomizer.HANDLER_HTTP_CLIENT_CODEC + ); + + if (fireInitialPipelineBuilt) { + connectionCustomizer.onInitialPipelineBuilt(); + } + connectionCustomizer.onStreamPipelineBuilt(); + + onNewConnectionEstablished1(this); + } + + @Override + boolean tryEarmarkForRequest() { + return !windDownConnection && hasLiveRequest.compareAndSet(false, true); + } + + @Override + boolean hasLiveRequests() { + return hasLiveRequest.get(); + } + + @Override + void fireReadTimeout(ChannelHandlerContext ctx) { + ctx.fireExceptionCaught(ReadTimeoutException.INSTANCE); + } + + @Override + void dispatch0(Sinks.One sink) { + if (!channel.isActive()) { + returnPendingRequest(sink); + return; + } + PoolHandle ph = new PoolHandle(false, channel) { + final ChannelHandlerContext lastContext = channel.pipeline().lastContext(); @Override - public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { - if (msg instanceof HttpContent) { - if (msg instanceof LastHttpContent) { - super.channelRead(ctx, msg); - } else { - Attribute streamKey = ctx.channel().attr(STREAM_KEY); - if (msg instanceof Http2Content) { - streamKey.set(((Http2Content) msg).stream()); - } - try { - super.channelRead(ctx, ((HttpContent) msg).content()); - } finally { - streamKey.set(null); - } + void taint() { + windDownConnection = true; + } + + @Override + void release() { + super.release(); + if (!windDownConnection) { + ChannelHandlerContext newLast = channel.pipeline().lastContext(); + if (lastContext != newLast) { + log.warn("BUG - Handler not removed: {}", newLast); + taint(); } + } + if (!windDownConnection) { + hasLiveRequest.set(false); + markConnectionAvailable(); } else { - super.channelRead(ctx, msg); + channel.close(); } } - }); - - p.addLast(ChannelPipelineCustomizer.HANDLER_MICRONAUT_SSE_CONTENT, new SimpleChannelInboundHandlerInstrumented(instrumenter, false) { @Override - public boolean acceptInboundMessage(Object msg) { - return msg instanceof ByteBuf; + boolean canReturn() { + return !windDownConnection; } @Override - protected void channelReadInstrumented(ChannelHandlerContext ctx, ByteBuf msg) { - try { - Attribute streamKey = ctx.channel().attr(STREAM_KEY); - Http2Stream http2Stream = streamKey.get(); - if (http2Stream != null) { - ctx.fireChannelRead(new DefaultHttp2Content(msg.copy(), http2Stream)); - } else { - ctx.fireChannelRead(new DefaultHttpContent(msg.copy())); - } - } finally { - msg.release(); - } + void notifyRequestPipelineBuilt() { + connectionCustomizer.onRequestPipelineBuilt(); } - }); + }; + emitPoolHandle(sink, ph); + } + private void returnPendingRequest(Sinks.One sink) { + // failed, but the pending request may still work on another connection. + addPendingRequest(sink); + hasLiveRequest.set(false); } - } - /** - * Allows overriding the final handler added to the pipeline. - * - * @param pipeline The pipeline - */ - protected void addFinalHandler(ChannelPipeline pipeline) { - pipeline.addLast( - ChannelPipelineCustomizer.HANDLER_HTTP_STREAM, - new HttpStreamsClientHandler() { - @Override - public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { - if (evt instanceof IdleStateEvent) { - // close the connection if it is idle for too long - ctx.close(); - } - super.userEventTriggered(ctx, evt); + @Override + void windDownConnection() { + super.windDownConnection(); + if (!hasLiveRequest.get()) { + channel.close(); } - }); - } + } - private boolean acceptsEventStream() { - return this.acceptsEvents; + @Override + void onInactive() { + super.onInactive(); + onConnectionInactive1(this); + } } - } - final class PoolHandle { - final Channel channel; - private final ChannelPool channelPool; - private boolean canReturn; + final class Http2ConnectionHolder extends ConnectionHolder { + private final AtomicInteger liveRequests = new AtomicInteger(0); + private final Set liveStreamChannels = new HashSet<>(); // todo: https://github.com/netty/netty/pull/12830 - private PoolHandle(ChannelPool channelPool, Channel channel) { - this.channel = channel; - this.channelPool = channelPool; - this.canReturn = channelPool != null; - } + Http2ConnectionHolder(Channel channel, NettyClientCustomizer customizer) { + super(channel, customizer); + } - /** - * Prevent this connection from being reused. - */ - void taint() { - canReturn = false; - } + void init() { + addTimeoutHandlers( + requestKey.isSecure() ? + ChannelPipelineCustomizer.HANDLER_SSL : + ChannelPipelineCustomizer.HANDLER_HTTP2_CONNECTION + ); - /** - * Close this connection or release it back to the pool. - */ - void release() { - if (channelPool != null) { - removeReadTimeoutHandler(channel.pipeline()); - if (!canReturn) { - channel.closeFuture().addListener((future -> - channelPool.release(channel) - )); - } else { - channelPool.release(channel); + connectionCustomizer.onStreamPipelineBuilt(); + + onNewConnectionEstablished2(this); + } + + @Override + boolean tryEarmarkForRequest() { + return !windDownConnection && incrementWithLimit(liveRequests, configuration.getConnectionPoolConfiguration().getMaxConcurrentRequestsPerHttp2Connection()); + } + + @Override + boolean hasLiveRequests() { + return liveRequests.get() > 0; + } + + @Override + void fireReadTimeout(ChannelHandlerContext ctx) { + for (Channel sc : liveStreamChannels) { + sc.pipeline().fireExceptionCaught(ReadTimeoutException.INSTANCE); } - } else { - // just close it to prevent any future reads without a handler registered - channel.close(); } - } - /** - * Whether this connection may be returned to a connection pool (i.e. should be kept - * keepalive). - * - * @return Whether this connection may be reused - */ - public boolean canReturn() { - return canReturn; - } + @Override + void dispatch0(Sinks.One sink) { + if (!channel.isActive() || windDownConnection) { + returnPendingRequest(sink); + return; + } + addInstrumentedListener(new Http2StreamChannelBootstrap(channel).open(), (Future future) -> { + if (future.isSuccess()) { + Http2StreamChannel streamChannel = future.get(); + streamChannel.pipeline() + .addLast(new Http2StreamFrameToHttpObjectCodec(false)) + .addLast(ChannelPipelineCustomizer.HANDLER_HTTP_DECOMPRESSOR, new HttpContentDecompressor()); + NettyClientCustomizer streamCustomizer = connectionCustomizer.specializeForChannel(streamChannel, NettyClientCustomizer.ChannelRole.HTTP2_STREAM); + PoolHandle ph = new PoolHandle(true, streamChannel) { + @Override + void taint() { + // do nothing, we don't reuse stream channels + } - /** - * Notify any {@link NettyClientCustomizer} that the request pipeline has been built. - */ - void notifyRequestPipelineBuilt() { - channel.attr(CHANNEL_CUSTOMIZER_KEY).get().onRequestPipelineBuilt(); + @Override + void release() { + super.release(); + liveStreamChannels.remove(streamChannel); + streamChannel.close(); + int newCount = liveRequests.decrementAndGet(); + if (windDownConnection && newCount <= 0) { + Http2ConnectionHolder.this.channel.close(); + } else { + markConnectionAvailable(); + } + } + + @Override + boolean canReturn() { + return true; + } + + @Override + void notifyRequestPipelineBuilt() { + streamCustomizer.onRequestPipelineBuilt(); + } + }; + liveStreamChannels.add(streamChannel); + emitPoolHandle(sink, ph); + } else { + log.debug("Failed to open http2 stream", future.cause()); + returnPendingRequest(sink); + } + }); + } + + private void returnPendingRequest(Sinks.One sink) { + // failed, but the pending request may still work on another connection. + addPendingRequest(sink); + liveRequests.decrementAndGet(); + } + + @Override + void windDownConnection() { + super.windDownConnection(); + if (liveRequests.get() == 0) { + channel.close(); + } + } + + @Override + void onInactive() { + super.onInactive(); + onConnectionInactive2(this); + } } } } diff --git a/http-client/src/main/java/io/micronaut/http/client/netty/DefaultHttpClient.java b/http-client/src/main/java/io/micronaut/http/client/netty/DefaultHttpClient.java index 0990fbd7597..c1cdb016b40 100644 --- a/http-client/src/main/java/io/micronaut/http/client/netty/DefaultHttpClient.java +++ b/http-client/src/main/java/io/micronaut/http/client/netty/DefaultHttpClient.java @@ -47,6 +47,7 @@ import io.micronaut.http.client.DefaultHttpClientConfiguration; import io.micronaut.http.client.HttpClient; import io.micronaut.http.client.HttpClientConfiguration; +import io.micronaut.http.client.HttpVersionSelection; import io.micronaut.http.client.LoadBalancer; import io.micronaut.http.client.ProxyHttpClient; import io.micronaut.http.client.ProxyRequestOptions; @@ -80,8 +81,8 @@ import io.micronaut.http.netty.NettyHttpRequestBuilder; import io.micronaut.http.netty.NettyHttpResponseBuilder; import io.micronaut.http.netty.channel.ChannelPipelineCustomizer; -import io.micronaut.http.netty.channel.ChannelPipelineListener; import io.micronaut.http.netty.stream.DefaultStreamedHttpResponse; +import io.micronaut.http.netty.stream.HttpStreamsClientHandler; import io.micronaut.http.netty.stream.JsonSubscriber; import io.micronaut.http.netty.stream.StreamedHttpRequest; import io.micronaut.http.netty.stream.StreamedHttpResponse; @@ -112,6 +113,7 @@ import io.netty.channel.ChannelFactory; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelPipeline; import io.netty.channel.EventLoopGroup; import io.netty.channel.MultithreadEventLoopGroup; @@ -122,16 +124,19 @@ import io.netty.handler.codec.http.DefaultHttpHeaders; import io.netty.handler.codec.http.DefaultLastHttpContent; import io.netty.handler.codec.http.EmptyHttpHeaders; +import io.netty.handler.codec.http.FullHttpMessage; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.FullHttpResponse; import io.netty.handler.codec.http.HttpContent; import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpHeaderValues; import io.netty.handler.codec.http.HttpHeaders; +import io.netty.handler.codec.http.HttpObjectAggregator; import io.netty.handler.codec.http.HttpRequest; import io.netty.handler.codec.http.HttpResponseStatus; import io.netty.handler.codec.http.HttpScheme; import io.netty.handler.codec.http.HttpUtil; +import io.netty.handler.codec.http.LastHttpContent; import io.netty.handler.codec.http.multipart.DefaultHttpDataFactory; import io.netty.handler.codec.http.multipart.FileUpload; import io.netty.handler.codec.http.multipart.HttpDataFactory; @@ -144,6 +149,7 @@ import io.netty.util.CharsetUtil; import io.netty.util.ReferenceCountUtil; import io.netty.util.concurrent.DefaultThreadFactory; +import io.netty.util.concurrent.Future; import io.netty.util.concurrent.Promise; import org.reactivestreams.Processor; import org.reactivestreams.Publisher; @@ -172,6 +178,7 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.NoSuchElementException; import java.util.Objects; import java.util.Optional; import java.util.concurrent.ThreadFactory; @@ -232,7 +239,7 @@ public class DefaultHttpClient implements protected MediaTypeCodecRegistry mediaTypeCodecRegistry; protected ByteBufferFactory byteBufferFactory = new NettyByteBufferFactory(); - final ConnectionManager connectionManager; + ConnectionManager connectionManager; private final List> clientFilterEntries; private final LoadBalancer loadBalancer; @@ -245,6 +252,7 @@ public class DefaultHttpClient implements private final RequestBinderRegistry requestBinderRegistry; private final List invocationInstrumenterFactories; private final String informationalServiceId; + private final ConversionService conversionService; /** * Construct a client for the given arguments. @@ -269,7 +277,7 @@ public DefaultHttpClient(@Nullable LoadBalancer loadBalancer, List invocationInstrumenterFactories, HttpClientFilter... filters) { this(loadBalancer, - configuration.getHttpVersion(), + null, configuration, contextPath, new DefaultHttpClientFilterResolver(annotationMetadataResolver, Arrays.asList(filters)), @@ -281,9 +289,10 @@ public DefaultHttpClient(@Nullable LoadBalancer loadBalancer, new DefaultRequestBinderRegistry(ConversionService.SHARED), null, NioSocketChannel::new, - Collections.emptySet(), CompositeNettyClientCustomizer.EMPTY, - invocationInstrumenterFactories, null); + invocationInstrumenterFactories, + null, + ConversionService.SHARED); } /** @@ -301,13 +310,12 @@ public DefaultHttpClient(@Nullable LoadBalancer loadBalancer, * @param requestBinderRegistry The request binder registry * @param eventLoopGroup The event loop group to use * @param socketChannelFactory The socket channel factory - * @param pipelineListeners The listeners to call for pipeline customization * @param clientCustomizer The pipeline customizer * @param invocationInstrumenterFactories The invocation instrumeter factories to instrument netty handlers execution with * @param informationalServiceId Optional service ID that will be passed to exceptions created by this client */ public DefaultHttpClient(@Nullable LoadBalancer loadBalancer, - @Nullable io.micronaut.http.HttpVersion explicitHttpVersion, + @Nullable HttpVersionSelection explicitHttpVersion, @NonNull HttpClientConfiguration configuration, @Nullable String contextPath, @NonNull HttpClientFilterResolver filterResolver, @@ -319,10 +327,10 @@ public DefaultHttpClient(@Nullable LoadBalancer loadBalancer, @NonNull RequestBinderRegistry requestBinderRegistry, @Nullable EventLoopGroup eventLoopGroup, @NonNull ChannelFactory socketChannelFactory, - Collection pipelineListeners, NettyClientCustomizer clientCustomizer, List invocationInstrumenterFactories, - @Nullable String informationalServiceId + @Nullable String informationalServiceId, + @NonNull ConversionService conversionService ) { ArgumentUtils.requireNonNull("nettyClientSslBuilder", nettyClientSslBuilder); ArgumentUtils.requireNonNull("codecRegistry", codecRegistry); @@ -359,6 +367,7 @@ public DefaultHttpClient(@Nullable LoadBalancer loadBalancer, this.webSocketRegistry = webSocketBeanRegistry != null ? webSocketBeanRegistry : WebSocketBeanRegistry.EMPTY; this.requestBinderRegistry = requestBinderRegistry; this.informationalServiceId = informationalServiceId; + this.conversionService = conversionService; this.connectionManager = new ConnectionManager( log, @@ -370,7 +379,6 @@ public DefaultHttpClient(@Nullable LoadBalancer loadBalancer, socketChannelFactory, nettyClientSslBuilder, clientCustomizer, - pipelineListeners, informationalServiceId); } @@ -523,6 +531,11 @@ public O retrieve(io.micronaut.http.HttpRequest request, Argument MutableHttpRequest toMutableRequest(io.micronaut.http.HttpRequest request) { + return MutableHttpRequestWrapper.wrapIfNecessary(conversionService, request); + } + @SuppressWarnings("SubscriberImplementation") @Override public Publisher>> eventStream(@NonNull io.micronaut.http.HttpRequest request) { @@ -691,7 +704,7 @@ public Publisher> dataStream(@NonNull io.micronaut.http.HttpRe public Publisher> dataStream(@NonNull io.micronaut.http.HttpRequest request, @NonNull Argument errorType) { final io.micronaut.http.HttpRequest parentRequest = ServerRequestContext.currentRequest().orElse(null); return new MicronautFlux<>(Flux.from(resolveRequestURI(request)) - .flatMap(requestURI -> dataStreamImpl(request, errorType, parentRequest, requestURI))) + .flatMap(requestURI -> dataStreamImpl(toMutableRequest(request), errorType, parentRequest, requestURI))) .doAfterNext(buffer -> { Object o = buffer.asNativeBuffer(); if (o instanceof ByteBuf) { @@ -712,7 +725,7 @@ public Publisher>> exchangeStre public Publisher>> exchangeStream(@NonNull io.micronaut.http.HttpRequest request, @NonNull Argument errorType) { io.micronaut.http.HttpRequest parentRequest = ServerRequestContext.currentRequest().orElse(null); return new MicronautFlux<>(Flux.from(resolveRequestURI(request)) - .flatMap(uri -> exchangeStreamImpl(parentRequest, request, errorType, uri))) + .flatMap(uri -> exchangeStreamImpl(parentRequest, toMutableRequest(request), errorType, uri))) .doAfterNext(byteBufferHttpResponse -> { ByteBuffer buffer = byteBufferHttpResponse.body(); if (buffer instanceof ReferenceCounted) { @@ -730,7 +743,7 @@ public Publisher jsonStream(@NonNull io.micronaut.http.HttpRequest public Publisher jsonStream(@NonNull io.micronaut.http.HttpRequest request, @NonNull Argument type, @NonNull Argument errorType) { final io.micronaut.http.HttpRequest parentRequest = ServerRequestContext.currentRequest().orElse(null); return Flux.from(resolveRequestURI(request)) - .flatMap(requestURI -> jsonStreamImpl(parentRequest, request, type, errorType, requestURI)); + .flatMap(requestURI -> jsonStreamImpl(parentRequest, toMutableRequest(request), type, errorType, requestURI)); } @SuppressWarnings("unchecked") @@ -749,7 +762,7 @@ public Publisher> exchange(@NonNull final io.micronaut.http.HttpRequest parentRequest = ServerRequestContext.currentRequest().orElse(null); Publisher uriPublisher = resolveRequestURI(request); return Flux.from(uriPublisher) - .switchMap(uri -> exchangeImpl(uri, parentRequest, request, bodyType, errorType)); + .switchMap(uri -> exchangeImpl(uri, parentRequest, toMutableRequest(request), bodyType, errorType)); } @Override @@ -844,8 +857,8 @@ private Publisher connectWebSocket(URI uri, MutableHttpRequest request .then(handler.getHandshakeCompletedMono()); } - private Flux>> exchangeStreamImpl(io.micronaut.http.HttpRequest parentRequest, io.micronaut.http.HttpRequest request, Argument errorType, URI requestURI) { - Flux> streamResponsePublisher = Flux.from(buildStreamExchange(parentRequest, request, requestURI, errorType)); + private Flux>> exchangeStreamImpl(io.micronaut.http.HttpRequest parentRequest, MutableHttpRequest request, Argument errorType, URI requestURI) { + Flux> streamResponsePublisher = Flux.from(buildStreamExchange(parentRequest, request, requestURI, errorType)); return streamResponsePublisher.switchMap(response -> { StreamedHttpResponse streamedHttpResponse = NettyHttpResponseBuilder.toStreamResponse(response); Flux httpContentReactiveSequence = Flux.from(streamedHttpResponse); @@ -863,19 +876,11 @@ private Flux>> exchangeStreamImpl(io.micronaut.ht thisResponse.setBody(byteBuffer); return (HttpResponse>) new HttpResponseWrapper<>(thisResponse); }); - }).doOnTerminate(() -> { - final Object o = request.getAttribute(NettyClientHttpRequest.CHANNEL).orElse(null); - if (o instanceof Channel) { - final Channel c = (Channel) o; - if (c.isOpen()) { - c.close(); - } - } }); } - private Flux jsonStreamImpl(io.micronaut.http.HttpRequest parentRequest, io.micronaut.http.HttpRequest request, Argument type, Argument errorType, URI requestURI) { - Flux> streamResponsePublisher = + private Flux jsonStreamImpl(io.micronaut.http.HttpRequest parentRequest, MutableHttpRequest request, Argument type, Argument errorType, URI requestURI) { + Flux> streamResponsePublisher = Flux.from(buildStreamExchange(parentRequest, request, requestURI, errorType)); return streamResponsePublisher.switchMap(response -> { if (!(response instanceof NettyStreamedHttpResponse)) { @@ -907,19 +912,11 @@ private Flux jsonStreamImpl(io.micronaut.http.HttpRequest parentReq }, streamArray); return Flux.from(jsonProcessor) .map(jsonNode -> mediaTypeCodec.decode(type, jsonNode)); - }).doOnTerminate(() -> { - final Object o = request.getAttribute(NettyClientHttpRequest.CHANNEL).orElse(null); - if (o instanceof Channel) { - final Channel c = (Channel) o; - if (c.isOpen()) { - c.close(); - } - } }); } - private Flux> dataStreamImpl(io.micronaut.http.HttpRequest request, Argument errorType, io.micronaut.http.HttpRequest parentRequest, URI requestURI) { - Flux> streamResponsePublisher = Flux.from(buildStreamExchange(parentRequest, request, requestURI, errorType)); + private Flux> dataStreamImpl(MutableHttpRequest request, Argument errorType, io.micronaut.http.HttpRequest parentRequest, URI requestURI) { + Flux> streamResponsePublisher = Flux.from(buildStreamExchange(parentRequest, request, requestURI, errorType)); Function> contentMapper = message -> { ByteBuf byteBuf = message.content(); return byteBufferFactory.wrap(byteBuf); @@ -933,15 +930,6 @@ private Flux> dataStreamImpl(io.micronaut.http.HttpRequest return httpContentReactiveSequence .filter(message -> !(message.content() instanceof EmptyByteBuf)) .map(contentMapper); - }) - .doOnTerminate(() -> { - final Object o = request.getAttribute(NettyClientHttpRequest.CHANNEL).orElse(null); - if (o instanceof Channel) { - final Channel c = (Channel) o; - if (c.isOpen()) { - c.close(); - } - } }); } @@ -949,14 +937,14 @@ private Flux> dataStreamImpl(io.micronaut.http.HttpRequest * Implementation of {@link #jsonStream}, {@link #dataStream}, {@link #exchangeStream}. */ @SuppressWarnings("MagicNumber") - private Publisher> buildStreamExchange( + private Publisher> buildStreamExchange( @Nullable io.micronaut.http.HttpRequest parentRequest, - @NonNull io.micronaut.http.HttpRequest request, + @NonNull MutableHttpRequest request, @NonNull URI requestURI, @Nullable Argument errorType) { - AtomicReference> requestWrapper = new AtomicReference<>(request); - Flux> streamResponsePublisher = connectAndStream(parentRequest, request, requestURI, requestWrapper, false, true); + AtomicReference> requestWrapper = new AtomicReference<>(request); + Flux> streamResponsePublisher = connectAndStream(parentRequest, request, requestURI, requestWrapper, false, true); streamResponsePublisher = readBodyOnError(errorType, streamResponsePublisher); @@ -965,7 +953,7 @@ private Publisher> buildStreamExchange( applyFilterToResponsePublisher(parentRequest, request, requestURI, requestWrapper, streamResponsePublisher) ); - return streamResponsePublisher.subscribeOn(connectionManager.getEventLoopScheduler()); + return streamResponsePublisher; } @Override @@ -978,15 +966,13 @@ public Publisher> proxy(@NonNull io.micronaut.http.HttpRe Objects.requireNonNull(options, "options"); return Flux.from(resolveRequestURI(request)) .flatMap(requestURI -> { - io.micronaut.http.MutableHttpRequest httpRequest = request instanceof MutableHttpRequest - ? (io.micronaut.http.MutableHttpRequest) request - : request.mutate(); + io.micronaut.http.MutableHttpRequest httpRequest = toMutableRequest(request); if (!options.isRetainHostHeader()) { httpRequest.headers(headers -> headers.remove(HttpHeaderNames.HOST)); } - AtomicReference> requestWrapper = new AtomicReference<>(httpRequest); - Flux> proxyResponsePublisher = connectAndStream(request, request, requestURI, requestWrapper, true, false); + AtomicReference> requestWrapper = new AtomicReference<>(httpRequest); + Flux> proxyResponsePublisher = connectAndStream(request, request, requestURI, requestWrapper, true, false); // apply filters //noinspection unchecked proxyResponsePublisher = Flux.from( @@ -1002,11 +988,11 @@ public Publisher> proxy(@NonNull io.micronaut.http.HttpRe }); } - private Flux> connectAndStream( + private Flux> connectAndStream( io.micronaut.http.HttpRequest parentRequest, io.micronaut.http.HttpRequest request, URI requestURI, - AtomicReference> requestWrapper, + AtomicReference> requestWrapper, boolean isProxy, boolean failOnError ) { @@ -1016,8 +1002,42 @@ private Flux> connectAndStream( } catch (Exception e) { return Flux.error(e); } - return connectionManager.connectForStream(requestKey, isProxy, isAcceptEvents(request)).flatMapMany(poolHandle -> { + return connectionManager.connect(requestKey).flatMapMany(poolHandle -> { request.setAttribute(NettyClientHttpRequest.CHANNEL, poolHandle.channel); + + boolean sse = !isProxy && isAcceptEvents(request); + poolHandle.channel.pipeline().addLast(new ChannelInboundHandlerAdapter() { + boolean ignoreOneLast = false; + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (msg instanceof io.netty.handler.codec.http.HttpResponse && + ((io.netty.handler.codec.http.HttpResponse) msg).status().equals(HttpResponseStatus.CONTINUE)) { + ignoreOneLast = true; + } + + super.channelRead(ctx, msg); + + if (msg instanceof LastHttpContent) { + if (ignoreOneLast) { + ignoreOneLast = false; + } else { + ctx.pipeline().remove(ChannelPipelineCustomizer.HANDLER_HTTP_STREAM); + ctx.pipeline().remove(this); + } + } + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + poolHandle.release(); + } + }); + if (sse) { + poolHandle.channel.pipeline().addLast(HttpLineBasedFrameDecoder.NAME, new HttpLineBasedFrameDecoder(configuration.getMaxContentLength(), true, true)); + } + poolHandle.channel.pipeline().addLast(ChannelPipelineCustomizer.HANDLER_HTTP_STREAM, new HttpStreamsClientHandler()); + return this.streamRequestThroughChannel( parentRequest, requestWrapper.get(), @@ -1034,10 +1054,10 @@ private Flux> connectAndStream( private Publisher> exchangeImpl( URI requestURI, io.micronaut.http.HttpRequest parentRequest, - io.micronaut.http.HttpRequest request, + MutableHttpRequest request, @NonNull Argument bodyType, @NonNull Argument errorType) { - AtomicReference> requestWrapper = new AtomicReference<>(request); + AtomicReference> requestWrapper = new AtomicReference<>(request); RequestKey requestKey; try { @@ -1046,9 +1066,22 @@ private Publisher> exchang return Flux.error(e); } - Mono handlePublisher = connectionManager.connectForExchange(requestKey, MediaType.MULTIPART_FORM_DATA_TYPE.equals(request.getContentType().orElse(null)), isAcceptEvents(request)); + Mono handlePublisher = connectionManager.connect(requestKey); Flux> responsePublisher = handlePublisher.flatMapMany(poolHandle -> { + poolHandle.channel.pipeline() + .addLast(ChannelPipelineCustomizer.HANDLER_HTTP_AGGREGATOR, new HttpObjectAggregator(configuration.getMaxContentLength()) { + @Override + protected void finishAggregation(FullHttpMessage aggregated) throws Exception { + // only set content-length if there's any content + if (!HttpUtil.isContentLengthSet(aggregated) && + aggregated.content().readableBytes() > 0) { + super.finishAggregation(aggregated); + } + } + }) + .addLast(ChannelPipelineCustomizer.HANDLER_HTTP_STREAM, new HttpStreamsClientHandler()); + return Flux.create(emitter -> { try { sendRequestThroughChannel( @@ -1056,7 +1089,6 @@ private Publisher> exchang bodyType, errorType, emitter, - poolHandle.channel, requestKey.isSecure(), poolHandle ); @@ -1082,7 +1114,7 @@ private Publisher> exchang final Duration rt = readTimeout.get(); if (!rt.isNegative()) { Duration duration = rt.plus(Duration.ofSeconds(1)); - finalReactiveSequence = finalReactiveSequence.timeout(duration) + finalReactiveSequence = finalReactiveSequence.timeout(duration) // todo: move to CM .onErrorResume(throwable -> { if (throwable instanceof TimeoutException) { return Flux.error(ReadTimeoutException.TIMEOUT_EXCEPTION); @@ -1167,11 +1199,11 @@ protected Object getLoadBalancerDiscriminator() { return null; } - private > Publisher applyFilterToResponsePublisher( + private > Publisher applyFilterToResponsePublisher( io.micronaut.http.HttpRequest parentRequest, io.micronaut.http.HttpRequest request, URI requestURI, - AtomicReference> requestWrapper, + AtomicReference> requestWrapper, Publisher responsePublisher) { if (request instanceof MutableHttpRequest) { @@ -1253,7 +1285,7 @@ protected NettyRequestWriter buildNettyRequest( if (Publishers.isConvertibleToPublisher(bodyValue)) { boolean isSingle = Publishers.isSingle(bodyValue.getClass()); - Publisher publisher = ConversionService.SHARED.convert(bodyValue, Publisher.class).orElseThrow(() -> + Publisher publisher = conversionService.convert(bodyValue, Publisher.class).orElseThrow(() -> new IllegalArgumentException("Unconvertible reactive type: " + bodyValue) ); @@ -1335,7 +1367,7 @@ protected NettyRequestWriter buildNettyRequest( .orElse(null); } if (bodyContent == null) { - bodyContent = ConversionService.SHARED.convert(bodyValue, ByteBuf.class).orElseThrow(() -> + bodyContent = conversionService.convert(bodyValue, ByteBuf.class).orElseThrow(() -> customizeException(new HttpClientException("Body [" + bodyValue + "] cannot be encoded to content type [" + requestContentType + "]. No possible codecs or converters found.")) ); } @@ -1359,7 +1391,7 @@ protected NettyRequestWriter buildNettyRequest( return new NettyRequestWriter(nettyRequest, postRequestEncoder); } - private Flux> readBodyOnError(@Nullable Argument errorType, @NonNull Flux> publisher) { + private Flux> readBodyOnError(@Nullable Argument errorType, @NonNull Flux> publisher) { if (errorType != null && errorType != HttpClient.DEFAULT_ERROR_TYPE) { return publisher.onErrorResume(clientException -> { if (clientException instanceof HttpClientResponseException) { @@ -1441,7 +1473,6 @@ private void sendRequestThroughChannel( Argument bodyType, Argument errorType, FluxSink> emitter, - Channel channel, boolean secure, ConnectionManager.PoolHandle poolHandle) throws HttpPostRequestEncoder.ErrorDataEncoderException { URI requestURI = finalRequest.getUri(); @@ -1467,11 +1498,11 @@ private void sendRequestThroughChannel( HttpRequest nettyRequest = requestWriter.getNettyRequest(); prepareHttpHeaders( - requestURI, - finalRequest, - nettyRequest, - permitsBody, - !poolHandle.canReturn() + poolHandle, + requestURI, + finalRequest, + nettyRequest, + permitsBody ); if (log.isDebugEnabled()) { @@ -1482,8 +1513,8 @@ private void sendRequestThroughChannel( traceRequest(finalRequest, nettyRequest); } - Promise> responsePromise = channel.eventLoop().newPromise(); - channel.pipeline().addLast(ChannelPipelineCustomizer.HANDLER_MICRONAUT_FULL_HTTP_RESPONSE, + Promise> responsePromise = poolHandle.channel.eventLoop().newPromise(); + poolHandle.channel.pipeline().addLast(ChannelPipelineCustomizer.HANDLER_MICRONAUT_FULL_HTTP_RESPONSE, new FullHttpResponseHandler<>(responsePromise, poolHandle, secure, finalRequest, bodyType, errorType)); poolHandle.notifyRequestPipelineBuilt(); Publisher> publisher = new NettyFuturePublisher<>(responsePromise, true); @@ -1493,16 +1524,16 @@ private void sendRequestThroughChannel( } publisher.subscribe(new ForwardingSubscriber<>(emitter)); - requestWriter.write(channel, secure, emitter); + requestWriter.write(poolHandle, secure, emitter); } - private Flux> streamRequestThroughChannel( + private Flux> streamRequestThroughChannel( io.micronaut.http.HttpRequest parentRequest, - io.micronaut.http.HttpRequest request, + MutableHttpRequest request, ConnectionManager.PoolHandle poolHandle, boolean failOnError, boolean secure) { - return Flux.>create(sink -> { + return Flux.>create(sink -> { try { streamRequestThroughChannel0(parentRequest, request, sink, poolHandle, secure); } catch (HttpPostRequestEncoder.ErrorDataEncoderException e) { @@ -1525,32 +1556,45 @@ private > Flux handleStreamHttpError( private void streamRequestThroughChannel0( io.micronaut.http.HttpRequest parentRequest, - final io.micronaut.http.HttpRequest finalRequest, - FluxSink emitter, + MutableHttpRequest request, + FluxSink> emitter, ConnectionManager.PoolHandle poolHandle, boolean secure) throws HttpPostRequestEncoder.ErrorDataEncoderException { - NettyRequestWriter requestWriter = prepareRequest( - finalRequest, - finalRequest.getUri(), - emitter + URI requestURI = request.getUri(); + boolean permitsBody = io.micronaut.http.HttpMethod.permitsRequestBody(request.getMethod()); + NettyRequestWriter requestWriter = buildNettyRequest( + request, + requestURI, + request + .getContentType() + .orElse(MediaType.APPLICATION_JSON_TYPE), + permitsBody, + null, + throwable -> { + if (!emitter.isCancelled()) { + emitter.error(throwable); + } + } ); + prepareHttpHeaders(poolHandle, requestURI, request, requestWriter.getNettyRequest(), permitsBody); + HttpRequest nettyRequest = requestWriter.getNettyRequest(); - Promise> responsePromise = poolHandle.channel.eventLoop().newPromise(); + Promise> responsePromise = poolHandle.channel.eventLoop().newPromise(); ChannelPipeline pipeline = poolHandle.channel.pipeline(); - pipeline.addLast(ChannelPipelineCustomizer.HANDLER_MICRONAUT_HTTP_RESPONSE_FULL, new StreamFullHttpResponseHandler(responsePromise, parentRequest, finalRequest)); - pipeline.addLast(ChannelPipelineCustomizer.HANDLER_MICRONAUT_HTTP_RESPONSE_STREAM, new StreamStreamHttpResponseHandler(responsePromise, parentRequest, finalRequest)); + pipeline.addLast(ChannelPipelineCustomizer.HANDLER_MICRONAUT_HTTP_RESPONSE_FULL, new StreamFullHttpResponseHandler(responsePromise, parentRequest, request)); + pipeline.addLast(ChannelPipelineCustomizer.HANDLER_MICRONAUT_HTTP_RESPONSE_STREAM, new StreamStreamHttpResponseHandler(responsePromise, parentRequest, request)); poolHandle.notifyRequestPipelineBuilt(); if (log.isDebugEnabled()) { - debugRequest(finalRequest.getUri(), nettyRequest); + debugRequest(request.getUri(), nettyRequest); } if (log.isTraceEnabled()) { - traceRequest(finalRequest, nettyRequest); + traceRequest(request, nettyRequest); } - requestWriter.write(poolHandle.channel, secure, emitter); - responsePromise.addListener(future -> { + requestWriter.write(poolHandle, secure, emitter); + responsePromise.addListener((Future> future) -> { if (future.isSuccess()) { emitter.next(future.getNow()); emitter.complete(); @@ -1580,23 +1624,22 @@ private String getHostHeader(URI requestURI) { } private void prepareHttpHeaders( - URI requestURI, - io.micronaut.http.HttpRequest request, - io.netty.handler.codec.http.HttpRequest nettyRequest, - boolean permitsBody, - boolean closeConnection) { + ConnectionManager.PoolHandle poolHandle, + URI requestURI, + io.micronaut.http.HttpRequest request, + HttpRequest nettyRequest, + boolean permitsBody) { HttpHeaders headers = nettyRequest.headers(); if (!headers.contains(HttpHeaderNames.HOST)) { headers.set(HttpHeaderNames.HOST, getHostHeader(requestURI)); } - // HTTP/2 assumes keep-alive connections - if (connectionManager.httpVersion != io.micronaut.http.HttpVersion.HTTP_2_0) { - if (closeConnection) { - headers.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE); - } else { + if (!poolHandle.http2) { + if (poolHandle.canReturn()) { headers.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.KEEP_ALIVE); + } else { + headers.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE); } } @@ -1621,7 +1664,7 @@ private void prepareHttpHeaders( } } - private ClientFilterChain buildChain(AtomicReference> requestWrapper, List filters) { + private ClientFilterChain buildChain(AtomicReference> requestWrapper, List filters) { AtomicInteger integer = new AtomicInteger(); int len = filters.size(); return new ClientFilterChain() { @@ -1668,7 +1711,7 @@ private HttpPostRequestEncoder buildFormDataRequest(MutableHttpRequest clientHtt } private void addBodyAttribute(HttpPostRequestEncoder postRequestEncoder, String key, Object value) throws HttpPostRequestEncoder.ErrorDataEncoderException { - Optional converted = ConversionService.SHARED.convert(value, String.class); + Optional converted = conversionService.convert(value, String.class); if (converted.isPresent()) { postRequestEncoder.addBodyAttribute(key, converted.get()); } @@ -1784,37 +1827,6 @@ private static MediaTypeCodecRegistry createDefaultMediaTypeRegistry() { ); } - private NettyRequestWriter prepareRequest( - io.micronaut.http.HttpRequest request, - URI requestURI, - FluxSink> emitter) throws HttpPostRequestEncoder.ErrorDataEncoderException { - MediaType requestContentType = request - .getContentType() - .orElse(MediaType.APPLICATION_JSON_TYPE); - - boolean permitsBody = io.micronaut.http.HttpMethod.permitsRequestBody(request.getMethod()); - - if (!(request instanceof MutableHttpRequest)) { - throw new IllegalArgumentException("A MutableHttpRequest is required"); - } - MutableHttpRequest clientHttpRequest = (MutableHttpRequest) request; - NettyRequestWriter requestWriter = buildNettyRequest( - clientHttpRequest, - requestURI, - requestContentType, - permitsBody, - null, - throwable -> { - if (!emitter.isCancelled()) { - emitter.error(throwable); - } - } - ); - io.netty.handler.codec.http.HttpRequest nettyRequest = requestWriter.getNettyRequest(); - prepareHttpHeaders(requestURI, request, nettyRequest, permitsBody, true); - return requestWriter; - } - private @NonNull InvocationInstrumenter combineFactories() { if (CollectionUtils.isEmpty(invocationInstrumenterFactories)) { return NOOP; @@ -1949,23 +1961,21 @@ private class NettyRequestWriter { * @param channelPool The channel pool * @param emitter The emitter */ - protected void write(Channel channel, boolean isSecure, FluxSink emitter) { - final ChannelPipeline pipeline = channel.pipeline(); - if (connectionManager.httpVersion == io.micronaut.http.HttpVersion.HTTP_2_0) { + protected void write(ConnectionManager.PoolHandle poolHandle, boolean isSecure, FluxSink emitter) { + if (poolHandle.http2) { + // todo: move to ConnectionManager, DefaultHttpClient shouldn't care about the scheme if (isSecure) { nettyRequest.headers().add(AbstractNettyHttpRequest.HTTP2_SCHEME, HttpScheme.HTTPS); } else { nettyRequest.headers().add(AbstractNettyHttpRequest.HTTP2_SCHEME, HttpScheme.HTTP); } } - processRequestWrite(channel, emitter, pipeline); - } - private void processRequestWrite(Channel channel, FluxSink emitter, ChannelPipeline pipeline) { + Channel channel = poolHandle.channel; ChannelFuture writeFuture; if (encoder != null && encoder.isChunked()) { channel.attr(AttributeKey.valueOf(ChannelPipelineCustomizer.HANDLER_HTTP_CHUNK)).set(true); - pipeline.addAfter(ChannelPipelineCustomizer.HANDLER_HTTP_STREAM, ChannelPipelineCustomizer.HANDLER_HTTP_CHUNK, new ChunkedWriteHandler()); + channel.pipeline().addAfter(ChannelPipelineCustomizer.HANDLER_HTTP_STREAM, ChannelPipelineCustomizer.HANDLER_HTTP_CHUNK, new ChunkedWriteHandler()); channel.write(nettyRequest); writeFuture = channel.writeAndFlush(encoder); } else { @@ -1975,6 +1985,7 @@ private void processRequestWrite(Channel channel, FluxSink emitter, ChannelPi connectionManager.addInstrumentedListener(writeFuture, f -> { try { if (!f.isSuccess()) { + poolHandle.taint(); if (!emitter.isCancelled()) { emitter.error(f.cause()); } @@ -2010,11 +2021,11 @@ private static class CurrentEvent { } private abstract class BaseHttpResponseHandler extends SimpleChannelInboundHandlerInstrumented { - private final Promise responsePromise; + private final Promise responsePromise; private final io.micronaut.http.HttpRequest parentRequest; private final io.micronaut.http.HttpRequest finalRequest; - public BaseHttpResponseHandler(Promise responsePromise, io.micronaut.http.HttpRequest parentRequest, io.micronaut.http.HttpRequest finalRequest) { + public BaseHttpResponseHandler(Promise responsePromise, io.micronaut.http.HttpRequest parentRequest, io.micronaut.http.HttpRequest finalRequest) { super(connectionManager.instrumenter); this.responsePromise = responsePromise; this.parentRequest = parentRequest; @@ -2082,6 +2093,7 @@ protected void channelReadInstrumented(ChannelHandlerContext ctx, R msg) throws traceHeaders(headers); } buildResponse(responsePromise, msg); + removeHandler(ctx); } private void setRedirectHeaders(@Nullable io.micronaut.http.HttpRequest request, MutableHttpRequest redirectRequest) { @@ -2101,6 +2113,8 @@ private void setRedirectHeaders(@Nullable io.micronaut.http.HttpRequest reque } } + protected abstract void removeHandler(ChannelHandlerContext ctx); + protected abstract Function> makeRedirectHandler(io.micronaut.http.HttpRequest parentRequest, MutableHttpRequest redirectRequest); protected abstract void buildResponse(Promise promise, R msg); @@ -2161,6 +2175,11 @@ protected void channelReadInstrumented(ChannelHandlerContext channelHandlerConte } } + @Override + protected void removeHandler(ChannelHandlerContext ctx) { + // done in channelReadInstrumented + } + @Override protected void buildResponse(Promise> promise, FullHttpResponse msg) { try { @@ -2274,6 +2293,12 @@ public Argument getErrorType(MediaType mediaType) { @Override public void handlerRemoved(ChannelHandlerContext ctx) { + ctx.pipeline().remove(ChannelPipelineCustomizer.HANDLER_HTTP_AGGREGATOR); + try { + ctx.pipeline().remove(ChannelPipelineCustomizer.HANDLER_HTTP_CHUNK); + } catch (NoSuchElementException ignored) { + } + ctx.pipeline().remove(ChannelPipelineCustomizer.HANDLER_HTTP_STREAM); poolHandle.release(); } @@ -2285,8 +2310,12 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { } } - private class StreamFullHttpResponseHandler extends BaseHttpResponseHandler> { - public StreamFullHttpResponseHandler(Promise> responsePromise, io.micronaut.http.HttpRequest parentRequest, io.micronaut.http.HttpRequest finalRequest) { + private class StreamFullHttpResponseHandler extends BaseHttpResponseHandler> { + public StreamFullHttpResponseHandler( + Promise> responsePromise, + io.micronaut.http.HttpRequest parentRequest, + io.micronaut.http.HttpRequest finalRequest) { + super(responsePromise, parentRequest, finalRequest); } @@ -2296,7 +2325,13 @@ public boolean acceptInboundMessage(Object msg) { } @Override - protected void buildResponse(Promise> promise, FullHttpResponse msg) { + protected void removeHandler(ChannelHandlerContext ctx) { + ctx.pipeline().remove(ChannelPipelineCustomizer.HANDLER_MICRONAUT_HTTP_RESPONSE_FULL); + ctx.pipeline().remove(ChannelPipelineCustomizer.HANDLER_MICRONAUT_HTTP_RESPONSE_STREAM); + } + + @Override + protected void buildResponse(Promise> promise, FullHttpResponse msg) { Publisher bodyPublisher; if (msg.content() instanceof EmptyByteBuf) { bodyPublisher = Publishers.empty(); @@ -2313,13 +2348,17 @@ protected void buildResponse(Promise> promise, FullHttpR } @Override - protected Function>> makeRedirectHandler(io.micronaut.http.HttpRequest parentRequest, MutableHttpRequest redirectRequest) { + protected Function>> makeRedirectHandler(io.micronaut.http.HttpRequest parentRequest, MutableHttpRequest redirectRequest) { return uri -> buildStreamExchange(parentRequest, redirectRequest, uri, null); } } - private class StreamStreamHttpResponseHandler extends BaseHttpResponseHandler> { - public StreamStreamHttpResponseHandler(Promise> responsePromise, io.micronaut.http.HttpRequest parentRequest, io.micronaut.http.HttpRequest finalRequest) { + private class StreamStreamHttpResponseHandler extends BaseHttpResponseHandler> { + public StreamStreamHttpResponseHandler( + Promise> responsePromise, + io.micronaut.http.HttpRequest parentRequest, + io.micronaut.http.HttpRequest finalRequest) { + super(responsePromise, parentRequest, finalRequest); } @@ -2329,12 +2368,18 @@ public boolean acceptInboundMessage(Object msg) { } @Override - protected void buildResponse(Promise> promise, StreamedHttpResponse msg) { + protected void removeHandler(ChannelHandlerContext ctx) { + ctx.pipeline().remove(ChannelPipelineCustomizer.HANDLER_MICRONAUT_HTTP_RESPONSE_FULL); + ctx.pipeline().remove(ChannelPipelineCustomizer.HANDLER_MICRONAUT_HTTP_RESPONSE_STREAM); + } + + @Override + protected void buildResponse(Promise> promise, StreamedHttpResponse msg) { promise.trySuccess(new NettyStreamedHttpResponse<>(msg)); } @Override - protected Function>> makeRedirectHandler(io.micronaut.http.HttpRequest parentRequest, MutableHttpRequest redirectRequest) { + protected Function>> makeRedirectHandler(io.micronaut.http.HttpRequest parentRequest, MutableHttpRequest redirectRequest) { return uri -> buildStreamExchange(parentRequest, redirectRequest, uri, null); } } diff --git a/http-client/src/main/java/io/micronaut/http/client/netty/DefaultNettyHttpClientRegistry.java b/http-client/src/main/java/io/micronaut/http/client/netty/DefaultNettyHttpClientRegistry.java index 6e0698a2a6e..07acfebeee2 100644 --- a/http-client/src/main/java/io/micronaut/http/client/netty/DefaultNettyHttpClientRegistry.java +++ b/http-client/src/main/java/io/micronaut/http/client/netty/DefaultNettyHttpClientRegistry.java @@ -27,20 +27,20 @@ import io.micronaut.core.annotation.Nullable; import io.micronaut.core.convert.ConversionService; import io.micronaut.core.util.StringUtils; -import io.micronaut.http.HttpVersion; import io.micronaut.http.MediaType; import io.micronaut.http.annotation.FilterMatcher; import io.micronaut.http.bind.DefaultRequestBinderRegistry; import io.micronaut.http.bind.RequestBinderRegistry; -import io.micronaut.http.client.HttpClientRegistry; -import io.micronaut.http.client.StreamingHttpClientRegistry; -import io.micronaut.http.client.ProxyHttpClient; -import io.micronaut.http.client.HttpClientConfiguration; import io.micronaut.http.client.HttpClient; -import io.micronaut.http.client.StreamingHttpClient; -import io.micronaut.http.client.ProxyHttpClientRegistry; +import io.micronaut.http.client.HttpClientConfiguration; +import io.micronaut.http.client.HttpClientRegistry; +import io.micronaut.http.client.HttpVersionSelection; import io.micronaut.http.client.LoadBalancer; import io.micronaut.http.client.LoadBalancerResolver; +import io.micronaut.http.client.ProxyHttpClient; +import io.micronaut.http.client.ProxyHttpClientRegistry; +import io.micronaut.http.client.StreamingHttpClient; +import io.micronaut.http.client.StreamingHttpClientRegistry; import io.micronaut.http.client.annotation.Client; import io.micronaut.http.client.exceptions.HttpClientException; import io.micronaut.http.client.filter.ClientFilterResolutionContext; @@ -58,8 +58,8 @@ import io.micronaut.http.netty.channel.EventLoopGroupRegistry; import io.micronaut.inject.InjectionPoint; import io.micronaut.inject.qualifiers.Qualifiers; -import io.micronaut.json.JsonMapper; import io.micronaut.json.JsonFeatures; +import io.micronaut.json.JsonMapper; import io.micronaut.json.codec.MapperMediaTypeCodec; import io.micronaut.scheduling.instrument.InvocationInstrumenterFactory; import io.micronaut.websocket.WebSocketClient; @@ -157,7 +157,7 @@ public DefaultNettyHttpClientRegistry( @NonNull @Override - public HttpClient getClient(HttpVersion httpVersion, @NonNull String clientId, @Nullable String path) { + public HttpClient getClient(@NonNull HttpVersionSelection httpVersion, @NonNull String clientId, @Nullable String path) { final ClientKey key = new ClientKey( httpVersion, clientId, @@ -391,7 +391,7 @@ private DefaultHttpClient getClient(ClientKey key, BeanContext beanContext, Anno private DefaultHttpClient buildClient( LoadBalancer loadBalancer, - HttpVersion httpVersion, + HttpVersionSelection httpVersion, HttpClientConfiguration configuration, String clientId, String contextPath, @@ -399,6 +399,7 @@ private DefaultHttpClient buildClient( AnnotationMetadata annotationMetadata) { EventLoopGroup eventLoopGroup = resolveEventLoopGroup(configuration, beanContext); + ConversionService conversionService = beanContext.getBean(ConversionService.class); return new DefaultHttpClient( loadBalancer, httpVersion, @@ -414,14 +415,14 @@ private DefaultHttpClient buildClient( codecRegistry, WebSocketBeanRegistry.forClient(beanContext), beanContext.findBean(RequestBinderRegistry.class).orElseGet(() -> - new DefaultRequestBinderRegistry(ConversionService.SHARED) + new DefaultRequestBinderRegistry(conversionService) ), eventLoopGroup, resolveSocketChannelFactory(configuration, beanContext), - pipelineListeners, clientCustomizer, invocationInstrumenterFactories, - clientId + clientId, + conversionService ); } @@ -476,8 +477,7 @@ private ChannelFactory resolveSocketChannelFactory(HttpClientConfiguration confi } private ClientKey getClientKey(AnnotationMetadata metadata) { - final HttpVersion httpVersion = - metadata.enumValue(Client.class, "httpVersion", HttpVersion.class).orElse(null); + HttpVersionSelection httpVersionSelection = HttpVersionSelection.forClientAnnotation(metadata); String clientId = metadata.stringValue(Client.class).orElse(null); String path = metadata.stringValue(Client.class, "path").orElse(null); List filterAnnotation = metadata @@ -486,7 +486,7 @@ private ClientKey getClientKey(AnnotationMetadata metadata) { metadata.classValue(Client.class, "configuration").orElse(null); JsonFeatures jsonFeatures = jsonMapper.detectFeatures(metadata).orElse(null); - return new ClientKey(httpVersion, clientId, filterAnnotation, path, configurationClass, jsonFeatures); + return new ClientKey(httpVersionSelection, clientId, filterAnnotation, path, configurationClass, jsonFeatures); } private static MediaTypeCodec createNewJsonCodec(BeanContext beanContext, JsonFeatures jsonFeatures) { @@ -502,7 +502,7 @@ private static MapperMediaTypeCodec getJsonCodec(BeanContext beanContext) { */ @Internal private static final class ClientKey { - final HttpVersion httpVersion; + final HttpVersionSelection httpVersion; final String clientId; final List filterAnnotations; final String path; @@ -510,7 +510,7 @@ private static final class ClientKey { final JsonFeatures jsonFeatures; ClientKey( - HttpVersion httpVersion, + HttpVersionSelection httpVersion, String clientId, List filterAnnotations, String path, diff --git a/http-client/src/main/java/io/micronaut/http/client/netty/HttpLineBasedFrameDecoder.java b/http-client/src/main/java/io/micronaut/http/client/netty/HttpLineBasedFrameDecoder.java new file mode 100644 index 00000000000..eacc0125f52 --- /dev/null +++ b/http-client/src/main/java/io/micronaut/http/client/netty/HttpLineBasedFrameDecoder.java @@ -0,0 +1,103 @@ +/* + * Copyright 2017-2022 original authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.micronaut.http.client.netty; + +import io.micronaut.core.annotation.Internal; +import io.micronaut.core.annotation.NonNull; +import io.micronaut.http.netty.channel.ChannelPipelineCustomizer; +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.handler.codec.LineBasedFrameDecoder; +import io.netty.handler.codec.http.DefaultHttpContent; +import io.netty.handler.codec.http.HttpContent; +import io.netty.handler.codec.http.HttpResponse; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.LastHttpContent; + +/** + * Variant of {@link LineBasedFrameDecoder} that accepts + * {@link io.netty.handler.codec.http.HttpContent} data. Note: this handler removes itself when the + * response has been consumed. + * + * @since 4.0.0 + */ +@Internal +final class HttpLineBasedFrameDecoder extends LineBasedFrameDecoder { + static final String NAME = ChannelPipelineCustomizer.HANDLER_MICRONAUT_SSE_EVENT_STREAM; + + private boolean ignoreOneLast = false; + + HttpLineBasedFrameDecoder(int maxLength, boolean stripDelimiter, boolean failFast) { + super(maxLength, stripDelimiter, failFast); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (msg instanceof HttpResponse && + ((HttpResponse) msg).status().equals(HttpResponseStatus.CONTINUE)) { + ignoreOneLast = true; + } + + if (msg instanceof HttpContent) { + super.channelRead(ctx, ((HttpContent) msg).content()); + } else { + ctx.fireChannelRead(msg); + } + + if (msg instanceof LastHttpContent) { + if (ignoreOneLast) { + ignoreOneLast = false; + } else { + // first, remove the handler so that LineBasedFrameDecoder flushes any further + // data. Then forward the LastHttpContent. + ctx.pipeline().remove(NAME); + ctx.fireChannelRead(LastHttpContent.EMPTY_LAST_CONTENT); + } + } + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) { + ctx.pipeline().addAfter(NAME, Wrap.NAME, Wrap.INSTANCE); + } + + @Override + protected void handlerRemoved0(ChannelHandlerContext ctx) { + ctx.pipeline().remove(Wrap.NAME); + } + + @Sharable + private static class Wrap extends ChannelInboundHandlerAdapter { + static final ChannelHandler INSTANCE = new Wrap(); + static final String NAME = ChannelPipelineCustomizer.HANDLER_MICRONAUT_SSE_CONTENT; + + @Override + public void channelRead(@NonNull ChannelHandlerContext ctx, @NonNull Object msg) throws Exception { + if (msg instanceof ByteBuf) { + ByteBuf buffer = (ByteBuf) msg; + // todo: this is necessary because downstream handlers sometimes do the + // `if (refcnt > 0) release` pattern. We should eventually fix that. + ByteBuf copy = buffer.copy(); + buffer.release(); + ctx.fireChannelRead(new DefaultHttpContent(copy)); + } else { + ctx.fireChannelRead(msg); + } + } + } +} diff --git a/http-client/src/main/java/io/micronaut/http/client/netty/IdleTimeoutHandler.java b/http-client/src/main/java/io/micronaut/http/client/netty/IdleTimeoutHandler.java deleted file mode 100644 index c2242e22535..00000000000 --- a/http-client/src/main/java/io/micronaut/http/client/netty/IdleTimeoutHandler.java +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Copyright 2017-2020 original authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.micronaut.http.client.netty; - -import io.micronaut.core.annotation.Internal; -import io.netty.channel.ChannelDuplexHandler; -import io.netty.channel.ChannelHandler; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundHandler; -import io.netty.handler.timeout.IdleState; -import io.netty.handler.timeout.IdleStateEvent; - -/** - * This class is responsible for detecting idle timeout events, upon which the channel in the pool is closed. - * - * @author Dan Maas - * @since 2.2.4 - */ -@ChannelHandler.Sharable -@Internal -final class IdleTimeoutHandler extends ChannelDuplexHandler { - - static final ChannelInboundHandler INSTANCE = new IdleTimeoutHandler(); - - private IdleTimeoutHandler() { - } - - @Override - public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { - if (evt instanceof IdleStateEvent) { - IdleStateEvent e = (IdleStateEvent) evt; - if (e.state() == IdleState.READER_IDLE || e.state() == IdleState.WRITER_IDLE) { - ctx.close(); - } - } - } -} diff --git a/http-client/src/main/java/io/micronaut/http/client/netty/IdlingConnectionHandler.java b/http-client/src/main/java/io/micronaut/http/client/netty/InitialConnectionErrorHandler.java similarity index 52% rename from http-client/src/main/java/io/micronaut/http/client/netty/IdlingConnectionHandler.java rename to http-client/src/main/java/io/micronaut/http/client/netty/InitialConnectionErrorHandler.java index d4e931566a2..5c6aea4d4f3 100644 --- a/http-client/src/main/java/io/micronaut/http/client/netty/IdlingConnectionHandler.java +++ b/http-client/src/main/java/io/micronaut/http/client/netty/InitialConnectionErrorHandler.java @@ -1,5 +1,5 @@ /* - * Copyright 2017-2020 original authors + * Copyright 2017-2022 original authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,36 +16,37 @@ package io.micronaut.http.client.netty; import io.micronaut.core.annotation.Internal; +import io.micronaut.core.annotation.Nullable; +import io.netty.channel.Channel; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundHandler; import io.netty.channel.ChannelInboundHandlerAdapter; -import io.netty.util.ReferenceCountUtil; +import io.netty.util.AttributeKey; /** - * This handler prevents reading a channel when it is not being used by the connection pool. - * - * @author Dan Maas - * @since 2.2.4 + * Handler for connection failures that happen during the handshake phases of a connection. */ -@ChannelHandler.Sharable @Internal -final class IdlingConnectionHandler extends ChannelInboundHandlerAdapter { - - static final ChannelInboundHandler INSTANCE = new IdlingConnectionHandler(); - - private IdlingConnectionHandler() { - } +@ChannelHandler.Sharable +abstract class InitialConnectionErrorHandler extends ChannelInboundHandlerAdapter { + private static final AttributeKey FAILURE_KEY = + AttributeKey.valueOf(InitialConnectionErrorHandler.class, "FAILURE_KEY"); @Override - public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { - ReferenceCountUtil.release(msg); + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + setFailureCause(ctx.channel(), cause); ctx.close(); } + static void setFailureCause(Channel channel, Throwable cause) { + channel.attr(FAILURE_KEY).set(cause); + } + @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { - ctx.close(); + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + super.channelInactive(ctx); + onNewConnectionFailure(ctx.channel().attr(FAILURE_KEY).get()); } + protected abstract void onNewConnectionFailure(@Nullable Throwable cause) throws Exception; } diff --git a/http-client/src/main/java/io/micronaut/http/client/netty/MutableHttpRequestWrapper.java b/http-client/src/main/java/io/micronaut/http/client/netty/MutableHttpRequestWrapper.java new file mode 100644 index 00000000000..8e1343be39f --- /dev/null +++ b/http-client/src/main/java/io/micronaut/http/client/netty/MutableHttpRequestWrapper.java @@ -0,0 +1,131 @@ +/* + * Copyright 2017-2022 original authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.micronaut.http.client.netty; + +import io.micronaut.core.annotation.Internal; +import io.micronaut.core.annotation.NonNull; +import io.micronaut.core.annotation.Nullable; +import io.micronaut.core.convert.ConversionContext; +import io.micronaut.core.convert.ConversionService; +import io.micronaut.core.type.Argument; +import io.micronaut.http.HttpRequest; +import io.micronaut.http.HttpRequestWrapper; +import io.micronaut.http.MutableHttpHeaders; +import io.micronaut.http.MutableHttpParameters; +import io.micronaut.http.MutableHttpRequest; +import io.micronaut.http.cookie.Cookie; + +import java.net.URI; +import java.util.Optional; + +/** + * Wrapper around an immutable {@link HttpRequest} that allows mutation. + * + * @param Body type + * @since 4.0.0 + */ +@Internal +final class MutableHttpRequestWrapper extends HttpRequestWrapper implements MutableHttpRequest { + private final ConversionService conversionService; + + @Nullable + private B body; + @Nullable + private URI uri; + + MutableHttpRequestWrapper(ConversionService conversionService, HttpRequest delegate) { + super(delegate); + this.conversionService = conversionService; + } + + static MutableHttpRequest wrapIfNecessary(ConversionService conversionService, HttpRequest request) { + if (request instanceof MutableHttpRequest) { + return (MutableHttpRequest) request; + } else { + return new MutableHttpRequestWrapper<>(conversionService, request); + } + } + + @NonNull + @Override + public Optional getBody() { + if (body == null) { + return getDelegate().getBody(); + } else { + return Optional.of(body); + } + } + + @NonNull + @Override + public Optional getBody(@NonNull Class type) { + if (body == null) { + return getDelegate().getBody(type); + } else { + return conversionService.convert(body, ConversionContext.of(type)); + } + } + + @NonNull + @Override + public Optional getBody(@NonNull Argument type) { + if (body == null) { + return getDelegate().getBody(type); + } else { + return conversionService.convert(body, ConversionContext.of(type)); + } + } + + @Override + public MutableHttpRequest cookie(Cookie cookie) { + throw new UnsupportedOperationException(); + } + + @Override + public MutableHttpRequest uri(URI uri) { + this.uri = uri; + return this; + } + + @Override + @NonNull + public URI getUri() { + if (uri == null) { + return getDelegate().getUri(); + } else { + return uri; + } + } + + @NonNull + @Override + public MutableHttpParameters getParameters() { + return (MutableHttpParameters) super.getParameters(); + } + + @NonNull + @Override + public MutableHttpHeaders getHeaders() { + return (MutableHttpHeaders) super.getHeaders(); + } + + @SuppressWarnings("unchecked") + @Override + public MutableHttpRequest body(T body) { + this.body = (B) body; + return (MutableHttpRequest) this; + } +} diff --git a/http-client/src/main/java/io/micronaut/http/client/netty/NettyClientCustomizer.java b/http-client/src/main/java/io/micronaut/http/client/netty/NettyClientCustomizer.java index cfe383afbea..a95e029f37f 100644 --- a/http-client/src/main/java/io/micronaut/http/client/netty/NettyClientCustomizer.java +++ b/http-client/src/main/java/io/micronaut/http/client/netty/NettyClientCustomizer.java @@ -79,6 +79,12 @@ enum ChannelRole { * {@link io.netty.channel.socket.SocketChannel}, representing an HTTP connection. */ CONNECTION, + /** + * The channel is a HTTP2 stream channel. + * + * @since 4.0.0 + */ + HTTP2_STREAM, } /** diff --git a/http-client/src/main/java/io/micronaut/http/client/netty/PoolResizer.java b/http-client/src/main/java/io/micronaut/http/client/netty/PoolResizer.java new file mode 100644 index 00000000000..2a0a28e65c4 --- /dev/null +++ b/http-client/src/main/java/io/micronaut/http/client/netty/PoolResizer.java @@ -0,0 +1,282 @@ +/* + * Copyright 2017-2022 original authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.micronaut.http.client.netty; + +import io.micronaut.core.annotation.Internal; +import io.micronaut.core.annotation.Nullable; +import io.micronaut.http.client.HttpClientConfiguration; +import io.micronaut.http.client.exceptions.HttpClientException; +import org.slf4j.Logger; +import reactor.core.publisher.Sinks; + +import java.util.Deque; +import java.util.List; +import java.util.concurrent.ConcurrentLinkedDeque; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; + +/** + * This class handles the sizing of a connection pool to conform to the configuration in + * {@link io.micronaut.http.client.HttpClientConfiguration.ConnectionPoolConfiguration}. + *

+ * This class consists of various mutator methods (e.g. {@link #addPendingRequest}) that + * may be called concurrently and in a reentrant fashion (e.g. inside {@link #openNewConnection()}). + * These mutator methods update their respective fields and then mark this class as + * {@link #dirty()}. The state management logic ensures that {@link #doSomeWork()} is called in a + * serialized fashion (no concurrency or reentrancy) at least once after each {@link #dirty()} + * call. + */ +@Internal +abstract class PoolResizer { + private final Logger log; + private final HttpClientConfiguration.ConnectionPoolConfiguration connectionPoolConfiguration; + + private final AtomicReference state = new AtomicReference<>(WorkState.IDLE); + + private final AtomicInteger pendingConnectionCount = new AtomicInteger(0); + + private final Deque> pendingRequests = new ConcurrentLinkedDeque<>(); + private final List http1Connections = new CopyOnWriteArrayList<>(); + private final List http2Connections = new CopyOnWriteArrayList<>(); + + PoolResizer(Logger log, HttpClientConfiguration.ConnectionPoolConfiguration connectionPoolConfiguration) { + this.log = log; + this.connectionPoolConfiguration = connectionPoolConfiguration; + } + + private void dirty() { + WorkState before = state.getAndUpdate(ws -> { + if (ws == WorkState.IDLE) { + return WorkState.ACTIVE_WITHOUT_PENDING_WORK; + } else { + return WorkState.ACTIVE_WITH_PENDING_WORK; + } + }); + if (before != WorkState.IDLE) { + // already in one of the active states, another thread will take care of our changes + return; + } + // we were in idle state, this thread will handle the changes. + while (true) { + try { + doSomeWork(); + } catch (Throwable t) { + // this is probably an irrecoverable failure, we need to bail immediately, but + // avoid locking up the state. Another thread might be able to continue work. + state.set(WorkState.IDLE); + throw t; + } + + WorkState endState = state.updateAndGet(ws -> { + if (ws == WorkState.ACTIVE_WITH_PENDING_WORK) { + return WorkState.ACTIVE_WITHOUT_PENDING_WORK; + } else { + return WorkState.IDLE; + } + }); + if (endState == WorkState.IDLE) { + // nothing else to do \o/ + break; + } + } + } + + private void doSomeWork() { + while (true) { + Sinks.One toDispatch = pendingRequests.pollFirst(); + if (toDispatch == null) { + break; + } + boolean dispatched = false; + for (ResizerConnection c : http2Connections) { + if (dispatchSafe(c, toDispatch)) { + dispatched = true; + break; + } + } + if (!dispatched) { + for (ResizerConnection c : http1Connections) { + if (dispatchSafe(c, toDispatch)) { + dispatched = true; + break; + } + } + } + if (!dispatched) { + pendingRequests.addFirst(toDispatch); + break; + } + } + + // snapshot our fields + int pendingRequestCount = this.pendingRequests.size(); + int pendingConnectionCount = this.pendingConnectionCount.get(); + int http1ConnectionCount = this.http1Connections.size(); + int http2ConnectionCount = this.http2Connections.size(); + + if (pendingRequestCount == 0) { + // if there are no pending requests, there is nothing to do. + return; + } + int connectionsToOpen = pendingRequestCount - pendingConnectionCount; + // make sure we won't exceed our config setting for pending connections + connectionsToOpen = Math.min(connectionsToOpen, connectionPoolConfiguration.getMaxPendingConnections() - pendingConnectionCount); + // limit the connection count to the protocol-specific settings, but only if that protocol was seen for this pool. + if (http1ConnectionCount > 0) { + connectionsToOpen = Math.min(connectionsToOpen, connectionPoolConfiguration.getMaxConcurrentHttp1Connections() - http1ConnectionCount); + } + if (http2ConnectionCount > 0) { + connectionsToOpen = Math.min(connectionsToOpen, connectionPoolConfiguration.getMaxConcurrentHttp2Connections() - http2ConnectionCount); + } + + if (connectionsToOpen > 0) { + this.pendingConnectionCount.addAndGet(connectionsToOpen); + for (int i = 0; i < connectionsToOpen; i++) { + try { + openNewConnection(); + } catch (Exception e) { + try { + onNewConnectionFailure(e); + } catch (Exception f) { + log.error("Internal error", f); + } + } + } + dirty(); + } + } + + private boolean dispatchSafe(ResizerConnection connection, Sinks.One toDispatch) { + try { + return connection.dispatch(toDispatch); + } catch (Exception e) { + try { + if (toDispatch.tryEmitError(e) != Sinks.EmitResult.OK) { + // this is probably fine, log it anyway + log.debug("Failure during connection dispatch operation, but dispatch request was already complete.", e); + } + } catch (Exception f) { + log.error("Internal error", f); + } + return true; + } + } + + abstract void openNewConnection() throws Exception; + + static boolean incrementWithLimit(AtomicInteger variable, int limit) { + while (true) { + int old = variable.get(); + if (old >= limit) { + return false; + } + if (variable.compareAndSet(old, old + 1)) { + return true; + } + } + } + + // can be overridden, so `throws Exception` ensures we handle any errors + void onNewConnectionFailure(@Nullable Throwable error) throws Exception { + // todo: implement a circuit breaker here? right now, we just fail one connection in the + // subclass implementation, but maybe we should do more. + pendingConnectionCount.decrementAndGet(); + dirty(); + } + + final void onNewConnectionEstablished1(ResizerConnection connection) { + http1Connections.add(connection); + pendingConnectionCount.decrementAndGet(); + dirty(); + } + + final void onNewConnectionEstablished2(ResizerConnection connection) { + http2Connections.add(connection); + pendingConnectionCount.decrementAndGet(); + dirty(); + } + + final void onConnectionInactive1(ResizerConnection connection) { + http1Connections.remove(connection); + dirty(); + } + + final void onConnectionInactive2(ResizerConnection connection) { + http2Connections.remove(connection); + dirty(); + } + + final void addPendingRequest(Sinks.One sink) { + if (pendingRequests.size() >= connectionPoolConfiguration.getMaxPendingAcquires()) { + sink.tryEmitError(new HttpClientException("Cannot acquire connection, exceeded max pending acquires configuration")); + return; + } + pendingRequests.addLast(sink); + dirty(); + } + + @Nullable + final Sinks.One pollPendingRequest() { + Sinks.One req = pendingRequests.pollFirst(); + if (req != null) { + dirty(); + } + return req; + } + + final void markConnectionAvailable() { + dirty(); + } + + final void forEachConnection(Consumer c) { + for (ResizerConnection http1Connection : http1Connections) { + c.accept(http1Connection); + } + for (ResizerConnection http2Connection : http2Connections) { + c.accept(http2Connection); + } + } + + private enum WorkState { + /** + * There are no pending changes, and nobody is currently executing {@link #doSomeWork()}. + */ + IDLE, + /** + * Someone is currently executing {@link #doSomeWork()}, but there were further changes + * after {@link #doSomeWork()} was called, so it needs to be called again. + */ + ACTIVE_WITH_PENDING_WORK, + /** + * Someone is currently executing {@link #doSomeWork()}, and there were no other changes + * since then. + */ + ACTIVE_WITHOUT_PENDING_WORK, + } + + abstract static class ResizerConnection { + /** + * Attempt to dispatch a stream on this connection. + * + * @param sink The pending request that wants to acquire this connection + * @return {@code true} if the acquisition may succeed (if it fails later, the pending + * request must be readded), or {@code false} if it fails immediately + */ + abstract boolean dispatch(Sinks.One sink) throws Exception; + } +} diff --git a/http-client/src/main/java/io/micronaut/http/client/netty/ssl/NettyClientSslBuilder.java b/http-client/src/main/java/io/micronaut/http/client/netty/ssl/NettyClientSslBuilder.java index 9c043c6c91f..e8e3864415f 100644 --- a/http-client/src/main/java/io/micronaut/http/client/netty/ssl/NettyClientSslBuilder.java +++ b/http-client/src/main/java/io/micronaut/http/client/netty/ssl/NettyClientSslBuilder.java @@ -17,8 +17,10 @@ import io.micronaut.context.annotation.BootstrapContextCompatible; import io.micronaut.core.annotation.Internal; +import io.micronaut.core.annotation.Nullable; import io.micronaut.core.io.ResourceResolver; import io.micronaut.http.HttpVersion; +import io.micronaut.http.client.HttpVersionSelection; import io.micronaut.http.ssl.AbstractClientSslConfiguration; import io.micronaut.http.ssl.ClientAuthentication; import io.micronaut.http.ssl.SslBuilder; @@ -26,7 +28,6 @@ import io.micronaut.http.ssl.SslConfigurationException; import io.netty.handler.codec.http2.Http2SecurityUtil; import io.netty.handler.ssl.ApplicationProtocolConfig; -import io.netty.handler.ssl.ApplicationProtocolNames; import io.netty.handler.ssl.ClientAuth; import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.SslContextBuilder; @@ -53,7 +54,7 @@ @Singleton @Internal @BootstrapContextCompatible -public class NettyClientSslBuilder extends SslBuilder { +public final class NettyClientSslBuilder extends SslBuilder { private static final Logger LOG = LoggerFactory.getLogger(NettyClientSslBuilder.class); /** @@ -71,20 +72,24 @@ public Optional build(SslConfiguration ssl) { @Override public Optional build(SslConfiguration ssl, HttpVersion httpVersion) { + return Optional.ofNullable(build(ssl, HttpVersionSelection.forLegacyVersion(httpVersion))); + } + + @Nullable + public SslContext build(SslConfiguration ssl, HttpVersionSelection versionSelection) { if (!ssl.isEnabled()) { - return Optional.empty(); + return null; } - final boolean isHttp2 = httpVersion == HttpVersion.HTTP_2_0; SslContextBuilder sslBuilder = SslContextBuilder - .forClient() - .keyManager(getKeyManagerFactory(ssl)) - .trustManager(getTrustManagerFactory(ssl)); + .forClient() + .keyManager(getKeyManagerFactory(ssl)) + .trustManager(getTrustManagerFactory(ssl)); if (ssl.getProtocols().isPresent()) { sslBuilder.protocols(ssl.getProtocols().get()); } if (ssl.getCiphers().isPresent()) { sslBuilder = sslBuilder.ciphers(Arrays.asList(ssl.getCiphers().get())); - } else if (isHttp2) { + } else if (versionSelection.isHttp2CipherSuites()) { sslBuilder.ciphers(Http2SecurityUtil.CIPHERS, SupportedCipherSuiteFilter.INSTANCE); } if (ssl.getClientAuthentication().isPresent()) { @@ -95,20 +100,19 @@ public Optional build(SslConfiguration ssl, HttpVersion httpVersion) sslBuilder = sslBuilder.clientAuth(ClientAuth.OPTIONAL); } } - if (isHttp2) { + if (versionSelection.isAlpn()) { SslProvider provider = SslProvider.isAlpnSupported(SslProvider.OPENSSL) ? SslProvider.OPENSSL : SslProvider.JDK; sslBuilder.sslProvider(provider); sslBuilder.applicationProtocolConfig(new ApplicationProtocolConfig( - ApplicationProtocolConfig.Protocol.ALPN, - ApplicationProtocolConfig.SelectorFailureBehavior.NO_ADVERTISE, - ApplicationProtocolConfig.SelectedListenerFailureBehavior.ACCEPT, - ApplicationProtocolNames.HTTP_1_1, - ApplicationProtocolNames.HTTP_2 + ApplicationProtocolConfig.Protocol.ALPN, + ApplicationProtocolConfig.SelectorFailureBehavior.NO_ADVERTISE, + ApplicationProtocolConfig.SelectedListenerFailureBehavior.ACCEPT, + versionSelection.getAlpnSupportedProtocols() )); } try { - return Optional.of(sslBuilder.build()); + return sslBuilder.build(); } catch (SSLException ex) { throw new SslConfigurationException("An error occurred while setting up SSL", ex); } diff --git a/http-client/src/main/java/io/micronaut/http/client/netty/websocket/NettyWebSocketClientHandler.java b/http-client/src/main/java/io/micronaut/http/client/netty/websocket/NettyWebSocketClientHandler.java index 5e825b38400..0bf36fc9700 100644 --- a/http-client/src/main/java/io/micronaut/http/client/netty/websocket/NettyWebSocketClientHandler.java +++ b/http-client/src/main/java/io/micronaut/http/client/netty/websocket/NettyWebSocketClientHandler.java @@ -146,7 +146,14 @@ public void handlerAdded(final ChannelHandlerContext ctx) { @Override public void channelActive(final ChannelHandlerContext ctx) { - handshaker.handshake(ctx.channel()); + handshaker.handshake(ctx.channel()).addListener(future -> { + if (future.isSuccess()) { + ctx.channel().config().setAutoRead(true); + ctx.read(); + } else { + handshakeFuture.tryFailure(future.cause()); + } + }); } @Override diff --git a/http-client/src/test/groovy/io/micronaut/http/client/ConnectTTLHandlerSpec.groovy b/http-client/src/test/groovy/io/micronaut/http/client/ConnectTTLHandlerSpec.groovy deleted file mode 100644 index e684cc596cc..00000000000 --- a/http-client/src/test/groovy/io/micronaut/http/client/ConnectTTLHandlerSpec.groovy +++ /dev/null @@ -1,41 +0,0 @@ -package io.micronaut.http.client - -import io.micronaut.http.client.netty.ConnectTTLHandler -import io.netty.channel.ChannelHandlerContext -import io.netty.channel.embedded.EmbeddedChannel -import spock.lang.Specification - - -class ConnectTTLHandlerSpec extends Specification{ - - - def "RELEASE_CHANNEL should be true for those channels who's connect-ttl is reached"(){ - - setup: - MockChannel channel = new MockChannel(); - ChannelHandlerContext context = Mock() - - when: - new ConnectTTLHandler(1).handlerAdded(context) - channel.runAllPendingTasks() - - then: - _ * context.channel() >> channel - - channel.attr(ConnectTTLHandler.RELEASE_CHANNEL) - - } - - class MockChannel extends EmbeddedChannel { - MockChannel() throws Exception { - super.doRegister() - } - - void runAllPendingTasks() throws InterruptedException { - super.runPendingTasks() - while (runScheduledPendingTasks() != -1) { - Thread.sleep(1) - } - } - } -} diff --git a/http-client/src/test/groovy/io/micronaut/http/client/ConnectionTTLSpec.groovy b/http-client/src/test/groovy/io/micronaut/http/client/ConnectionTTLSpec.groovy index e2660359945..dc03f800445 100644 --- a/http-client/src/test/groovy/io/micronaut/http/client/ConnectionTTLSpec.groovy +++ b/http-client/src/test/groovy/io/micronaut/http/client/ConnectionTTLSpec.groovy @@ -8,15 +8,12 @@ import io.micronaut.http.annotation.Controller import io.micronaut.http.annotation.Get import io.micronaut.runtime.server.EmbeddedServer import io.netty.channel.Channel -import io.netty.channel.pool.AbstractChannelPoolMap import spock.lang.AutoCleanup import spock.lang.Retry import spock.lang.Shared import spock.lang.Specification import spock.util.concurrent.PollingConditions -import java.lang.reflect.Field - @Retry class ConnectionTTLSpec extends Specification { @@ -37,7 +34,7 @@ class ConnectionTTLSpec extends Specification { when:"make first request" httpClient.toBlocking().retrieve(HttpRequest.GET('/connectTTL/'),String) - Channel ch = getQueuedChannels(httpClient).first + Channel ch = getQueuedChannels(httpClient).get(0) then:"ensure that connection is open as connect-ttl is not reached" getQueuedChannels(httpClient).size() == 1 @@ -67,11 +64,11 @@ class ConnectionTTLSpec extends Specification { when:"make first request" httpClient.toBlocking().retrieve(HttpRequest.GET('/connectTTL/'),String) - Deque deque = getQueuedChannels(httpClient) + List deque = getQueuedChannels(httpClient) then:"ensure that connection is open as connect-ttl is not reached" new PollingConditions().eventually { - deque.first.isOpen() + deque.get(0).isOpen() } when:"make another request after some time" @@ -80,7 +77,7 @@ class ConnectionTTLSpec extends Specification { then:"ensure channel is still open" new PollingConditions().eventually { - deque.first.isOpen() + deque.get(0).isOpen() } cleanup: @@ -99,11 +96,11 @@ class ConnectionTTLSpec extends Specification { when:"make first request" httpClient.toBlocking().retrieve(HttpRequest.GET('/connectTTL/'),String) - Deque deque = getQueuedChannels(httpClient) + Collection deque = getQueuedChannels(httpClient) then:"ensure that connection is open as connect-ttl is not reached" new PollingConditions().eventually { - deque.first.isOpen() + deque.get(0).isOpen() } when:"make another request" @@ -111,7 +108,7 @@ class ConnectionTTLSpec extends Specification { then:"ensure channel is still open" new PollingConditions().eventually { - deque.first.isOpen() + deque.get(0).isOpen() } cleanup: @@ -130,7 +127,7 @@ class ConnectionTTLSpec extends Specification { when:"make first request" httpClient.toBlocking().retrieve(HttpRequest.GET('/connectTTL/'),String) - Channel ch = getQueuedChannels(httpClient).first + Channel ch = getQueuedChannels(httpClient).get(0) then:"ensure that connection is open as connect-ttl is not reached" getQueuedChannels(httpClient).size() == 1 @@ -149,12 +146,8 @@ class ConnectionTTLSpec extends Specification { clientContext.close() } - Deque getQueuedChannels(HttpClient client) { - AbstractChannelPoolMap poolMap = client.connectionManager.poolMap - Field mapField = AbstractChannelPoolMap.getDeclaredField("map") - mapField.setAccessible(true) - Map innerMap = mapField.get(poolMap) - return innerMap.values().first().deque + List getQueuedChannels(HttpClient client) { + return client.connectionManager.channels } @Requires(property = 'spec.name', value = 'ConnectionTTLSpec') diff --git a/http-client/src/test/groovy/io/micronaut/http/client/IdleTimeoutSpec.groovy b/http-client/src/test/groovy/io/micronaut/http/client/IdleTimeoutSpec.groovy index a65805a92bf..c5edd768f3f 100644 --- a/http-client/src/test/groovy/io/micronaut/http/client/IdleTimeoutSpec.groovy +++ b/http-client/src/test/groovy/io/micronaut/http/client/IdleTimeoutSpec.groovy @@ -8,15 +8,12 @@ import io.micronaut.http.annotation.Controller import io.micronaut.http.annotation.Get import io.micronaut.runtime.server.EmbeddedServer import io.netty.channel.Channel -import io.netty.channel.pool.AbstractChannelPoolMap import spock.lang.AutoCleanup import spock.lang.Retry import spock.lang.Shared import spock.lang.Specification import spock.util.concurrent.PollingConditions -import java.lang.reflect.Field - @Retry class IdleTimeoutSpec extends Specification { @@ -35,11 +32,10 @@ class IdleTimeoutSpec extends Specification { when: "make first request" httpClient.toBlocking().retrieve(HttpRequest.GET('/idleTimeout/'), String) - Deque deque = getQueuedChannels(httpClient) - Channel ch1 = deque.first + Channel ch1 = getQueuedChannels(httpClient).get(0) then: "ensure that connection is open as connection-pool-idle-timeout is not reached" - deque.size() == 1 + getQueuedChannels(httpClient).size() == 1 ch1.isOpen() new PollingConditions(timeout: 2).eventually { !ch1.isOpen() @@ -50,14 +46,14 @@ class IdleTimeoutSpec extends Specification { then: new PollingConditions().eventually { - assert deque.size() > 0 + assert getQueuedChannels(httpClient).size() > 0 } when: - Channel ch2 = deque.first + Channel ch2 = getQueuedChannels(httpClient).get(0) then: "ensure channel 2 is open and channel 2 != channel 1" - deque.size() == 1 + getQueuedChannels(httpClient).size() == 1 ch1 != ch2 ch2.isOpen() new PollingConditions(timeout: 2).eventually { @@ -79,13 +75,13 @@ class IdleTimeoutSpec extends Specification { when: "make first request" httpClient.toBlocking().retrieve(HttpRequest.GET('/idleTimeout/'), String) - Deque deque = getQueuedChannels(httpClient) - Channel ch1 = deque.first + List deque = getQueuedChannels(httpClient) + Channel ch1 = deque.get(0) then: "ensure that connection is open as connection-pool-idle-timeout is not reached" deque.size() == 1 new PollingConditions().eventually { - deque.first.isOpen() + deque.get(0).isOpen() } when: "make another request" @@ -97,13 +93,13 @@ class IdleTimeoutSpec extends Specification { } when: - Channel ch2 = deque.first + Channel ch2 = deque.get(0) then: "ensure channel is still open" ch1 == ch2 deque.size() == 1 new PollingConditions().eventually { - deque.first.isOpen() + deque.get(0).isOpen() } cleanup: @@ -111,12 +107,8 @@ class IdleTimeoutSpec extends Specification { clientContext.close() } - Deque getQueuedChannels(HttpClient client) { - AbstractChannelPoolMap poolMap = client.connectionManager.poolMap - Field mapField = AbstractChannelPoolMap.getDeclaredField("map") - mapField.setAccessible(true) - Map innerMap = mapField.get(poolMap) - return innerMap.values().first().deque + List getQueuedChannels(HttpClient client) { + return client.connectionManager.channels } @Requires(property = 'spec.name', value = 'IdleTimeoutSpec') diff --git a/http-client/src/test/groovy/io/micronaut/http/client/ReadTimeoutSpec.groovy b/http-client/src/test/groovy/io/micronaut/http/client/ReadTimeoutSpec.groovy index a04d8e137c0..9debe5c3b98 100644 --- a/http-client/src/test/groovy/io/micronaut/http/client/ReadTimeoutSpec.groovy +++ b/http-client/src/test/groovy/io/micronaut/http/client/ReadTimeoutSpec.groovy @@ -298,11 +298,10 @@ class ReadTimeoutSpec extends Specification { .filter { it.clientId == "http://localhost:${embeddedServer.getPort()}" } .findFirst() .get() - def pool = getPool(clients.get(clientKey)) then:"Connections are not leaked" conditions.eventually { - pool.acquiredChannelCount() == 0 + clients.get(clientKey).connectionManager.liveRequestCount() == 0 } cleanup: diff --git a/http-client/src/test/groovy/io/micronaut/http/client/SslRefreshSpec.groovy b/http-client/src/test/groovy/io/micronaut/http/client/SslRefreshSpec.groovy index 40a68f9eb8b..90770f722d8 100644 --- a/http-client/src/test/groovy/io/micronaut/http/client/SslRefreshSpec.groovy +++ b/http-client/src/test/groovy/io/micronaut/http/client/SslRefreshSpec.groovy @@ -43,7 +43,10 @@ class SslRefreshSpec extends Specification { 'micronaut.http.client.ssl.client-authentication': 'NEED', 'micronaut.http.client.ssl.key-store.path': 'classpath:certs/client1.p12', 'micronaut.http.client.ssl.key-store.password': 'secret', - 'micronaut.http.client.ssl.insecure-trust-all-certificates': true + 'micronaut.http.client.ssl.insecure-trust-all-certificates': true, + 'micronaut.http.client.pool.enabled': false, + // need to force http1 because our ciphers are not supported by http2 + 'micronaut.http.client.http-version': '1.1', ] @Shared @AutoCleanup EmbeddedServer embeddedServer = ApplicationContext .builder() diff --git a/http-client/src/test/groovy/io/micronaut/http/client/config/DefaultHttpClientConfigurationSpec.groovy b/http-client/src/test/groovy/io/micronaut/http/client/config/DefaultHttpClientConfigurationSpec.groovy index ec839c1b0ae..6a2ed45e48b 100644 --- a/http-client/src/test/groovy/io/micronaut/http/client/config/DefaultHttpClientConfigurationSpec.groovy +++ b/http-client/src/test/groovy/io/micronaut/http/client/config/DefaultHttpClientConfigurationSpec.groovy @@ -75,7 +75,6 @@ class DefaultHttpClientConfigurationSpec extends Specification { where: key | property | value | expected 'enabled' | 'enabled' | 'false' | false - 'max-connections' | 'maxConnections' | '10' | 10 } void "test overriding logger for the client"() { diff --git a/http-client/src/test/groovy/io/micronaut/http/client/netty/ConnectionManagerSpec.groovy b/http-client/src/test/groovy/io/micronaut/http/client/netty/ConnectionManagerSpec.groovy new file mode 100644 index 00000000000..3aba2aab49f --- /dev/null +++ b/http-client/src/test/groovy/io/micronaut/http/client/netty/ConnectionManagerSpec.groovy @@ -0,0 +1,1226 @@ +package io.micronaut.http.client.netty + +import io.micronaut.context.ApplicationContext +import io.micronaut.context.event.BeanCreatedEvent +import io.micronaut.context.event.BeanCreatedEventListener +import io.micronaut.http.HttpRequest +import io.micronaut.http.HttpResponse +import io.micronaut.http.HttpStatus +import io.micronaut.http.HttpVersion +import io.micronaut.http.MediaType +import io.micronaut.http.client.HttpClient +import io.micronaut.http.client.StreamingHttpClient +import io.micronaut.http.client.exceptions.ReadTimeoutException +import io.micronaut.http.client.multipart.MultipartBody +import io.micronaut.http.netty.channel.ChannelPipelineCustomizer +import io.micronaut.http.server.netty.ssl.CertificateProvidedSslBuilder +import io.micronaut.http.ssl.SslConfiguration +import io.micronaut.websocket.WebSocketSession +import io.micronaut.websocket.annotation.ClientWebSocket +import io.micronaut.websocket.annotation.OnMessage +import io.netty.buffer.ByteBufAllocator +import io.netty.buffer.Unpooled +import io.netty.channel.Channel +import io.netty.channel.ChannelFuture +import io.netty.channel.ChannelHandler +import io.netty.channel.ChannelHandlerContext +import io.netty.channel.ChannelId +import io.netty.channel.ChannelInboundHandlerAdapter +import io.netty.channel.ChannelInitializer +import io.netty.channel.ChannelPromise +import io.netty.channel.ServerChannel +import io.netty.channel.embedded.EmbeddedChannel +import io.netty.handler.codec.http.DefaultFullHttpResponse +import io.netty.handler.codec.http.DefaultHttpContent +import io.netty.handler.codec.http.DefaultHttpResponse +import io.netty.handler.codec.http.DefaultLastHttpContent +import io.netty.handler.codec.http.FullHttpRequest +import io.netty.handler.codec.http.HttpContentCompressor +import io.netty.handler.codec.http.HttpMethod +import io.netty.handler.codec.http.HttpObjectAggregator +import io.netty.handler.codec.http.HttpResponseStatus +import io.netty.handler.codec.http.HttpServerCodec +import io.netty.handler.codec.http.HttpServerUpgradeHandler +import io.netty.handler.codec.http.LastHttpContent +import io.netty.handler.codec.http.websocketx.TextWebSocketFrame +import io.netty.handler.codec.http.websocketx.WebSocketServerHandshakerFactory +import io.netty.handler.codec.http2.DefaultHttp2DataFrame +import io.netty.handler.codec.http2.DefaultHttp2Headers +import io.netty.handler.codec.http2.DefaultHttp2HeadersFrame +import io.netty.handler.codec.http2.Http2FrameCodec +import io.netty.handler.codec.http2.Http2FrameCodecBuilder +import io.netty.handler.codec.http2.Http2FrameStream +import io.netty.handler.codec.http2.Http2FrameStreamEvent +import io.netty.handler.codec.http2.Http2Headers +import io.netty.handler.codec.http2.Http2HeadersFrame +import io.netty.handler.codec.http2.Http2ResetFrame +import io.netty.handler.codec.http2.Http2ServerUpgradeCodec +import io.netty.handler.codec.http2.Http2SettingsAckFrame +import io.netty.handler.codec.http2.Http2SettingsFrame +import io.netty.handler.codec.http2.Http2Stream +import io.netty.handler.logging.LogLevel +import io.netty.handler.logging.LoggingHandler +import io.netty.handler.ssl.ApplicationProtocolNegotiationHandler +import io.netty.handler.ssl.SslContextBuilder +import io.netty.handler.ssl.util.SelfSignedCertificate +import io.netty.util.AsciiString +import io.netty.util.concurrent.GenericFutureListener +import jakarta.inject.Singleton +import org.spockframework.runtime.model.parallel.ExecutionMode +import reactor.core.publisher.Flux +import reactor.core.publisher.Mono +import spock.lang.Execution +import spock.lang.Specification +import spock.lang.Unroll + +import java.nio.charset.StandardCharsets +import java.util.concurrent.CompletableFuture +import java.util.concurrent.ExecutionException +import java.util.concurrent.Future +import java.util.concurrent.TimeUnit +import java.util.zip.GZIPOutputStream + +@Execution(ExecutionMode.CONCURRENT) +class ConnectionManagerSpec extends Specification { + private static void patch(DefaultHttpClient httpClient, EmbeddedTestConnectionBase... connections) { + httpClient.connectionManager = new ConnectionManager(httpClient.connectionManager) { + int i = 0 + + @Override + protected ChannelFuture doConnect(DefaultHttpClient.RequestKey requestKey, ChannelInitializer channelInitializer) { + try { + def connection = connections[i++] + connection.clientChannel = new EmbeddedChannel(new DummyChannelId('client' + i), connection.clientInitializer, channelInitializer) + def promise = connection.clientChannel.newPromise() + promise.setSuccess() + return promise + } catch (Throwable t) { + // print it immediately to make sure it's not swallowed + t.printStackTrace() + throw t + } + } + } + } + + def 'simple http2 get'() { + def ctx = ApplicationContext.run([ + 'micronaut.http.client.ssl.insecure-trust-all-certificates': true, + ]) + def client = ctx.getBean(DefaultHttpClient) + + def conn = new EmbeddedTestConnectionHttp2() + conn.setupHttp2Tls() + patch(client, conn) + + def future = conn.testExchangeRequest(client) + conn.exchangeSettings() + conn.testExchangeResponse(future) + + cleanup: + client.close() + ctx.close() + } + + def 'http2 streaming get'() { + def ctx = ApplicationContext.run([ + 'micronaut.http.client.ssl.insecure-trust-all-certificates': true, + ]) + def client = ctx.getBean(DefaultHttpClient) + + def conn = new EmbeddedTestConnectionHttp2() + conn.setupHttp2Tls() + patch(client, conn) + + def r1 = conn.testStreamingRequest(client) + conn.exchangeSettings() + conn.testStreamingResponse(r1) + + cleanup: + client.close() + ctx.close() + } + + def 'simple http1 get'() { + def ctx = ApplicationContext.run() + def client = ctx.getBean(DefaultHttpClient) + + def conn = new EmbeddedTestConnectionHttp1() + conn.setupHttp1() + patch(client, conn) + + conn.testExchangeResponse(conn.testExchangeRequest(client)) + + cleanup: + client.close() + ctx.close() + } + + def 'http1 get with compression'() { + def ctx = ApplicationContext.run() + def client = ctx.getBean(DefaultHttpClient) + + def conn = new EmbeddedTestConnectionHttp1() + conn.setupHttp1() + conn.serverChannel.pipeline().addLast(new HttpContentCompressor()) + patch(client, conn) + + def future = Mono.from(client.exchange( + HttpRequest.GET('http://example.com/foo').header('accept-encoding', 'gzip'), String)).toFuture() + future.exceptionally(t -> t.printStackTrace()) + conn.advance() + + assert conn.serverChannel.readInbound() instanceof io.netty.handler.codec.http.HttpRequest + + def response = new DefaultFullHttpResponse(io.netty.handler.codec.http.HttpVersion.HTTP_1_1, HttpResponseStatus.OK, Unpooled.wrappedBuffer("foo".bytes)) + response.headers().add('content-length', 3) + conn.serverChannel.writeOutbound(response) + + conn.advance() + assert future.get().status() == HttpStatus.OK + assert future.get().body() == 'foo' + + cleanup: + client.close() + ctx.close() + } + + def 'http2 get with compression'() { + def ctx = ApplicationContext.run([ + 'micronaut.http.client.ssl.insecure-trust-all-certificates': true, + ]) + def client = ctx.getBean(DefaultHttpClient) + + def conn = new EmbeddedTestConnectionHttp2() + conn.setupHttp2Tls() + patch(client, conn) + + def future = Mono.from(client.exchange('https://example.com/foo', String)).toFuture() + future.exceptionally(t -> t.printStackTrace()) + conn.exchangeSettings() + + Http2HeadersFrame request = conn.serverChannel.readInbound() + def responseHeaders = new DefaultHttp2Headers() + responseHeaders.add(Http2Headers.PseudoHeaderName.STATUS.value(), "200") + responseHeaders.add('content-encoding', "gzip") + conn.serverChannel.writeOutbound(new DefaultHttp2HeadersFrame(responseHeaders, false).stream(request.stream())) + def compressedOut = new ByteArrayOutputStream() + try (OutputStream os = new GZIPOutputStream(compressedOut)) { + os.write('foo'.bytes) + } + conn.serverChannel.writeOutbound(new DefaultHttp2DataFrame(Unpooled.wrappedBuffer(compressedOut.toByteArray()), true).stream(request.stream())) + + conn.advance() + def response = future.get() + assert response.status() == HttpStatus.OK + assert response.body() == 'foo' + + cleanup: + client.close() + ctx.close() + } + + def 'simple http1 tls get'() { + def ctx = ApplicationContext.run([ + 'micronaut.http.client.ssl.insecure-trust-all-certificates': true, + ]) + def client = ctx.getBean(DefaultHttpClient) + + def conn = new EmbeddedTestConnectionHttp1() + conn.setupHttp1Tls() + patch(client, conn) + + conn.testExchangeResponse(conn.testExchangeRequest(client)) + + cleanup: + client.close() + ctx.close() + } + + def 'simple h2c get'() { + def ctx = ApplicationContext.run([ + 'micronaut.http.client.plaintext-mode': 'h2c', + ]) + def client = ctx.getBean(DefaultHttpClient) + + def conn = new EmbeddedTestConnectionHttp2() + conn.setupH2c() + patch(client, conn) + + def future = conn.testExchangeRequest(client) + conn.exchangeH2c() + conn.testExchangeResponse(future) + + cleanup: + client.close() + ctx.close() + } + + def 'http1 streaming get'() { + def ctx = ApplicationContext.run() + def client = ctx.getBean(DefaultHttpClient) + + def conn = new EmbeddedTestConnectionHttp1() + conn.setupHttp1() + patch(client, conn) + + conn.testStreamingResponse(conn.testStreamingRequest(client)) + + cleanup: + client.close() + ctx.close() + } + + def 'http2 concurrent stream'() { + given: + def ctx = ApplicationContext.run([ + 'micronaut.http.client.ssl.insecure-trust-all-certificates': true, + ]) + def client = ctx.getBean(DefaultHttpClient) + + def conn1 = new EmbeddedTestConnectionHttp2() + conn1.setupHttp2Tls() + def conn2 = new EmbeddedTestConnectionHttp2() + conn2.setupHttp2Tls() + patch(client, conn1, conn2) + + when: + // start two requests. this will open two connections + def f1 = Mono.from(client.exchange('https://example.com/r1')).toFuture() + f1.exceptionally(t -> t.printStackTrace()) + def f2 = Mono.from(client.exchange('https://example.com/r2')).toFuture() + f2.exceptionally(t -> t.printStackTrace()) + + then: + // no data yet, haven't finished the handshake + conn1.serverChannel.readInbound() == null + + when: + // finish handshake for first connection + conn1.exchangeSettings() + then: + // both requests immediately go to the first connection + def req1 = conn1.serverChannel. readInbound() + req1.headers().get(Http2Headers.PseudoHeaderName.PATH.value()) == '/r1' + def req2 = conn1.serverChannel. readInbound() + req2.stream().id() != req1.stream().id() + req2.headers().get(Http2Headers.PseudoHeaderName.PATH.value()) == '/r2' + + when: + // start a third request, this should reuse the existing connection + def f3 = Mono.from(client.exchange('https://example.com/r3')).toFuture() + f3.exceptionally(t -> t.printStackTrace()) + conn1.advance() + then: + def req3 = conn1.serverChannel. readInbound() + req3.stream().id() != req1.stream().id() + req3.stream().id() != req2.stream().id() + req3.headers().get(Http2Headers.PseudoHeaderName.PATH.value()) == '/r3' + + // finish up the third request + when: + conn1.respondOk(req3.stream()) + conn1.advance() + then: + f3.get().status() == HttpStatus.OK + + // finish up the second and first request + when: + conn1.respondOk(req2.stream()) + conn1.respondOk(req1.stream()) + conn1.advance() + then: + f1.get().status() == HttpStatus.OK + f2.get().status() == HttpStatus.OK + + cleanup: + client.close() + ctx.close() + } + + def 'http1 reuse'() { + def ctx = ApplicationContext.run() + def client = ctx.getBean(DefaultHttpClient) + + def conn = new EmbeddedTestConnectionHttp1() + conn.setupHttp1() + patch(client, conn) + + conn.testExchangeResponse(conn.testExchangeRequest(client)) + + Queue responseData1 = conn.testStreamingRequest(client) + conn.testStreamingResponse(responseData1) + conn.testExchangeResponse(conn.testExchangeRequest(client)) + Queue responseData = conn.testStreamingRequest(client) + conn.testStreamingResponse(responseData) + + cleanup: + client.close() + ctx.close() + } + + def 'http1 plain text customization'() { + given: + def ctx = ApplicationContext.run() + def client = ctx.getBean(DefaultHttpClient) + def tracker = ctx.getBean(CustomizerTracker) + + def conn = new EmbeddedTestConnectionHttp1() + conn.setupHttp1() + patch(client, conn) + + when: + conn.testExchangeResponse(conn.testExchangeRequest(client)) + + Queue responseData = conn.testStreamingRequest(client) + conn.testStreamingResponse(responseData) + + then: + def outerChannel = tracker.initialPipelineBuilt.poll() + outerChannel.channel == conn.clientChannel + outerChannel.handlerNames.contains(ChannelPipelineCustomizer.HANDLER_HTTP_CLIENT_CODEC) + outerChannel.handlerNames.contains(ChannelPipelineCustomizer.HANDLER_HTTP_DECODER) + !outerChannel.handlerNames.contains(ChannelPipelineCustomizer.HANDLER_HTTP_AGGREGATOR) + !outerChannel.handlerNames.contains(ChannelPipelineCustomizer.HANDLER_HTTP_STREAM) + tracker.initialPipelineBuilt.isEmpty() + + def innerChannel = tracker.streamPipelineBuilt.poll() + innerChannel.channel == conn.clientChannel + innerChannel.handlerNames == outerChannel.handlerNames + tracker.streamPipelineBuilt.isEmpty() + + def req1Channel = tracker.requestPipelineBuilt.poll() + req1Channel.channel == conn.clientChannel + req1Channel.handlerNames.contains(ChannelPipelineCustomizer.HANDLER_HTTP_AGGREGATOR) + req1Channel.handlerNames.contains(ChannelPipelineCustomizer.HANDLER_MICRONAUT_FULL_HTTP_RESPONSE) + + def req2Channel = tracker.requestPipelineBuilt.poll() + req2Channel.channel == conn.clientChannel + req2Channel.handlerNames.contains(ChannelPipelineCustomizer.HANDLER_MICRONAUT_HTTP_RESPONSE_FULL) + req2Channel.handlerNames.contains(ChannelPipelineCustomizer.HANDLER_MICRONAUT_HTTP_RESPONSE_STREAM) + + tracker.requestPipelineBuilt.isEmpty() + + cleanup: + client.close() + ctx.close() + } + + def 'http2 customization'(boolean secure) { + given: + def ctx = ApplicationContext.run([ + 'micronaut.http.client.ssl.insecure-trust-all-certificates': true, + 'micronaut.http.client.plaintext-mode': 'h2c', + ]) + def client = ctx.getBean(DefaultHttpClient) + def tracker = ctx.getBean(CustomizerTracker) + + def conn = new EmbeddedTestConnectionHttp2() + if (secure) { + conn.setupHttp2Tls() + } else { + conn.setupH2c() + } + patch(client, conn) + + when: + def r1 = conn.testExchangeRequest(client) + if (secure) { + conn.exchangeSettings() + } else { + conn.exchangeH2c() + } + conn.testExchangeResponse(r1) + + def r2 = conn.testStreamingRequest(client) + conn.testStreamingResponse(r2) + + then: + def outerChannel = tracker.initialPipelineBuilt.poll() + outerChannel.channel == conn.clientChannel + outerChannel.handlerNames.contains(ChannelPipelineCustomizer.HANDLER_SSL) == secure + !outerChannel.handlerNames.contains(ChannelPipelineCustomizer.HANDLER_HTTP2_CONNECTION) + tracker.initialPipelineBuilt.isEmpty() + + def innerChannel = tracker.streamPipelineBuilt.poll() + innerChannel.channel == conn.clientChannel + innerChannel.handlerNames.contains(ChannelPipelineCustomizer.HANDLER_HTTP2_CONNECTION) + tracker.streamPipelineBuilt.isEmpty() + + def req1Channel = tracker.requestPipelineBuilt.poll() + req1Channel.role == NettyClientCustomizer.ChannelRole.HTTP2_STREAM + req1Channel.channel !== conn.clientChannel + req1Channel.handlerNames.contains(ChannelPipelineCustomizer.HANDLER_HTTP_AGGREGATOR) + req1Channel.handlerNames.contains(ChannelPipelineCustomizer.HANDLER_MICRONAUT_FULL_HTTP_RESPONSE) + + def req2Channel = tracker.requestPipelineBuilt.poll() + req2Channel.role == NettyClientCustomizer.ChannelRole.HTTP2_STREAM + req2Channel.channel !== conn.clientChannel + req2Channel.handlerNames.contains(ChannelPipelineCustomizer.HANDLER_MICRONAUT_HTTP_RESPONSE_FULL) + req2Channel.handlerNames.contains(ChannelPipelineCustomizer.HANDLER_MICRONAUT_HTTP_RESPONSE_STREAM) + + tracker.requestPipelineBuilt.isEmpty() + + cleanup: + client.close() + ctx.close() + + where: + secure << [true, false] + } + + def 'http1 exchange read timeout'() { + def ctx = ApplicationContext.run([ + 'micronaut.http.client.read-timeout': '5s', + ]) + def client = ctx.getBean(DefaultHttpClient) + + def conn = new EmbeddedTestConnectionHttp1() + conn.setupHttp1() + patch(client, conn) + + // do one request + conn.testExchangeResponse(conn.testExchangeRequest(client)) + conn.clientChannel.unfreezeTime() + // connection is in reserve, should not time out + TimeUnit.SECONDS.sleep(10) + conn.advance() + + // second request + def future = Mono.from(client.exchange('http://example.com/foo', String)).toFuture() + conn.advance() + + // todo: move to advanceTime once IdleStateHandler supports it + TimeUnit.SECONDS.sleep(5) + conn.advance() + + assert future.isDone() + when: + future.get() + then: + def e = thrown ExecutionException + e.cause instanceof ReadTimeoutException + + cleanup: + client.close() + ctx.close() + } + + def 'http2 exchange read timeout'() { + def ctx = ApplicationContext.run([ + 'micronaut.http.client.ssl.insecure-trust-all-certificates': true, + 'micronaut.http.client.read-timeout': '5s', + ]) + def client = ctx.getBean(DefaultHttpClient) + + def conn = new EmbeddedTestConnectionHttp2() + conn.setupHttp2Tls() + patch(client, conn) + + // one request opens the connection + def r1 = conn.testExchangeRequest(client) + conn.exchangeSettings() + conn.testExchangeResponse(r1) + conn.clientChannel.unfreezeTime() + + // connection is in reserve, should not time out + TimeUnit.SECONDS.sleep(10) + conn.advance() + + // second request + def future = Mono.from(client.exchange('https://example.com/foo', String)).toFuture() + conn.advance() + + // todo: move to advanceTime once IdleStateHandler supports it + TimeUnit.SECONDS.sleep(5) + conn.advance() + + assert future.isDone() + when: + future.get() + then: + def e = thrown ExecutionException + e.cause instanceof ReadTimeoutException + + cleanup: + client.close() + ctx.close() + } + + def 'http1 ttl'() { + def ctx = ApplicationContext.run([ + 'micronaut.http.client.connect-ttl': '100s', + ]) + def client = ctx.getBean(DefaultHttpClient) + + def conn1 = new EmbeddedTestConnectionHttp1() + conn1.setupHttp1() + def conn2 = new EmbeddedTestConnectionHttp1() + conn2.setupHttp1() + patch(client, conn1, conn2) + + def r1 = conn1.testExchangeRequest(client) + conn1.clientChannel.advanceTimeBy(101, TimeUnit.SECONDS) + conn1.testExchangeResponse(r1) + + // conn1 should expire now, conn2 will be the next connection + conn2.testExchangeResponse(conn2.testExchangeRequest(client)) + + cleanup: + client.close() + ctx.close() + } + + def 'http2 ttl'() { + def ctx = ApplicationContext.run([ + 'micronaut.http.client.ssl.insecure-trust-all-certificates': true, + 'micronaut.http.client.connect-ttl': '100s', + ]) + def client = ctx.getBean(DefaultHttpClient) + + def conn1 = new EmbeddedTestConnectionHttp2() + conn1.setupHttp2Tls() + def conn2 = new EmbeddedTestConnectionHttp2() + conn2.setupHttp2Tls() + patch(client, conn1, conn2) + + def r1 = conn1.testExchangeRequest(client) + conn1.exchangeSettings() + conn1.clientChannel.advanceTimeBy(101, TimeUnit.SECONDS) + conn1.testExchangeResponse(r1) + + // conn1 should expire now, conn2 will be the next connection + def r2 = conn2.testExchangeRequest(client) + conn2.exchangeSettings() + conn2.testExchangeResponse(r2) + + cleanup: + client.close() + ctx.close() + } + + def 'http1 pool timeout'() { + def ctx = ApplicationContext.run([ + 'micronaut.http.client.connection-pool-idle-timeout': '5s', + ]) + def client = ctx.getBean(DefaultHttpClient) + + def conn1 = new EmbeddedTestConnectionHttp1() + conn1.setupHttp1() + def conn2 = new EmbeddedTestConnectionHttp1() + conn2.setupHttp1() + patch(client, conn1, conn2) + + conn1.testExchangeResponse(conn1.testExchangeRequest(client)) + conn1.clientChannel.unfreezeTime() + // todo: move to advanceTime once IdleStateHandler supports it + TimeUnit.SECONDS.sleep(5) + conn1.advance() + // conn1 should expire now, conn2 will be the next connection + conn2.testExchangeResponse(conn2.testExchangeRequest(client)) + + cleanup: + client.close() + ctx.close() + } + + @Unroll + def 'websocket ssl=#secure'(boolean secure) { + def ctx = ApplicationContext.run([ + 'micronaut.http.client.ssl.insecure-trust-all-certificates': true, + 'micronaut.http.client.connect-ttl': '100s', + ]) + def client = ctx.getBean(DefaultHttpClient) + + def conn = new EmbeddedTestConnectionHttp1() + if (secure) { + conn.setupHttp1Tls() + } else { + conn.setupHttp1() + } + conn.serverChannel.pipeline().addLast(new HttpObjectAggregator(1024)) + patch(client, conn) + + def uri = conn.scheme + "://example.com/foo" + Mono.from(client.connect(Ws, uri)).subscribe() + conn.advance() + io.netty.handler.codec.http.HttpRequest req = conn.serverChannel.readInbound() + def handshaker = new WebSocketServerHandshakerFactory(uri, null, false).newHandshaker(req) + handshaker.handshake(conn.serverChannel, req) + conn.advance() + + conn.serverChannel.writeOutbound(new TextWebSocketFrame('foo')) + conn.advance() + TextWebSocketFrame response = conn.serverChannel.readInbound() + assert response.text() == 'received: foo' + + cleanup: + client.close() + ctx.close() + + where: + secure << [true, false] + } + + @ClientWebSocket + static class Ws implements AutoCloseable { + @Override + void close() throws Exception { + } + + @OnMessage + def onMessage(String msg, WebSocketSession session) { + return session.send('received: ' + msg) + } + } + + def 'cancel pool acquisition'() { + def ctx = ApplicationContext.run() + def client = ctx.getBean(DefaultHttpClient) + + def conn = new EmbeddedTestConnectionHttp1() + conn.setupHttp1() + + ChannelPromise delayPromise + def normalInit = conn.clientInitializer + // hack: delay the channelActive call until we complete delayPromise + conn.clientInitializer = new ChannelInitializer() { + @Override + protected void initChannel(EmbeddedChannel ch) throws Exception { + ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { + @Override + void channelActive(ChannelHandlerContext chtx) throws Exception { + delayPromise = chtx.newPromise() + delayPromise.addListener(new GenericFutureListener>() { + @Override + void operationComplete(io.netty.util.concurrent.Future future) throws Exception { + chtx.fireChannelActive() + } + }) + } + }) + ch.pipeline().addLast(normalInit) + } + } + + patch(client, conn) + + def subscription = Mono.from(client.exchange(conn.scheme + '://example.com/foo')).subscribe() + conn.advance() + subscription.dispose() + // this completes the handshake + delayPromise.setSuccess() + conn.advance() + + conn.testExchangeResponse(conn.testExchangeRequest(client)) + + cleanup: + client.close() + ctx.close() + } + + def 'max pending acquires'() { + def ctx = ApplicationContext.run([ + 'micronaut.http.client.pool.max-pending-acquires': 5, + 'micronaut.http.client.pool.max-pending-connections': 1, + ]) + def client = ctx.getBean(DefaultHttpClient) + + def conn = new EmbeddedTestConnectionHttp1() + conn.setupHttp1() + + ChannelPromise delayPromise + def normalInit = conn.clientInitializer + // hack: delay the channelActive call until we complete delayPromise + conn.clientInitializer = new ChannelInitializer() { + @Override + protected void initChannel(EmbeddedChannel ch) throws Exception { + ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { + @Override + void channelActive(ChannelHandlerContext chtx) throws Exception { + delayPromise = chtx.newPromise() + delayPromise.addListener(new GenericFutureListener>() { + @Override + void operationComplete(io.netty.util.concurrent.Future future) throws Exception { + chtx.fireChannelActive() + } + }) + } + }) + ch.pipeline().addLast(normalInit) + } + } + + patch(client, conn) + + List> futures = new ArrayList<>() + for (int i = 0; i < 6; i++) { + futures.add(Mono.from(client.exchange(conn.scheme + '://example.com/foo')).toFuture()) + } + conn.advance() + + for (int i = 0; i < 5; i++) { + assert !futures.get(i).isDone() + } + assert futures.get(5).isDone() + assert futures.get(5).completedExceptionally + + cleanup: + client.close() + ctx.close() + } + + def 'max http1 connections'() { + def ctx = ApplicationContext.run([ + 'micronaut.http.client.pool.max-pending-connections': 1, + 'micronaut.http.client.pool.max-concurrent-http1-connections': 2, + ]) + def client = ctx.getBean(DefaultHttpClient) + + def conn1 = new EmbeddedTestConnectionHttp1() + conn1.setupHttp1() + def conn2 = new EmbeddedTestConnectionHttp1() + conn2.setupHttp1() + + patch(client, conn1, conn2) + + // we open four requests, the first two of which will open connections. + List>> futures = [ + conn1.testExchangeRequest(client), + conn2.testExchangeRequest(client), + conn1.testExchangeRequest(client), + conn1.testExchangeRequest(client), + ] + + conn1.testExchangeResponse(futures.get(0)) + conn1.testExchangeResponse(futures.get(2)) + conn1.testExchangeResponse(futures.get(3)) + conn2.testExchangeResponse(futures.get(1)) + + cleanup: + client.close() + ctx.close() + } + + def 'multipart request'() { + def ctx = ApplicationContext.run() + def client = ctx.getBean(DefaultHttpClient) + + def conn = new EmbeddedTestConnectionHttp1() + conn.setupHttp1() + patch(client, conn) + conn.serverChannel.pipeline().addLast(new HttpObjectAggregator(1024)) + + def future = Mono.from(client.exchange(HttpRequest.POST(conn.scheme + '://example.com/foo', MultipartBody.builder() + .addPart('foo', 'fn', MediaType.TEXT_PLAIN_TYPE, 'bar'.bytes) + .build()) + .contentType(MediaType.MULTIPART_FORM_DATA), String)).toFuture() + future.exceptionally(t -> t.printStackTrace()) + conn.advance() + + FullHttpRequest request = conn.serverChannel.readInbound() + assert request.uri() == '/foo' + assert request.method() == HttpMethod.POST + assert request.headers().get('host') == 'example.com' + assert request.headers().get("connection") == "keep-alive" + assert request.content().isReadable(100) // cba to check the exact content + + def response = new DefaultFullHttpResponse(io.netty.handler.codec.http.HttpVersion.HTTP_1_1, HttpResponseStatus.OK, Unpooled.wrappedBuffer('foo'.bytes)) + response.headers().add("Content-Length", 3) + conn.serverChannel.writeOutbound(response) + conn.advance() + assert future.get().body() == 'foo' + + cleanup: + client.close() + ctx.close() + } + + def 'publisher request'() { + def ctx = ApplicationContext.run() + def client = ctx.getBean(DefaultHttpClient) + + def conn = new EmbeddedTestConnectionHttp1() + conn.setupHttp1() + patch(client, conn) + conn.serverChannel.pipeline().addLast(new HttpObjectAggregator(1024)) + + def future = Mono.from(client.exchange(HttpRequest.POST(conn.scheme + '://example.com/foo', Flux.fromIterable([1,2,3,4,5])) + .contentType(MediaType.APPLICATION_JSON_TYPE), String)).toFuture() + future.exceptionally(t -> t.printStackTrace()) + conn.advance() + + FullHttpRequest request = conn.serverChannel.readInbound() + assert request.uri() == '/foo' + assert request.method() == HttpMethod.POST + assert request.headers().get('host') == 'example.com' + assert request.headers().get("connection") == "keep-alive" + assert request.content().toString(StandardCharsets.UTF_8) == '[1,2,3,4,5]' + + def response = new DefaultFullHttpResponse(io.netty.handler.codec.http.HttpVersion.HTTP_1_1, HttpResponseStatus.OK, Unpooled.wrappedBuffer('foo'.bytes)) + response.headers().add("Content-Length", 3) + conn.serverChannel.writeOutbound(response) + conn.advance() + assert future.get().body() == 'foo' + + cleanup: + client.close() + ctx.close() + } + + def 'connection pool disabled http1'() { + def ctx = ApplicationContext.run([ + 'micronaut.http.client.pool.enabled': false, + ]) + def client = ctx.getBean(DefaultHttpClient) + + def conn1 = new EmbeddedTestConnectionHttp1() + conn1.setupHttp1() + def conn2 = new EmbeddedTestConnectionHttp1() + conn2.setupHttp1() + patch(client, conn1, conn2) + + def r1 = conn1.testExchangeRequest(client) + conn1.testExchangeResponse(r1, "close") + + def r2 = conn2.testExchangeRequest(client) + conn2.testExchangeResponse(r2, "close") + + cleanup: + client.close() + ctx.close() + } + + def 'connection pool disabled http2'() { + def ctx = ApplicationContext.run([ + 'micronaut.http.client.pool.enabled': false, + 'micronaut.http.client.ssl.insecure-trust-all-certificates': true, + ]) + def client = ctx.getBean(DefaultHttpClient) + + def conn1 = new EmbeddedTestConnectionHttp2() + conn1.setupHttp2Tls() + def conn2 = new EmbeddedTestConnectionHttp2() + conn2.setupHttp2Tls() + patch(client, conn1, conn2) + + def r1 = conn1.testExchangeRequest(client) + conn1.exchangeSettings() + conn1.testExchangeResponse(r1) + + def r2 = conn2.testExchangeRequest(client) + conn2.exchangeSettings() + conn2.testExchangeResponse(r2) + + cleanup: + client.close() + ctx.close() + } + + static class EmbeddedTestConnectionBase { + final EmbeddedChannel serverChannel + EmbeddedChannel clientChannel + ChannelInitializer clientInitializer = new ChannelInitializer() { + @Override + protected void initChannel(EmbeddedChannel ch) throws Exception { + ch.freezeTime() + ch.config().setAutoRead(false) + EmbeddedTestUtil.connect(serverChannel, ch) + } + } + + EmbeddedTestConnectionBase() { + serverChannel = new EmbeddedServerChannel(new DummyChannelId('server')) + serverChannel.freezeTime() + serverChannel.config().setAutoRead(true) + } + + final void advance() { + EmbeddedTestUtil.advance(serverChannel, clientChannel) + } + } + + static class EmbeddedServerChannel extends EmbeddedChannel implements ServerChannel { + EmbeddedServerChannel(ChannelId channelId) { + super(channelId) + } + } + + static class EmbeddedTestConnectionHttp1 extends EmbeddedTestConnectionBase { + private String scheme + + void setupHttp1() { + scheme = 'http' + serverChannel.pipeline() + .addLast(new HttpServerCodec()) + } + + void setupHttp1Tls() { + def certificate = new SelfSignedCertificate() + def builder = SslContextBuilder.forServer(certificate.key(), certificate.cert()) + CertificateProvidedSslBuilder.setupSslBuilder(builder, new SslConfiguration(), HttpVersion.HTTP_1_1); + def tlsHandler = builder.build().newHandler(ByteBufAllocator.DEFAULT) + + scheme = 'https' + serverChannel.pipeline() + .addLast(tlsHandler) + .addLast(new HttpServerCodec()) + } + + void respondOk() { + def response = new DefaultFullHttpResponse(io.netty.handler.codec.http.HttpVersion.HTTP_1_1, HttpResponseStatus.OK) + response.headers().add('content-length', 0) + serverChannel.writeOutbound(response) + } + + CompletableFuture> testExchangeRequest(HttpClient client) { + def future = Mono.from(client.exchange(scheme + '://example.com/foo')).toFuture() + future.exceptionally(t -> t.printStackTrace()) + advance() + return future + } + + void testExchangeResponse(CompletableFuture> future, String connectionHeader = "keep-alive") { + io.netty.handler.codec.http.HttpRequest request = serverChannel.readInbound() + assert request.uri() == '/foo' + assert request.method() == HttpMethod.GET + assert request.headers().get('host') == 'example.com' + assert request.headers().get("connection") == connectionHeader + + def tail = serverChannel.readInbound() + assert tail == null || tail instanceof LastHttpContent + + respondOk() + advance() + + assert future.get().status() == HttpStatus.OK + } + + private Queue testStreamingRequest(StreamingHttpClient client) { + def responseData = new ArrayDeque() + Flux.from(client.dataStream(HttpRequest.GET(scheme + '://example.com/foo'))) + .doOnError(t -> t.printStackTrace()) + .doOnComplete(() -> responseData.add("END")) + .subscribe(b -> responseData.add(b.toString(StandardCharsets.UTF_8))) + responseData + } + + private void testStreamingResponse(Queue responseData) { + advance() + + io.netty.handler.codec.http.HttpRequest request = serverChannel.readInbound() + assert request.uri() == '/foo' + assert request.method() == HttpMethod.GET + assert request.headers().get('host') == 'example.com' + assert request.headers().get("connection") == "keep-alive" + + def tail = serverChannel.readInbound() + assert tail == null || tail instanceof LastHttpContent + + def response = new DefaultHttpResponse(io.netty.handler.codec.http.HttpVersion.HTTP_1_1, HttpResponseStatus.OK) + response.headers().add('content-length', 6) + serverChannel.writeOutbound(response) + serverChannel.writeOutbound(new DefaultHttpContent(Unpooled.wrappedBuffer('foo'.bytes))) + advance() + + assert responseData.poll() == 'foo' + assert responseData.isEmpty() + + serverChannel.writeOutbound(new DefaultLastHttpContent(Unpooled.wrappedBuffer('bar'.bytes))) + advance() + + assert responseData.poll() == 'bar' + assert responseData.poll() == 'END' + } + } + + static class EmbeddedTestConnectionHttp2 extends EmbeddedTestConnectionBase { + private String scheme + Http2FrameStream h2cResponseStream + + void setupHttp2Tls() { + scheme = 'https' + + def certificate = new SelfSignedCertificate() + def builder = SslContextBuilder.forServer(certificate.key(), certificate.cert()) + CertificateProvidedSslBuilder.setupSslBuilder(builder, new SslConfiguration(), HttpVersion.HTTP_2_0); + def tlsHandler = builder.build().newHandler(ByteBufAllocator.DEFAULT) + + serverChannel.pipeline() + .addLast(tlsHandler) + .addLast(new ApplicationProtocolNegotiationHandler("h2") { + @Override + protected void configurePipeline(ChannelHandlerContext chtx, String protocol) throws Exception { + chtx.pipeline() + .addLast(Http2FrameCodecBuilder.forServer().build()) + } + }) + } + + void setupH2c() { + scheme = 'http' + + ChannelHandler responseStreamHandler = new ChannelInboundHandlerAdapter() { + @Override + void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + if (evt instanceof Http2FrameStreamEvent && evt.stream().id() == 1) { + h2cResponseStream = evt.stream() + } + + super.userEventTriggered(ctx, evt) + } + } + Http2FrameCodec frameCodec = Http2FrameCodecBuilder.forServer() + .build() + HttpServerUpgradeHandler.UpgradeCodecFactory upgradeCodecFactory = protocol -> { + if (AsciiString.contentEquals("h2c", protocol)) { + return new Http2ServerUpgradeCodec(frameCodec, responseStreamHandler) + } else { + return null + } + } + + HttpServerCodec sourceCodec = new HttpServerCodec() + serverChannel.pipeline() + .addLast(new LoggingHandler(LogLevel.INFO)) + .addLast(sourceCodec) + .addLast(new HttpServerUpgradeHandler(sourceCodec, upgradeCodecFactory, 1024)) + } + + void exchangeSettings() { + advance() + + assert serverChannel.readInbound() instanceof Http2SettingsFrame + assert serverChannel.readInbound() instanceof Http2SettingsAckFrame + } + + void exchangeH2c() { + advance() + + Http2HeadersFrame upgradeRequest = serverChannel.readInbound() + assert upgradeRequest.headers().get(Http2Headers.PseudoHeaderName.METHOD.value()) == 'GET' + assert upgradeRequest.headers().get(Http2Headers.PseudoHeaderName.PATH.value()) == '/' + assert upgradeRequest.headers().get(Http2Headers.PseudoHeaderName.AUTHORITY.value()) == 'example.com:80' + assert upgradeRequest.headers().get('content-length') == '0' + // client closes the stream immediately + assert upgradeRequest.stream().state() == Http2Stream.State.CLOSED + + assert serverChannel.readInbound() instanceof Http2SettingsFrame + assert serverChannel.readInbound() instanceof Http2ResetFrame + assert serverChannel.readInbound() instanceof Http2SettingsAckFrame + } + + void respondOk(Http2FrameStream stream) { + def responseHeaders = new DefaultHttp2Headers() + responseHeaders.add(Http2Headers.PseudoHeaderName.STATUS.value(), "200") + serverChannel.writeOutbound(new DefaultHttp2HeadersFrame(responseHeaders, true).stream(stream)) + } + + Future> testExchangeRequest(HttpClient client) { + def future = Mono.from(client.exchange(scheme + '://example.com/foo')).toFuture() + future.exceptionally(t -> t.printStackTrace()) + return future + } + + void testExchangeResponse(Future> future) { + Http2HeadersFrame request = serverChannel.readInbound() + assert request.headers().get(Http2Headers.PseudoHeaderName.PATH.value()) == '/foo' + assert request.headers().get(Http2Headers.PseudoHeaderName.SCHEME.value()) == scheme + assert request.headers().get(Http2Headers.PseudoHeaderName.AUTHORITY.value()) == 'example.com' + assert request.headers().get(Http2Headers.PseudoHeaderName.METHOD.value()) == 'GET' + + respondOk(request.stream()) + advance() + + def response = future.get() + assert response.status() == HttpStatus.OK + } + + Queue testStreamingRequest(StreamingHttpClient client) { + def responseData = new ArrayDeque() + Flux.from(client.dataStream(HttpRequest.GET(scheme + '://example.com/foo'))) + .doOnError(t -> t.printStackTrace()) + .doOnComplete(() -> responseData.add("END")) + .subscribe(b -> responseData.add(b.toString(StandardCharsets.UTF_8))) + return responseData + } + + void testStreamingResponse(Queue responseData) { + advance() + Http2HeadersFrame request = serverChannel.readInbound() + assert request.headers().get(Http2Headers.PseudoHeaderName.PATH.value()) == '/foo' + assert request.headers().get(Http2Headers.PseudoHeaderName.SCHEME.value()) == scheme + assert request.headers().get(Http2Headers.PseudoHeaderName.AUTHORITY.value()) == 'example.com' + assert request.headers().get(Http2Headers.PseudoHeaderName.METHOD.value()) == 'GET' + + def responseHeaders = new DefaultHttp2Headers() + responseHeaders.add(Http2Headers.PseudoHeaderName.STATUS.value(), "200") + serverChannel.writeOutbound(new DefaultHttp2HeadersFrame(responseHeaders, false).stream(request.stream())) + serverChannel.writeOutbound(new DefaultHttp2DataFrame(Unpooled.wrappedBuffer('foo'.bytes)).stream(request.stream())) + advance() + + assert responseData.poll() == 'foo' + assert responseData.isEmpty() + + serverChannel.writeOutbound(new DefaultHttp2DataFrame(Unpooled.wrappedBuffer('bar'.bytes), true).stream(request.stream())) + advance() + + assert responseData.poll() == 'bar' + assert responseData.poll() == 'END' + } + } + + @Singleton + static class CustomizerTracker implements NettyClientCustomizer, BeanCreatedEventListener { + final Queue initialPipelineBuilt = new ArrayDeque<>() + final Queue streamPipelineBuilt = new ArrayDeque<>() + final Queue requestPipelineBuilt = new ArrayDeque<>() + + @Override + NettyClientCustomizer specializeForChannel(Channel channel, ChannelRole role) { + return new NettyClientCustomizer() { + @Override + NettyClientCustomizer specializeForChannel(Channel channel_, ChannelRole role_) { + return CustomizerTracker.this.specializeForChannel(channel_, role_) + } + + Snapshot snap() { + return new Snapshot(channel, role, channel.pipeline().names()) + } + + @Override + void onInitialPipelineBuilt() { + initialPipelineBuilt.add(snap()) + } + + @Override + void onStreamPipelineBuilt() { + streamPipelineBuilt.add(snap()) + } + + @Override + void onRequestPipelineBuilt() { + requestPipelineBuilt.add(snap()) + } + } + } + + @Override + Registry onCreated(BeanCreatedEvent event) { + event.getBean().register(this) + return event.getBean() + } + + static class Snapshot { + final Channel channel + final ChannelRole role + final List handlerNames + + Snapshot(Channel channel, ChannelRole role, List handlerNames) { + this.channel = channel + this.role = role + this.handlerNames = handlerNames + } + } + } +} diff --git a/http-client/src/test/groovy/io/micronaut/http/client/netty/DefaultNettyHttpClientRegistrySpec.groovy b/http-client/src/test/groovy/io/micronaut/http/client/netty/DefaultNettyHttpClientRegistrySpec.groovy index 5abda077f12..2aa503905ad 100644 --- a/http-client/src/test/groovy/io/micronaut/http/client/netty/DefaultNettyHttpClientRegistrySpec.groovy +++ b/http-client/src/test/groovy/io/micronaut/http/client/netty/DefaultNettyHttpClientRegistrySpec.groovy @@ -1,17 +1,13 @@ package io.micronaut.http.client.netty import io.micronaut.context.ApplicationContext -import io.micronaut.context.BeanContext import io.micronaut.context.annotation.Requires import io.micronaut.context.event.BeanCreatedEvent import io.micronaut.context.event.BeanCreatedEventListener import io.micronaut.http.annotation.Get import io.micronaut.http.client.HttpClient import io.micronaut.http.client.annotation.Client -import io.micronaut.http.netty.channel.ChannelPipelineCustomizer -import io.micronaut.runtime.server.EmbeddedServer -import io.micronaut.test.extensions.spock.annotation.MicronautTest -import io.netty.util.Attribute +import io.netty.channel.Channel import io.netty.util.AttributeKey import jakarta.inject.Inject import jakarta.inject.Singleton @@ -88,22 +84,27 @@ class DefaultNettyHttpClientRegistrySpec extends Specification { @Requires(property = 'spec.name', value = 'DefaultNettyHttpClientRegistrySpec') @Singleton - static class MyCustomizer implements BeanCreatedEventListener { + static class MyCustomizer implements BeanCreatedEventListener { static final AttributeKey CUSTOMIZED = AttributeKey.valueOf('micronaut.test.customized') def connected = 0 def duplicate = false @Override - ChannelPipelineCustomizer onCreated(BeanCreatedEvent event) { - event.bean.doOnConnect { - if (it.channel().hasAttr(CUSTOMIZED)) { - duplicate = true + NettyClientCustomizer.Registry onCreated(BeanCreatedEvent event) { + event.bean.register(new NettyClientCustomizer() { + @Override + NettyClientCustomizer specializeForChannel(Channel channel, NettyClientCustomizer.ChannelRole role) { + if (role == NettyClientCustomizer.ChannelRole.CONNECTION) { + if (channel.hasAttr(CUSTOMIZED)) { + duplicate = true + } + channel.attr(CUSTOMIZED).set(true) + connected++ + } + return this } - it.channel().attr(CUSTOMIZED).set(true) - connected++ - return it - } + }) return event.bean } } diff --git a/http-client/src/test/groovy/io/micronaut/http/client/netty/DummyChannelId.groovy b/http-client/src/test/groovy/io/micronaut/http/client/netty/DummyChannelId.groovy new file mode 100644 index 00000000000..c5d6af6d3ff --- /dev/null +++ b/http-client/src/test/groovy/io/micronaut/http/client/netty/DummyChannelId.groovy @@ -0,0 +1,26 @@ +package io.micronaut.http.client.netty + +import io.netty.channel.ChannelId + +class DummyChannelId implements ChannelId { + final String name + + DummyChannelId(String name) { + this.name = name + } + + @Override + String asShortText() { + return name + } + + @Override + String asLongText() { + return name + } + + @Override + int compareTo(ChannelId o) { + return asLongText() <=> o.asLongText() + } +} diff --git a/http-client/src/test/groovy/io/micronaut/http/client/netty/EmbeddedTestUtil.groovy b/http-client/src/test/groovy/io/micronaut/http/client/netty/EmbeddedTestUtil.groovy new file mode 100644 index 00000000000..38441839f08 --- /dev/null +++ b/http-client/src/test/groovy/io/micronaut/http/client/netty/EmbeddedTestUtil.groovy @@ -0,0 +1,108 @@ +package io.micronaut.http.client.netty + +import io.netty.buffer.ByteBuf +import io.netty.buffer.CompositeByteBuf +import io.netty.channel.ChannelHandlerContext +import io.netty.channel.ChannelOutboundHandlerAdapter +import io.netty.channel.ChannelPromise +import io.netty.channel.embedded.EmbeddedChannel + +// todo: can we unify this with the util class in http-server-netty tests? +class EmbeddedTestUtil { + static void advance(EmbeddedChannel... channels) { + boolean advanced + do { + advanced = false + for (EmbeddedChannel channel : channels) { + if (channel.hasPendingTasks()) { + advanced = true + channel.runPendingTasks() + } + channel.checkException() + } + } while (advanced); + } + + static void connect(EmbeddedChannel server, EmbeddedChannel client) { + new ConnectionDirection(server, client).register() + new ConnectionDirection(client, server).register() + } + + private static class ConnectionDirection { + final EmbeddedChannel source + final EmbeddedChannel dest + CompositeByteBuf sourceQueue + final List sourceQueueFutures = new ArrayList<>(); + final Queue destQueue = new ArrayDeque<>() + boolean readPending + + ConnectionDirection(EmbeddedChannel source, EmbeddedChannel dest) { + this.source = source + this.dest = dest + } + + private void forwardNow(ByteBuf msg) { + if (!dest.isOpen()) { + return + } + dest.writeOneInbound(msg) + dest.pipeline().fireChannelReadComplete() + } + + void register() { + source.pipeline().addFirst(new ChannelOutboundHandlerAdapter() { + @Override + void write(ChannelHandlerContext ctx_, Object msg, ChannelPromise promise) throws Exception { + if (!(msg instanceof ByteBuf)) { + throw new IllegalArgumentException("Can only forward bytes, got " + msg) + } + if (!msg.isReadable()) { + // no data + msg.release() + promise.setSuccess() + return + } + + if (sourceQueue == null) { + sourceQueue = ((ByteBuf) msg).alloc().compositeBuffer() + } + sourceQueue.addComponent(true, (ByteBuf) msg) + if (!promise.isVoid()) { + sourceQueueFutures.add(promise) + } + } + + @Override + void flush(ChannelHandlerContext ctx_) throws Exception { + if (sourceQueue != null) { + ByteBuf packet = sourceQueue + sourceQueue = null + + for (ChannelPromise promise : sourceQueueFutures) { + promise.trySuccess() + } + sourceQueueFutures.clear() + + if (readPending || dest.config().isAutoRead()) { + dest.eventLoop().execute(() -> forwardNow(packet)) + readPending = false + } else { + destQueue.add(packet) + } + } + } + }) + dest.pipeline().addFirst(new ChannelOutboundHandlerAdapter() { + @Override + void read(ChannelHandlerContext ctx) throws Exception { + if (destQueue.isEmpty()) { + readPending = true + } else { + ByteBuf msg = destQueue.poll() + ctx.channel().eventLoop().execute(() -> forwardNow(msg)) + } + } + }) + } + } +} diff --git a/http-netty/src/main/java/io/micronaut/http/netty/channel/ChannelPipelineCustomizer.java b/http-netty/src/main/java/io/micronaut/http/netty/channel/ChannelPipelineCustomizer.java index 9883e5779dc..746774dc1c3 100644 --- a/http-netty/src/main/java/io/micronaut/http/netty/channel/ChannelPipelineCustomizer.java +++ b/http-netty/src/main/java/io/micronaut/http/netty/channel/ChannelPipelineCustomizer.java @@ -56,6 +56,12 @@ public interface ChannelPipelineCustomizer { String HANDLER_WEBSOCKET_UPGRADE = "websocket-upgrade-handler"; String HANDLER_MICRONAUT_INBOUND = "micronaut-inbound-handler"; String HANDLER_ACCESS_LOGGER = "http-access-logger"; + String HANDLER_INITIAL_ERROR = "initial-error"; + /** + * Handler that listens for channelActive to trigger, which will finish up the connection + * setup. + */ + String HANDLER_ACTIVITY_LISTENER = "activity-listener"; /** * @return Is this customizer the client. diff --git a/http-netty/src/main/java/io/micronaut/http/netty/stream/HttpStreamsClientHandler.java b/http-netty/src/main/java/io/micronaut/http/netty/stream/HttpStreamsClientHandler.java index f7702f69059..3ff8be28cda 100644 --- a/http-netty/src/main/java/io/micronaut/http/netty/stream/HttpStreamsClientHandler.java +++ b/http-netty/src/main/java/io/micronaut/http/netty/stream/HttpStreamsClientHandler.java @@ -88,7 +88,6 @@ protected boolean hasBody(HttpResponse response) { return true; } - if (HttpUtil.isContentLengthSet(response)) { return HttpUtil.getContentLength(response) > 0; } @@ -183,7 +182,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception @Override public void write(final ChannelHandlerContext ctx, Object msg, final ChannelPromise promise) throws Exception { - if (ctx.channel().attr(AttributeKey.valueOf(ChannelPipelineCustomizer.HANDLER_HTTP_CHUNK)).get() == Boolean.TRUE) { + if (Boolean.TRUE.equals(ctx.channel().attr(AttributeKey.valueOf(ChannelPipelineCustomizer.HANDLER_HTTP_CHUNK)).get())) { ctx.write(msg, promise); } else { super.write(ctx, msg, promise); diff --git a/http-netty/src/main/java/io/micronaut/http/netty/stream/HttpStreamsServerHandler.java b/http-netty/src/main/java/io/micronaut/http/netty/stream/HttpStreamsServerHandler.java index f819aad2bc3..a2d4cf67584 100644 --- a/http-netty/src/main/java/io/micronaut/http/netty/stream/HttpStreamsServerHandler.java +++ b/http-netty/src/main/java/io/micronaut/http/netty/stream/HttpStreamsServerHandler.java @@ -23,7 +23,16 @@ import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPipeline; import io.netty.channel.ChannelPromise; -import io.netty.handler.codec.http.*; +import io.netty.handler.codec.http.DefaultFullHttpResponse; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpContent; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.HttpResponse; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.HttpUtil; +import io.netty.handler.codec.http.HttpVersion; import io.netty.handler.codec.http.websocketx.WebSocketFrame; import io.netty.handler.codec.http.websocketx.WebSocketServerHandshaker; import io.netty.handler.codec.http.websocketx.WebSocketVersion; @@ -211,6 +220,15 @@ protected void consumedInMessage(ChannelHandlerContext ctx) { webSocketResponse = null; webSocketResponseChannelPromise = null; } + if (inFlight == 0) { + // normally, after writing the response, the routing handler triggers a read() for the + // next request. However, if at this point the request is not fully read yet (e.g. + // still missing a LastHttpContent), then that read() call will simply read the + // remaining content, and the HandlerPublisher also won't trigger more read()s since + // it's complete. To prevent the connection from being stuck in that case, we trigger a + // read here. + ctx.read(); + } } private void handleWebSocketResponse(ChannelHandlerContext ctx, HttpResponse message, ChannelPromise promise) { diff --git a/http-server-netty/src/main/java/io/micronaut/http/server/netty/RoutingInBoundHandler.java b/http-server-netty/src/main/java/io/micronaut/http/server/netty/RoutingInBoundHandler.java index 01ad951e2c9..41c9c9dca87 100644 --- a/http-server-netty/src/main/java/io/micronaut/http/server/netty/RoutingInBoundHandler.java +++ b/http-server-netty/src/main/java/io/micronaut/http/server/netty/RoutingInBoundHandler.java @@ -1251,7 +1251,7 @@ private void writeFinalNettyResponse(MutableHttpResponse message, HttpRequest if (!isHttp2) { if (!nettyHeaders.contains(HttpHeaderNames.CONNECTION)) { boolean expectKeepAlive = nettyResponse.protocolVersion().isKeepAliveDefault() || request.getHeaders().isKeepAlive(); - if (!decodeError && (expectKeepAlive || httpStatus < 500 || serverConfiguration.isKeepAliveOnServerError())) { + if (!decodeError && expectKeepAlive && (httpStatus < 500 || serverConfiguration.isKeepAliveOnServerError())) { nettyHeaders.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.KEEP_ALIVE); } else { nettyHeaders.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE); diff --git a/http-server-netty/src/main/java/io/micronaut/http/server/netty/configuration/NettyHttpServerConfiguration.java b/http-server-netty/src/main/java/io/micronaut/http/server/netty/configuration/NettyHttpServerConfiguration.java index 4c05c0dcae3..d615bb1dc8a 100644 --- a/http-server-netty/src/main/java/io/micronaut/http/server/netty/configuration/NettyHttpServerConfiguration.java +++ b/http-server-netty/src/main/java/io/micronaut/http/server/netty/configuration/NettyHttpServerConfiguration.java @@ -111,7 +111,7 @@ public class NettyHttpServerConfiguration extends HttpServerConfiguration { * The default configuration for boolean flag indicating whether to add connection header `keep-alive` to responses with HttpStatus > 499. */ @SuppressWarnings("WeakerAccess") - public static final boolean DEFAULT_KEEP_ALIVE_ON_SERVER_ERROR = false; + public static final boolean DEFAULT_KEEP_ALIVE_ON_SERVER_ERROR = true; private static final Logger LOG = LoggerFactory.getLogger(NettyHttpServerConfiguration.class); diff --git a/http-server-netty/src/test/groovy/io/micronaut/http/server/netty/binding/HttpResponseSpec.groovy b/http-server-netty/src/test/groovy/io/micronaut/http/server/netty/binding/HttpResponseSpec.groovy index 2c2ea5d784e..2b61ac1134e 100644 --- a/http-server-netty/src/test/groovy/io/micronaut/http/server/netty/binding/HttpResponseSpec.groovy +++ b/http-server-netty/src/test/groovy/io/micronaut/http/server/netty/binding/HttpResponseSpec.groovy @@ -24,18 +24,13 @@ import io.micronaut.http.HttpStatus import io.micronaut.http.annotation.Controller import io.micronaut.http.annotation.Get import io.micronaut.http.client.HttpClient -import io.micronaut.http.client.DefaultHttpClientConfiguration import io.micronaut.http.client.exceptions.HttpClientResponseException import io.micronaut.http.server.netty.AbstractMicronautSpec -import io.micronaut.runtime.Micronaut import io.micronaut.runtime.server.EmbeddedServer import reactor.core.publisher.Flux import spock.lang.Shared import spock.lang.Unroll -import java.time.Duration -import java.time.temporal.ChronoUnit - /** * @author Graeme Rocher * @since 1.0 @@ -71,19 +66,19 @@ class HttpResponseSpec extends AbstractMicronautSpec { where: action | status | body | headers - "ok" | HttpStatus.OK | null | [connection: 'close'] - "ok-with-body" | HttpStatus.OK | "some text" | ['content-length': '9', 'content-type': 'text/plain'] + [connection: 'close'] - "error-with-body" | HttpStatus.INTERNAL_SERVER_ERROR | "some text" | ['content-length': '9', 'content-type': 'text/plain'] + [connection: 'close'] - "ok-with-body-object" | HttpStatus.OK | '{"name":"blah","age":10}' | defaultHeaders + ['content-length': '24', 'content-type': 'application/json'] + [connection: 'close'] - "status" | HttpStatus.MOVED_PERMANENTLY | null | [connection: 'close'] - "created-body" | HttpStatus.CREATED | '{"name":"blah","age":10}' | defaultHeaders + ['content-length': '24', 'content-type': 'application/json'] + [connection: 'close'] - "created-uri" | HttpStatus.CREATED | null | [connection: 'close', 'location': 'http://test.com'] - "created-body-uri" | HttpStatus.CREATED | '{"name":"blah","age":10}' | defaultHeaders + ['content-length': '24', 'content-type': 'application/json'] + [connection: 'close', 'location': 'http://test.com'] - "accepted" | HttpStatus.ACCEPTED | null | [connection: 'close'] - "accepted-uri" | HttpStatus.ACCEPTED | null | [connection: 'close', 'location': 'http://example.com'] - "disallow" | HttpStatus.METHOD_NOT_ALLOWED | null | [connection: "close", 'allow': 'DELETE'] - "optional-response/false" | HttpStatus.OK | null | [connection: 'close'] - "optional-response/true" | HttpStatus.NOT_FOUND | null | ['content-type': 'application/json', 'content-length': '162', connection: 'close'] + "ok" | HttpStatus.OK | null | [connection: 'keep-alive'] + "ok-with-body" | HttpStatus.OK | "some text" | ['content-length': '9', 'content-type': 'text/plain'] + [connection: 'keep-alive'] + "error-with-body" | HttpStatus.INTERNAL_SERVER_ERROR | "some text" | ['content-length': '9', 'content-type': 'text/plain'] + [connection: 'keep-alive'] + "ok-with-body-object" | HttpStatus.OK | '{"name":"blah","age":10}' | defaultHeaders + ['content-length': '24', 'content-type': 'application/json'] + [connection: 'keep-alive'] + "status" | HttpStatus.MOVED_PERMANENTLY | null | [connection: 'keep-alive'] + "created-body" | HttpStatus.CREATED | '{"name":"blah","age":10}' | defaultHeaders + ['content-length': '24', 'content-type': 'application/json'] + [connection: 'keep-alive'] + "created-uri" | HttpStatus.CREATED | null | [connection: 'keep-alive', 'location': 'http://test.com'] + "created-body-uri" | HttpStatus.CREATED | '{"name":"blah","age":10}' | defaultHeaders + ['content-length': '24', 'content-type': 'application/json'] + [connection: 'keep-alive', 'location': 'http://test.com'] + "accepted" | HttpStatus.ACCEPTED | null | [connection: 'keep-alive'] + "accepted-uri" | HttpStatus.ACCEPTED | null | [connection: 'keep-alive', 'location': 'http://example.com'] + "disallow" | HttpStatus.METHOD_NOT_ALLOWED | null | [connection: "keep-alive", 'allow': 'DELETE'] + "optional-response/false" | HttpStatus.OK | null | [connection: 'keep-alive'] + "optional-response/true" | HttpStatus.NOT_FOUND | null | ['content-type': 'application/json', 'content-length': '162', connection: 'keep-alive'] } @@ -104,7 +99,7 @@ class HttpResponseSpec extends AbstractMicronautSpec { } def responseBody = response.body.orElse(null) - def defaultHeaders = [connection: 'close'] + def defaultHeaders = [connection: 'keep-alive'] then: response.code() == status.code @@ -113,15 +108,15 @@ class HttpResponseSpec extends AbstractMicronautSpec { where: action | status | body | headers - "ok" | HttpStatus.OK | null | [connection: 'close'] - "ok-with-body" | HttpStatus.OK | "some text" | ['content-length': '9', 'content-type': 'text/plain'] + [connection: 'close'] - "error-with-body" | HttpStatus.INTERNAL_SERVER_ERROR | "some text" | ['content-length': '9', 'content-type': 'text/plain'] + [connection: 'close'] - "ok-with-body-object" | HttpStatus.OK | '{"name":"blah","age":10}' | defaultHeaders + ['content-length': '24', 'content-type': 'application/json'] + [connection: 'close'] - "status" | HttpStatus.MOVED_PERMANENTLY | null | [connection: 'close'] - "created-body" | HttpStatus.CREATED | '{"name":"blah","age":10}' | defaultHeaders + ['content-length': '24', 'content-type': 'application/json'] + [connection: 'close'] - "created-uri" | HttpStatus.CREATED | null | [connection: 'close', 'location': 'http://test.com'] - "accepted" | HttpStatus.ACCEPTED | null | [connection: 'close'] - "accepted-uri" | HttpStatus.ACCEPTED | null | [connection: 'close', 'location': 'http://example.com'] + "ok" | HttpStatus.OK | null | [connection: 'keep-alive'] + "ok-with-body" | HttpStatus.OK | "some text" | ['content-length': '9', 'content-type': 'text/plain'] + [connection: 'keep-alive'] + "error-with-body" | HttpStatus.INTERNAL_SERVER_ERROR | "some text" | ['content-length': '9', 'content-type': 'text/plain'] + [connection: 'keep-alive'] + "ok-with-body-object" | HttpStatus.OK | '{"name":"blah","age":10}' | defaultHeaders + ['content-length': '24', 'content-type': 'application/json'] + [connection: 'keep-alive'] + "status" | HttpStatus.MOVED_PERMANENTLY | null | [connection: 'keep-alive'] + "created-body" | HttpStatus.CREATED | '{"name":"blah","age":10}' | defaultHeaders + ['content-length': '24', 'content-type': 'application/json'] + [connection: 'keep-alive'] + "created-uri" | HttpStatus.CREATED | null | [connection: 'keep-alive', 'location': 'http://test.com'] + "accepted" | HttpStatus.ACCEPTED | null | [connection: 'keep-alive'] + "accepted-uri" | HttpStatus.ACCEPTED | null | [connection: 'keep-alive', 'location': 'http://example.com'] } void "test content encoding"() { @@ -232,9 +227,13 @@ class HttpResponseSpec extends AbstractMicronautSpec { server.close() } - void "test keep alive connection header is not set by default for > 499 response"() { + void "test keep alive connection header is not set for 500 response"() { when: - EmbeddedServer server = ApplicationContext.run(EmbeddedServer, ['micronaut.server.date-header': false, (SPEC_NAME_PROPERTY):getClass().simpleName]) + EmbeddedServer server = ApplicationContext.run(EmbeddedServer, [ + 'micronaut.server.netty.keepAliveOnServerError': false, + 'micronaut.server.date-header': false, + (SPEC_NAME_PROPERTY):getClass().simpleName, + ]) ApplicationContext ctx = server.getApplicationContext() HttpClient client = ctx.createBean(HttpClient, server.getURL()) @@ -253,19 +252,13 @@ class HttpResponseSpec extends AbstractMicronautSpec { server.close() } - void "test connection header is defaulted to keep-alive when configured to true for > 499 response"() { + void "test connection header is defaulted to keep-alive by default for > 499 response"() { when: - DefaultHttpClientConfiguration config = new DefaultHttpClientConfiguration() - - // The client will explicitly request "Connection: close" unless using a connection pool, so set it up - config.connectionPoolConfiguration.enabled = true - EmbeddedServer server = ApplicationContext.run(EmbeddedServer, [ - (SPEC_NAME_PROPERTY):getClass().simpleName, - 'micronaut.server.netty.keepAliveOnServerError':true + (SPEC_NAME_PROPERTY):getClass().simpleName ]) def ctx = server.getApplicationContext() - HttpClient client = ctx.createBean(HttpClient, embeddedServer.getURL(), config) + HttpClient client = ctx.createBean(HttpClient, server.getURL()) Flux.from(client.exchange( HttpRequest.GET('/test-header/fail') diff --git a/http-server-netty/src/test/groovy/io/micronaut/http/server/netty/binding/NettyHttpServerSpec.groovy b/http-server-netty/src/test/groovy/io/micronaut/http/server/netty/binding/NettyHttpServerSpec.groovy index fc59c5aaad7..6ac5f9fd84d 100644 --- a/http-server-netty/src/test/groovy/io/micronaut/http/server/netty/binding/NettyHttpServerSpec.groovy +++ b/http-server-netty/src/test/groovy/io/micronaut/http/server/netty/binding/NettyHttpServerSpec.groovy @@ -20,7 +20,11 @@ import io.micronaut.context.env.Environment import io.micronaut.context.env.PropertySource import io.micronaut.context.event.StartupEvent import io.micronaut.core.io.socket.SocketUtils -import io.micronaut.http.* +import io.micronaut.http.HttpHeaders +import io.micronaut.http.HttpMethod +import io.micronaut.http.HttpRequest +import io.micronaut.http.HttpResponse +import io.micronaut.http.HttpStatus import io.micronaut.http.annotation.Controller import io.micronaut.http.annotation.Get import io.micronaut.http.annotation.Put @@ -32,7 +36,6 @@ import io.micronaut.runtime.Micronaut import io.micronaut.runtime.event.annotation.EventListener import io.micronaut.runtime.server.EmbeddedServer import jakarta.inject.Singleton -import reactor.core.publisher.Flux import spock.lang.Retry import spock.lang.Specification import spock.lang.Stepwise @@ -179,7 +182,6 @@ class NettyHttpServerSpec extends Specification { DefaultHttpClientConfiguration config = new DefaultHttpClientConfiguration() // The client will explicitly request "Connection: close" unless using a connection pool, so set it up config.connectionPoolConfiguration.enabled = true - config.connectionPoolConfiguration.maxConnections = 2; config.connectionPoolConfiguration.acquireTimeout = Duration.of(3, ChronoUnit.SECONDS); ApplicationContext applicationContext = Micronaut.run() diff --git a/http-server-netty/src/test/groovy/io/micronaut/http/server/netty/configuration/NettyHttpServerConfigurationSpec.groovy b/http-server-netty/src/test/groovy/io/micronaut/http/server/netty/configuration/NettyHttpServerConfigurationSpec.groovy index 06a32e21179..8eec5b1c913 100644 --- a/http-server-netty/src/test/groovy/io/micronaut/http/server/netty/configuration/NettyHttpServerConfigurationSpec.groovy +++ b/http-server-netty/src/test/groovy/io/micronaut/http/server/netty/configuration/NettyHttpServerConfigurationSpec.groovy @@ -472,17 +472,17 @@ class NettyHttpServerConfigurationSpec extends Specification { NettyHttpServerConfiguration config = beanContext.getBean(NettyHttpServerConfiguration) then: - !config.keepAliveOnServerError + config.keepAliveOnServerError cleanup: beanContext.close() } - void "test keepAlive configuration set to true"() { + void "test keepAlive configuration set to false"() { given: ApplicationContext beanContext = new DefaultApplicationContext("test") beanContext.environment.addPropertySource(PropertySource.of("test", - ['micronaut.server.netty.keepAliveOnServerError': true] + ['micronaut.server.netty.keepAliveOnServerError': false] )) beanContext.start() @@ -490,7 +490,7 @@ class NettyHttpServerConfigurationSpec extends Specification { NettyHttpServerConfiguration config = beanContext.getBean(NettyHttpServerConfiguration) then: - config.keepAliveOnServerError + !config.keepAliveOnServerError cleanup: beanContext.close() diff --git a/http-server-netty/src/test/groovy/io/micronaut/http/server/netty/websocket/BinaryWebSocketSpec.groovy b/http-server-netty/src/test/groovy/io/micronaut/http/server/netty/websocket/BinaryWebSocketSpec.groovy index a4bd3a39ca2..4fc32dac55c 100644 --- a/http-server-netty/src/test/groovy/io/micronaut/http/server/netty/websocket/BinaryWebSocketSpec.groovy +++ b/http-server-netty/src/test/groovy/io/micronaut/http/server/netty/websocket/BinaryWebSocketSpec.groovy @@ -19,10 +19,12 @@ import io.micronaut.context.ApplicationContext import io.micronaut.context.annotation.Requires import io.micronaut.context.event.BeanCreatedEvent import io.micronaut.context.event.BeanCreatedEventListener -import io.micronaut.http.netty.channel.ChannelPipelineCustomizer +import io.micronaut.http.client.netty.NettyClientCustomizer +import io.micronaut.http.server.netty.NettyServerCustomizer import io.micronaut.runtime.server.EmbeddedServer import io.micronaut.websocket.WebSocketClient import io.netty.buffer.Unpooled +import io.netty.channel.Channel import io.netty.channel.ChannelHandlerContext import io.netty.channel.ChannelOutboundHandlerAdapter import io.netty.channel.ChannelPipeline @@ -218,7 +220,8 @@ class BinaryWebSocketSpec extends Specification { 'spec.name' : 'test per-message compression', 'micronaut.server.port': -1 ]) - def compressionDetectionCustomizer = ctx.getBean(CompressionDetectionCustomizer) + def cdcServer = ctx.getBean(CompressionDetectionCustomizerServer) + def cdcClient = ctx.getBean(CompressionDetectionCustomizerClient) EmbeddedServer embeddedServer = ctx.getBean(EmbeddedServer) embeddedServer.start() PollingConditions conditions = new PollingConditions(timeout: 15, delay: 0.5) @@ -237,11 +240,12 @@ class BinaryWebSocketSpec extends Specification { fred.replies.size() == 1 } - compressionDetectionCustomizer.getPipelines().size() == 4 + cdcServer.getPipelines().size() == 2 + cdcClient.getPipelines().size() == 2 when: "A message is sent" List interceptors = new ArrayList<>() - for (ChannelPipeline pipeline : compressionDetectionCustomizer.getPipelines()) { + for (ChannelPipeline pipeline : cdcServer.getPipelines() + cdcClient.getPipelines()) { def interceptor = new MessageInterceptor() if (pipeline.get('ws-encoder') != null) { pipeline.addAfter('ws-encoder', 'MessageInterceptor', interceptor) @@ -268,17 +272,62 @@ class BinaryWebSocketSpec extends Specification { @Singleton @Requires(property = 'spec.name', value = 'test per-message compression') - static class CompressionDetectionCustomizer implements BeanCreatedEventListener { + static class CompressionDetectionCustomizerServer implements BeanCreatedEventListener { List pipelines = Collections.synchronizedList(new ArrayList<>()) @Override - ChannelPipelineCustomizer onCreated(BeanCreatedEvent event) { - event.getBean().doOnConnect { - pipelines.add(it) - return it + NettyServerCustomizer.Registry onCreated(BeanCreatedEvent event) { + event.getBean().register(new Customizer(null)) + return event.bean + } + + class Customizer implements NettyServerCustomizer { + final Channel channel + + Customizer(Channel channel) { + this.channel = channel + } + + @Override + NettyServerCustomizer specializeForChannel(Channel channel, ChannelRole role) { + return new Customizer(channel) } + + @Override + void onInitialPipelineBuilt() { + pipelines.add(channel.pipeline()) + } + } + } + + @Singleton + @Requires(property = 'spec.name', value = 'test per-message compression') + static class CompressionDetectionCustomizerClient implements BeanCreatedEventListener { + List pipelines = Collections.synchronizedList(new ArrayList<>()) + + @Override + NettyClientCustomizer.Registry onCreated(BeanCreatedEvent event) { + event.getBean().register(new Customizer(null)) return event.bean } + + class Customizer implements NettyClientCustomizer { + final Channel channel + + Customizer(Channel channel) { + this.channel = channel + } + + @Override + NettyClientCustomizer specializeForChannel(Channel channel, ChannelRole role) { + return new Customizer(channel) + } + + @Override + void onInitialPipelineBuilt() { + pipelines.add(channel.pipeline()) + } + } } static class MessageInterceptor extends ChannelOutboundHandlerAdapter { diff --git a/src/main/docs/guide/httpClient/lowLevelHttpClient/clientConfiguration.adoc b/src/main/docs/guide/httpClient/lowLevelHttpClient/clientConfiguration.adoc index 0b65eb51d34..aedaeffbe25 100644 --- a/src/main/docs/guide/httpClient/lowLevelHttpClient/clientConfiguration.adoc +++ b/src/main/docs/guide/httpClient/lowLevelHttpClient/clientConfiguration.adoc @@ -75,9 +75,17 @@ Alternatively, if you don't use service discovery you can use the `configuration ReactorHttpClient httpClient; ---- -=== Using HTTP Client Connection Pooling +=== Connection Pooling and HTTP/2 -A client that handles a significant number of requests will benefit from enabling HTTP client connection pooling. The following configuration enables pooling for the `foo` client: +Connections using normal HTTP (without TLS/SSL) use HTTP/1.1. This can be configured using the `plaintext-mode` configuration option. + +Secure connections (i.e. HTTP**S**, with TLS/SSL) use a feature called "Application Layer Protocol Negotiation" (ALPN) that is part of TLS to select the HTTP version. If the server supports HTTP/2, the Micronaut HTTP Client will use that capability by default, but if it doesn't, HTTP/1.1 is still supported. This is configured using the `alpn-modes` option, which is a list of supported ALPN protocol IDs (`"h2"` and `"http/1.1"`). + +NOTE: The HTTP/2 standard forbids the use of certain less secure TLS cipher suites for HTTP/2 connections. When the HTTP client supports HTTP/2 (which is the default), it will not support those cipher suites. Removing `"h2"` from `alpn-modes` will enable support for all cipher suites. + +Each HTTP/1.1 connection can only support one request at a time, but can be reused for subsequent requests using the `keep-alive` mechanism. HTTP/2 connections can support any number of concurrent requests. + +To remove the overhead of opening a new connection for each request, the Micronaut HTTP Client will reuse HTTP connections wherever possible. They are managed in a _connection pool_. HTTP/1.1 connections are kept around using keep-alive and are used for new requests, and for HTTP/2, a single connection is used for all requests. .Manually configuring HTTP services [source,yaml] @@ -90,15 +98,15 @@ micronaut: - http://foo1 - http://foo2 pool: - enabled: true # <1> - max-connections: 50 # <2> + max-concurrent-http1-connections: 50 # <1> ---- -<1> Enables the pool -<2> Sets the maximum number of connections in the pool +<1> Limit maximum concurrent HTTP/1.1 connections See the API for link:{api}/io/micronaut/http/client/HttpClientConfiguration.ConnectionPoolConfiguration.html[ConnectionPoolConfiguration] for details on available pool configuration options. +By setting the `pool.enabled` property to `false`, you can disable connection reuse. The pool is still used and other configuration options (e.g. concurrent HTTP 1 connections) still apply, but one connection will only serve one request. + === Configuring Event Loop Groups By default, Micronaut shares a common Netty `EventLoopGroup` for worker threads and all HTTP client threads. diff --git a/test-suite-groovy/src/test/groovy/io/micronaut/docs/netty/LogbookNettyClientCustomizer.groovy b/test-suite-groovy/src/test/groovy/io/micronaut/docs/netty/LogbookNettyClientCustomizer.groovy index de84e9777b7..a4d6b4b1aa5 100644 --- a/test-suite-groovy/src/test/groovy/io/micronaut/docs/netty/LogbookNettyClientCustomizer.groovy +++ b/test-suite-groovy/src/test/groovy/io/micronaut/docs/netty/LogbookNettyClientCustomizer.groovy @@ -1,12 +1,13 @@ package io.micronaut.docs.netty -import io.micronaut.context.annotation.Requires; +import io.micronaut.context.annotation.Requires +import io.micronaut.context.event.BeanCreatedEvent; // tag::imports[] -import io.micronaut.context.event.BeanCreatedEvent import io.micronaut.context.event.BeanCreatedEventListener import io.micronaut.http.client.netty.NettyClientCustomizer +import io.micronaut.http.netty.channel.ChannelPipelineCustomizer import io.netty.channel.Channel import jakarta.inject.Singleton import org.zalando.logbook.Logbook @@ -47,8 +48,9 @@ class LogbookNettyClientCustomizer } @Override - void onStreamPipelineBuilt() { - channel.pipeline().addLast( // <5> + void onRequestPipelineBuilt() { + channel.pipeline().addBefore( // <5> + ChannelPipelineCustomizer.HANDLER_HTTP_STREAM, "logbook", new LogbookClientHandler(logbook) ) diff --git a/test-suite-kotlin/src/test/kotlin/io/micronaut/docs/netty/LogbookNettyClientCustomizer.kt b/test-suite-kotlin/src/test/kotlin/io/micronaut/docs/netty/LogbookNettyClientCustomizer.kt index 5b980a2fcb8..7fdeb32a624 100644 --- a/test-suite-kotlin/src/test/kotlin/io/micronaut/docs/netty/LogbookNettyClientCustomizer.kt +++ b/test-suite-kotlin/src/test/kotlin/io/micronaut/docs/netty/LogbookNettyClientCustomizer.kt @@ -6,6 +6,7 @@ import io.micronaut.context.event.BeanCreatedEvent import io.micronaut.context.event.BeanCreatedEventListener import io.micronaut.http.client.netty.NettyClientCustomizer import io.micronaut.http.client.netty.NettyClientCustomizer.ChannelRole +import io.micronaut.http.netty.channel.ChannelPipelineCustomizer import io.netty.channel.Channel import jakarta.inject.Singleton import org.zalando.logbook.Logbook @@ -29,8 +30,9 @@ class LogbookNettyClientCustomizer(private val logbook: Logbook) : override fun specializeForChannel(channel: Channel, role: ChannelRole) = Customizer(channel) // <4> - override fun onStreamPipelineBuilt() { - channel!!.pipeline().addLast( // <5> + override fun onRequestPipelineBuilt() { + channel!!.pipeline().addBefore( // <5> + ChannelPipelineCustomizer.HANDLER_HTTP_STREAM, "logbook", LogbookClientHandler(logbook) ) diff --git a/test-suite/src/test/groovy/io/micronaut/http/client/http2/Http2RequestSpec.groovy b/test-suite/src/test/groovy/io/micronaut/http/client/http2/Http2RequestSpec.groovy index 88407a6ecd4..561bb44d1e8 100644 --- a/test-suite/src/test/groovy/io/micronaut/http/client/http2/Http2RequestSpec.groovy +++ b/test-suite/src/test/groovy/io/micronaut/http/client/http2/Http2RequestSpec.groovy @@ -152,6 +152,7 @@ class Http2RequestSpec extends Specification { "micronaut.server.http-version" : "2.0", 'micronaut.server.ssl.buildSelfSigned': true, 'micronaut.server.ssl.port': -1, + "micronaut.http.client.http-version" : "1.1", "micronaut.http.client.log-level" : "TRACE", "micronaut.server.netty.log-level" : "TRACE", 'micronaut.http.client.ssl.insecure-trust-all-certificates': true @@ -198,6 +199,7 @@ class Http2RequestSpec extends Specification { "micronaut.server.http-version" : "2.0", 'micronaut.server.ssl.buildSelfSigned': true, 'micronaut.server.ssl.port': -1, + "micronaut.http.client.http-version" : "1.1", "micronaut.http.client.log-level" : "TRACE", "micronaut.server.netty.log-level" : "TRACE" ]) diff --git a/test-suite/src/test/groovy/io/micronaut/http2/Http2AccessLoggerSpec.groovy b/test-suite/src/test/groovy/io/micronaut/http2/Http2AccessLoggerSpec.groovy index fcfc136201f..8839249b5d0 100644 --- a/test-suite/src/test/groovy/io/micronaut/http2/Http2AccessLoggerSpec.groovy +++ b/test-suite/src/test/groovy/io/micronaut/http2/Http2AccessLoggerSpec.groovy @@ -1,22 +1,21 @@ package io.micronaut.http2 +import ch.qos.logback.classic.Logger import ch.qos.logback.classic.spi.ILoggingEvent import ch.qos.logback.core.AppenderBase -import ch.qos.logback.classic.Logger -import io.micronaut.http.client.HttpClient -import io.micronaut.http.client.StreamingHttpClient -import org.reactivestreams.Publisher -import org.slf4j.LoggerFactory - import io.micronaut.context.ApplicationContext import io.micronaut.core.type.Argument import io.micronaut.docs.server.json.Person import io.micronaut.http.HttpRequest import io.micronaut.http.MediaType import io.micronaut.http.annotation.Get +import io.micronaut.http.client.HttpClient +import io.micronaut.http.client.StreamingHttpClient import io.micronaut.http.client.annotation.Client import io.micronaut.http.sse.Event import io.micronaut.runtime.server.EmbeddedServer +import org.reactivestreams.Publisher +import org.slf4j.LoggerFactory import reactor.core.publisher.Flux import spock.lang.AutoCleanup import spock.lang.Shared @@ -144,6 +143,7 @@ class Http2AccessLoggerSpec extends Specification { 'micronaut.server.ssl.buildSelfSigned': true, 'micronaut.server.ssl.port': -1, "micronaut.http.client.log-level" : "TRACE", + "micronaut.http.client.http-version" : "1.1", "micronaut.server.netty.log-level" : "TRACE", 'micronaut.server.netty.access-logger.enabled': true, 'micronaut.http.client.ssl.insecure-trust-all-certificates': true diff --git a/test-suite/src/test/java/io/micronaut/docs/netty/LogbookNettyClientCustomizer.java b/test-suite/src/test/java/io/micronaut/docs/netty/LogbookNettyClientCustomizer.java index a2d0069f968..6f37b73fc11 100644 --- a/test-suite/src/test/java/io/micronaut/docs/netty/LogbookNettyClientCustomizer.java +++ b/test-suite/src/test/java/io/micronaut/docs/netty/LogbookNettyClientCustomizer.java @@ -8,7 +8,6 @@ import io.micronaut.http.client.netty.NettyClientCustomizer; import io.micronaut.http.netty.channel.ChannelPipelineCustomizer; import io.netty.channel.Channel; -import io.netty.channel.ChannelPipeline; import jakarta.inject.Singleton; import org.zalando.logbook.Logbook; import org.zalando.logbook.netty.LogbookClientHandler; @@ -47,8 +46,9 @@ public NettyClientCustomizer specializeForChannel(Channel channel, ChannelRole r } @Override - public void onStreamPipelineBuilt() { - channel.pipeline().addLast( // <5> + public void onRequestPipelineBuilt() { + channel.pipeline().addBefore( // <5> + ChannelPipelineCustomizer.HANDLER_HTTP_STREAM, "logbook", new LogbookClientHandler(logbook) );