Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix reentrant subscribe in StreamingNettyByteBody #11051

Merged
merged 2 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import io.micronaut.http.body.CloseableByteBody;
import io.micronaut.http.exceptions.BufferLengthExceededException;
import io.micronaut.http.exceptions.ContentLengthExceededException;
import io.micronaut.http.netty.EventLoopFlow;
import io.micronaut.http.netty.PublisherAsBlocking;
import io.micronaut.http.netty.PublisherAsStream;
import io.netty.buffer.ByteBuf;
Expand Down Expand Up @@ -57,14 +56,30 @@
@Internal
public final class StreamingNettyByteBody extends NettyByteBody implements CloseableByteBody {
private final SharedBuffer sharedBuffer;
/**
* We have reserve, subscribe, and add calls in {@link SharedBuffer} that all modify the same
* data structures. They can all happen concurrently and must be moved to the event loop. We
* also need to ensure that a reserve and associated subscribe stay serialized
* ({@link io.micronaut.http.netty.EventLoopFlow} semantics). But because of the potential
* concurrency, we actually need stronger semantics than
* {@link io.micronaut.http.netty.EventLoopFlow}.
* <p>
* The solution is to use the old {@link EventLoop#inEventLoop()} + {@link EventLoop#execute}
* pattern. Serialization semantics for reserve to subscribe are guaranteed using this field:
* If the reserve call is delayed, this field is {@code true}, and the subscribe call will also
* be delayed. This approach is possible because we only need to serialize a single reserve
* with a single subscribe.
*/
private final boolean forceDelaySubscribe;
private BufferConsumer.Upstream upstream;

public StreamingNettyByteBody(SharedBuffer sharedBuffer) {
this(sharedBuffer, sharedBuffer.rootUpstream);
this(sharedBuffer, false, sharedBuffer.rootUpstream);
}

private StreamingNettyByteBody(SharedBuffer sharedBuffer, BufferConsumer.Upstream upstream) {
private StreamingNettyByteBody(SharedBuffer sharedBuffer, boolean forceDelaySubscribe, BufferConsumer.Upstream upstream) {
this.sharedBuffer = sharedBuffer;
this.forceDelaySubscribe = forceDelaySubscribe;
this.upstream = upstream;
}

Expand All @@ -74,7 +89,7 @@ public BufferConsumer.Upstream primary(BufferConsumer primary) {
failClaim();
}
this.upstream = null;
sharedBuffer.subscribe(primary, upstream);
sharedBuffer.subscribe(primary, upstream, forceDelaySubscribe);
return upstream;
}

Expand All @@ -86,8 +101,8 @@ public BufferConsumer.Upstream primary(BufferConsumer primary) {
}
UpstreamBalancer.UpstreamPair pair = UpstreamBalancer.balancer(upstream, backpressureMode);
this.upstream = pair.left();
this.sharedBuffer.reserve();
return new StreamingNettyByteBody(sharedBuffer, pair.right());
boolean forceDelaySubscribe = this.sharedBuffer.reserve();
return new StreamingNettyByteBody(sharedBuffer, forceDelaySubscribe, pair.right());
}

@Override
Expand Down Expand Up @@ -163,7 +178,7 @@ public void error(Throwable e) {
this.upstream = null;
upstream.start();
upstream.onBytesConsumed(Long.MAX_VALUE);
return sharedBuffer.subscribeFull(upstream).map(AvailableNettyByteBody::new);
return sharedBuffer.subscribeFull(upstream, forceDelaySubscribe).map(AvailableNettyByteBody::new);
}

@Override
Expand All @@ -176,14 +191,14 @@ public void close() {
upstream.allowDiscard();
upstream.disregardBackpressure();
upstream.start();
sharedBuffer.subscribe(null, upstream);
sharedBuffer.subscribe(null, upstream, forceDelaySubscribe);
}

