diff --git a/Makefile b/Makefile index 7811c09..bf3427e 100644 --- a/Makefile +++ b/Makefile @@ -6,7 +6,11 @@ covertools: go get github.com/mattn/goveralls go get golang.org/x/tools/cmd/cover -deps: gx covertools +ginkgo: + go get github.com/onsi/ginkgo/ginkgo + go get github.com/onsi/gomega + +deps: gx covertools ginkgo gx --verbose install --global gx-go rewrite diff --git a/conn.go b/conn.go index 1a11590..e589fb0 100644 --- a/conn.go +++ b/conn.go @@ -2,50 +2,89 @@ package conn import ( "context" + "errors" "io" "net" - "time" logging "github.com/ipfs/go-log" - mpool "github.com/jbenet/go-msgio/mpool" + ci "github.com/libp2p/go-libp2p-crypto" ic "github.com/libp2p/go-libp2p-crypto" iconn "github.com/libp2p/go-libp2p-interface-conn" lgbl "github.com/libp2p/go-libp2p-loggables" peer "github.com/libp2p/go-libp2p-peer" + secio "github.com/libp2p/go-libp2p-secio" tpt "github.com/libp2p/go-libp2p-transport" + smux "github.com/libp2p/go-stream-muxer" ma "github.com/multiformats/go-multiaddr" ) var log = logging.Logger("conn") -// ReleaseBuffer puts the given byte array back into the buffer pool, -// first verifying that it is the correct size -func ReleaseBuffer(b []byte) { - log.Debugf("Releasing buffer! (cap,size = %d, %d)", cap(b), len(b)) - mpool.ByteSlicePool.Put(uint32(cap(b)), b) -} - -// singleConn represents a single connection to another Peer (IPFS Node). +// singleConn represents a single stream-multipexed connection to another Peer (IPFS Node). type singleConn struct { - local peer.ID - remote peer.ID - maconn tpt.Conn - event io.Closer + streamConn smux.Conn + tptConn tpt.Conn + + secSession secio.Session + + event io.Closer } -// newConn constructs a new connection -func newSingleConn(ctx context.Context, local, remote peer.ID, maconn tpt.Conn) (iconn.Conn, error) { - ml := lgbl.Dial("conn", local, remote, maconn.LocalMultiaddr(), maconn.RemoteMultiaddr()) +var _ iconn.Conn = &singleConn{} - conn := &singleConn{ - local: local, - remote: remote, - maconn: maconn, - event: log.EventBegin(ctx, "connLifetime", ml), +// newSingleConn constructs a new connection +func newSingleConn(ctx context.Context, local, remote peer.ID, privKey ci.PrivKey, tptConn tpt.Conn, pstpt smux.Transport, isServer bool) (iconn.Conn, error) { + ml := lgbl.Dial("conn", local, remote, tptConn.LocalMultiaddr(), tptConn.RemoteMultiaddr()) + + var streamConn smux.Conn + var secSession secio.Session + + c := tptConn + // 1. secure the connection + if privKey != nil && iconn.EncryptConnections { + var err error + secSession, err = setupSecureSession(ctx, local, privKey, tptConn) + if err != nil { + return nil, err + } + c = &secureConn{ + insecure: tptConn, + secure: secSession, + } + } else { + log.Warning("creating INSECURE connection %s at %s", tptConn.LocalMultiaddr(), tptConn.RemoteMultiaddr()) + } + + // 2. start stream multipling + var err error + streamConn, err = pstpt.NewConn(c, isServer) + if err != nil { + return nil, err } - log.Debugf("newSingleConn %p: %v to %v", conn, local, remote) - return conn, nil + sconn := &singleConn{ + streamConn: streamConn, + tptConn: tptConn, + secSession: secSession, + event: log.EventBegin(ctx, "connLifetime", ml), + } + + log.Debugf("newSingleConn %p: %v to %v", sconn, local, remote) + return sconn, nil +} + +func setupSecureSession(ctx context.Context, local peer.ID, privKey ci.PrivKey, ch io.ReadWriteCloser) (secio.Session, error) { + if local == "" { + return nil, errors.New("local peer is nil") + } + if privKey == nil { + return nil, errors.New("private key is nil") + } + sessgen := secio.SessionGenerator{ + LocalID: local, + PrivateKey: privKey, + } + return sessgen.NewSession(ctx, ch) } // close is the internal close function, called by ContextCloser.Close @@ -57,8 +96,8 @@ func (c *singleConn) Close() error { } }() - // close underlying connection - return c.maconn.Close() + // closing the stream muxer also closes the underlying net.Conn + return c.streamConn.Close() } // ID is an identifier unique to this connection. @@ -71,62 +110,59 @@ func (c *singleConn) String() string { } func (c *singleConn) LocalAddr() net.Addr { - return c.maconn.LocalAddr() + return c.tptConn.LocalAddr() } func (c *singleConn) RemoteAddr() net.Addr { - return c.maconn.RemoteAddr() + return c.tptConn.RemoteAddr() } func (c *singleConn) LocalPrivateKey() ic.PrivKey { + if c.secSession != nil { + return c.secSession.LocalPrivateKey() + } return nil } func (c *singleConn) RemotePublicKey() ic.PubKey { + if c.secSession != nil { + return c.secSession.RemotePublicKey() + } return nil } -func (c *singleConn) SetDeadline(t time.Time) error { - return c.maconn.SetDeadline(t) -} -func (c *singleConn) SetReadDeadline(t time.Time) error { - return c.maconn.SetReadDeadline(t) -} - -func (c *singleConn) SetWriteDeadline(t time.Time) error { - return c.maconn.SetWriteDeadline(t) -} - // LocalMultiaddr is the Multiaddr on this side func (c *singleConn) LocalMultiaddr() ma.Multiaddr { - return c.maconn.LocalMultiaddr() + return c.tptConn.LocalMultiaddr() } // RemoteMultiaddr is the Multiaddr on the remote side func (c *singleConn) RemoteMultiaddr() ma.Multiaddr { - return c.maconn.RemoteMultiaddr() + return c.tptConn.RemoteMultiaddr() } func (c *singleConn) Transport() tpt.Transport { - return c.maconn.Transport() + return c.tptConn.Transport() } // LocalPeer is the Peer on this side func (c *singleConn) LocalPeer() peer.ID { - return c.local + return c.secSession.LocalPeer() } // RemotePeer is the Peer on the remote side func (c *singleConn) RemotePeer() peer.ID { - return c.remote + return c.secSession.RemotePeer() +} + +func (c *singleConn) AcceptStream() (smux.Stream, error) { + return c.streamConn.AcceptStream() } -// Read reads data, net.Conn style -func (c *singleConn) Read(buf []byte) (int, error) { - return c.maconn.Read(buf) +func (c *singleConn) OpenStream() (smux.Stream, error) { + return c.streamConn.OpenStream() } -// Write writes data, net.Conn style -func (c *singleConn) Write(buf []byte) (int, error) { - return c.maconn.Write(buf) +func (c *singleConn) IsClosed() bool { + return c.streamConn.IsClosed() } diff --git a/conn_suite_test.go b/conn_suite_test.go new file mode 100644 index 0000000..99cdf20 --- /dev/null +++ b/conn_suite_test.go @@ -0,0 +1,82 @@ +package conn + +import ( + "context" + "strings" + "testing" + "time" + + ci "github.com/libp2p/go-libp2p-crypto" + iconn "github.com/libp2p/go-libp2p-interface-conn" + peer "github.com/libp2p/go-libp2p-peer" + tpt "github.com/libp2p/go-libp2p-transport" + tcpt "github.com/libp2p/go-tcp-transport" + tu "github.com/libp2p/go-testutil" + ma "github.com/multiformats/go-multiaddr" + yamux "github.com/whyrusleeping/go-smux-yamux" + grc "github.com/whyrusleeping/gorocheck" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +func TestGoLibp2pConn(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "go-libp2p-conn Suite") +} + +var _ = AfterEach(func() { + time.Sleep(300 * time.Millisecond) + Expect(grc.CheckForLeaks(func(r *grc.Goroutine) bool { + return strings.Contains(r.Function, "go-log.") || + strings.Contains(r.Stack[0], "testing.(*T).Run") || + strings.Contains(r.Function, "specrunner.") || + strings.Contains(r.Function, "runtime.gopark") + })).To(Succeed()) +}) + +// the stream muxer used for tests using the single stream connection +var streamMuxer = yamux.DefaultTransport + +// dialRawConn dials a tpt.Conn +// but it stops there. It doesn't do protocol selection and handshake +func dialRawConn(laddr, raddr ma.Multiaddr) tpt.Conn { + d, err := tcpt.NewTCPTransport().Dialer(laddr) + Expect(err).ToNot(HaveOccurred()) + c, err := d.Dial(raddr) + Expect(err).ToNot(HaveOccurred()) + return c +} + +// getTransport gets the right transport for a multiaddr +func getTransport(a ma.Multiaddr) tpt.Transport { + return tcpt.NewTCPTransport() +} + +// getListener creates a listener based on the PeerNetParams +// it updates the PeerNetParams to reflect the local address that was selected by the kernel +func getListener(ctx context.Context, p *tu.PeerNetParams) iconn.Listener { + tptListener, err := getTransport(p.Addr).Listen(p.Addr) + Expect(err).ToNot(HaveOccurred()) + list, err := WrapTransportListener(ctx, tptListener, p.ID, streamMuxer, p.PrivKey) + Expect(err).ToNot(HaveOccurred()) + p.Addr = list.Multiaddr() + return list +} + +func getDialer(localPeer peer.ID, privKey ci.PrivKey, addr ma.Multiaddr) *Dialer { + d := NewDialer(localPeer, privKey, nil, streamMuxer) + d.fallback = nil // unset the fallback dialer. We want tests use the configured dialer, and to fail otherwise + tptd, err := getTransport(addr).Dialer(addr) + Expect(err).ToNot(HaveOccurred()) + d.AddDialer(tptd) + return d +} + +// randPeerNetParams works like testutil.RandPeerNetParams +// if called for a multi-stream transport, it replaces the address with a QUIC address +func randPeerNetParams() *tu.PeerNetParams { + p, err := tu.RandPeerNetParams() + Expect(err).ToNot(HaveOccurred()) + return p +} diff --git a/conn_test.go b/conn_test.go index bf930a6..a96ddab 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1,139 +1,312 @@ package conn import ( - "bytes" "context" "fmt" - "runtime" + "io" "sync" - "testing" "time" - msgio "github.com/jbenet/go-msgio" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + iconn "github.com/libp2p/go-libp2p-interface-conn" - travis "github.com/libp2p/go-testutil/ci/travis" + smux "github.com/libp2p/go-stream-muxer" ) -func msgioWrap(c iconn.Conn) msgio.ReadWriter { - return msgio.NewReadWriter(c) -} +var _ = Describe("Connections", func() { + It("uses the right handshake protocol", func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() -func testOneSendRecv(t *testing.T, c1, c2 iconn.Conn) { - mc1 := msgioWrap(c1) - mc2 := msgioWrap(c2) + p1 := randPeerNetParams() + l1 := getListener(ctx, p1) + defer l1.Close() + go l1.Accept() + }) - log.Debugf("testOneSendRecv from %s to %s", c1.LocalPeer(), c2.LocalPeer()) - m1 := []byte("hello") - if err := mc1.WriteMsg(m1); err != nil { - t.Fatal(err) - } - m2, err := mc2.ReadMsg() - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(m1, m2) { - t.Fatalf("failed to send: %s %s", m1, m2) - } -} + for _, val := range []bool{true, false} { + secure := val -func testNotOneSendRecv(t *testing.T, c1, c2 iconn.Conn) { - mc1 := msgioWrap(c1) - mc2 := msgioWrap(c2) + It(fmt.Sprintf("establishes a connection (secure: %t)", secure), func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - m1 := []byte("hello") - if err := mc1.WriteMsg(m1); err == nil { - t.Fatal("write should have failed", err) - } - _, err := mc2.ReadMsg() - if err == nil { - t.Fatal("read should have failed", err) + p1 := randPeerNetParams() + p2 := randPeerNetParams() + if !secure { + p1.PrivKey = nil + p2.PrivKey = nil + } + + l1 := getListener(ctx, p1) + defer l1.Close() + + // accept a connection, accept a stream on this connection and echo everything + go func() { + defer GinkgoRecover() + c, err := l1.Accept() + Expect(err).ToNot(HaveOccurred()) + str, err := c.AcceptStream() + Expect(err).ToNot(HaveOccurred()) + go io.Copy(str, str) + }() + + d2 := getDialer(p2.ID, p2.PrivKey, p2.Addr) + c, err := d2.Dial(ctx, p1.Addr, p1.ID) + Expect(err).ToNot(HaveOccurred()) + defer c.Close() + str, err := c.OpenStream() + Expect(err).ToNot(HaveOccurred()) + _, err = str.Write([]byte("beep")) + Expect(err).ToNot(HaveOccurred()) + _, err = str.Write([]byte("boop")) + Expect(err).ToNot(HaveOccurred()) + + out := make([]byte, 8) + _, err = io.ReadFull(str, out) + Expect(err).ToNot(HaveOccurred()) + Expect(out).To(Equal([]byte("beepboop"))) + }) } -} -func TestClose(t *testing.T) { - // t.Skip("Skipping in favor of another test") + It("continues accepting connections while another accept is hanging", func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + p1 := randPeerNetParams() + p2 := randPeerNetParams() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - c1, c2, _, _ := setupSingleConn(t, ctx) + l1 := getListener(ctx, p1) + defer l1.Close() - testOneSendRecv(t, c1, c2) - testOneSendRecv(t, c2, c1) + go func() { + defer GinkgoRecover() + conn := dialRawConn(p2.Addr, l1.Multiaddr()) + defer conn.Close() // hang this connection - c1.Close() - testNotOneSendRecv(t, c1, c2) + // ensure that the first conn hits first + time.Sleep(50 * time.Millisecond) + d2 := getDialer(p2.ID, p2.PrivKey, p2.Addr) + conn2, err := d2.Dial(ctx, p1.Addr, p1.ID) + Expect(err).ToNot(HaveOccurred()) + defer conn2.Close() + }() - c2.Close() - testNotOneSendRecv(t, c2, c1) - testNotOneSendRecv(t, c1, c2) -} + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := l1.Accept() + Expect(err).ToNot(HaveOccurred()) + close(done) + }() + Eventually(done).Should(BeClosed()) + }) -func TestCloseLeak(t *testing.T) { - // t.Skip("Skipping in favor of another test") - if testing.Short() { - t.SkipNow() - } + It("timeouts", func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - if travis.IsRunning() { - t.Skip("this doesn't work well on travis") - } + old := ConnAcceptTimeout + ConnAcceptTimeout = 3 * time.Second + defer func() { ConnAcceptTimeout = old }() + + p1 := randPeerNetParams() + p2 := randPeerNetParams() + + l1 := getListener(ctx, p1) + defer l1.Close() + + n := 20 + + before := time.Now() + var wg sync.WaitGroup + for i := 0; i < n; i++ { + wg.Add(1) + go func() { + defer GinkgoRecover() + defer wg.Done() + c := dialRawConn(p2.Addr, l1.Multiaddr()) + defer c.Close() + // hang this connection until timeout + io.ReadFull(c, make([]byte, 1000)) + }() + } + + // wait to make sure the hanging dials have started + time.Sleep(50 * time.Millisecond) + + accepted := make(chan struct{}) // this chan is closed once all good connections have been accepted + goodN := 10 + for i := 0; i < goodN; i++ { + go func(i int) { + defer GinkgoRecover() + d2 := getDialer(p2.ID, p2.PrivKey, p2.Addr) + conn, err := d2.Dial(ctx, p1.Addr, p1.ID) + Expect(err).ToNot(HaveOccurred()) + <-accepted + conn.Close() + }(i) + } - var wg sync.WaitGroup + for i := 0; i < goodN; i++ { + _, err := l1.Accept() + Expect(err).ToNot(HaveOccurred()) + } + close(accepted) + Expect(time.Now()).To(BeTemporally("<", before.Add(ConnAcceptTimeout/4))) + Eventually(func() bool { + wg.Wait() // wait for the timeouts for the raw connections to occur + return true + }, ConnAcceptTimeout).Should(BeTrue()) + Expect(time.Now()).To(BeTemporally(">", before.Add(ConnAcceptTimeout))) + + // make sure we can dial in still after a bunch of timeouts + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := l1.Accept() + Expect(err).ToNot(HaveOccurred()) + close(done) + }() + + d2 := getDialer(p2.ID, p2.PrivKey, p2.Addr) + conn, err := d2.Dial(ctx, p1.Addr, p1.ID) + Expect(err).ToNot(HaveOccurred()) + defer conn.Close() + Eventually(done).Should(BeClosed()) + }) - runPair := func(num int) { + It("doesn't complete the handshake with the wrong keys", func() { ctx, cancel := context.WithCancel(context.Background()) - c1, c2, _, _ := setupSingleConn(t, ctx) + defer cancel() - mc1 := msgioWrap(c1) - mc2 := msgioWrap(c2) + p1 := randPeerNetParams() + p2 := randPeerNetParams() - for i := 0; i < num; i++ { - b1 := []byte(fmt.Sprintf("beep%d", i)) - mc1.WriteMsg(b1) - b2, err := mc2.ReadMsg() - if err != nil { - panic(err) - } - if !bytes.Equal(b1, b2) { - panic(fmt.Errorf("bytes not equal: %s != %s", b1, b2)) - } + l1 := getListener(ctx, p1) + defer l1.Close() - b2 = []byte(fmt.Sprintf("boop%d", i)) - mc2.WriteMsg(b2) - b1, err = mc1.ReadMsg() - if err != nil { - panic(err) - } - if !bytes.Equal(b1, b2) { - panic(fmt.Errorf("bytes not equal: %s != %s", b1, b2)) - } + // use the wrong private key here, correct would be: p2.PrivKey + d2 := getDialer(p2.ID, p1.PrivKey, p2.Addr) + + accepted := make(chan struct{}) + go func() { + l1.Accept() + close(accepted) + }() + + _, err := d2.Dial(ctx, p1.Addr, p1.ID) + Expect(err).To(MatchError("peer.ID does not match PrivateKey")) + // make sure no connection was accepted + Consistently(accepted).ShouldNot(BeClosed()) + }) + + Context("closing", func() { + setupConn := func(ctx context.Context) (iconn.Conn, iconn.Conn) { + p1 := randPeerNetParams() + p2 := randPeerNetParams() + + l1 := getListener(ctx, p1) + + var c2 iconn.Conn + d2 := getDialer(p2.ID, p2.PrivKey, p2.Addr) + done := make(chan error) + go func() { + defer GinkgoRecover() + var err error + c2, err = d2.Dial(ctx, p1.Addr, p1.ID) + Expect(err).ToNot(HaveOccurred()) + close(done) + }() - <-time.After(time.Microsecond * 5) + c1, err := l1.Accept() + Expect(err).ToNot(HaveOccurred()) + Eventually(done).Should(BeClosed()) + return c1, c2 } - c1.Close() - c2.Close() - cancel() // close the listener - wg.Done() - } + openStreamAndSend := func(c1, c2 iconn.Conn) { + str1, err := c1.OpenStream() + Expect(err).ToNot(HaveOccurred()) + m1 := []byte("hello") + _, err = str1.Write(m1) + Expect(err).ToNot(HaveOccurred()) + str2, err := c2.AcceptStream() + Expect(err).ToNot(HaveOccurred()) + m2 := make([]byte, len(m1)) + _, err = str2.Read(m2) + Expect(err).ToNot(HaveOccurred()) + Expect(m1).To(Equal(m2)) + } - var cons = 5 - var msgs = 50 - log.Debugf("Running %d connections * %d msgs.\n", cons, msgs) - for i := 0; i < cons; i++ { - wg.Add(1) - go runPair(msgs) - } + checkStreamOpenAcceptFails := func(c1, c2 iconn.Conn) { + _, err := c1.OpenStream() + Expect(err).To(HaveOccurred()) + accepted := make(chan struct{}) + go func() { + _, err := c2.AcceptStream() + Expect(err).To(HaveOccurred()) + close(accepted) + }() + Eventually(accepted).Should(BeClosed()) + } - log.Debugf("Waiting...\n") - wg.Wait() - // done! + It("closes", func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - time.Sleep(time.Millisecond * 150) - ngr := runtime.NumGoroutine() - if ngr > 25 { - // note, this is really innacurate - //panic("uncomment me to debug") - t.Fatal("leaking goroutines:", ngr) - } -} + c1, c2 := setupConn(ctx) + openStreamAndSend(c1, c2) + openStreamAndSend(c2, c1) + + c1.Close() + Expect(c1.IsClosed()).To(BeTrue()) + Eventually(c2.IsClosed).Should(BeTrue()) + checkStreamOpenAcceptFails(c2, c1) + checkStreamOpenAcceptFails(c1, c2) + }) + + It("doesn't leak", func() { + // runPair opens one stream and sends num messages + runPair := func(c1, c2 iconn.Conn, num int) { + var str2 smux.Stream + str1, err := c1.OpenStream() + Expect(err).ToNot(HaveOccurred()) + + for i := 0; i < num; i++ { + b1 := []byte("beep") + _, err := str1.Write(b1) + Expect(err).ToNot(HaveOccurred()) + if str2 == nil { + str2, err = c2.AcceptStream() + Expect(err).ToNot(HaveOccurred()) + } + b2 := make([]byte, len(b1)) + _, err = str2.Read(b2) + Expect(err).ToNot(HaveOccurred()) + Expect(b1).To(Equal(b2)) + } + } + + var cons = 10 + var msgs = 10 + var wg sync.WaitGroup + for i := 0; i < cons; i++ { + wg.Add(1) + ctx, cancel := context.WithCancel(context.Background()) + c1, c2 := setupConn(ctx) + go func(c1, c2 iconn.Conn) { + defer GinkgoRecover() + defer cancel() + runPair(c1, c2, msgs) + c1.Close() + c2.Close() + wg.Done() + }(c1, c2) + } + + wg.Wait() + }) + }) +}) diff --git a/dial.go b/dial.go index 446687b..bc11f38 100644 --- a/dial.go +++ b/dial.go @@ -3,23 +3,35 @@ package conn import ( "context" "fmt" - "math/rand" "strings" "time" - addrutil "github.com/libp2p/go-addr-util" ci "github.com/libp2p/go-libp2p-crypto" iconn "github.com/libp2p/go-libp2p-interface-conn" ipnet "github.com/libp2p/go-libp2p-interface-pnet" lgbl "github.com/libp2p/go-libp2p-loggables" peer "github.com/libp2p/go-libp2p-peer" - transport "github.com/libp2p/go-libp2p-transport" + tpt "github.com/libp2p/go-libp2p-transport" + smux "github.com/libp2p/go-stream-muxer" ma "github.com/multiformats/go-multiaddr" - manet "github.com/multiformats/go-multiaddr-net" msmux "github.com/multiformats/go-multistream" ) -type WrapFunc func(transport.Conn) transport.Conn +// DialTimeout is the maximum duration a Dial is allowed to take. +// This includes the time between dialing the raw network connection, +// protocol selection as well the handshake, if applicable. +var DialTimeout = 60 * time.Second + +// dialTimeoutErr occurs when the DialTimeout is exceeded. +type dialTimeoutErr struct{} + +func (dialTimeoutErr) Error() string { return "deadline exceeded" } +func (dialTimeoutErr) Temporary() bool { return true } +func (dialTimeoutErr) Timeout() bool { return true } + +// The WrapFunc is used to wrap a tpt.Conn. +// It must not block. +type WrapFunc func(tpt.Conn) tpt.Conn // Dialer is an object that can open connections. We could have a "convenience" // Dial function as before, but it would have many arguments, as dialing is @@ -33,7 +45,7 @@ type Dialer struct { // Dialers are the sub-dialers usable by this dialer // selected in order based on the address being dialed - Dialers []transport.Dialer + Dialers []tpt.Dialer // PrivateKey used to initialize a secure connection. // Warning: if PrivateKey is nil, connection will not be secured. @@ -47,15 +59,18 @@ type Dialer struct { // Wrapper to wrap the raw connection (optional) Wrapper WrapFunc - fallback transport.Dialer + fallback tpt.Dialer + + streamMuxer smux.Transport } -func NewDialer(p peer.ID, pk ci.PrivKey, wrap WrapFunc) *Dialer { +func NewDialer(p peer.ID, pk ci.PrivKey, wrap WrapFunc, sm smux.Transport) *Dialer { return &Dialer{ - LocalPeer: p, - PrivateKey: pk, - Wrapper: wrap, - fallback: new(transport.FallbackDialer), + LocalPeer: p, + PrivateKey: pk, + Wrapper: wrap, + fallback: new(tpt.FallbackDialer), + streamMuxer: sm, } } @@ -69,48 +84,47 @@ func (d *Dialer) String() string { // Example: d.DialAddr(ctx, peer.Addresses()[0], peer) func (d *Dialer) Dial(ctx context.Context, raddr ma.Multiaddr, remote peer.ID) (iconn.Conn, error) { logdial := lgbl.Dial("conn", d.LocalPeer, remote, nil, raddr) + defer log.EventBegin(ctx, "connDial", logdial).Done() logdial["encrypted"] = (d.PrivateKey != nil) // log wether this will be an encrypted dial or not. logdial["inPrivNet"] = (d.Protector != nil) - defer log.EventBegin(ctx, "connDial", logdial).Done() - if d.Protector == nil && ipnet.ForcePrivateNetwork { log.Error("tried to dial with no Private Network Protector but usage" + " of Private Networks is forced by the enviroment") return nil, ipnet.ErrNotInPrivateNetwork } - var connOut iconn.Conn - var errOut error - done := make(chan struct{}) - - // do it async to ensure we respect don contexteone - go func() { - defer func() { - select { - case done <- struct{}{}: - case <-ctx.Done(): - } - }() + c, err := d.doDial(ctx, raddr, remote) + if err != nil { + logdial["error"] = err.Error() + logdial["dial"] = "failure" + return nil, err + } + logdial["dial"] = "success" + return c, nil +} - maconn, err := d.rawConnDial(ctx, raddr, remote) - if err != nil { - errOut = err - return - } +func (d *Dialer) doDial(ctx context.Context, raddr ma.Multiaddr, remote peer.ID) (iconn.Conn, error) { + rawConn, err := d.rawConnDial(ctx, raddr, remote) + if err != nil { + return nil, err + } + done := make(chan connOrErr, 1) + // do it async to ensure we respect the context + go func() { if d.Protector != nil { - pconn, err := d.Protector.Protect(maconn) + var pconn tpt.Conn + pconn, err = d.Protector.Protect(rawConn) if err != nil { - maconn.Close() - errOut = err + done <- connOrErr{err: err} return } - maconn = pconn + rawConn = pconn } if d.Wrapper != nil { - maconn = d.Wrapper(maconn) + rawConn = d.Wrapper(rawConn) } cryptoProtoChoice := SecioTag @@ -118,63 +132,43 @@ func (d *Dialer) Dial(ctx context.Context, raddr ma.Multiaddr, remote peer.ID) ( cryptoProtoChoice = NoEncryptionTag } - maconn.SetReadDeadline(time.Now().Add(NegotiateReadTimeout)) - - err = msmux.SelectProtoOrFail(cryptoProtoChoice, maconn) - if err != nil { - errOut = err - return - } - - maconn.SetReadDeadline(time.Time{}) - - c, err := newSingleConn(ctx, d.LocalPeer, remote, maconn) - if err != nil { - maconn.Close() - errOut = err - return - } - if d.PrivateKey == nil || !iconn.EncryptConnections { - log.Warning("dialer %s dialing INSECURELY %s at %s!", d, remote, raddr) - connOut = c + if err := msmux.SelectProtoOrFail(cryptoProtoChoice, rawConn); err != nil { + done <- connOrErr{err: err} return } - c2, err := newSecureConn(ctx, d.PrivateKey, c) + c, err := newSingleConn(ctx, d.LocalPeer, remote, d.PrivateKey, rawConn, d.streamMuxer, false) if err != nil { - errOut = err - c.Close() + done <- connOrErr{err: err} return } - connOut = c2 + done <- connOrErr{conn: c} }() + var res connOrErr select { case <-ctx.Done(): - logdial["error"] = ctx.Err().Error() - logdial["dial"] = "failure" + rawConn.Close() return nil, ctx.Err() - case <-done: - // whew, finished. - } - - if errOut != nil { - logdial["error"] = errOut.Error() - logdial["dial"] = "failure" - return nil, errOut + case <-time.After(DialTimeout): + rawConn.Close() + return nil, &dialTimeoutErr{} + case res = <-done: + if res.err != nil { + rawConn.Close() + } } - logdial["dial"] = "success" - return connOut, nil + return res.conn, res.err } -func (d *Dialer) AddDialer(pd transport.Dialer) { +func (d *Dialer) AddDialer(pd tpt.Dialer) { d.Dialers = append(d.Dialers, pd) } // returns dialer that can dial the given address -func (d *Dialer) subDialerForAddr(raddr ma.Multiaddr) transport.Dialer { +func (d *Dialer) subDialerForAddr(raddr ma.Multiaddr) tpt.Dialer { for _, pd := range d.Dialers { if pd.Matches(raddr) { return pd @@ -189,7 +183,7 @@ func (d *Dialer) subDialerForAddr(raddr ma.Multiaddr) transport.Dialer { } // rawConnDial dials the underlying net.Conn + manet.Conns -func (d *Dialer) rawConnDial(ctx context.Context, raddr ma.Multiaddr, remote peer.ID) (transport.Conn, error) { +func (d *Dialer) rawConnDial(ctx context.Context, raddr ma.Multiaddr, remote peer.ID) (tpt.Conn, error) { if strings.HasPrefix(raddr.String(), "/ip4/0.0.0.0") { log.Event(ctx, "connDialZeroAddr", lgbl.Dial("conn", d.LocalPeer, remote, nil, raddr)) return nil, fmt.Errorf("Attempted to connect to zero address: %s", raddr) @@ -202,66 +196,3 @@ func (d *Dialer) rawConnDial(ctx context.Context, raddr ma.Multiaddr, remote pee return sd.DialContext(ctx, raddr) } - -func pickLocalAddr(laddrs []ma.Multiaddr, raddr ma.Multiaddr) (laddr ma.Multiaddr) { - if len(laddrs) < 1 { - return nil - } - - // make sure that we ONLY use local addrs that match the remote addr. - laddrs = manet.AddrMatch(raddr, laddrs) - if len(laddrs) < 1 { - return nil - } - - // make sure that we ONLY use local addrs that CAN dial the remote addr. - // filter out all the local addrs that aren't capable - raddrIPLayer := ma.Split(raddr)[0] - raddrIsLoopback := manet.IsIPLoopback(raddrIPLayer) - raddrIsLinkLocal := manet.IsIP6LinkLocal(raddrIPLayer) - laddrs = addrutil.FilterAddrs(laddrs, func(a ma.Multiaddr) bool { - laddrIPLayer := ma.Split(a)[0] - laddrIsLoopback := manet.IsIPLoopback(laddrIPLayer) - laddrIsLinkLocal := manet.IsIP6LinkLocal(laddrIPLayer) - if laddrIsLoopback { // our loopback addrs can only dial loopbacks. - return raddrIsLoopback - } - if laddrIsLinkLocal { - return raddrIsLinkLocal // out linklocal addrs can only dial link locals. - } - return true - }) - - // TODO pick with a good heuristic - // we use a random one for now to prevent bad addresses from making nodes unreachable - // with a random selection, multiple tries may work. - return laddrs[rand.Intn(len(laddrs))] -} - -// MultiaddrProtocolsMatch returns whether two multiaddrs match in protocol stacks. -func MultiaddrProtocolsMatch(a, b ma.Multiaddr) bool { - ap := a.Protocols() - bp := b.Protocols() - - if len(ap) != len(bp) { - return false - } - - for i, api := range ap { - if api.Code != bp[i].Code { - return false - } - } - - return true -} - -// MultiaddrNetMatch returns the first Multiaddr found to match network. -func MultiaddrNetMatch(tgt ma.Multiaddr, srcs []ma.Multiaddr) ma.Multiaddr { - for _, a := range srcs { - if MultiaddrProtocolsMatch(tgt, a) { - return a - } - } - return nil -} diff --git a/dial_test.go b/dial_test.go index de70d07..0eb5155 100644 --- a/dial_test.go +++ b/dial_test.go @@ -1,750 +1,61 @@ package conn import ( - "bytes" "context" - "fmt" - "io" "net" - "runtime" - "strings" - "sync" - "testing" "time" - ic "github.com/libp2p/go-libp2p-crypto" - iconn "github.com/libp2p/go-libp2p-interface-conn" - ipnet "github.com/libp2p/go-libp2p-interface-pnet" - peer "github.com/libp2p/go-libp2p-peer" - transport "github.com/libp2p/go-libp2p-transport" - tcpt "github.com/libp2p/go-tcp-transport" - tu "github.com/libp2p/go-testutil" ma "github.com/multiformats/go-multiaddr" - msmux "github.com/multiformats/go-multistream" - grc "github.com/whyrusleeping/gorocheck" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" ) -func goroFilter(r *grc.Goroutine) bool { - return strings.Contains(r.Function, "go-log.") || strings.Contains(r.Stack[0], "testing.(*T).Run") -} - -func echoListen(ctx context.Context, listener iconn.Listener) { - for { - c, err := listener.Accept() - if err != nil { - - select { - case <-ctx.Done(): - return - default: - } - - if ne, ok := err.(net.Error); ok && ne.Temporary() { - <-time.After(time.Microsecond * 10) - continue - } - - log.Debugf("echoListen: listener appears to be closing") - return - } - - go echo(c.(iconn.Conn)) - } -} - -func echo(c iconn.Conn) { - io.Copy(c, c) -} - -func setupSecureConn(t *testing.T, ctx context.Context) (a, b iconn.Conn, p1, p2 tu.PeerNetParams) { - return setupConn(t, ctx, true) -} - -func setupSingleConn(t *testing.T, ctx context.Context) (a, b iconn.Conn, p1, p2 tu.PeerNetParams) { - return setupConn(t, ctx, false) -} - -func Listen(ctx context.Context, addr ma.Multiaddr, local peer.ID, sk ic.PrivKey) (iconn.Listener, error) { - list, err := tcpt.NewTCPTransport().Listen(addr) - if err != nil { - return nil, err - } - - return WrapTransportListener(ctx, list, local, sk) -} - -func dialer(t *testing.T, a ma.Multiaddr) transport.Dialer { - tpt := tcpt.NewTCPTransport() - tptd, err := tpt.Dialer(a) - if err != nil { - t.Fatal(err) - } - - return tptd -} - -func setupConn(t *testing.T, ctx context.Context, secure bool) (a, b iconn.Conn, p1, p2 tu.PeerNetParams) { - - p1 = tu.RandPeerNetParamsOrFatal(t) - p2 = tu.RandPeerNetParamsOrFatal(t) - - key1 := p1.PrivKey - key2 := p2.PrivKey - if !secure { - key1 = nil - key2 = nil - } - l1, err := Listen(ctx, p1.Addr, p1.ID, key1) - if err != nil { - t.Fatal(err) - } - p1.Addr = l1.Multiaddr() // Addr has been determined by kernel. - - d2 := &Dialer{ - LocalPeer: p2.ID, - PrivateKey: key2, - } - - d2.AddDialer(dialer(t, p2.Addr)) - - var c2 iconn.Conn - - done := make(chan error) - go func() { - defer close(done) - - var err error - c2, err = d2.Dial(ctx, p1.Addr, p1.ID) - if err != nil { - done <- err - return - } - - // if secure, need to read + write, as that's what triggers the handshake. - if secure { - if err := sayHello(c2); err != nil { - done <- err - } - } - }() - - c1, err := l1.Accept() - if err != nil { - t.Fatal("failed to accept", err) - } - - // if secure, need to read + write, as that's what triggers the handshake. - if secure { - if err := sayHello(c1); err != nil { - done <- err - } - } - - if err := <-done; err != nil { - t.Fatal(err) - } - - return c1.(iconn.Conn), c2, p1, p2 -} - -func sayHello(c net.Conn) error { - h := []byte("hello") - if _, err := c.Write(h); err != nil { - return err - } - if _, err := c.Read(h); err != nil { - return err - } - if string(h) != "hello" { - return fmt.Errorf("did not get hello") - } - return nil -} - -func testDialer(t *testing.T, secure bool) { - // t.Skip("Skipping in favor of another test") - - p1 := tu.RandPeerNetParamsOrFatal(t) - p2 := tu.RandPeerNetParamsOrFatal(t) - - key1 := p1.PrivKey - key2 := p2.PrivKey - if !secure { - key1 = nil - key2 = nil - t.Log("testing insecurely") - } else { - t.Log("testing securely") - } - - ctx, cancel := context.WithCancel(context.Background()) - l1, err := Listen(ctx, p1.Addr, p1.ID, key1) - if err != nil { - t.Fatal(err) - } - p1.Addr = l1.Multiaddr() // Addr has been determined by kernel. - - d2 := &Dialer{ - LocalPeer: p2.ID, - PrivateKey: key2, - } - d2.AddDialer(dialer(t, p2.Addr)) - - go echoListen(ctx, l1) - - c, err := d2.Dial(ctx, p1.Addr, p1.ID) - if err != nil { - t.Fatal("error dialing peer", err) - } - - // fmt.Println("sending") - mc := msgioWrap(c) - mc.WriteMsg([]byte("beep")) - mc.WriteMsg([]byte("boop")) - out, err := mc.ReadMsg() - if err != nil { - t.Fatal(err) - } - - // fmt.Println("recving", string(out)) - data := string(out) - if data != "beep" { - t.Error("unexpected conn output", data) - } - - out, err = mc.ReadMsg() - if err != nil { - t.Fatal(err) - } - - data = string(out) - if string(out) != "boop" { - t.Error("unexpected conn output", data) - } - - // fmt.Println("closing") - c.Close() - l1.Close() - cancel() -} - -func TestDialerInsecure(t *testing.T) { - // t.Skip("Skipping in favor of another test") - testDialer(t, false) -} - -func TestDialerSecure(t *testing.T) { - // t.Skip("Skipping in favor of another test") - testDialer(t, true) -} - -func testDialerCloseEarly(t *testing.T, secure bool) { - // t.Skip("Skipping in favor of another test") - - p1 := tu.RandPeerNetParamsOrFatal(t) - p2 := tu.RandPeerNetParamsOrFatal(t) - - key1 := p1.PrivKey - if !secure { - key1 = nil - t.Log("testing insecurely") - } else { - t.Log("testing securely") - } - - ctx, cancel := context.WithCancel(context.Background()) - l1, err := Listen(ctx, p1.Addr, p1.ID, key1) - if err != nil { - t.Fatal(err) - } - p1.Addr = l1.Multiaddr() // Addr has been determined by kernel. - - // lol nesting - d2 := &Dialer{ - LocalPeer: p2.ID, - PrivateKey: p2.PrivKey, //-- dont give it key. we'll just close the conn. - } - d2.AddDialer(dialer(t, p2.Addr)) - - errs := make(chan error, 100) - done := make(chan struct{}, 1) - gotclosed := make(chan struct{}, 1) - go func() { - defer func() { done <- struct{}{} }() - - c, err := l1.Accept() - if err != nil { - if strings.Contains(err.Error(), "closed") { - gotclosed <- struct{}{} - return - } - errs <- err - } - - if _, err := c.Write([]byte("hello")); err != nil { - gotclosed <- struct{}{} - return - } - - errs <- fmt.Errorf("wrote to conn") - }() - - c, err := d2.Dial(ctx, p1.Addr, p1.ID) - if err != nil { - t.Fatal(err) - } - c.Close() // close it early. - - readerrs := func() { - for { - select { - case e := <-errs: - t.Error(e) - default: - return - } - } - } - readerrs() - - l1.Close() - <-done - cancel() - readerrs() - close(errs) - - select { - case <-gotclosed: - default: - t.Error("did not get closed") - } -} - -// we dont do a handshake with singleConn, so cant "close early." -// func TestDialerCloseEarlyInsecure(t *testing.T) { -// // t.Skip("Skipping in favor of another test") -// testDialerCloseEarly(t, false) -// } - -func TestDialerCloseEarlySecure(t *testing.T) { - // t.Skip("Skipping in favor of another test") - testDialerCloseEarly(t, true) -} - -func TestMultistreamHeader(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - p1 := tu.RandPeerNetParamsOrFatal(t) - - l1, err := Listen(ctx, p1.Addr, p1.ID, p1.PrivKey) - if err != nil { - t.Fatal(err) - } - - p1.Addr = l1.Multiaddr() // Addr has been determined by kernel. - - go func() { - _, _ = l1.Accept() - }() - - con, err := net.Dial("tcp", l1.Addr().String()) - if err != nil { - t.Fatal(err) - } - defer con.Close() - - err = msmux.SelectProtoOrFail(SecioTag, con) - if err != nil { - t.Fatal(err) - } -} - -func TestFailedAccept(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - p1 := tu.RandPeerNetParamsOrFatal(t) - - l1, err := Listen(ctx, p1.Addr, p1.ID, p1.PrivKey) - if err != nil { - t.Fatal(err) - } - - p1.Addr = l1.Multiaddr() // Addr has been determined by kernel. - - done := make(chan struct{}) - go func() { - defer close(done) - con, err := net.Dial("tcp", l1.Addr().String()) - if err != nil { - t.Error("first dial failed: ", err) - } - - // write some garbage - con.Write(bytes.Repeat([]byte{255}, 1000)) - - con.Close() - - con, err = net.Dial("tcp", l1.Addr().String()) - if err != nil { - t.Error("second dial failed: ", err) - } - defer con.Close() - - err = msmux.SelectProtoOrFail(SecioTag, con) - if err != nil { - t.Error("msmux select failed: ", err) - } - }() - - c, err := l1.Accept() - if err != nil { - t.Fatal("connections after a failed accept should still work: ", err) - } - - c.Close() - <-done -} - -func TestHangingAccept(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - p1 := tu.RandPeerNetParamsOrFatal(t) - - l1, err := Listen(ctx, p1.Addr, p1.ID, p1.PrivKey) - if err != nil { - t.Fatal(err) - } - - p1.Addr = l1.Multiaddr() // Addr has been determined by kernel. - - done := make(chan struct{}) - go func() { - defer close(done) - con, err := net.Dial("tcp", l1.Addr().String()) - if err != nil { - t.Error("first dial failed: ", err) - } - // hang this connection - defer con.Close() - - // ensure that the first conn hits first - time.Sleep(time.Millisecond * 50) - - con2, err := net.Dial("tcp", l1.Addr().String()) - if err != nil { - t.Error("second dial failed: ", err) - } - defer con2.Close() - - err = msmux.SelectProtoOrFail(SecioTag, con2) - if err != nil { - t.Error("msmux select failed: ", err) - } - - _, err = con2.Write([]byte("test")) - if err != nil { - t.Error("con write failed: ", err) - } - }() - - c, err := l1.Accept() - if err != nil { - t.Fatal("connections after a failed accept should still work: ", err) - } - - c.Close() - <-done -} - -// This test kicks off N (=300) concurrent dials, which wait d (=20ms) seconds before failing. -// That wait holds up the handshake (multistream AND crypto), which will happen BEFORE -// l1.Accept() returns a connection. This test checks that the handshakes all happen -// concurrently in the listener side, and not sequentially. This ensures that a hanging dial -// will not block the listener from accepting other dials concurrently. -func TestConcurrentAccept(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - p1 := tu.RandPeerNetParamsOrFatal(t) - - l1, err := Listen(ctx, p1.Addr, p1.ID, p1.PrivKey) - if err != nil { - t.Fatal(err) - } - - n := 300 - delay := time.Millisecond * 20 - if runtime.GOOS == "darwin" { - n = 100 - } - - p1.Addr = l1.Multiaddr() // Addr has been determined by kernel. - - var wg sync.WaitGroup - for i := 0; i < n; i++ { - wg.Add(1) +var _ = Describe("dialing", func() { + It("errors when it can't dial the raw connection", func() { + p := randPeerNetParams() + d := getDialer(p.ID, p.PrivKey, p.Addr) + raddr, err := ma.NewMultiaddr("/ip4/1.2.3.4/tcp/0") + Expect(err).ToNot(HaveOccurred()) + _, err = d.Dial(context.Background(), raddr, p.ID) + Expect(err).To(HaveOccurred()) + }) + + It("returns immediately when the context is canceled", func() { + p1 := randPeerNetParams() + tptList, err := getTransport(p1.Addr).Listen(p1.Addr) + Expect(err).ToNot(HaveOccurred()) + defer tptList.Close() + + dialed := make(chan struct{}) + ctx, cancel := context.WithCancel(context.Background()) go func() { - defer wg.Done() - con, err := net.Dial("tcp", l1.Addr().String()) - if err != nil { - log.Error(err) - t.Error("first dial failed: ", err) - return - } - // hang this connection - defer con.Close() - - time.Sleep(delay) - err = msmux.SelectProtoOrFail(SecioTag, con) - if err != nil { - t.Error(err) - } - }() - } - - before := time.Now() - for i := 0; i < n; i++ { - c, err := l1.Accept() - if err != nil { - t.Fatal("connections after a failed accept should still work: ", err) - } - - c.Close() - } - - limit := delay * time.Duration(n) - took := time.Since(before) - if took > limit { - t.Fatal("took too long!") - } - log.Errorf("took: %s (less than %s)", took, limit) - l1.Close() - wg.Wait() - cancel() - - time.Sleep(time.Millisecond * 100) - - err = grc.CheckForLeaks(goroFilter) - if err != nil { - t.Fatal(err) - } -} - -func TestConnectionTimeouts(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - old := NegotiateReadTimeout - NegotiateReadTimeout = time.Second * 5 - defer func() { NegotiateReadTimeout = old }() - - p1 := tu.RandPeerNetParamsOrFatal(t) - - l1, err := Listen(ctx, p1.Addr, p1.ID, p1.PrivKey) - if err != nil { - t.Fatal(err) - } - - n := 100 - if runtime.GOOS == "darwin" { - n = 50 - } - - p1.Addr = l1.Multiaddr() // Addr has been determined by kernel. - - var wg sync.WaitGroup - for i := 0; i < n; i++ { - wg.Add(1) - go func() { - defer wg.Done() - con, err := net.Dial("tcp", l1.Addr().String()) - if err != nil { - log.Error(err) - t.Error("first dial failed: ", err) - return - } - defer con.Close() - - // hang this connection until timeout - io.ReadFull(con, make([]byte, 1000)) + defer GinkgoRecover() + p2 := randPeerNetParams() + d := getDialer(p2.ID, p2.PrivKey, p2.Addr) + _, err = d.Dial(ctx, tptList.Multiaddr(), p2.ID) + Expect(err).To(MatchError(context.Canceled)) + close(dialed) }() - } - - // wait to make sure the hanging dials have started - time.Sleep(time.Millisecond * 50) - - good_n := 20 - for i := 0; i < good_n; i++ { - wg.Add(1) - go func() { - defer wg.Done() - con, err := net.Dial("tcp", l1.Addr().String()) - if err != nil { - log.Error(err) - t.Error("first dial failed: ", err) - return - } - defer con.Close() - - // dial these ones through - err = msmux.SelectProtoOrFail(SecioTag, con) - if err != nil { - t.Error(err) - } - }() - } - - before := time.Now() - for i := 0; i < good_n; i++ { - c, err := l1.Accept() - if err != nil { - t.Fatal("connections during hung dials should still work: ", err) - } - - c.Close() - } - - took := time.Since(before) - - if took > time.Second*5 { - t.Fatal("hanging dials shouldnt block good dials") - } - - wg.Wait() - - go func() { - con, err := net.Dial("tcp", l1.Addr().String()) - if err != nil { - log.Error(err) - t.Error("first dial failed: ", err) - return - } - defer con.Close() - - // dial these ones through - err = msmux.SelectProtoOrFail(SecioTag, con) - if err != nil { - t.Error(err) - } - }() - - // make sure we can dial in still after a bunch of timeouts - con, err := l1.Accept() - if err != nil { - t.Fatal(err) - } - - con.Close() - l1.Close() - cancel() - - time.Sleep(time.Millisecond * 100) - - err = grc.CheckForLeaks(goroFilter) - if err != nil { - t.Fatal(err) - } -} - -func TestForcePNet(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - ipnet.ForcePrivateNetwork = true - defer func() { - ipnet.ForcePrivateNetwork = false - }() - - p := tu.RandPeerNetParamsOrFatal(t) - list, err := tcpt.NewTCPTransport().Listen(p.Addr) - if err != nil { - t.Fatal(err) - } - - _, err = WrapTransportListenerWithProtector(ctx, list, p.ID, p.PrivKey, nil) - if err != ipnet.ErrNotInPrivateNetwork { - t.Fatal("Wrong error, expected error lack of protector") - } -} - -type fakeProtector struct { - used bool -} - -func (f *fakeProtector) Fingerprint() []byte { - return make([]byte, 32) -} - -func (f *fakeProtector) Protect(c transport.Conn) (transport.Conn, error) { - f.used = true - return &rot13Crypt{c}, nil -} - -type rot13Crypt struct { - transport.Conn -} - -func (r *rot13Crypt) Read(b []byte) (int, error) { - n, err := r.Conn.Read(b) - if err != nil { - return n, err - } - - for i, _ := range b { - b[i] = byte((uint8(b[i]) - 13) & 0xff) - } - return n, err -} - -func (r *rot13Crypt) Write(b []byte) (int, error) { - for i, _ := range b { - b[i] = byte((uint8(b[i]) + 13) & 0xff) - } - return r.Conn.Write(b) -} - -func TestPNetIsUsed(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - p1 := tu.RandPeerNetParamsOrFatal(t) - p2 := tu.RandPeerNetParamsOrFatal(t) - - p1Protec := &fakeProtector{} - - list, err := tcpt.NewTCPTransport().Listen(p1.Addr) - if err != nil { - t.Fatal(err) - } - - l1, err := WrapTransportListenerWithProtector(ctx, list, p1.ID, p1.PrivKey, p1Protec) - if err != nil { - t.Fatal(err) - } - p1.Addr = l1.Multiaddr() // Addr has been determined by kernel. - - d2 := NewDialer(p2.ID, p2.PrivKey, nil) - d2.Protector = &fakeProtector{} - - d2.AddDialer(dialer(t, p2.Addr)) - _, err = d2.Dial(ctx, p1.Addr, p1.ID) - if err != nil { - t.Fatal(err) - } - - _, err = l1.Accept() - if err != nil { - t.Fatal(err) - } - - if !p1Protec.used { - t.Error("Listener did not use protector for the connection") - } - - if !d2.Protector.(*fakeProtector).used { - t.Error("Dialer did not use protector for the connection") - } -} + Consistently(dialed).ShouldNot(BeClosed()) + cancel() + Eventually(dialed).Should(BeClosed()) + }) + + It("times out during multistream selection", func() { + old := DialTimeout + DialTimeout = time.Second + defer func() { DialTimeout = old }() + + p1 := randPeerNetParams() + p2 := randPeerNetParams() + tptList, err := getTransport(p1.Addr).Listen(p1.Addr) + Expect(err).ToNot(HaveOccurred()) + defer tptList.Close() + + d := getDialer(p2.ID, p2.PrivKey, p2.Addr) + _, err = d.Dial(context.Background(), tptList.Multiaddr(), p2.ID) + Expect(err).To(HaveOccurred()) + Expect(err.(net.Error).Timeout()).To(BeTrue()) + Expect(err.(net.Error).Temporary()).To(BeTrue()) + }) +}) diff --git a/listen.go b/listen.go index f2e24c2..b36b02d 100644 --- a/listen.go +++ b/listen.go @@ -15,8 +15,9 @@ import ( iconn "github.com/libp2p/go-libp2p-interface-conn" ipnet "github.com/libp2p/go-libp2p-interface-pnet" peer "github.com/libp2p/go-libp2p-peer" - transport "github.com/libp2p/go-libp2p-transport" + tpt "github.com/libp2p/go-libp2p-transport" filter "github.com/libp2p/go-maddr-filter" + smux "github.com/libp2p/go-stream-muxer" ma "github.com/multiformats/go-multiaddr" msmux "github.com/multiformats/go-multistream" ) @@ -27,21 +28,23 @@ const ( ) var ( - connAcceptBuffer = 32 - NegotiateReadTimeout = time.Second * 60 + connAcceptBuffer = 32 + ConnAcceptTimeout = 60 * time.Second ) // ConnWrapper is any function that wraps a raw multiaddr connection -type ConnWrapper func(transport.Conn) transport.Conn +type ConnWrapper func(tpt.Conn) tpt.Conn // listener is an object that can accept connections. It implements Listener type listener struct { - transport.Listener + tpt.Listener local peer.ID // LocalPeer is the identity of the local Peer privk ic.PrivKey // private key to use to initialize secure conns protec ipnet.Protector + streamMuxer smux.Transport + filters *filter.Filters wrapper ConnWrapper @@ -51,11 +54,13 @@ type listener struct { mux *msmux.MultistreamMuxer - incoming chan connErr + incoming chan connOrErr ctx context.Context } +var _ iconn.Listener = &listener{} + func (l *listener) teardown() error { defer log.Debugf("listener closed: %s %s", l.local, l.Multiaddr()) return l.Listener.Close() @@ -74,41 +79,22 @@ func (l *listener) SetAddrFilters(fs *filter.Filters) { l.filters = fs } -type connErr struct { - conn transport.Conn +type connOrErr struct { + conn iconn.Conn err error } // Accept waits for and returns the next connection to the listener. -// Note that unfortunately this -func (l *listener) Accept() (transport.Conn, error) { - for con := range l.incoming { - if con.err != nil { - return nil, con.err - } - - c, err := newSingleConn(l.ctx, l.local, "", con.conn) - if err != nil { - con.conn.Close() - if l.catcher.IsTemporary(err) { - continue - } - return nil, err - } +func (l *listener) Accept() (iconn.Conn, error) { + if l.privk == nil || !iconn.EncryptConnections { + log.Warningf("listener %s listening INSECURELY!", l) + } - if l.privk == nil || !iconn.EncryptConnections { - log.Warning("listener %s listening INSECURELY!", l) - return c, nil - } - sc, err := newSecureConn(l.ctx, l.privk, c) - if err != nil { - con.conn.Close() - log.Infof("ignoring conn we failed to secure: %s %s", err, c) - continue - } - return sc, nil + c, ok := <-l.incoming + if !ok { + return nil, fmt.Errorf("listener is closed") } - return nil, fmt.Errorf("listener is closed") + return c.conn, c.err } func (l *listener) Addr() net.Addr { @@ -149,64 +135,83 @@ func (l *listener) handleIncoming() { defer wg.Done() for { - maconn, err := l.Listener.Accept() + conn, err := l.Listener.Accept() if err != nil { if l.catcher.IsTemporary(err) { continue } - - l.incoming <- connErr{err: err} + l.incoming <- connOrErr{err: err} return } - log.Debugf("listener %s got connection: %s <---> %s", l, maconn.LocalMultiaddr(), maconn.RemoteMultiaddr()) - - if l.filters != nil && l.filters.AddrBlocked(maconn.RemoteMultiaddr()) { - log.Debugf("blocked connection from %s", maconn.RemoteMultiaddr()) - maconn.Close() + if l.filters != nil && l.filters.AddrBlocked(conn.RemoteMultiaddr()) { + log.Debugf("blocked connection from %s", conn.RemoteMultiaddr()) + conn.Close() continue } - // If we have a wrapper func, wrap this conn - if l.wrapper != nil { - maconn = l.wrapper(maconn) - } + + log.Debugf("listener %s got connection: %s <---> %s", l, conn.LocalMultiaddr(), conn.RemoteMultiaddr()) wg.Add(1) go func() { defer wg.Done() - if l.protec != nil { - pc, err := l.protec.Protect(maconn) - if err != nil { - maconn.Close() - log.Warning("protector failed: ", err) + + ctx, cancel := context.WithTimeout(l.ctx, ConnAcceptTimeout) + defer cancel() + + done := make(chan struct{}) + var singleConn iconn.Conn + go func() { + defer close(done) + + if l.protec != nil { + pc, err := l.protec.Protect(conn) + if err != nil { + conn.Close() + log.Warning("protector failed: ", err) + return + } + conn = pc } - maconn = pc - } - maconn.SetReadDeadline(time.Now().Add(NegotiateReadTimeout)) - _, _, err = l.mux.Negotiate(maconn) - if err != nil { - log.Warning("incoming conn: negotiation of crypto protocol failed: ", err) - maconn.Close() - return - } + // If we have a wrapper func, wrap this conn + if l.wrapper != nil { + conn = l.wrapper(conn) + } - // clear read readline - maconn.SetReadDeadline(time.Time{}) + if _, _, err := l.mux.Negotiate(conn); err != nil { + log.Warning("incoming conn: negotiation of crypto protocol failed: ", err) + conn.Close() + return + } - l.incoming <- connErr{conn: maconn} + singleConn, err = newSingleConn(ctx, l.local, "", l.privk, conn, l.streamMuxer, true) + if err != nil { + log.Warning("connection setup failed: ", err) + conn.Close() + } + }() + + select { + case <-ctx.Done(): + log.Warning("incoming conn: conn not established in time:", ctx.Err().Error()) + conn.Close() + case <-done: // connection completed (or errored) + if singleConn != nil { + l.incoming <- connOrErr{conn: singleConn} + } + } }() } } -func WrapTransportListener(ctx context.Context, ml transport.Listener, local peer.ID, +func WrapTransportListener(ctx context.Context, ml tpt.Listener, local peer.ID, pstpt smux.Transport, sk ic.PrivKey) (iconn.Listener, error) { - return WrapTransportListenerWithProtector(ctx, ml, local, sk, nil) + return WrapTransportListenerWithProtector(ctx, ml, local, sk, pstpt, nil) } -func WrapTransportListenerWithProtector(ctx context.Context, ml transport.Listener, local peer.ID, - sk ic.PrivKey, protec ipnet.Protector) (iconn.Listener, error) { - +func WrapTransportListenerWithProtector(ctx context.Context, ml tpt.Listener, local peer.ID, + sk ic.PrivKey, pstpt smux.Transport, protec ipnet.Protector) (iconn.Listener, error) { if protec == nil && ipnet.ForcePrivateNetwork { log.Error("tried to listen with no Private Network Protector but usage" + " of Private Networks is forced by the enviroment") @@ -214,13 +219,14 @@ func WrapTransportListenerWithProtector(ctx context.Context, ml transport.Listen } l := &listener{ - Listener: ml, - local: local, - privk: sk, - protec: protec, - mux: msmux.NewMultistreamMuxer(), - incoming: make(chan connErr, connAcceptBuffer), - ctx: ctx, + Listener: ml, + local: local, + privk: sk, + protec: protec, + mux: msmux.NewMultistreamMuxer(), + incoming: make(chan connOrErr, connAcceptBuffer), + ctx: ctx, + streamMuxer: pstpt, } l.proc = goprocessctx.WithContextAndTeardown(ctx, l.teardown) l.catcher.IsTemp = func(e error) bool { diff --git a/listen_test.go b/listen_test.go new file mode 100644 index 0000000..7ffffe5 --- /dev/null +++ b/listen_test.go @@ -0,0 +1,208 @@ +package conn + +import ( + "bytes" + "context" + "net" + "sync" + "time" + + tpt "github.com/libp2p/go-libp2p-transport" + filter "github.com/libp2p/go-maddr-filter" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Listener", func() { + Context("accepting connections", func() { + It("returns immediately when the context is cancelled", func() { + p1 := randPeerNetParams() + ctx, cancel := context.WithCancel(context.Background()) + l := getListener(ctx, p1) + + accepted := make(chan struct{}) + go func() { + _, _ = l.Accept() + close(accepted) + }() + Consistently(accepted).ShouldNot(BeClosed()) + cancel() + Eventually(accepted).Should(BeClosed()) + }) + + It("returns immediately when it is closed", func() { + p1 := randPeerNetParams() + l := getListener(context.Background(), p1) + + accepted := make(chan struct{}) + go func() { + _, _ = l.Accept() + close(accepted) + }() + Consistently(accepted).ShouldNot(BeClosed()) + l.Close() + Eventually(accepted).Should(BeClosed()) + }) + + It("continues accepting connections after one accept failed", func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + p1 := randPeerNetParams() + p2 := randPeerNetParams() + + l1 := getListener(ctx, p1) + defer l1.Close() + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + c := dialRawConn(p2.Addr, l1.Multiaddr()) + defer c.Close() + // write some garbage. This will fail the protocol selection + _, err := c.Write(bytes.Repeat([]byte{255}, 1000)) + Expect(err).ToNot(HaveOccurred()) + close(done) + }() + + accepted := make(chan struct{}) + go func() { + defer GinkgoRecover() + c, err := l1.Accept() + Expect(err).ToNot(HaveOccurred()) + c.Close() + close(accepted) + }() + + // make sure it doesn't accept the raw connection + Eventually(done).Should(BeClosed()) + Consistently(accepted).ShouldNot(BeClosed()) + + // now dial the real connection, and make sure it is accepted + d := getDialer(p2.ID, p2.PrivKey, p2.Addr) + _, err := d.Dial(ctx, p1.Addr, p1.ID) + Expect(err).ToNot(HaveOccurred()) + + Eventually(accepted).Should(BeClosed()) + }) + + // This test kicks off N (=10) concurrent dials, which wait d (=20ms) seconds before failing. + // That wait holds up the handshake (multistream AND crypto), which will happen BEFORE + // l1.Accept() returns a connection. This test checks that the handshakes all happen + // concurrently in the listener side, and not sequentially. This ensures that a hanging dial + // will not block the listener from accepting other dials concurrently. + It("accepts concurrently", func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + p1 := randPeerNetParams() + p2 := randPeerNetParams() + + l1 := getListener(ctx, p1) + defer l1.Close() + + n := 10 + delay := 50 * time.Millisecond + + accepted := make(chan struct{}) + go func() { + defer GinkgoRecover() + for i := 0; i < n; i++ { + conn, err := l1.Accept() + Expect(err).ToNot(HaveOccurred()) + defer conn.Close() + } + close(accepted) + }() + + var wg sync.WaitGroup + for i := 0; i < n; i++ { + wg.Add(1) + go func() { + defer GinkgoRecover() + defer wg.Done() + d2 := getDialer(p2.ID, p2.PrivKey, p2.Addr) + d2.Wrapper = func(c tpt.Conn) tpt.Conn { + time.Sleep(delay) + return c + } + before := time.Now() + _, err := d2.Dial(ctx, p1.Addr, p1.ID) + Expect(err).ToNot(HaveOccurred()) + // make sure the delay actually worked + Expect(time.Now()).To(BeTemporally(">", before.Add(delay))) + }() + } + + wg.Wait() + // the Eventually timeout is 100ms, which is a lot smaller than n*delay = 500ms + Eventually(accepted).Should(BeClosed()) + }) + + Context("address filters", func() { + It("doesn't accept connections from filtered addresses", func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + p1 := randPeerNetParams() + p2 := randPeerNetParams() + + filt := filter.NewFilters() + _, ipnet, err := net.ParseCIDR("127.0.1.2/16") + Expect(err).ToNot(HaveOccurred()) + filt.AddDialFilter(ipnet) + Expect(filt.AddrBlocked(p2.Addr)).To(BeTrue()) + + l := getListener(ctx, p1) + defer l.Close() + l.SetAddrFilters(filt) + + accepted := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, _ = l.Accept() + close(accepted) + }() + + d := getDialer(p2.ID, p2.PrivKey, p2.Addr) + _, err = d.Dial(ctx, p1.Addr, p1.ID) + Expect(err).To(HaveOccurred()) + Eventually(accepted).ShouldNot(BeClosed()) + }) + + It("accepts connections from addresses that are not filtered", func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + p1 := randPeerNetParams() + p2 := randPeerNetParams() + + filt := filter.NewFilters() + _, ipnet, err := net.ParseCIDR("192.168.1.2/16") + Expect(err).ToNot(HaveOccurred()) + filt.AddDialFilter(ipnet) + Expect(filt.AddrBlocked(p2.Addr)).To(BeFalse()) + + l := getListener(ctx, p1) + defer l.Close() + l.SetAddrFilters(filt) + + accepted := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := l.Accept() + Expect(err).ToNot(HaveOccurred()) + close(accepted) + }() + + d := getDialer(p2.ID, p2.PrivKey, p2.Addr) + c2, err := d.Dial(ctx, p1.Addr, p1.ID) + Expect(err).ToNot(HaveOccurred()) + defer c2.Close() + Eventually(accepted).Should(BeClosed()) + time.Sleep(time.Second) + }) + }) + }) +}) diff --git a/protector_test.go b/protector_test.go new file mode 100644 index 0000000..0c0b6ae --- /dev/null +++ b/protector_test.go @@ -0,0 +1,156 @@ +package conn + +import ( + "context" + "errors" + + ipnet "github.com/libp2p/go-libp2p-interface-pnet" + tpt "github.com/libp2p/go-libp2p-transport" + tcpt "github.com/libp2p/go-tcp-transport" + tu "github.com/libp2p/go-testutil" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +type fakeProtector struct { + used bool +} + +func (f *fakeProtector) Fingerprint() []byte { + return make([]byte, 32) +} + +func (f *fakeProtector) Protect(c tpt.Conn) (tpt.Conn, error) { + f.used = true + return &rot13Crypt{c}, nil +} + +type rot13Crypt struct { + tpt.Conn +} + +func (r *rot13Crypt) Read(b []byte) (int, error) { + n, err := r.Conn.Read(b) + for i := 0; i < n; i++ { + b[i] = b[i] - 13 + } + return n, err +} + +func (r *rot13Crypt) Write(b []byte) (int, error) { + p := make([]byte, len(b)) // write MUST NOT modify b + for i := range b { + p[i] = b[i] + 13 + } + return r.Conn.Write(p) +} + +var errProtect = errors.New("protecting failed") + +type erroringProtector struct{} + +func (f *erroringProtector) Fingerprint() []byte { + return make([]byte, 32) +} + +func (f *erroringProtector) Protect(c tpt.Conn) (tpt.Conn, error) { + return nil, errProtect +} + +var _ = Describe("using the protector", func() { + It("uses a protector for single-stream connections", func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + p1 := randPeerNetParams() + p2 := randPeerNetParams() + p1Protec := &fakeProtector{} + p2Protec := &fakeProtector{} + + list, err := tcpt.NewTCPTransport().Listen(p1.Addr) + Expect(err).ToNot(HaveOccurred()) + l1, err := WrapTransportListenerWithProtector(ctx, list, p1.ID, p1.PrivKey, streamMuxer, p1Protec) + Expect(err).ToNot(HaveOccurred()) + p1.Addr = l1.Multiaddr() // Addr has been determined by kernel. + + d2 := getDialer(p2.ID, p2.PrivKey, p2.Addr) + d2.Protector = p2Protec + + accepted := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := l1.Accept() + Expect(err).ToNot(HaveOccurred()) + close(accepted) + }() + + c2, err := d2.Dial(ctx, p1.Addr, p1.ID) + Expect(err).ToNot(HaveOccurred()) + defer c2.Close() + + Expect(p2Protec.used).To(BeTrue()) + Eventually(accepted).Should(BeClosed()) + Expect(p1Protec.used).To(BeTrue()) + }) + + Context("forcing a private network", func() { + var p1, p2 *tu.PeerNetParams + var list tpt.Listener + + BeforeEach(func() { + ipnet.ForcePrivateNetwork = true + p1 = randPeerNetParams() + p2 = randPeerNetParams() + var err error + list, err = tcpt.NewTCPTransport().Listen(p1.Addr) + Expect(err).ToNot(HaveOccurred()) + }) + + AfterEach(func() { + ipnet.ForcePrivateNetwork = false + }) + + It("errors if no protector is specified for the listener", func() { + _, err := WrapTransportListenerWithProtector(context.Background(), list, p1.ID, p1.PrivKey, streamMuxer, nil) + Expect(err).To(Equal(ipnet.ErrNotInPrivateNetwork)) + }) + + It("errors if no protector is specified for the dialer", func() { + d := getDialer(p2.ID, p2.PrivKey, p2.Addr) + _, err := d.Dial(context.Background(), list.Multiaddr(), p1.ID) + Expect(err).To(Equal(ipnet.ErrNotInPrivateNetwork)) + }) + }) + + It("correctly handles a protected that errors", func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + p1 := randPeerNetParams() + p2 := randPeerNetParams() + p1Protec := &erroringProtector{} + p2Protec := &erroringProtector{} + + list, err := tcpt.NewTCPTransport().Listen(p1.Addr) + Expect(err).ToNot(HaveOccurred()) + l1, err := WrapTransportListenerWithProtector(ctx, list, p1.ID, p1.PrivKey, streamMuxer, p1Protec) + Expect(err).ToNot(HaveOccurred()) + p1.Addr = l1.Multiaddr() // Addr has been determined by kernel. + + d2 := getDialer(p2.ID, p2.PrivKey, p2.Addr) + d2.Protector = p2Protec + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, _ = l1.Accept() + close(done) + }() + + _, err = d2.Dial(ctx, p1.Addr, p1.ID) + Expect(err).To(MatchError(errProtect)) + // make sure no connection was accepted + Consistently(done).ShouldNot(BeClosed()) + }) +}) diff --git a/secure_conn.go b/secure_conn.go index d726d14..ee01dc5 100644 --- a/secure_conn.go +++ b/secure_conn.go @@ -1,14 +1,9 @@ package conn import ( - "context" - "errors" "net" "time" - ic "github.com/libp2p/go-libp2p-crypto" - iconn "github.com/libp2p/go-libp2p-interface-conn" - peer "github.com/libp2p/go-libp2p-peer" secio "github.com/libp2p/go-libp2p-secio" tpt "github.com/libp2p/go-libp2p-transport" ma "github.com/multiformats/go-multiaddr" @@ -16,113 +11,50 @@ import ( // secureConn wraps another Conn object with an encrypted channel. type secureConn struct { - insecure iconn.Conn // the wrapped conn + insecure tpt.Conn // the wrapped conn secure secio.Session // secure Session } -// newConn constructs a new connection -func newSecureConn(ctx context.Context, sk ic.PrivKey, insecure iconn.Conn) (iconn.Conn, error) { +var _ tpt.Conn = &secureConn{} - if insecure == nil { - return nil, errors.New("insecure is nil") - } - if insecure.LocalPeer() == "" { - return nil, errors.New("insecure.LocalPeer() is nil") - } - if sk == nil { - return nil, errors.New("private key is nil") - } - - // NewSession performs the secure handshake, which takes multiple RTT - sessgen := secio.SessionGenerator{LocalID: insecure.LocalPeer(), PrivateKey: sk} - secure, err := sessgen.NewSession(ctx, insecure) - if err != nil { - return nil, err - } +func (c *secureConn) Read(buf []byte) (int, error) { + return c.secure.ReadWriter().Read(buf) +} - conn := &secureConn{ - insecure: insecure, - secure: secure, - } - return conn, nil +func (c *secureConn) Write(buf []byte) (int, error) { + return c.secure.ReadWriter().Write(buf) } func (c *secureConn) Close() error { return c.secure.Close() } -// ID is an identifier unique to this connection. -func (c *secureConn) ID() string { - return iconn.ID(c) -} - -func (c *secureConn) String() string { - return iconn.String(c, "secureConn") -} - func (c *secureConn) LocalAddr() net.Addr { return c.insecure.LocalAddr() } -func (c *secureConn) RemoteAddr() net.Addr { - return c.insecure.RemoteAddr() -} - -func (c *secureConn) SetDeadline(t time.Time) error { - return c.insecure.SetDeadline(t) -} - -func (c *secureConn) SetReadDeadline(t time.Time) error { - return c.insecure.SetReadDeadline(t) -} - -func (c *secureConn) SetWriteDeadline(t time.Time) error { - return c.insecure.SetWriteDeadline(t) -} - -// LocalMultiaddr is the Multiaddr on this side func (c *secureConn) LocalMultiaddr() ma.Multiaddr { return c.insecure.LocalMultiaddr() } -// RemoteMultiaddr is the Multiaddr on the remote side -func (c *secureConn) RemoteMultiaddr() ma.Multiaddr { - return c.insecure.RemoteMultiaddr() -} - -// LocalPeer is the Peer on this side -func (c *secureConn) LocalPeer() peer.ID { - return c.secure.LocalPeer() -} - -// RemotePeer is the Peer on the remote side -func (c *secureConn) RemotePeer() peer.ID { - return c.secure.RemotePeer() -} - -// LocalPrivateKey is the public key of the peer on this side -func (c *secureConn) LocalPrivateKey() ic.PrivKey { - return c.secure.LocalPrivateKey() +func (c *secureConn) RemoteAddr() net.Addr { + return c.insecure.RemoteAddr() } -// RemotePubKey is the public key of the peer on the remote side -func (c *secureConn) RemotePublicKey() ic.PubKey { - return c.secure.RemotePublicKey() +func (c *secureConn) RemoteMultiaddr() ma.Multiaddr { + return c.insecure.RemoteMultiaddr() } -// Read reads data, net.Conn style -func (c *secureConn) Read(buf []byte) (int, error) { - return c.secure.ReadWriter().Read(buf) +func (c *secureConn) SetDeadline(t time.Time) error { + return c.insecure.SetDeadline(t) } -// Write writes data, net.Conn style -func (c *secureConn) Write(buf []byte) (int, error) { - return c.secure.ReadWriter().Write(buf) +func (c *secureConn) SetReadDeadline(t time.Time) error { + return c.insecure.SetDeadline(t) } -// ReleaseMsg releases a buffer -func (c *secureConn) ReleaseMsg(m []byte) { - c.secure.ReadWriter().ReleaseMsg(m) +func (c *secureConn) SetWriteDeadline(t time.Time) error { + return c.insecure.SetDeadline(t) } func (c *secureConn) Transport() tpt.Transport { diff --git a/secure_conn_test.go b/secure_conn_test.go deleted file mode 100644 index 80fb477..0000000 --- a/secure_conn_test.go +++ /dev/null @@ -1,212 +0,0 @@ -package conn - -import ( - "bytes" - "context" - "runtime" - "sync" - "testing" - "time" - - ic "github.com/libp2p/go-libp2p-crypto" - iconn "github.com/libp2p/go-libp2p-interface-conn" - travis "github.com/libp2p/go-testutil/ci/travis" -) - -func upgradeToSecureConn(t *testing.T, ctx context.Context, sk ic.PrivKey, c iconn.Conn) (iconn.Conn, error) { - if c, ok := c.(*secureConn); ok { - return c, nil - } - - // shouldn't happen, because dial + listen already return secure conns. - s, err := newSecureConn(ctx, sk, c) - if err != nil { - return nil, err - } - - // need to read + write, as that's what triggers the handshake. - h := []byte("hello") - if _, err := s.Write(h); err != nil { - return nil, err - } - if _, err := s.Read(h); err != nil { - return nil, err - } - return s, nil -} - -func secureHandshake(t *testing.T, ctx context.Context, sk ic.PrivKey, c iconn.Conn, done chan error) { - _, err := upgradeToSecureConn(t, ctx, sk, c) - done <- err -} - -func TestSecureSimple(t *testing.T) { - // t.Skip("Skipping in favor of another test") - - numMsgs := 100 - if testing.Short() { - numMsgs = 10 - } - - ctx := context.Background() - c1, c2, p1, p2 := setupSingleConn(t, ctx) - - done := make(chan error) - go secureHandshake(t, ctx, p1.PrivKey, c1, done) - go secureHandshake(t, ctx, p2.PrivKey, c2, done) - - for i := 0; i < 2; i++ { - if err := <-done; err != nil { - t.Fatal(err) - } - } - - for i := 0; i < numMsgs; i++ { - testOneSendRecv(t, c1, c2) - testOneSendRecv(t, c2, c1) - } - - c1.Close() - c2.Close() -} - -func TestSecureClose(t *testing.T) { - // t.Skip("Skipping in favor of another test") - - ctx := context.Background() - c1, c2, p1, p2 := setupSingleConn(t, ctx) - - done := make(chan error) - go secureHandshake(t, ctx, p1.PrivKey, c1, done) - go secureHandshake(t, ctx, p2.PrivKey, c2, done) - - for i := 0; i < 2; i++ { - if err := <-done; err != nil { - t.Fatal(err) - } - } - - testOneSendRecv(t, c1, c2) - - c1.Close() - testNotOneSendRecv(t, c1, c2) - - c2.Close() - testNotOneSendRecv(t, c1, c2) - testNotOneSendRecv(t, c2, c1) - -} - -func TestSecureCancelHandshake(t *testing.T) { - // t.Skip("Skipping in favor of another test") - - ctx, cancel := context.WithCancel(context.Background()) - c1, c2, p1, p2 := setupSingleConn(t, ctx) - - done := make(chan error) - go secureHandshake(t, ctx, p1.PrivKey, c1, done) - time.Sleep(time.Millisecond) - cancel() // cancel ctx - go secureHandshake(t, ctx, p2.PrivKey, c2, done) - - for i := 0; i < 2; i++ { - if err := <-done; err == nil { - t.Error("cancel should've errored out") - } - } -} - -func TestSecureHandshakeFailsWithWrongKeys(t *testing.T) { - // t.Skip("Skipping in favor of another test") - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - c1, c2, p1, p2 := setupSingleConn(t, ctx) - - done := make(chan error) - go secureHandshake(t, ctx, p2.PrivKey, c1, done) - go secureHandshake(t, ctx, p1.PrivKey, c2, done) - - for i := 0; i < 2; i++ { - if err := <-done; err == nil { - t.Fatal("wrong keys should've errored out.") - } - } -} - -func TestSecureCloseLeak(t *testing.T) { - // t.Skip("Skipping in favor of another test") - - if testing.Short() { - t.SkipNow() - } - if travis.IsRunning() { - t.Skip("this doesn't work well on travis") - } - - runPair := func(c1, c2 iconn.Conn, num int) { - mc1 := msgioWrap(c1) - mc2 := msgioWrap(c2) - - log.Debugf("runPair %d", num) - - for i := 0; i < num; i++ { - log.Debugf("runPair iteration %d", i) - b1 := []byte("beep") - mc1.WriteMsg(b1) - b2, err := mc2.ReadMsg() - if err != nil { - panic(err) - } - if !bytes.Equal(b1, b2) { - panic("bytes not equal") - } - - b2 = []byte("beep") - mc2.WriteMsg(b2) - b1, err = mc1.ReadMsg() - if err != nil { - panic(err) - } - if !bytes.Equal(b1, b2) { - panic("bytes not equal") - } - - time.Sleep(time.Microsecond * 5) - } - } - - var cons = 5 - var msgs = 50 - log.Debugf("Running %d connections * %d msgs.\n", cons, msgs) - - var wg sync.WaitGroup - for i := 0; i < cons; i++ { - wg.Add(1) - - ctx, cancel := context.WithCancel(context.Background()) - c1, c2, _, _ := setupSecureConn(t, ctx) - go func(c1, c2 iconn.Conn) { - - defer func() { - c1.Close() - c2.Close() - cancel() - wg.Done() - }() - - runPair(c1, c2, msgs) - }(c1, c2) - } - - log.Debugf("Waiting...") - wg.Wait() - // done! - - time.Sleep(time.Millisecond * 150) - ngr := runtime.NumGoroutine() - if ngr > 25 { - // panic("uncomment me to debug") - t.Fatal("leaking goroutines:", ngr) - } -}