diff --git a/libp2p/connmanager.nim b/libp2p/connmanager.nim index ad37f00618..970f93597f 100644 --- a/libp2p/connmanager.nim +++ b/libp2p/connmanager.nim @@ -19,7 +19,8 @@ logScope: declareGauge(libp2p_peers, "total connected peers") -const MaxConnectionsPerPeer = 5 +const + MaxConnectionsPerPeer = 5 type TooManyConnections* = object of CatchableError @@ -236,12 +237,17 @@ proc cleanupConn(c: ConnManager, conn: Connection) {.async.} = trace "Connection cleaned up", conn -proc peerStartup(c: ConnManager, conn: Connection) {.async.} = +proc onConnUpgraded(c: ConnManager, conn: Connection) {.async.} = try: trace "Triggering connect events", conn + doAssert(not isNil(conn.upgraded), + "The `upgraded` event hasn't been properly initialized!") + conn.upgraded.complete() + let peerId = conn.peerInfo.peerId await c.triggerPeerEvents( peerId, PeerEvent(kind: PeerEventKind.Joined, initiator: conn.dir == Direction.Out)) + await c.triggerConnEvent( peerId, ConnEvent(kind: ConnEventKind.Connected, incoming: conn.dir == Direction.In)) except CatchableError as exc: @@ -384,7 +390,7 @@ proc storeMuxer*(c: ConnManager, trace "Stored muxer", muxer, handle = not handle.isNil, connections = c.conns.len - asyncSpawn c.peerStartup(muxer.connection) + asyncSpawn c.onConnUpgraded(muxer.connection) proc getStream*(c: ConnManager, peerId: PeerID, diff --git a/libp2p/protocols/secure/secure.nim b/libp2p/protocols/secure/secure.nim index 22395ee1e0..1a47b0083f 100644 --- a/libp2p/protocols/secure/secure.nim +++ b/libp2p/protocols/secure/secure.nim @@ -46,6 +46,7 @@ proc init*(T: type SecureConn, peerInfo: peerInfo, observedAddr: observedAddr, closeEvent: conn.closeEvent, + upgraded: conn.upgraded, timeout: timeout, dir: conn.dir) result.initStream() diff --git a/libp2p/stream/connection.nim b/libp2p/stream/connection.nim index a20e98ec7f..eae67c6459 100644 --- a/libp2p/stream/connection.nim +++ b/libp2p/stream/connection.nim @@ -32,6 +32,7 @@ type timeoutHandler*: TimeoutHandler # timeout handler peerInfo*: PeerInfo observedAddr*: Multiaddress + upgraded*: Future[void] proc timeoutMonitor(s: Connection) {.async, gcsafe.} @@ -49,6 +50,9 @@ method initStream*(s: Connection) = doAssert(isNil(s.timerTaskFut)) + if isNil(s.upgraded): + s.upgraded = newFuture[void]() + if s.timeout > 0.millis: trace "Monitoring for timeout", s, timeout = s.timeout @@ -61,10 +65,15 @@ method initStream*(s: Connection) = method closeImpl*(s: Connection): Future[void] = # Cleanup timeout timer trace "Closing connection", s + if not isNil(s.timerTaskFut) and not s.timerTaskFut.finished: s.timerTaskFut.cancel() s.timerTaskFut = nil + if not isNil(s.upgraded) and not s.upgraded.finished: + s.upgraded.cancel() + s.upgraded = nil + trace "Closed connection", s procCall LPStream(s).closeImpl() diff --git a/libp2p/stream/lpstream.nim b/libp2p/stream/lpstream.nim index c433b3a380..1ced16bfd5 100644 --- a/libp2p/stream/lpstream.nim +++ b/libp2p/stream/lpstream.nim @@ -125,7 +125,7 @@ method initStream*(s: LPStream) {.base.} = libp2p_open_streams.inc(labelValues = [s.objName, $s.dir]) inc getStreamTracker(s.objName).opened - debug "Stream created", s, objName = s.objName, dir = $s.dir + trace "Stream created", s, objName = s.objName, dir = $s.dir proc join*(s: LPStream): Future[void] = s.closeEvent.wait() @@ -258,7 +258,7 @@ method closeImpl*(s: LPStream): Future[void] {.async, base.} = s.closeEvent.fire() libp2p_open_streams.dec(labelValues = [s.objName, $s.dir]) inc getStreamTracker(s.objName).closed - debug "Closed stream", s, objName = s.objName, dir = $s.dir + trace "Closed stream", s, objName = s.objName, dir = $s.dir method close*(s: LPStream): Future[void] {.base, async.} = # {.raises [Defect].} ## close the stream - this may block, but will not raise exceptions diff --git a/libp2p/switch.nim b/libp2p/switch.nim index d254ae6089..9dbc31b2cd 100644 --- a/libp2p/switch.nim +++ b/libp2p/switch.nim @@ -26,6 +26,7 @@ import stream/connection, peerinfo, protocols/identify, muxers/muxer, + utils/semaphore, connmanager, peerid, errors @@ -45,6 +46,9 @@ declareCounter(libp2p_dialed_peers, "dialed peers") declareCounter(libp2p_failed_dials, "failed dials") declareCounter(libp2p_failed_upgrade, "peers failed upgrade") +const + ConcurrentUpgrades* = 4 + type UpgradeFailedError* = object of CatchableError DialFailedError* = object of CatchableError @@ -223,23 +227,26 @@ proc upgradeIncoming(s: Switch, incomingConn: Connection) {.async, gcsafe.} = # trace "Starting secure handler", conn let secure = s.secureManagers.filterIt(it.codec == proto)[0] - var sconn: Connection + var cconn = conn try: - sconn = await secure.secure(conn, false) + var sconn = await secure.secure(cconn, false) if isNil(sconn): return + cconn = sconn # add the muxer for muxer in s.muxers.values: ms.addHandler(muxer.codecs, muxer) # handle subsequent secure requests - await ms.handle(sconn) + await ms.handle(cconn) except CatchableError as exc: debug "Exception in secure handler during incoming upgrade", msg = exc.msg, conn + if not cconn.upgraded.finished: + cconn.upgraded.fail(exc) finally: - if not isNil(sconn): - await sconn.close() + if not isNil(cconn): + await cconn.close() trace "Stopped secure handler", conn @@ -254,6 +261,8 @@ proc upgradeIncoming(s: Switch, incomingConn: Connection) {.async, gcsafe.} = # await ms.handle(incomingConn, active = true) except CatchableError as exc: debug "Exception upgrading incoming", exc = exc.msg + if not incomingConn.upgraded.finished: + incomingConn.upgraded.fail(exc) finally: await incomingConn.close() @@ -416,31 +425,61 @@ proc mount*[T: LPProtocol](s: Switch, proto: T, matcher: Matcher = nil) {.gcsafe s.ms.addHandler(proto.codecs, proto, matcher) +proc upgradeMonitor(conn: Connection, upgrades: AsyncSemaphore) {.async.} = + ## monitor connection for upgrades + ## + try: + # Since we don't control the flow of the + # upgrade, this timeout guarantees that a + # "hanged" remote doesn't hold the upgrade + # forever + await conn.upgraded.wait(30.seconds) # wait for connection to be upgraded + trace "Connection upgrade succeeded" + except CatchableError as exc: + # if not isNil(conn): # for some reason, this can be nil + await conn.close() + + trace "Exception awaiting connection upgrade", exc = exc.msg, conn + finally: + upgrades.release() # don't forget to release the slot! + proc accept(s: Switch, transport: Transport) {.async.} = # noraises - ## transport's accept loop + ## switch accept loop, ran for every transport ## + let upgrades = AsyncSemaphore.init(ConcurrentUpgrades) while transport.running: var conn: Connection try: debug "About to accept incoming connection" - conn = await transport.accept() - if not isNil(conn): - debug "Accepted an incoming connection", conn - asyncSpawn s.upgradeIncoming(conn) # perform upgrade on incoming connection - else: + # remember to always release the slot when + # the upgrade succeeds or fails, this is + # currently done by the `upgradeMonitor` + await upgrades.acquire() # first wait for an upgrade slot to become available + conn = await transport.accept() # next attempt to get a connection + if isNil(conn): # A nil connection means that we might have hit a # file-handle limit (or another non-fatal error), # we can get one on the next try, but we should # be careful to not end up in a thigh loop that # will starve the main event loop, thus we sleep # here before retrying. + trace "Unable to get a connection, sleeping" await sleepAsync(100.millis) # TODO: should be configurable? + upgrades.release() + continue + + debug "Accepted an incoming connection", conn + asyncSpawn upgradeMonitor(conn, upgrades) + asyncSpawn s.upgradeIncoming(conn) + except CancelledError as exc: + trace "releasing semaphore on cancellation" + upgrades.release() # always release the slot except CatchableError as exc: debug "Exception in accept loop, exiting", exc = exc.msg + upgrades.release() # always release the slot if not isNil(conn): await conn.close() - return proc start*(s: Switch): Future[seq[Future[void]]] {.async, gcsafe.} = @@ -460,13 +499,6 @@ proc start*(s: Switch): Future[seq[Future[void]]] {.async, gcsafe.} = proc stop*(s: Switch) {.async.} = trace "Stopping switch" - for a in s.acceptFuts: - if not a.finished: - a.cancel() - - checkFutures( - await allFinished(s.acceptFuts)) - # close and cleanup all connections await s.connManager.close() @@ -478,6 +510,18 @@ proc stop*(s: Switch) {.async.} = except CatchableError as exc: warn "error cleaning up transports", msg = exc.msg + try: + await allFutures(s.acceptFuts) + .wait(1.seconds) + except CatchableError as exc: + trace "Exception while stopping accept loops", exc = exc.msg + + # check that all futures were properly + # stopped and otherwise cancel them + for a in s.acceptFuts: + if not a.finished: + a.cancel() + trace "Switch stopped" proc muxerHandler(s: Switch, muxer: Muxer) {.async, gcsafe.} = diff --git a/libp2p/transports/tcptransport.nim b/libp2p/transports/tcptransport.nim index f48cc09cbb..50513b19c7 100644 --- a/libp2p/transports/tcptransport.nim +++ b/libp2p/transports/tcptransport.nim @@ -71,7 +71,7 @@ proc connHandler*(t: TcpTransport, await client.closeWait() raise exc - debug "Handling tcp connection", address = $observedAddr, + trace "Handling tcp connection", address = $observedAddr, dir = $dir, clients = t.clients[Direction.In].len + t.clients[Direction.Out].len @@ -130,7 +130,10 @@ method start*(t: TcpTransport, ma: MultiAddress) {.async.} = await procCall Transport(t).start(ma) trace "Starting TCP transport" - t.server = createStreamServer(t.ma, t.flags, t) + t.server = createStreamServer( + ma = t.ma, + flags = t.flags, + udata = t) # always get the resolved address in case we're bound to 0.0.0.0:0 t.ma = MultiAddress.init(t.server.sock.getLocalAddress()).tryGet() @@ -142,6 +145,8 @@ method stop*(t: TcpTransport) {.async, gcsafe.} = ## stop the transport ## + t.running = false # mark stopped as soon as possible + try: trace "Stopping TCP transport" await procCall Transport(t).stop() # call base @@ -160,8 +165,6 @@ method stop*(t: TcpTransport) {.async, gcsafe.} = inc getTcpTransportTracker().closed except CatchableError as exc: trace "Error shutting down tcp transport", exc = exc.msg - finally: - t.running = false method accept*(t: TcpTransport): Future[Connection] {.async, gcsafe.} = ## accept a new TCP connection @@ -179,12 +182,12 @@ method accept*(t: TcpTransport): Future[Connection] {.async, gcsafe.} = # that can't. debug "OS Error", exc = exc.msg except TransportTooManyError as exc: - warn "Too many files opened", exc = exc.msg + debug "Too many files opened", exc = exc.msg except TransportUseClosedError as exc: - info "Server was closed", exc = exc.msg + debug "Server was closed", exc = exc.msg raise newTransportClosedError(exc) except CatchableError as exc: - trace "Unexpected error creating connection", exc = exc.msg + warn "Unexpected error creating connection", exc = exc.msg raise exc method dial*(t: TcpTransport, diff --git a/libp2p/utils/semaphore.nim b/libp2p/utils/semaphore.nim new file mode 100644 index 0000000000..72ed68b752 --- /dev/null +++ b/libp2p/utils/semaphore.nim @@ -0,0 +1,75 @@ +## Nim-Libp2p +## Copyright (c) 2020 Status Research & Development GmbH +## Licensed under either of +## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +## * MIT license ([LICENSE-MIT](LICENSE-MIT)) +## at your option. +## This file may not be copied, modified, or distributed except according to +## those terms. + +import sequtils +import chronos, chronicles + +# TODO: this should probably go in chronos + +logScope: + topics = "libp2p semaphore" + +type + AsyncSemaphore* = ref object of RootObj + size*: int + count*: int + queue*: seq[Future[void]] + +proc init*(T: type AsyncSemaphore, size: int): T = + T(size: size, count: size) + +proc tryAcquire*(s: AsyncSemaphore): bool = + ## Attempts to acquire a resource, if successful + ## returns true, otherwise false + ## + + if s.count > 0 and s.queue.len == 0: + s.count.dec + trace "Acquired slot", available = s.count, queue = s.queue.len + return true + +proc acquire*(s: AsyncSemaphore): Future[void] = + ## Acquire a resource and decrement the resource + ## counter. If no more resources are available, + ## the returned future will not complete until + ## the resource count goes above 0 again. + ## + + let fut = newFuture[void]("AsyncSemaphore.acquire") + if s.tryAcquire(): + fut.complete() + return fut + + s.queue.add(fut) + s.count.dec + trace "Queued slot", available = s.count, queue = s.queue.len + return fut + +proc release*(s: AsyncSemaphore) = + ## Release a resource from the semaphore, + ## by picking the first future from the queue + ## and completing it and incrementing the + ## internal resource count + ## + + doAssert(s.count <= s.size) + + if s.count < s.size: + trace "Releasing slot", available = s.count, + queue = s.queue.len + + if s.queue.len > 0: + var fut = s.queue.pop() + if not fut.finished(): + fut.complete() + + s.count.inc # increment the resource count + trace "Released slot", available = s.count, + queue = s.queue.len + return diff --git a/tests/helpers.nim b/tests/helpers.nim index b8ef25158f..20cb856996 100644 --- a/tests/helpers.nim +++ b/tests/helpers.nim @@ -9,6 +9,8 @@ import ../libp2p/stream/lpstream import ../libp2p/muxers/mplex/lpchannel import ../libp2p/protocols/secure/secure +export unittest + const StreamTransportTrackerName = "stream.transport" StreamServerTrackerName = "stream.server" diff --git a/tests/testnative.nim b/tests/testnative.nim index 06b325895d..005d7e8c16 100644 --- a/tests/testnative.nim +++ b/tests/testnative.nim @@ -1,6 +1,7 @@ import testvarint, testminprotobuf, - teststreamseq + teststreamseq, + testsemaphore import testminasn1, testrsa, diff --git a/tests/testsemaphore.nim b/tests/testsemaphore.nim new file mode 100644 index 0000000000..39134b69ed --- /dev/null +++ b/tests/testsemaphore.nim @@ -0,0 +1,103 @@ +import random +import chronos +import ../libp2p/utils/semaphore + +import ./helpers + +randomize() + +suite "AsyncSemaphore": + asyncTest "should acquire": + let sema = AsyncSemaphore.init(3) + + await sema.acquire() + await sema.acquire() + await sema.acquire() + + check sema.count == 0 + + asyncTest "should release": + let sema = AsyncSemaphore.init(3) + + await sema.acquire() + await sema.acquire() + await sema.acquire() + + check sema.count == 0 + sema.release() + sema.release() + sema.release() + check sema.count == 3 + + asyncTest "should queue acquire": + let sema = AsyncSemaphore.init(1) + + await sema.acquire() + let fut = sema.acquire() + + check sema.count == -1 + check sema.queue.len == 1 + sema.release() + sema.release() + check sema.count == 1 + + await sleepAsync(10.millis) + check fut.finished() + + asyncTest "should keep count == size": + let sema = AsyncSemaphore.init(1) + sema.release() + sema.release() + sema.release() + check sema.count == 1 + + asyncTest "should tryAcquire": + let sema = AsyncSemaphore.init(1) + await sema.acquire() + check sema.tryAcquire() == false + + asyncTest "should tryAcquire and acquire": + let sema = AsyncSemaphore.init(4) + check sema.tryAcquire() == true + check sema.tryAcquire() == true + check sema.tryAcquire() == true + check sema.tryAcquire() == true + check sema.count == 0 + + let fut = sema.acquire() + check fut.finished == false + check sema.count == -1 + # queue is only used when count is < 0 + check sema.queue.len == 1 + + sema.release() + sema.release() + sema.release() + sema.release() + sema.release() + + check fut.finished == true + check sema.count == 4 + check sema.queue.len == 0 + + asyncTest "should restrict resource access": + let sema = AsyncSemaphore.init(3) + var resource = 0 + + proc task() {.async.} = + try: + await sema.acquire() + resource.inc() + check resource > 0 and resource <= 3 + let sleep = rand(0..10).millis + # echo sleep + await sleepAsync(sleep) + finally: + resource.dec() + sema.release() + + var tasks: seq[Future[void]] + for i in 0..<10: + tasks.add(task()) + + await allFutures(tasks)