diff --git a/p2p/protocol/internal/circuitv1-deprecated/relay.go b/p2p/protocol/internal/circuitv1-deprecated/relay.go index 646df850cd..0afcf0467a 100644 --- a/p2p/protocol/internal/circuitv1-deprecated/relay.go +++ b/p2p/protocol/internal/circuitv1-deprecated/relay.go @@ -369,13 +369,18 @@ func (r *Relay) handleHopStream(s inet.Stream, msg *pb.CircuitRelay) { r.addLiveHop(src.ID, dst.ID) - var wg sync.WaitGroup - wg.Add(2) + goroutines := new(int32) + *goroutines = 2 + done := func() { + if atomic.AddInt32(goroutines, -1) == 0 { + r.rmLiveHop(src.ID, dst.ID) + } + } // Don't reset streams after finishing or the other side will get an // error, not an EOF. go func() { - defer wg.Done() + defer done() buf := pool.Get(HopStreamBufferSize) defer pool.Put(buf) @@ -394,7 +399,7 @@ func (r *Relay) handleHopStream(s inet.Stream, msg *pb.CircuitRelay) { }() go func() { - defer wg.Done() + defer done() buf := pool.Get(HopStreamBufferSize) defer pool.Put(buf) @@ -411,11 +416,6 @@ func (r *Relay) handleHopStream(s inet.Stream, msg *pb.CircuitRelay) { } log.Debugf("relayed %d bytes from %s to %s", count, src.ID.Pretty(), dst.ID.Pretty()) }() - - go func() { - wg.Wait() - r.rmLiveHop(src.ID, dst.ID) - }() } func (r *Relay) handleStopStream(s inet.Stream, msg *pb.CircuitRelay) {