From bf15e47a8a934c563f5d4767d5da062b8551ef6b Mon Sep 17 00:00:00 2001 From: Dr Ian Preston Date: Tue, 15 Aug 2023 15:02:16 +0100 Subject: [PATCH] Fix yamux handling of writes bigger than the window size (#295) * Fix yamux handling of writes bigger than the window size * Fix inefficient window size handling * Release delayed yamux send buffer once sent. * Add Unit test --------- Co-authored-by: Anton Nashatyrev --- .../io/libp2p/mux/yamux/YamuxHandler.kt | 19 ++++-- .../main/kotlin/io/libp2p/protocol/Ping.kt | 9 ++- .../java/io/libp2p/core/HostTestJava.java | 63 +++++++++++++++++++ .../io/libp2p/mux/yamux/YamuxHandlerTest.kt | 29 ++++++++- 4 files changed, 110 insertions(+), 10 deletions(-) diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt index 3f975f441..887745c0a 100644 --- a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt +++ b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt @@ -34,19 +34,26 @@ open class YamuxHandler( private val buffered = ArrayDeque() fun add(data: ByteBuf) { - buffered.add(data) + buffered.add(data.retain()) } fun flush(sendWindow: AtomicInteger, id: MuxId): Int { var written = 0 while (! buffered.isEmpty()) { val buf = buffered.first() - if (buf.readableBytes() + written < sendWindow.get()) { - buffered.removeFirst() + val readableBytes = buf.readableBytes() + if (readableBytes + written < sendWindow.get()) { sendBlocks(ctx, buf, sendWindow, id) - written += buf.readableBytes() - } else + written += readableBytes + buf.release() + buffered.removeFirst() + } else { + // partial write to fit within window + val toRead = sendWindow.get() - written + sendBlocks(ctx, buf.readSlice(toRead), sendWindow, id) + written += toRead break + } } return written } @@ -96,7 +103,7 @@ open class YamuxHandler( } val newWindow = recWindow.addAndGet(-size.toInt()) if (newWindow < INITIAL_WINDOW_SIZE / 2) { - val delta = INITIAL_WINDOW_SIZE / 2 + val delta = INITIAL_WINDOW_SIZE - newWindow recWindow.addAndGet(delta) ctx.write(YamuxFrame(msg.id, YamuxType.WINDOW_UPDATE, 0, delta.toLong())) ctx.flush() diff --git a/libp2p/src/main/kotlin/io/libp2p/protocol/Ping.kt b/libp2p/src/main/kotlin/io/libp2p/protocol/Ping.kt index 616cd6450..7a9c20a0f 100644 --- a/libp2p/src/main/kotlin/io/libp2p/protocol/Ping.kt +++ b/libp2p/src/main/kotlin/io/libp2p/protocol/Ping.kt @@ -22,20 +22,23 @@ interface PingController { fun ping(): CompletableFuture } -class Ping : PingBinding(PingProtocol()) +class Ping(pingSize: Int) : PingBinding(PingProtocol(pingSize)) { + constructor() : this(32) +} open class PingBinding(ping: PingProtocol) : StrictProtocolBinding("/ipfs/ping/1.0.0", ping) class PingTimeoutException : Libp2pException() -open class PingProtocol : ProtocolHandler(Long.MAX_VALUE, Long.MAX_VALUE) { +open class PingProtocol(var pingSize: Int) : ProtocolHandler(Long.MAX_VALUE, Long.MAX_VALUE) { var timeoutScheduler by lazyVar { Executors.newSingleThreadScheduledExecutor() } var curTime: () -> Long = { System.currentTimeMillis() } var random = Random() - var pingSize = 32 var pingTimeout = Duration.ofSeconds(10) + constructor() : this(32) + override fun onStartInitiator(stream: Stream): CompletableFuture { val handler = PingInitiator() stream.pushHandler(handler) diff --git a/libp2p/src/test/java/io/libp2p/core/HostTestJava.java b/libp2p/src/test/java/io/libp2p/core/HostTestJava.java index f8f92ef86..c8713f8ad 100644 --- a/libp2p/src/test/java/io/libp2p/core/HostTestJava.java +++ b/libp2p/src/test/java/io/libp2p/core/HostTestJava.java @@ -93,6 +93,69 @@ void ping() throws Exception { System.out.println("Server stopped"); } + @Test + void largePing() throws Exception { + int pingSize = 200 * 1024; + String localListenAddress = "/ip4/127.0.0.1/tcp/40002"; + + Host clientHost = new HostBuilder() + .transport(TcpTransport::new) + .secureChannel((k, m) -> new TlsSecureChannel(k, m, "ECDSA")) + .muxer(StreamMuxerProtocol::getYamux) + .build(); + + Host serverHost = new HostBuilder() + .transport(TcpTransport::new) + .secureChannel(TlsSecureChannel::new) + .muxer(StreamMuxerProtocol::getYamux) + .protocol(new Ping(pingSize)) + .listen(localListenAddress) + .build(); + + CompletableFuture clientStarted = clientHost.start(); + CompletableFuture serverStarted = serverHost.start(); + clientStarted.get(5, TimeUnit.SECONDS); + System.out.println("Client started"); + serverStarted.get(5, TimeUnit.SECONDS); + System.out.println("Server started"); + + Assertions.assertEquals(0, clientHost.listenAddresses().size()); + Assertions.assertEquals(1, serverHost.listenAddresses().size()); + Assertions.assertEquals( + localListenAddress + "/p2p/" + serverHost.getPeerId(), + serverHost.listenAddresses().get(0).toString() + ); + + StreamPromise ping = + clientHost.getNetwork().connect( + serverHost.getPeerId(), + new Multiaddr(localListenAddress) + ).thenApply( + it -> it.muxerSession().createStream(new Ping(pingSize)) + ) + .join(); + + Stream pingStream = ping.getStream().get(5, TimeUnit.SECONDS); + System.out.println("Ping stream created"); + PingController pingCtr = ping.getController().get(5, TimeUnit.SECONDS); + System.out.println("Ping controller created"); + + for (int i = 0; i < 10; i++) { + long latency = pingCtr.ping().join();//get(5, TimeUnit.SECONDS); + System.out.println("Ping is " + latency); + } + pingStream.close().get(5, TimeUnit.SECONDS); + System.out.println("Ping stream closed"); + + Assertions.assertThrows(ExecutionException.class, () -> + pingCtr.ping().get(5, TimeUnit.SECONDS)); + + clientHost.stop().get(5, TimeUnit.SECONDS); + System.out.println("Client stopped"); + serverHost.stop().get(5, TimeUnit.SECONDS); + System.out.println("Server stopped"); + } + @Test void keyPairGeneration() { Pair pair = KeyKt.generateKeyPair(KEY_TYPE.SECP256K1); diff --git a/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt b/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt index d46336713..4fc35691d 100644 --- a/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt @@ -10,6 +10,8 @@ import io.libp2p.mux.MuxHandlerAbstractTest import io.libp2p.mux.MuxHandlerAbstractTest.AbstractTestMuxFrame.Flag.* import io.libp2p.tools.readAllBytesAndRelease import io.netty.channel.ChannelHandlerContext +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test class YamuxHandlerTest : MuxHandlerAbstractTest() { @@ -27,8 +29,10 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { } } + private fun Long.toMuxId() = MuxId(parentChannelId, this, true) + override fun writeFrame(frame: AbstractTestMuxFrame) { - val muxId = MuxId(parentChannelId, frame.streamId, true) + val muxId = frame.streamId.toMuxId() val yamuxFrame = when (frame.flag) { Open -> YamuxFrame(muxId, YamuxType.DATA, YamuxFlags.SYN, 0) Data -> YamuxFrame( @@ -65,4 +69,27 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { return readFrameQueue.removeFirstOrNull() } + + @Test + fun `data should be buffered and sent after window increased from zero`() { + val handler = openStreamByLocal() + val streamId = readFrameOrThrow().streamId + + ech.writeInbound( + YamuxFrame( + streamId.toMuxId(), + YamuxType.WINDOW_UPDATE, + YamuxFlags.ACK, + -INITIAL_WINDOW_SIZE.toLong() + ) + ) + + handler.ctx.writeAndFlush("1984".fromHex().toByteBuf(allocateBuf())) + + assertThat(readFrame()).isNull() + + ech.writeInbound(YamuxFrame(streamId.toMuxId(), YamuxType.WINDOW_UPDATE, YamuxFlags.ACK, 5000)) + val frame = readFrameOrThrow() + assertThat(frame.data).isEqualTo("1984") + } }