/**
* This class buffers input data and distributes it to multiple {@link StreamingNettyByteBody}
* instances.
* <p>Thread safety: The {@link BufferConsumer} methods <i>must</i> only be called from one
* thread, the {@link #eventLoopFlow} thread. The other methods (subscribe, reserve) can be
* thread, the {@link #eventLoop} thread. The other methods (subscribe, reserve) can be
* called from any thread.
*/
public static final class SharedBuffer implements BufferConsumer {
Expand All @@ -193,7 +208,7 @@ public static final class SharedBuffer implements BufferConsumer {
@Nullable
private final ResourceLeakTracker<SharedBuffer> tracker = LEAK_DETECTOR.get().track(this);

private final EventLoopFlow eventLoopFlow;
private final EventLoop eventLoop;
private final BodySizeLimits limits;
/**
* Upstream of all subscribers. This is only used to cancel incoming data if the max
Expand Down Expand Up @@ -230,6 +245,11 @@ public static final class SharedBuffer implements BufferConsumer {
* in a reentrant fashion.
*/
private boolean working = false;
/**
* {@code true} during {@link #add(ByteBuf)} to avoid reentrant subscribe or reserve calls.
* Field must only be accessed on the event loop.
*/
private boolean adding = false;
/**
* Number of bytes received so far.
*/
Expand All @@ -242,7 +262,7 @@ public static final class SharedBuffer implements BufferConsumer {
private volatile long expectedLength = -1;

public SharedBuffer(EventLoop loop, BodySizeLimits limits, Upstream rootUpstream) {
this.eventLoopFlow = new EventLoopFlow(loop);
this.eventLoop = loop;
this.limits = limits;
this.rootUpstream = rootUpstream;
}
Expand Down Expand Up @@ -274,9 +294,13 @@ public void setExpectedLength(long length) {
this.expectedLength = length;
}

void reserve() {
if (eventLoopFlow.executeNow(this::reserve0)) {
boolean reserve() {
if (eventLoop.inEventLoop() && !adding) {
reserve0();
return false;
} else {
eventLoop.execute(this::reserve0);
return true;
}
}

Expand All @@ -295,10 +319,13 @@ private void reserve0() {
*
* @param subscriber The subscriber to add. Can be {@code null}, then the bytes will just be discarded
* @param specificUpstream The upstream for the subscriber. This is used to call allowDiscard if there was an error
* @param forceDelay Whether to require an {@link EventLoop#execute} call to ensure serialization with previous {@link #reserve()} call
*/
void subscribe(@Nullable BufferConsumer subscriber, Upstream specificUpstream) {
if (eventLoopFlow.executeNow(() -> subscribe0(subscriber, specificUpstream))) {
void subscribe(@Nullable BufferConsumer subscriber, Upstream specificUpstream, boolean forceDelay) {
if (!forceDelay && eventLoop.inEventLoop() && !adding) {
subscribe0(subscriber, specificUpstream);
} else {
eventLoop.execute(() -> subscribe0(subscriber, specificUpstream));
}
}

Expand Down Expand Up @@ -354,16 +381,18 @@ private void subscribe0(@Nullable BufferConsumer subscriber, Upstream specificUp
* body.
*
* @param specificUpstream The upstream for the subscriber. This is used to call allowDiscard if there was an error
* @param forceDelay Whether to require an {@link EventLoop#execute} call to ensure serialization with previous {@link #reserve()} call
* @return A flow that will complete when all data has arrived, with a buffer containing that data
*/
ExecutionFlow<ByteBuf> subscribeFull(Upstream specificUpstream) {
ExecutionFlow<ByteBuf> subscribeFull(Upstream specificUpstream, boolean forceDelay) {
DelayedExecutionFlow<ByteBuf> asyncFlow = DelayedExecutionFlow.create();
if (eventLoopFlow.executeNow(() -> {
ExecutionFlow<ByteBuf> res = subscribeFull0(asyncFlow, specificUpstream, false);
assert res == asyncFlow;
})) {
if (!forceDelay && eventLoop.inEventLoop() && !adding) {
return subscribeFull0(asyncFlow, specificUpstream, true);
} else {
eventLoop.execute(() -> {
ExecutionFlow<ByteBuf> res = subscribeFull0(asyncFlow, specificUpstream, false);
assert res == asyncFlow;
});
return asyncFlow;
}
}
Expand Down Expand Up @@ -445,6 +474,7 @@ public void add(ByteBuf buf) {
buf.release();
return;
}
adding = true;
// calculate the new total length
long newLength = lengthSoFar + buf.readableBytes();
lengthSoFar = newLength;
Expand All @@ -453,6 +483,7 @@ public void add(ByteBuf buf) {
buf.release();
error(new ContentLengthExceededException(limits.maxBodySize(), newLength));
rootUpstream.allowDiscard();
adding = false;
return;
}

Expand Down Expand Up @@ -486,6 +517,7 @@ public void add(ByteBuf buf) {
} else {
buf.release();
}
adding = false;
working = false;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package io.micronaut.http.server.netty.handler

import io.micronaut.core.io.buffer.ByteBuffer
import io.micronaut.http.body.AvailableByteBody
import io.micronaut.http.body.ByteBody
import io.micronaut.http.body.CloseableAvailableByteBody
import io.micronaut.http.body.CloseableByteBody
import io.netty.buffer.ByteBuf
Expand Down Expand Up @@ -604,6 +605,38 @@ class PipeliningServerHandlerSpec extends Specification {
unwritten == 0
}

def 'reentrant close'() {
given:
def resp = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK)
resp.headers().add(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED)
def ch = new EmbeddedChannel(new PipeliningServerHandler(new RequestHandler() {
@Override
void accept(ChannelHandlerContext ctx, HttpRequest request, CloseableByteBody body, OutboundAccess outboundAccess) {
def split = body.split(ByteBody.SplitBackpressureMode.FASTEST)
Flux.from(split.toByteArrayPublisher())
.subscribe {
body.close()
outboundAccess.writeFull(resp)
}
}

@Override
void handleUnboundError(Throwable cause) {
cause.printStackTrace()
}
}))


def request = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/")
request.headers().add(HttpHeaderNames.CONTENT_LENGTH, 3)
when:
ch.writeInbound(request)
ch.writeInbound(new DefaultLastHttpContent(Unpooled.copiedBuffer("foo", StandardCharsets.UTF_8)))

then:
ch.checkException()
}

static class MonitorHandler extends ChannelOutboundHandlerAdapter {
int flush = 0
int read = 0
Expand Down
Loading