diff --git a/p2p/host/basic/basic_host.go b/p2p/host/basic/basic_host.go index 7338a41e27..d69751b0d3 100644 --- a/p2p/host/basic/basic_host.go +++ b/p2p/host/basic/basic_host.go @@ -631,10 +631,24 @@ func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.I }, nil } - selected, err := msmux.SelectOneOf(pidStrings, s) - if err != nil { + // Negotiate the protocol in the background, obeying the context. + var selected string + errCh := make(chan error, 1) + go func() { + selected, err = msmux.SelectOneOf(pidStrings, s) + errCh <- err + }() + select { + case err = <-errCh: + if err != nil { + s.Reset() + return nil, err + } + case <-ctx.Done(): s.Reset() - return nil, err + // wait for the negotiation to cancel. + <-errCh + return nil, ctx.Err() } selpid := protocol.ID(selected) diff --git a/p2p/host/basic/basic_host_test.go b/p2p/host/basic/basic_host_test.go index 2f2ab7624a..3585b01b65 100644 --- a/p2p/host/basic/basic_host_test.go +++ b/p2p/host/basic/basic_host_test.go @@ -3,6 +3,7 @@ package basichost import ( "bytes" "context" + "fmt" "io" "reflect" "sync" @@ -777,6 +778,49 @@ func TestHostAddrChangeDetection(t *testing.T) { } } +func TestNegotiationCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + h1, h2 := getHostPair(ctx, t) + defer h1.Close() + defer h2.Close() + + // pre-negotiation so we can make the negotiation hang. + h1.Network().SetStreamHandler(func(s network.Stream) { + <-ctx.Done() // wait till the test is done. + s.Reset() + }) + + ctx2, cancel2 := context.WithCancel(ctx) + defer cancel2() + + errCh := make(chan error, 1) + go func() { + s, err := h2.NewStream(ctx2, h1.ID(), "/testing") + if s != nil { + errCh <- fmt.Errorf("expected to fail negotiation") + return + } + errCh <- err + }() + select { + case err := <-errCh: + t.Fatal(err) + case <-time.After(10 * time.Millisecond): + // ok, hung. + } + cancel2() + + select { + case err := <-errCh: + require.Equal(t, err, context.Canceled) + case <-time.After(500 * time.Millisecond): + // failed to cancel + t.Fatal("expected negotiation to be canceled") + } +} + func waitForAddrChangeEvent(ctx context.Context, sub event.Subscription, t *testing.T) event.EvtLocalAddressesUpdated { for { select {