From e4d7c8573d21c7d17031e74796dfaf00673fb187 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 24 Mar 2017 11:38:41 +0700 Subject: [PATCH 01/10] more consistent import names --- dial.go | 16 ++++++++-------- listen.go | 14 +++++++------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/dial.go b/dial.go index 446687b..735c5a7 100644 --- a/dial.go +++ b/dial.go @@ -13,13 +13,13 @@ import ( ipnet "github.com/libp2p/go-libp2p-interface-pnet" lgbl "github.com/libp2p/go-libp2p-loggables" peer "github.com/libp2p/go-libp2p-peer" - transport "github.com/libp2p/go-libp2p-transport" + tpt "github.com/libp2p/go-libp2p-transport" ma "github.com/multiformats/go-multiaddr" manet "github.com/multiformats/go-multiaddr-net" msmux "github.com/multiformats/go-multistream" ) -type WrapFunc func(transport.Conn) transport.Conn +type WrapFunc func(tpt.Conn) tpt.Conn // Dialer is an object that can open connections. We could have a "convenience" // Dial function as before, but it would have many arguments, as dialing is @@ -33,7 +33,7 @@ type Dialer struct { // Dialers are the sub-dialers usable by this dialer // selected in order based on the address being dialed - Dialers []transport.Dialer + Dialers []tpt.Dialer // PrivateKey used to initialize a secure connection. // Warning: if PrivateKey is nil, connection will not be secured. @@ -47,7 +47,7 @@ type Dialer struct { // Wrapper to wrap the raw connection (optional) Wrapper WrapFunc - fallback transport.Dialer + fallback tpt.Dialer } func NewDialer(p peer.ID, pk ci.PrivKey, wrap WrapFunc) *Dialer { @@ -55,7 +55,7 @@ func NewDialer(p peer.ID, pk ci.PrivKey, wrap WrapFunc) *Dialer { LocalPeer: p, PrivateKey: pk, Wrapper: wrap, - fallback: new(transport.FallbackDialer), + fallback: new(tpt.FallbackDialer), } } @@ -169,12 +169,12 @@ func (d *Dialer) Dial(ctx context.Context, raddr ma.Multiaddr, remote peer.ID) ( return connOut, nil } -func (d *Dialer) AddDialer(pd transport.Dialer) { +func (d *Dialer) AddDialer(pd tpt.Dialer) { d.Dialers = append(d.Dialers, pd) } // returns dialer that can dial the given address -func (d *Dialer) subDialerForAddr(raddr ma.Multiaddr) transport.Dialer { +func (d *Dialer) subDialerForAddr(raddr ma.Multiaddr) tpt.Dialer { for _, pd := range d.Dialers { if pd.Matches(raddr) { return pd @@ -189,7 +189,7 @@ func (d *Dialer) subDialerForAddr(raddr ma.Multiaddr) transport.Dialer { } // rawConnDial dials the underlying net.Conn + manet.Conns -func (d *Dialer) rawConnDial(ctx context.Context, raddr ma.Multiaddr, remote peer.ID) (transport.Conn, error) { +func (d *Dialer) rawConnDial(ctx context.Context, raddr ma.Multiaddr, remote peer.ID) (tpt.Conn, error) { if strings.HasPrefix(raddr.String(), "/ip4/0.0.0.0") { log.Event(ctx, "connDialZeroAddr", lgbl.Dial("conn", d.LocalPeer, remote, nil, raddr)) return nil, fmt.Errorf("Attempted to connect to zero address: %s", raddr) diff --git a/listen.go b/listen.go index f2e24c2..1ba7d9b 100644 --- a/listen.go +++ b/listen.go @@ -15,7 +15,7 @@ import ( iconn "github.com/libp2p/go-libp2p-interface-conn" ipnet "github.com/libp2p/go-libp2p-interface-pnet" peer "github.com/libp2p/go-libp2p-peer" - transport "github.com/libp2p/go-libp2p-transport" + tpt "github.com/libp2p/go-libp2p-transport" filter "github.com/libp2p/go-maddr-filter" ma "github.com/multiformats/go-multiaddr" msmux "github.com/multiformats/go-multistream" @@ -32,11 +32,11 @@ var ( ) // ConnWrapper is any function that wraps a raw multiaddr connection -type ConnWrapper func(transport.Conn) transport.Conn +type ConnWrapper func(tpt.Conn) tpt.Conn // listener is an object that can accept connections. It implements Listener type listener struct { - transport.Listener + tpt.Listener local peer.ID // LocalPeer is the identity of the local Peer privk ic.PrivKey // private key to use to initialize secure conns @@ -75,13 +75,13 @@ func (l *listener) SetAddrFilters(fs *filter.Filters) { } type connErr struct { - conn transport.Conn + conn tpt.Conn err error } // Accept waits for and returns the next connection to the listener. // Note that unfortunately this -func (l *listener) Accept() (transport.Conn, error) { +func (l *listener) Accept() (tpt.Conn, error) { for con := range l.incoming { if con.err != nil { return nil, con.err @@ -199,12 +199,12 @@ func (l *listener) handleIncoming() { } } -func WrapTransportListener(ctx context.Context, ml transport.Listener, local peer.ID, +func WrapTransportListener(ctx context.Context, ml tpt.Listener, local peer.ID, sk ic.PrivKey) (iconn.Listener, error) { return WrapTransportListenerWithProtector(ctx, ml, local, sk, nil) } -func WrapTransportListenerWithProtector(ctx context.Context, ml transport.Listener, local peer.ID, +func WrapTransportListenerWithProtector(ctx context.Context, ml tpt.Listener, local peer.ID, sk ic.PrivKey, protec ipnet.Protector) (iconn.Listener, error) { if protec == nil && ipnet.ForcePrivateNetwork { From bd7637af31376906a9b674a34fc6b0461e5452a7 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 15 Aug 2017 12:00:13 +0700 Subject: [PATCH 02/10] remove dead code Those functions aren't used anywhere, so this is not expected to break anything. They also don't belong into this package anyway. --- conn.go | 8 ------- dial.go | 66 --------------------------------------------------------- 2 files changed, 74 deletions(-) diff --git a/conn.go b/conn.go index 1a11590..f6bdc84 100644 --- a/conn.go +++ b/conn.go @@ -7,7 +7,6 @@ import ( "time" logging "github.com/ipfs/go-log" - mpool "github.com/jbenet/go-msgio/mpool" ic "github.com/libp2p/go-libp2p-crypto" iconn "github.com/libp2p/go-libp2p-interface-conn" lgbl "github.com/libp2p/go-libp2p-loggables" @@ -18,13 +17,6 @@ import ( var log = logging.Logger("conn") -// ReleaseBuffer puts the given byte array back into the buffer pool, -// first verifying that it is the correct size -func ReleaseBuffer(b []byte) { - log.Debugf("Releasing buffer! (cap,size = %d, %d)", cap(b), len(b)) - mpool.ByteSlicePool.Put(uint32(cap(b)), b) -} - // singleConn represents a single connection to another Peer (IPFS Node). type singleConn struct { local peer.ID diff --git a/dial.go b/dial.go index 735c5a7..17d2b70 100644 --- a/dial.go +++ b/dial.go @@ -3,11 +3,9 @@ package conn import ( "context" "fmt" - "math/rand" "strings" "time" - addrutil "github.com/libp2p/go-addr-util" ci "github.com/libp2p/go-libp2p-crypto" iconn "github.com/libp2p/go-libp2p-interface-conn" ipnet "github.com/libp2p/go-libp2p-interface-pnet" @@ -15,7 +13,6 @@ import ( peer "github.com/libp2p/go-libp2p-peer" tpt "github.com/libp2p/go-libp2p-transport" ma "github.com/multiformats/go-multiaddr" - manet "github.com/multiformats/go-multiaddr-net" msmux "github.com/multiformats/go-multistream" ) @@ -202,66 +199,3 @@ func (d *Dialer) rawConnDial(ctx context.Context, raddr ma.Multiaddr, remote pee return sd.DialContext(ctx, raddr) } - -func pickLocalAddr(laddrs []ma.Multiaddr, raddr ma.Multiaddr) (laddr ma.Multiaddr) { - if len(laddrs) < 1 { - return nil - } - - // make sure that we ONLY use local addrs that match the remote addr. - laddrs = manet.AddrMatch(raddr, laddrs) - if len(laddrs) < 1 { - return nil - } - - // make sure that we ONLY use local addrs that CAN dial the remote addr. - // filter out all the local addrs that aren't capable - raddrIPLayer := ma.Split(raddr)[0] - raddrIsLoopback := manet.IsIPLoopback(raddrIPLayer) - raddrIsLinkLocal := manet.IsIP6LinkLocal(raddrIPLayer) - laddrs = addrutil.FilterAddrs(laddrs, func(a ma.Multiaddr) bool { - laddrIPLayer := ma.Split(a)[0] - laddrIsLoopback := manet.IsIPLoopback(laddrIPLayer) - laddrIsLinkLocal := manet.IsIP6LinkLocal(laddrIPLayer) - if laddrIsLoopback { // our loopback addrs can only dial loopbacks. - return raddrIsLoopback - } - if laddrIsLinkLocal { - return raddrIsLinkLocal // out linklocal addrs can only dial link locals. - } - return true - }) - - // TODO pick with a good heuristic - // we use a random one for now to prevent bad addresses from making nodes unreachable - // with a random selection, multiple tries may work. - return laddrs[rand.Intn(len(laddrs))] -} - -// MultiaddrProtocolsMatch returns whether two multiaddrs match in protocol stacks. -func MultiaddrProtocolsMatch(a, b ma.Multiaddr) bool { - ap := a.Protocols() - bp := b.Protocols() - - if len(ap) != len(bp) { - return false - } - - for i, api := range ap { - if api.Code != bp[i].Code { - return false - } - } - - return true -} - -// MultiaddrNetMatch returns the first Multiaddr found to match network. -func MultiaddrNetMatch(tgt ma.Multiaddr, srcs []ma.Multiaddr) ma.Multiaddr { - for _, a := range srcs { - if MultiaddrProtocolsMatch(tgt, a) { - return a - } - } - return nil -} From c9c0af814135eb08f4fbc5531111ae3791d04359 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 10 Apr 2017 11:28:13 +0700 Subject: [PATCH 03/10] setup new connections This package now sets up new connections. A go-libp2p-conn.Conn is an encrypted, stream-multiplexed connection. If the underlying tpt.Conn is a single-stream connection, it is first encrypted (using secio) and then a stream multiplexer is used to provide multistream support. --- conn.go | 145 +++++++++++++++++++++--------- dial.go | 72 +++++++++------ listen.go | 110 +++++++++++++---------- secure_conn.go | 116 +++++------------------- secure_conn_test.go | 212 -------------------------------------------- 5 files changed, 233 insertions(+), 422 deletions(-) delete mode 100644 secure_conn_test.go diff --git a/conn.go b/conn.go index f6bdc84..258908b 100644 --- a/conn.go +++ b/conn.go @@ -2,44 +2,108 @@ package conn import ( "context" + "errors" "io" "net" - "time" logging "github.com/ipfs/go-log" + ci "github.com/libp2p/go-libp2p-crypto" ic "github.com/libp2p/go-libp2p-crypto" iconn "github.com/libp2p/go-libp2p-interface-conn" lgbl "github.com/libp2p/go-libp2p-loggables" peer "github.com/libp2p/go-libp2p-peer" + secio "github.com/libp2p/go-libp2p-secio" tpt "github.com/libp2p/go-libp2p-transport" + smux "github.com/libp2p/go-stream-muxer" ma "github.com/multiformats/go-multiaddr" ) var log = logging.Logger("conn") -// singleConn represents a single connection to another Peer (IPFS Node). +// singleConn represents a single stream-multipexed connection to another Peer (IPFS Node). type singleConn struct { - local peer.ID - remote peer.ID - maconn tpt.Conn - event io.Closer -} + streamConn smux.Conn + tptConn tpt.Conn + + secSession secio.Session + + event io.Closer +} + +var _ iconn.Conn = &singleConn{} + +// newSingleConn constructs a new connection +func newSingleConn(ctx context.Context, local, remote peer.ID, privKey ci.PrivKey, tptConn tpt.Conn, pstpt smux.Transport, isServer bool) (iconn.Conn, error) { + ml := lgbl.Dial("conn", local, remote, tptConn.LocalMultiaddr(), tptConn.RemoteMultiaddr()) + + var streamConn smux.Conn + var secSession secio.Session + switch conn := tptConn.(type) { + case tpt.SingleStreamConn: + c := conn + // 1. secure the connection + if privKey != nil && iconn.EncryptConnections { + var err error + secSession, err = setupSecureSession(ctx, local, privKey, conn) + if err != nil { + return nil, err + } + c = &secureDuplexConn{ + SingleStreamConn: conn, + secure: secSession, + } + } else { + log.Warning("creating INSECURE connection %s at %s", tptConn.LocalMultiaddr(), tptConn.RemoteMultiaddr()) + } -// newConn constructs a new connection -func newSingleConn(ctx context.Context, local, remote peer.ID, maconn tpt.Conn) (iconn.Conn, error) { - ml := lgbl.Dial("conn", local, remote, maconn.LocalMultiaddr(), maconn.RemoteMultiaddr()) + // 2. start stream multipling + var err error + streamConn, err = pstpt.NewConn(c, isServer) + if err != nil { + return nil, err + } + case tpt.MultiStreamConn: + panic("not implemented yet") + } conn := &singleConn{ - local: local, - remote: remote, - maconn: maconn, - event: log.EventBegin(ctx, "connLifetime", ml), + streamConn: streamConn, + tptConn: tptConn, + secSession: secSession, + event: log.EventBegin(ctx, "connLifetime", ml), } log.Debugf("newSingleConn %p: %v to %v", conn, local, remote) return conn, nil } +func setupSecureSession(ctx context.Context, local peer.ID, privKey ci.PrivKey, ch io.ReadWriteCloser) (secio.Session, error) { + if local == "" { + return nil, errors.New("local peer is nil") + } + if privKey == nil { + return nil, errors.New("private key is nil") + } + sessgen := secio.SessionGenerator{ + LocalID: local, + PrivateKey: privKey, + } + secSession, err := sessgen.NewSession(ctx, ch) + if err != nil { + return nil, err + } + // force the handshake right now + // TODO: find a better solution for this + b := []byte("handshake") + if _, err := secSession.ReadWriter().Write(b); err != nil { + return nil, err + } + if _, err := io.ReadFull(secSession.ReadWriter(), b); err != nil { + return nil, err + } + return secSession, nil +} + // close is the internal close function, called by ContextCloser.Close func (c *singleConn) Close() error { defer func() { @@ -49,8 +113,8 @@ func (c *singleConn) Close() error { } }() - // close underlying connection - return c.maconn.Close() + // closing the stream muxer also closes the underlying net.Conn + return c.streamConn.Close() } // ID is an identifier unique to this connection. @@ -63,62 +127,63 @@ func (c *singleConn) String() string { } func (c *singleConn) LocalAddr() net.Addr { - return c.maconn.LocalAddr() + return c.tptConn.LocalAddr() } func (c *singleConn) RemoteAddr() net.Addr { - return c.maconn.RemoteAddr() + return c.tptConn.RemoteAddr() } func (c *singleConn) LocalPrivateKey() ic.PrivKey { + if c.secSession != nil { + return c.secSession.LocalPrivateKey() + } return nil } func (c *singleConn) RemotePublicKey() ic.PubKey { + if c.secSession != nil { + return c.secSession.RemotePublicKey() + } return nil } -func (c *singleConn) SetDeadline(t time.Time) error { - return c.maconn.SetDeadline(t) -} -func (c *singleConn) SetReadDeadline(t time.Time) error { - return c.maconn.SetReadDeadline(t) -} - -func (c *singleConn) SetWriteDeadline(t time.Time) error { - return c.maconn.SetWriteDeadline(t) -} - // LocalMultiaddr is the Multiaddr on this side func (c *singleConn) LocalMultiaddr() ma.Multiaddr { - return c.maconn.LocalMultiaddr() + return c.tptConn.LocalMultiaddr() } // RemoteMultiaddr is the Multiaddr on the remote side func (c *singleConn) RemoteMultiaddr() ma.Multiaddr { - return c.maconn.RemoteMultiaddr() + return c.tptConn.RemoteMultiaddr() } func (c *singleConn) Transport() tpt.Transport { - return c.maconn.Transport() + return c.tptConn.Transport() } // LocalPeer is the Peer on this side func (c *singleConn) LocalPeer() peer.ID { - return c.local + return c.secSession.LocalPeer() } // RemotePeer is the Peer on the remote side func (c *singleConn) RemotePeer() peer.ID { - return c.remote + return c.secSession.RemotePeer() +} + +func (c *singleConn) AcceptStream() (smux.Stream, error) { + return c.streamConn.AcceptStream() +} + +func (c *singleConn) OpenStream() (smux.Stream, error) { + return c.streamConn.OpenStream() } -// Read reads data, net.Conn style -func (c *singleConn) Read(buf []byte) (int, error) { - return c.maconn.Read(buf) +func (c *singleConn) Serve(s smux.StreamHandler) { + c.streamConn.Serve(s) } -// Write writes data, net.Conn style -func (c *singleConn) Write(buf []byte) (int, error) { - return c.maconn.Write(buf) +func (c *singleConn) IsClosed() bool { + return c.streamConn.IsClosed() } diff --git a/dial.go b/dial.go index 17d2b70..307dc3c 100644 --- a/dial.go +++ b/dial.go @@ -3,6 +3,7 @@ package conn import ( "context" "fmt" + "io" "strings" "time" @@ -12,12 +13,20 @@ import ( lgbl "github.com/libp2p/go-libp2p-loggables" peer "github.com/libp2p/go-libp2p-peer" tpt "github.com/libp2p/go-libp2p-transport" + smux "github.com/libp2p/go-stream-muxer" ma "github.com/multiformats/go-multiaddr" msmux "github.com/multiformats/go-multistream" ) type WrapFunc func(tpt.Conn) tpt.Conn +type timeoutReadWriteCloser interface { + io.ReadWriteCloser + SetDeadline(time.Time) error + SetReadDeadline(time.Time) error + SetWriteDeadline(time.Time) error +} + // Dialer is an object that can open connections. We could have a "convenience" // Dial function as before, but it would have many arguments, as dialing is // no longer simple (need a peerstore, a local peer, a context, a network, etc) @@ -45,14 +54,17 @@ type Dialer struct { Wrapper WrapFunc fallback tpt.Dialer + + streamMuxer smux.Transport } -func NewDialer(p peer.ID, pk ci.PrivKey, wrap WrapFunc) *Dialer { +func NewDialer(p peer.ID, pk ci.PrivKey, wrap WrapFunc, sm smux.Transport) *Dialer { return &Dialer{ - LocalPeer: p, - PrivateKey: pk, - Wrapper: wrap, - fallback: new(tpt.FallbackDialer), + LocalPeer: p, + PrivateKey: pk, + Wrapper: wrap, + fallback: new(tpt.FallbackDialer), + streamMuxer: sm, } } @@ -81,7 +93,7 @@ func (d *Dialer) Dial(ctx context.Context, raddr ma.Multiaddr, remote peer.ID) ( var errOut error done := make(chan struct{}) - // do it async to ensure we respect don contexteone + // do it async to ensure we respect done context go func() { defer func() { select { @@ -90,24 +102,25 @@ func (d *Dialer) Dial(ctx context.Context, raddr ma.Multiaddr, remote peer.ID) ( } }() - maconn, err := d.rawConnDial(ctx, raddr, remote) + tptConn, err := d.rawConnDial(ctx, raddr, remote) if err != nil { errOut = err return } if d.Protector != nil { - pconn, err := d.Protector.Protect(maconn) + var pconn tpt.Conn + pconn, err = d.Protector.Protect(tptConn) if err != nil { - maconn.Close() + tptConn.Close() errOut = err return } - maconn = pconn + tptConn = pconn } if d.Wrapper != nil { - maconn = d.Wrapper(maconn) + tptConn = d.Wrapper(tptConn) } cryptoProtoChoice := SecioTag @@ -115,36 +128,37 @@ func (d *Dialer) Dial(ctx context.Context, raddr ma.Multiaddr, remote peer.ID) ( cryptoProtoChoice = NoEncryptionTag } - maconn.SetReadDeadline(time.Now().Add(NegotiateReadTimeout)) - - err = msmux.SelectProtoOrFail(cryptoProtoChoice, maconn) - if err != nil { - errOut = err - return + var stream timeoutReadWriteCloser + switch con := tptConn.(type) { + case tpt.SingleStreamConn: + stream = con + case tpt.MultiStreamConn: + stream, err = con.OpenStream() + if err != nil { + errOut = err + return + } + defer stream.Close() } - maconn.SetReadDeadline(time.Time{}) - - c, err := newSingleConn(ctx, d.LocalPeer, remote, maconn) + stream.SetReadDeadline(time.Now().Add(NegotiateReadTimeout)) + err = msmux.SelectProtoOrFail(cryptoProtoChoice, stream) if err != nil { - maconn.Close() errOut = err return } - if d.PrivateKey == nil || !iconn.EncryptConnections { - log.Warning("dialer %s dialing INSECURELY %s at %s!", d, remote, raddr) - connOut = c - return - } - c2, err := newSecureConn(ctx, d.PrivateKey, c) + // clear deadline + stream.SetReadDeadline(time.Time{}) + + c, err := newSingleConn(ctx, d.LocalPeer, remote, d.PrivateKey, tptConn, d.streamMuxer, false) if err != nil { + tptConn.Close() errOut = err - c.Close() return } - connOut = c2 + connOut = c }() select { diff --git a/listen.go b/listen.go index 1ba7d9b..25bad18 100644 --- a/listen.go +++ b/listen.go @@ -17,6 +17,7 @@ import ( peer "github.com/libp2p/go-libp2p-peer" tpt "github.com/libp2p/go-libp2p-transport" filter "github.com/libp2p/go-maddr-filter" + smux "github.com/libp2p/go-stream-muxer" ma "github.com/multiformats/go-multiaddr" msmux "github.com/multiformats/go-multistream" ) @@ -42,6 +43,8 @@ type listener struct { privk ic.PrivKey // private key to use to initialize secure conns protec ipnet.Protector + streamMuxer smux.Transport + filters *filter.Filters wrapper ConnWrapper @@ -51,11 +54,13 @@ type listener struct { mux *msmux.MultistreamMuxer - incoming chan connErr + incoming chan connOrErr ctx context.Context } +var _ iconn.Listener = &listener{} + func (l *listener) teardown() error { defer log.Debugf("listener closed: %s %s", l.local, l.Multiaddr()) return l.Listener.Close() @@ -74,39 +79,30 @@ func (l *listener) SetAddrFilters(fs *filter.Filters) { l.filters = fs } -type connErr struct { +type connOrErr struct { conn tpt.Conn err error } // Accept waits for and returns the next connection to the listener. -// Note that unfortunately this -func (l *listener) Accept() (tpt.Conn, error) { +func (l *listener) Accept() (iconn.Conn, error) { for con := range l.incoming { if con.err != nil { return nil, con.err } - - c, err := newSingleConn(l.ctx, l.local, "", con.conn) - if err != nil { - con.conn.Close() - if l.catcher.IsTemporary(err) { - continue - } - return nil, err - } + tptConn := con.conn if l.privk == nil || !iconn.EncryptConnections { - log.Warning("listener %s listening INSECURELY!", l) - return c, nil + log.Warningf("listener %s listening INSECURELY!", l) } - sc, err := newSecureConn(l.ctx, l.privk, c) + + c, err := newSingleConn(l.ctx, l.local, "", l.privk, tptConn, l.streamMuxer, true) if err != nil { - con.conn.Close() - log.Infof("ignoring conn we failed to secure: %s %s", err, c) + tptConn.Close() continue } - return sc, nil + + return c, nil } return nil, fmt.Errorf("listener is closed") } @@ -149,64 +145,79 @@ func (l *listener) handleIncoming() { defer wg.Done() for { - maconn, err := l.Listener.Accept() + conn, err := l.Listener.Accept() if err != nil { if l.catcher.IsTemporary(err) { continue } - l.incoming <- connErr{err: err} + l.incoming <- connOrErr{err: err} return } - log.Debugf("listener %s got connection: %s <---> %s", l, maconn.LocalMultiaddr(), maconn.RemoteMultiaddr()) + log.Debugf("listener %s got connection: %s <---> %s", l, conn.LocalMultiaddr(), conn.RemoteMultiaddr()) - if l.filters != nil && l.filters.AddrBlocked(maconn.RemoteMultiaddr()) { - log.Debugf("blocked connection from %s", maconn.RemoteMultiaddr()) - maconn.Close() + if l.filters != nil && l.filters.AddrBlocked(conn.RemoteMultiaddr()) { + log.Debugf("blocked connection from %s", conn.RemoteMultiaddr()) + conn.Close() continue } - // If we have a wrapper func, wrap this conn - if l.wrapper != nil { - maconn = l.wrapper(maconn) - } wg.Add(1) go func() { defer wg.Done() if l.protec != nil { - pc, err := l.protec.Protect(maconn) + pc, err := l.protec.Protect(conn) if err != nil { - maconn.Close() + conn.Close() log.Warning("protector failed: ", err) + return } - maconn = pc + conn = pc + } + + // If we have a wrapper func, wrap this conn + if l.wrapper != nil { + conn = l.wrapper(conn) } - maconn.SetReadDeadline(time.Now().Add(NegotiateReadTimeout)) - _, _, err = l.mux.Negotiate(maconn) + var stream timeoutReadWriteCloser + switch conn := conn.(type) { + case tpt.SingleStreamConn: + stream = conn + case tpt.MultiStreamConn: + stream, err = conn.AcceptStream() + if err != nil { + conn.Close() + log.Warning("accepting stream failed: ", err) + return + } + defer stream.Close() + } + + // TODO: should the negotiate timeout include the time taken by AcceptStream? + stream.SetReadDeadline(time.Now().Add(NegotiateReadTimeout)) + _, _, err = l.mux.Negotiate(stream) if err != nil { log.Warning("incoming conn: negotiation of crypto protocol failed: ", err) - maconn.Close() + conn.Close() return } - // clear read readline - maconn.SetReadDeadline(time.Time{}) + stream.SetReadDeadline(time.Time{}) - l.incoming <- connErr{conn: maconn} + l.incoming <- connOrErr{conn: conn} }() } } -func WrapTransportListener(ctx context.Context, ml tpt.Listener, local peer.ID, +func WrapTransportListener(ctx context.Context, ml tpt.Listener, local peer.ID, pstpt smux.Transport, sk ic.PrivKey) (iconn.Listener, error) { - return WrapTransportListenerWithProtector(ctx, ml, local, sk, nil) + return WrapTransportListenerWithProtector(ctx, ml, local, sk, pstpt, nil) } func WrapTransportListenerWithProtector(ctx context.Context, ml tpt.Listener, local peer.ID, - sk ic.PrivKey, protec ipnet.Protector) (iconn.Listener, error) { - + sk ic.PrivKey, pstpt smux.Transport, protec ipnet.Protector) (iconn.Listener, error) { if protec == nil && ipnet.ForcePrivateNetwork { log.Error("tried to listen with no Private Network Protector but usage" + " of Private Networks is forced by the enviroment") @@ -214,13 +225,14 @@ func WrapTransportListenerWithProtector(ctx context.Context, ml tpt.Listener, lo } l := &listener{ - Listener: ml, - local: local, - privk: sk, - protec: protec, - mux: msmux.NewMultistreamMuxer(), - incoming: make(chan connErr, connAcceptBuffer), - ctx: ctx, + Listener: ml, + local: local, + privk: sk, + protec: protec, + mux: msmux.NewMultistreamMuxer(), + incoming: make(chan connOrErr, connAcceptBuffer), + ctx: ctx, + streamMuxer: pstpt, } l.proc = goprocessctx.WithContextAndTeardown(ctx, l.teardown) l.catcher.IsTemp = func(e error) bool { diff --git a/secure_conn.go b/secure_conn.go index d726d14..ec0513a 100644 --- a/secure_conn.go +++ b/secure_conn.go @@ -1,130 +1,62 @@ package conn import ( - "context" - "errors" "net" "time" - ic "github.com/libp2p/go-libp2p-crypto" - iconn "github.com/libp2p/go-libp2p-interface-conn" - peer "github.com/libp2p/go-libp2p-peer" secio "github.com/libp2p/go-libp2p-secio" tpt "github.com/libp2p/go-libp2p-transport" ma "github.com/multiformats/go-multiaddr" ) -// secureConn wraps another Conn object with an encrypted channel. -type secureConn struct { - insecure iconn.Conn // the wrapped conn - secure secio.Session // secure Session +// secureSingleStreamConn wraps another SingleStreamConn object with an encrypted channel. +type secureSingleStreamConn struct { + insecure tpt.SingleStreamConn // the wrapped conn + secure secio.Session // secure Session } -// newConn constructs a new connection -func newSecureConn(ctx context.Context, sk ic.PrivKey, insecure iconn.Conn) (iconn.Conn, error) { +var _ tpt.SingleStreamConn = &secureSingleStreamConn{} - if insecure == nil { - return nil, errors.New("insecure is nil") - } - if insecure.LocalPeer() == "" { - return nil, errors.New("insecure.LocalPeer() is nil") - } - if sk == nil { - return nil, errors.New("private key is nil") - } - - // NewSession performs the secure handshake, which takes multiple RTT - sessgen := secio.SessionGenerator{LocalID: insecure.LocalPeer(), PrivateKey: sk} - secure, err := sessgen.NewSession(ctx, insecure) - if err != nil { - return nil, err - } - - conn := &secureConn{ - insecure: insecure, - secure: secure, - } - return conn, nil -} - -func (c *secureConn) Close() error { - return c.secure.Close() +func (c *secureSingleStreamConn) Read(buf []byte) (int, error) { + return c.secure.ReadWriter().Read(buf) } -// ID is an identifier unique to this connection. -func (c *secureConn) ID() string { - return iconn.ID(c) +func (c *secureSingleStreamConn) Write(buf []byte) (int, error) { + return c.secure.ReadWriter().Write(buf) } -func (c *secureConn) String() string { - return iconn.String(c, "secureConn") +func (c *secureSingleStreamConn) Close() error { + return c.secure.Close() } -func (c *secureConn) LocalAddr() net.Addr { +func (c *secureSingleStreamConn) LocalAddr() net.Addr { return c.insecure.LocalAddr() } -func (c *secureConn) RemoteAddr() net.Addr { - return c.insecure.RemoteAddr() -} - -func (c *secureConn) SetDeadline(t time.Time) error { - return c.insecure.SetDeadline(t) -} - -func (c *secureConn) SetReadDeadline(t time.Time) error { - return c.insecure.SetReadDeadline(t) -} - -func (c *secureConn) SetWriteDeadline(t time.Time) error { - return c.insecure.SetWriteDeadline(t) -} - -// LocalMultiaddr is the Multiaddr on this side -func (c *secureConn) LocalMultiaddr() ma.Multiaddr { +func (c *secureSingleStreamConn) LocalMultiaddr() ma.Multiaddr { return c.insecure.LocalMultiaddr() } -// RemoteMultiaddr is the Multiaddr on the remote side -func (c *secureConn) RemoteMultiaddr() ma.Multiaddr { - return c.insecure.RemoteMultiaddr() -} - -// LocalPeer is the Peer on this side -func (c *secureConn) LocalPeer() peer.ID { - return c.secure.LocalPeer() -} - -// RemotePeer is the Peer on the remote side -func (c *secureConn) RemotePeer() peer.ID { - return c.secure.RemotePeer() -} - -// LocalPrivateKey is the public key of the peer on this side -func (c *secureConn) LocalPrivateKey() ic.PrivKey { - return c.secure.LocalPrivateKey() +func (c *secureSingleStreamConn) RemoteAddr() net.Addr { + return c.insecure.RemoteAddr() } -// RemotePubKey is the public key of the peer on the remote side -func (c *secureConn) RemotePublicKey() ic.PubKey { - return c.secure.RemotePublicKey() +func (c *secureSingleStreamConn) RemoteMultiaddr() ma.Multiaddr { + return c.insecure.RemoteMultiaddr() } -// Read reads data, net.Conn style -func (c *secureConn) Read(buf []byte) (int, error) { - return c.secure.ReadWriter().Read(buf) +func (c *secureSingleStreamConn) SetDeadline(t time.Time) error { + return c.insecure.SetDeadline(t) } -// Write writes data, net.Conn style -func (c *secureConn) Write(buf []byte) (int, error) { - return c.secure.ReadWriter().Write(buf) +func (c *secureSingleStreamConn) SetReadDeadline(t time.Time) error { + return c.insecure.SetDeadline(t) } -// ReleaseMsg releases a buffer -func (c *secureConn) ReleaseMsg(m []byte) { - c.secure.ReadWriter().ReleaseMsg(m) +func (c *secureSingleStreamConn) SetWriteDeadline(t time.Time) error { + return c.insecure.SetDeadline(t) } -func (c *secureConn) Transport() tpt.Transport { +func (c *secureSingleStreamConn) Transport() tpt.Transport { return c.insecure.Transport() } diff --git a/secure_conn_test.go b/secure_conn_test.go deleted file mode 100644 index 80fb477..0000000 --- a/secure_conn_test.go +++ /dev/null @@ -1,212 +0,0 @@ -package conn - -import ( - "bytes" - "context" - "runtime" - "sync" - "testing" - "time" - - ic "github.com/libp2p/go-libp2p-crypto" - iconn "github.com/libp2p/go-libp2p-interface-conn" - travis "github.com/libp2p/go-testutil/ci/travis" -) - -func upgradeToSecureConn(t *testing.T, ctx context.Context, sk ic.PrivKey, c iconn.Conn) (iconn.Conn, error) { - if c, ok := c.(*secureConn); ok { - return c, nil - } - - // shouldn't happen, because dial + listen already return secure conns. - s, err := newSecureConn(ctx, sk, c) - if err != nil { - return nil, err - } - - // need to read + write, as that's what triggers the handshake. - h := []byte("hello") - if _, err := s.Write(h); err != nil { - return nil, err - } - if _, err := s.Read(h); err != nil { - return nil, err - } - return s, nil -} - -func secureHandshake(t *testing.T, ctx context.Context, sk ic.PrivKey, c iconn.Conn, done chan error) { - _, err := upgradeToSecureConn(t, ctx, sk, c) - done <- err -} - -func TestSecureSimple(t *testing.T) { - // t.Skip("Skipping in favor of another test") - - numMsgs := 100 - if testing.Short() { - numMsgs = 10 - } - - ctx := context.Background() - c1, c2, p1, p2 := setupSingleConn(t, ctx) - - done := make(chan error) - go secureHandshake(t, ctx, p1.PrivKey, c1, done) - go secureHandshake(t, ctx, p2.PrivKey, c2, done) - - for i := 0; i < 2; i++ { - if err := <-done; err != nil { - t.Fatal(err) - } - } - - for i := 0; i < numMsgs; i++ { - testOneSendRecv(t, c1, c2) - testOneSendRecv(t, c2, c1) - } - - c1.Close() - c2.Close() -} - -func TestSecureClose(t *testing.T) { - // t.Skip("Skipping in favor of another test") - - ctx := context.Background() - c1, c2, p1, p2 := setupSingleConn(t, ctx) - - done := make(chan error) - go secureHandshake(t, ctx, p1.PrivKey, c1, done) - go secureHandshake(t, ctx, p2.PrivKey, c2, done) - - for i := 0; i < 2; i++ { - if err := <-done; err != nil { - t.Fatal(err) - } - } - - testOneSendRecv(t, c1, c2) - - c1.Close() - testNotOneSendRecv(t, c1, c2) - - c2.Close() - testNotOneSendRecv(t, c1, c2) - testNotOneSendRecv(t, c2, c1) - -} - -func TestSecureCancelHandshake(t *testing.T) { - // t.Skip("Skipping in favor of another test") - - ctx, cancel := context.WithCancel(context.Background()) - c1, c2, p1, p2 := setupSingleConn(t, ctx) - - done := make(chan error) - go secureHandshake(t, ctx, p1.PrivKey, c1, done) - time.Sleep(time.Millisecond) - cancel() // cancel ctx - go secureHandshake(t, ctx, p2.PrivKey, c2, done) - - for i := 0; i < 2; i++ { - if err := <-done; err == nil { - t.Error("cancel should've errored out") - } - } -} - -func TestSecureHandshakeFailsWithWrongKeys(t *testing.T) { - // t.Skip("Skipping in favor of another test") - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - c1, c2, p1, p2 := setupSingleConn(t, ctx) - - done := make(chan error) - go secureHandshake(t, ctx, p2.PrivKey, c1, done) - go secureHandshake(t, ctx, p1.PrivKey, c2, done) - - for i := 0; i < 2; i++ { - if err := <-done; err == nil { - t.Fatal("wrong keys should've errored out.") - } - } -} - -func TestSecureCloseLeak(t *testing.T) { - // t.Skip("Skipping in favor of another test") - - if testing.Short() { - t.SkipNow() - } - if travis.IsRunning() { - t.Skip("this doesn't work well on travis") - } - - runPair := func(c1, c2 iconn.Conn, num int) { - mc1 := msgioWrap(c1) - mc2 := msgioWrap(c2) - - log.Debugf("runPair %d", num) - - for i := 0; i < num; i++ { - log.Debugf("runPair iteration %d", i) - b1 := []byte("beep") - mc1.WriteMsg(b1) - b2, err := mc2.ReadMsg() - if err != nil { - panic(err) - } - if !bytes.Equal(b1, b2) { - panic("bytes not equal") - } - - b2 = []byte("beep") - mc2.WriteMsg(b2) - b1, err = mc1.ReadMsg() - if err != nil { - panic(err) - } - if !bytes.Equal(b1, b2) { - panic("bytes not equal") - } - - time.Sleep(time.Microsecond * 5) - } - } - - var cons = 5 - var msgs = 50 - log.Debugf("Running %d connections * %d msgs.\n", cons, msgs) - - var wg sync.WaitGroup - for i := 0; i < cons; i++ { - wg.Add(1) - - ctx, cancel := context.WithCancel(context.Background()) - c1, c2, _, _ := setupSecureConn(t, ctx) - go func(c1, c2 iconn.Conn) { - - defer func() { - c1.Close() - c2.Close() - cancel() - wg.Done() - }() - - runPair(c1, c2, msgs) - }(c1, c2) - } - - log.Debugf("Waiting...") - wg.Wait() - // done! - - time.Sleep(time.Millisecond * 150) - ngr := runtime.NumGoroutine() - if ngr > 25 { - // panic("uncomment me to debug") - t.Fatal("leaking goroutines:", ngr) - } -} From a8718edbd9fdc2eb5dfeb3900b1909aaf416da27 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 19 Jul 2017 12:38:08 +0700 Subject: [PATCH 04/10] remove conn.Serve This method was recently removed from the interface. --- conn.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/conn.go b/conn.go index 258908b..d71cca8 100644 --- a/conn.go +++ b/conn.go @@ -180,10 +180,6 @@ func (c *singleConn) OpenStream() (smux.Stream, error) { return c.streamConn.OpenStream() } -func (c *singleConn) Serve(s smux.StreamHandler) { - c.streamConn.Serve(s) -} - func (c *singleConn) IsClosed() bool { return c.streamConn.IsClosed() } From f8e46f7be9dec2f5f4e5fb2a69148230a6564bc5 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 9 Aug 2017 17:05:41 +0700 Subject: [PATCH 05/10] reenable the tests, and add some more --- Makefile | 6 +- conn_suite_test.go | 115 +++++++ conn_test.go | 401 +++++++++++++++++------- dial_test.go | 750 --------------------------------------------- listen_test.go | 225 ++++++++++++++ protector_test.go | 266 ++++++++++++++++ 6 files changed, 906 insertions(+), 857 deletions(-) create mode 100644 conn_suite_test.go delete mode 100644 dial_test.go create mode 100644 listen_test.go create mode 100644 protector_test.go diff --git a/Makefile b/Makefile index 7811c09..bf3427e 100644 --- a/Makefile +++ b/Makefile @@ -6,7 +6,11 @@ covertools: go get github.com/mattn/goveralls go get golang.org/x/tools/cmd/cover -deps: gx covertools +ginkgo: + go get github.com/onsi/ginkgo/ginkgo + go get github.com/onsi/gomega + +deps: gx covertools ginkgo gx --verbose install --global gx-go rewrite diff --git a/conn_suite_test.go b/conn_suite_test.go new file mode 100644 index 0000000..ee91944 --- /dev/null +++ b/conn_suite_test.go @@ -0,0 +1,115 @@ +package conn + +import ( + "context" + "strings" + "testing" + "time" + + ci "github.com/libp2p/go-libp2p-crypto" + iconn "github.com/libp2p/go-libp2p-interface-conn" + peer "github.com/libp2p/go-libp2p-peer" + tpt "github.com/libp2p/go-libp2p-transport" + tcpt "github.com/libp2p/go-tcp-transport" + tu "github.com/libp2p/go-testutil" + quict "github.com/marten-seemann/libp2p-quic-transport" + ma "github.com/multiformats/go-multiaddr" + yamux "github.com/whyrusleeping/go-smux-yamux" + grc "github.com/whyrusleeping/gorocheck" + "github.com/whyrusleeping/mafmt" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +func TestGoLibp2pConn(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "go-libp2p-conn Suite") +} + +var _ = AfterEach(func() { + time.Sleep(300 * time.Millisecond) + Expect(grc.CheckForLeaks(func(r *grc.Goroutine) bool { + return strings.Contains(r.Function, "go-log.") || + strings.Contains(r.Stack[0], "testing.(*T).Run") || + strings.Contains(r.Function, "specrunner.") || + strings.Contains(r.Function, "runtime.gopark") + })).To(Succeed()) +}) + +// the stream muxer used for tests using the single stream connection +var streamMuxer = yamux.DefaultTransport + +type transportType uint8 + +const ( + singleStreamTransport transportType = 1 + iota + multiStreamTransport +) + +var transportTypes = []transportType{singleStreamTransport} + +func (t transportType) String() string { + if t == multiStreamTransport { + return "multi-stream transport" + } + return "single-stream transport" +} + +// dialRawConn dials a tpt.Conn +// but it stops there. It doesn't do protocol selection and handshake +func dialRawConn(laddr, raddr ma.Multiaddr) tpt.Conn { + var d tpt.Dialer + if mafmt.QUIC.Matches(laddr) { + var err error + d, err = quict.NewQuicTransport().Dialer(laddr) + Expect(err).ToNot(HaveOccurred()) + } else { + var err error + d, err = tcpt.NewTCPTransport().Dialer(laddr) + Expect(err).ToNot(HaveOccurred()) + } + c, err := d.Dial(raddr) + Expect(err).ToNot(HaveOccurred()) + return c +} + +// getTransport gets the right transport for a multiaddr +func getTransport(a ma.Multiaddr) tpt.Transport { + if mafmt.QUIC.Matches(a) { + return quict.NewQuicTransport() + } + return tcpt.NewTCPTransport() +} + +// getListener creates a listener based on the PeerNetParams +// it updates the PeerNetParams to reflect the local address that was selected by the kernel +func getListener(ctx context.Context, p *tu.PeerNetParams) iconn.Listener { + tptListener, err := getTransport(p.Addr).Listen(p.Addr) + Expect(err).ToNot(HaveOccurred()) + list, err := WrapTransportListener(ctx, tptListener, p.ID, streamMuxer, p.PrivKey) + Expect(err).ToNot(HaveOccurred()) + p.Addr = list.Multiaddr() + return list +} + +func getDialer(localPeer peer.ID, privKey ci.PrivKey, addr ma.Multiaddr) *Dialer { + d := NewDialer(localPeer, privKey, nil, streamMuxer) + d.fallback = nil // unset the fallback dialer. We want tests use the configured dialer, and to fail otherwise + tptd, err := getTransport(addr).Dialer(addr) + Expect(err).ToNot(HaveOccurred()) + d.AddDialer(tptd) + return d +} + +// randPeerNetParams works like testutil.RandPeerNetParams +// if called for a multi-stream transport, it replaces the address with a QUIC address +func randPeerNetParams(tr transportType) *tu.PeerNetParams { + p, err := tu.RandPeerNetParams() + Expect(err).ToNot(HaveOccurred()) + if tr == multiStreamTransport { + p.Addr, err = ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/quic") + Expect(err).ToNot(HaveOccurred()) + } + return p +} diff --git a/conn_test.go b/conn_test.go index bf930a6..077376c 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1,139 +1,328 @@ package conn import ( - "bytes" "context" "fmt" - "runtime" + "io" "sync" - "testing" "time" - msgio "github.com/jbenet/go-msgio" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + iconn "github.com/libp2p/go-libp2p-interface-conn" - travis "github.com/libp2p/go-testutil/ci/travis" + tpt "github.com/libp2p/go-libp2p-transport" + smux "github.com/libp2p/go-stream-muxer" ) -func msgioWrap(c iconn.Conn) msgio.ReadWriter { - return msgio.NewReadWriter(c) -} +var _ = Describe("Connections", func() { + It("uses the right handshake protocol", func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() -func testOneSendRecv(t *testing.T, c1, c2 iconn.Conn) { - mc1 := msgioWrap(c1) - mc2 := msgioWrap(c2) + p1 := randPeerNetParams(singleStreamTransport) + l1 := getListener(ctx, p1) + defer l1.Close() + go l1.Accept() + }) - log.Debugf("testOneSendRecv from %s to %s", c1.LocalPeer(), c2.LocalPeer()) - m1 := []byte("hello") - if err := mc1.WriteMsg(m1); err != nil { - t.Fatal(err) - } - m2, err := mc2.ReadMsg() - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(m1, m2) { - t.Fatalf("failed to send: %s %s", m1, m2) - } -} + for _, val := range transportTypes { + tr := val -func testNotOneSendRecv(t *testing.T, c1, c2 iconn.Conn) { - mc1 := msgioWrap(c1) - mc2 := msgioWrap(c2) + Context(fmt.Sprintf("using a %s", tr), func() { + for _, val := range []bool{true, false} { + secure := val - m1 := []byte("hello") - if err := mc1.WriteMsg(m1); err == nil { - t.Fatal("write should have failed", err) - } - _, err := mc2.ReadMsg() - if err == nil { - t.Fatal("read should have failed", err) - } -} + It(fmt.Sprintf("establishes a connection (secure: %t)", secure), func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() -func TestClose(t *testing.T) { - // t.Skip("Skipping in favor of another test") + p1 := randPeerNetParams(tr) + p2 := randPeerNetParams(tr) + if !secure { + p1.PrivKey = nil + p2.PrivKey = nil + } - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - c1, c2, _, _ := setupSingleConn(t, ctx) + l1 := getListener(ctx, p1) + defer l1.Close() - testOneSendRecv(t, c1, c2) - testOneSendRecv(t, c2, c1) + // accept a connection, accept a stream on this connection and echo everything + go func() { + defer GinkgoRecover() + c, err := l1.Accept() + Expect(err).ToNot(HaveOccurred()) + str, err := c.AcceptStream() + Expect(err).ToNot(HaveOccurred()) + go io.Copy(str, str) + }() - c1.Close() - testNotOneSendRecv(t, c1, c2) + d2 := getDialer(p2.ID, p2.PrivKey, p2.Addr) + c, err := d2.Dial(ctx, p1.Addr, p1.ID) + Expect(err).ToNot(HaveOccurred()) + defer c.Close() + str, err := c.OpenStream() + Expect(err).ToNot(HaveOccurred()) + _, err = str.Write([]byte("beep")) + Expect(err).ToNot(HaveOccurred()) + _, err = str.Write([]byte("boop")) + Expect(err).ToNot(HaveOccurred()) - c2.Close() - testNotOneSendRecv(t, c2, c1) - testNotOneSendRecv(t, c1, c2) -} + out := make([]byte, 8) + _, err = io.ReadFull(str, out) + Expect(err).ToNot(HaveOccurred()) + Expect(out).To(Equal([]byte("beepboop"))) + }) + } -func TestCloseLeak(t *testing.T) { - // t.Skip("Skipping in favor of another test") - if testing.Short() { - t.SkipNow() - } + It("continues accepting connections while another accept is hanging", func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - if travis.IsRunning() { - t.Skip("this doesn't work well on travis") - } + p1 := randPeerNetParams(tr) + p2 := randPeerNetParams(tr) - var wg sync.WaitGroup + l1 := getListener(ctx, p1) + defer l1.Close() - runPair := func(num int) { - ctx, cancel := context.WithCancel(context.Background()) - c1, c2, _, _ := setupSingleConn(t, ctx) + go func() { + defer GinkgoRecover() + conn := dialRawConn(p2.Addr, l1.Multiaddr()) + defer conn.Close() // hang this connection - mc1 := msgioWrap(c1) - mc2 := msgioWrap(c2) + // ensure that the first conn hits first + time.Sleep(50 * time.Millisecond) + d2 := getDialer(p2.ID, p2.PrivKey, p2.Addr) + conn2, err := d2.Dial(ctx, p1.Addr, p1.ID) + Expect(err).ToNot(HaveOccurred()) + defer conn2.Close() + }() - for i := 0; i < num; i++ { - b1 := []byte(fmt.Sprintf("beep%d", i)) - mc1.WriteMsg(b1) - b2, err := mc2.ReadMsg() - if err != nil { - panic(err) - } - if !bytes.Equal(b1, b2) { - panic(fmt.Errorf("bytes not equal: %s != %s", b1, b2)) - } + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := l1.Accept() + Expect(err).ToNot(HaveOccurred()) + close(done) + }() + Eventually(done).Should(BeClosed()) + }) - b2 = []byte(fmt.Sprintf("boop%d", i)) - mc2.WriteMsg(b2) - b1, err = mc1.ReadMsg() - if err != nil { - panic(err) - } - if !bytes.Equal(b1, b2) { - panic(fmt.Errorf("bytes not equal: %s != %s", b1, b2)) - } + It("timeouts", func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - <-time.After(time.Microsecond * 5) - } + old := NegotiateReadTimeout + NegotiateReadTimeout = 3 * time.Second + defer func() { NegotiateReadTimeout = old }() - c1.Close() - c2.Close() - cancel() // close the listener - wg.Done() - } + p1 := randPeerNetParams(tr) + p2 := randPeerNetParams(tr) - var cons = 5 - var msgs = 50 - log.Debugf("Running %d connections * %d msgs.\n", cons, msgs) - for i := 0; i < cons; i++ { - wg.Add(1) - go runPair(msgs) - } + l1 := getListener(ctx, p1) + defer l1.Close() + + n := 20 + + before := time.Now() + var wg sync.WaitGroup + for i := 0; i < n; i++ { + wg.Add(1) + go func() { + defer GinkgoRecover() + defer wg.Done() + var conn io.Reader + c := dialRawConn(p2.Addr, l1.Multiaddr()) + defer c.Close() + switch tr { + case singleStreamTransport: + conn = c.(tpt.SingleStreamConn) + case multiStreamTransport: + var err error + conn, err = c.(tpt.MultiStreamConn).OpenStream() + Expect(err).ToNot(HaveOccurred()) + } + // hang this connection until timeout + io.ReadFull(conn, make([]byte, 1000)) + }() + } + + // wait to make sure the hanging dials have started + time.Sleep(50 * time.Millisecond) + + accepted := make(chan struct{}) // this chan is closed once all good connections have been accepted + goodN := 10 + for i := 0; i < goodN; i++ { + go func(i int) { + defer GinkgoRecover() + d2 := getDialer(p2.ID, p2.PrivKey, p2.Addr) + conn, err := d2.Dial(ctx, p1.Addr, p1.ID) + Expect(err).ToNot(HaveOccurred()) + <-accepted + conn.Close() + }(i) + } + + for i := 0; i < goodN; i++ { + _, err := l1.Accept() + Expect(err).ToNot(HaveOccurred()) + } + close(accepted) + Expect(time.Now()).To(BeTemporally("<", before.Add(NegotiateReadTimeout/4))) + Eventually(func() bool { + wg.Wait() // wait for the timeouts for the raw connections to occur + return true + }, NegotiateReadTimeout).Should(BeTrue()) + Expect(time.Now()).To(BeTemporally(">", before.Add(NegotiateReadTimeout))) + + // make sure we can dial in still after a bunch of timeouts + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := l1.Accept() + Expect(err).ToNot(HaveOccurred()) + close(done) + }() + + d2 := getDialer(p2.ID, p2.PrivKey, p2.Addr) + conn, err := d2.Dial(ctx, p1.Addr, p1.ID) + Expect(err).ToNot(HaveOccurred()) + defer conn.Close() + Eventually(done).Should(BeClosed()) + }) + + It("doesn't complete the handshake with the wrong keys", func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + p1 := randPeerNetParams(tr) + p2 := randPeerNetParams(tr) + + l1 := getListener(ctx, p1) + defer l1.Close() + + // use the wrong private key here, correct would be: p2.PrivKey + d2 := getDialer(p2.ID, p1.PrivKey, p2.Addr) + + accepted := make(chan struct{}) + go func() { + l1.Accept() + close(accepted) + }() + + _, err := d2.Dial(ctx, p1.Addr, p1.ID) + Expect(err).To(MatchError("peer.ID does not match PrivateKey")) + // make sure no connection was accepted + Consistently(accepted).ShouldNot(BeClosed()) + }) + + Context("closing", func() { + setupConn := func(ctx context.Context, tr transportType) (iconn.Conn, iconn.Conn) { + p1 := randPeerNetParams(tr) + p2 := randPeerNetParams(tr) + + l1 := getListener(ctx, p1) + + var c2 iconn.Conn + d2 := getDialer(p2.ID, p2.PrivKey, p2.Addr) + done := make(chan error) + go func() { + defer GinkgoRecover() + var err error + c2, err = d2.Dial(ctx, p1.Addr, p1.ID) + Expect(err).ToNot(HaveOccurred()) + close(done) + }() + + c1, err := l1.Accept() + Expect(err).ToNot(HaveOccurred()) + Eventually(done).Should(BeClosed()) + return c1, c2 + } + + openStreamAndSend := func(c1, c2 iconn.Conn) { + str1, err := c1.OpenStream() + Expect(err).ToNot(HaveOccurred()) + m1 := []byte("hello") + _, err = str1.Write(m1) + Expect(err).ToNot(HaveOccurred()) + str2, err := c2.AcceptStream() + Expect(err).ToNot(HaveOccurred()) + m2 := make([]byte, len(m1)) + _, err = str2.Read(m2) + Expect(err).ToNot(HaveOccurred()) + Expect(m1).To(Equal(m2)) + } + + checkStreamOpenAcceptFails := func(c1, c2 iconn.Conn) { + _, err := c1.OpenStream() + Expect(err).To(HaveOccurred()) + accepted := make(chan struct{}) + go func() { + _, err := c2.AcceptStream() + Expect(err).To(HaveOccurred()) + close(accepted) + }() + Eventually(accepted).Should(BeClosed()) + } + + It("closes", func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + c1, c2 := setupConn(ctx, tr) + openStreamAndSend(c1, c2) + openStreamAndSend(c2, c1) + + c1.Close() + Expect(c1.IsClosed()).To(BeTrue()) + Eventually(c2.IsClosed).Should(BeTrue()) + checkStreamOpenAcceptFails(c2, c1) + checkStreamOpenAcceptFails(c1, c2) + }) + + It("doesn't leak", func() { + // runPair opens one stream and sends num messages + runPair := func(c1, c2 iconn.Conn, num int) { + var str2 smux.Stream + str1, err := c1.OpenStream() + Expect(err).ToNot(HaveOccurred()) + + for i := 0; i < num; i++ { + b1 := []byte("beep") + _, err := str1.Write(b1) + Expect(err).ToNot(HaveOccurred()) + if str2 == nil { + str2, err = c2.AcceptStream() + Expect(err).ToNot(HaveOccurred()) + } + b2 := make([]byte, len(b1)) + _, err = str2.Read(b2) + Expect(err).ToNot(HaveOccurred()) + Expect(b1).To(Equal(b2)) + } + } - log.Debugf("Waiting...\n") - wg.Wait() - // done! + var cons = 10 + var msgs = 10 + var wg sync.WaitGroup + for i := 0; i < cons; i++ { + wg.Add(1) + ctx, cancel := context.WithCancel(context.Background()) + c1, c2 := setupConn(ctx, tr) + go func(c1, c2 iconn.Conn) { + defer GinkgoRecover() + defer cancel() + runPair(c1, c2, msgs) + c1.Close() + c2.Close() + wg.Done() + }(c1, c2) + } - time.Sleep(time.Millisecond * 150) - ngr := runtime.NumGoroutine() - if ngr > 25 { - // note, this is really innacurate - //panic("uncomment me to debug") - t.Fatal("leaking goroutines:", ngr) + wg.Wait() + }) + }) + }) } -} +}) diff --git a/dial_test.go b/dial_test.go deleted file mode 100644 index de70d07..0000000 --- a/dial_test.go +++ /dev/null @@ -1,750 +0,0 @@ -package conn - -import ( - "bytes" - "context" - "fmt" - "io" - "net" - "runtime" - "strings" - "sync" - "testing" - "time" - - ic "github.com/libp2p/go-libp2p-crypto" - iconn "github.com/libp2p/go-libp2p-interface-conn" - ipnet "github.com/libp2p/go-libp2p-interface-pnet" - peer "github.com/libp2p/go-libp2p-peer" - transport "github.com/libp2p/go-libp2p-transport" - tcpt "github.com/libp2p/go-tcp-transport" - tu "github.com/libp2p/go-testutil" - ma "github.com/multiformats/go-multiaddr" - msmux "github.com/multiformats/go-multistream" - grc "github.com/whyrusleeping/gorocheck" -) - -func goroFilter(r *grc.Goroutine) bool { - return strings.Contains(r.Function, "go-log.") || strings.Contains(r.Stack[0], "testing.(*T).Run") -} - -func echoListen(ctx context.Context, listener iconn.Listener) { - for { - c, err := listener.Accept() - if err != nil { - - select { - case <-ctx.Done(): - return - default: - } - - if ne, ok := err.(net.Error); ok && ne.Temporary() { - <-time.After(time.Microsecond * 10) - continue - } - - log.Debugf("echoListen: listener appears to be closing") - return - } - - go echo(c.(iconn.Conn)) - } -} - -func echo(c iconn.Conn) { - io.Copy(c, c) -} - -func setupSecureConn(t *testing.T, ctx context.Context) (a, b iconn.Conn, p1, p2 tu.PeerNetParams) { - return setupConn(t, ctx, true) -} - -func setupSingleConn(t *testing.T, ctx context.Context) (a, b iconn.Conn, p1, p2 tu.PeerNetParams) { - return setupConn(t, ctx, false) -} - -func Listen(ctx context.Context, addr ma.Multiaddr, local peer.ID, sk ic.PrivKey) (iconn.Listener, error) { - list, err := tcpt.NewTCPTransport().Listen(addr) - if err != nil { - return nil, err - } - - return WrapTransportListener(ctx, list, local, sk) -} - -func dialer(t *testing.T, a ma.Multiaddr) transport.Dialer { - tpt := tcpt.NewTCPTransport() - tptd, err := tpt.Dialer(a) - if err != nil { - t.Fatal(err) - } - - return tptd -} - -func setupConn(t *testing.T, ctx context.Context, secure bool) (a, b iconn.Conn, p1, p2 tu.PeerNetParams) { - - p1 = tu.RandPeerNetParamsOrFatal(t) - p2 = tu.RandPeerNetParamsOrFatal(t) - - key1 := p1.PrivKey - key2 := p2.PrivKey - if !secure { - key1 = nil - key2 = nil - } - l1, err := Listen(ctx, p1.Addr, p1.ID, key1) - if err != nil { - t.Fatal(err) - } - p1.Addr = l1.Multiaddr() // Addr has been determined by kernel. - - d2 := &Dialer{ - LocalPeer: p2.ID, - PrivateKey: key2, - } - - d2.AddDialer(dialer(t, p2.Addr)) - - var c2 iconn.Conn - - done := make(chan error) - go func() { - defer close(done) - - var err error - c2, err = d2.Dial(ctx, p1.Addr, p1.ID) - if err != nil { - done <- err - return - } - - // if secure, need to read + write, as that's what triggers the handshake. - if secure { - if err := sayHello(c2); err != nil { - done <- err - } - } - }() - - c1, err := l1.Accept() - if err != nil { - t.Fatal("failed to accept", err) - } - - // if secure, need to read + write, as that's what triggers the handshake. - if secure { - if err := sayHello(c1); err != nil { - done <- err - } - } - - if err := <-done; err != nil { - t.Fatal(err) - } - - return c1.(iconn.Conn), c2, p1, p2 -} - -func sayHello(c net.Conn) error { - h := []byte("hello") - if _, err := c.Write(h); err != nil { - return err - } - if _, err := c.Read(h); err != nil { - return err - } - if string(h) != "hello" { - return fmt.Errorf("did not get hello") - } - return nil -} - -func testDialer(t *testing.T, secure bool) { - // t.Skip("Skipping in favor of another test") - - p1 := tu.RandPeerNetParamsOrFatal(t) - p2 := tu.RandPeerNetParamsOrFatal(t) - - key1 := p1.PrivKey - key2 := p2.PrivKey - if !secure { - key1 = nil - key2 = nil - t.Log("testing insecurely") - } else { - t.Log("testing securely") - } - - ctx, cancel := context.WithCancel(context.Background()) - l1, err := Listen(ctx, p1.Addr, p1.ID, key1) - if err != nil { - t.Fatal(err) - } - p1.Addr = l1.Multiaddr() // Addr has been determined by kernel. - - d2 := &Dialer{ - LocalPeer: p2.ID, - PrivateKey: key2, - } - d2.AddDialer(dialer(t, p2.Addr)) - - go echoListen(ctx, l1) - - c, err := d2.Dial(ctx, p1.Addr, p1.ID) - if err != nil { - t.Fatal("error dialing peer", err) - } - - // fmt.Println("sending") - mc := msgioWrap(c) - mc.WriteMsg([]byte("beep")) - mc.WriteMsg([]byte("boop")) - out, err := mc.ReadMsg() - if err != nil { - t.Fatal(err) - } - - // fmt.Println("recving", string(out)) - data := string(out) - if data != "beep" { - t.Error("unexpected conn output", data) - } - - out, err = mc.ReadMsg() - if err != nil { - t.Fatal(err) - } - - data = string(out) - if string(out) != "boop" { - t.Error("unexpected conn output", data) - } - - // fmt.Println("closing") - c.Close() - l1.Close() - cancel() -} - -func TestDialerInsecure(t *testing.T) { - // t.Skip("Skipping in favor of another test") - testDialer(t, false) -} - -func TestDialerSecure(t *testing.T) { - // t.Skip("Skipping in favor of another test") - testDialer(t, true) -} - -func testDialerCloseEarly(t *testing.T, secure bool) { - // t.Skip("Skipping in favor of another test") - - p1 := tu.RandPeerNetParamsOrFatal(t) - p2 := tu.RandPeerNetParamsOrFatal(t) - - key1 := p1.PrivKey - if !secure { - key1 = nil - t.Log("testing insecurely") - } else { - t.Log("testing securely") - } - - ctx, cancel := context.WithCancel(context.Background()) - l1, err := Listen(ctx, p1.Addr, p1.ID, key1) - if err != nil { - t.Fatal(err) - } - p1.Addr = l1.Multiaddr() // Addr has been determined by kernel. - - // lol nesting - d2 := &Dialer{ - LocalPeer: p2.ID, - PrivateKey: p2.PrivKey, //-- dont give it key. we'll just close the conn. - } - d2.AddDialer(dialer(t, p2.Addr)) - - errs := make(chan error, 100) - done := make(chan struct{}, 1) - gotclosed := make(chan struct{}, 1) - go func() { - defer func() { done <- struct{}{} }() - - c, err := l1.Accept() - if err != nil { - if strings.Contains(err.Error(), "closed") { - gotclosed <- struct{}{} - return - } - errs <- err - } - - if _, err := c.Write([]byte("hello")); err != nil { - gotclosed <- struct{}{} - return - } - - errs <- fmt.Errorf("wrote to conn") - }() - - c, err := d2.Dial(ctx, p1.Addr, p1.ID) - if err != nil { - t.Fatal(err) - } - c.Close() // close it early. - - readerrs := func() { - for { - select { - case e := <-errs: - t.Error(e) - default: - return - } - } - } - readerrs() - - l1.Close() - <-done - cancel() - readerrs() - close(errs) - - select { - case <-gotclosed: - default: - t.Error("did not get closed") - } -} - -// we dont do a handshake with singleConn, so cant "close early." -// func TestDialerCloseEarlyInsecure(t *testing.T) { -// // t.Skip("Skipping in favor of another test") -// testDialerCloseEarly(t, false) -// } - -func TestDialerCloseEarlySecure(t *testing.T) { - // t.Skip("Skipping in favor of another test") - testDialerCloseEarly(t, true) -} - -func TestMultistreamHeader(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - p1 := tu.RandPeerNetParamsOrFatal(t) - - l1, err := Listen(ctx, p1.Addr, p1.ID, p1.PrivKey) - if err != nil { - t.Fatal(err) - } - - p1.Addr = l1.Multiaddr() // Addr has been determined by kernel. - - go func() { - _, _ = l1.Accept() - }() - - con, err := net.Dial("tcp", l1.Addr().String()) - if err != nil { - t.Fatal(err) - } - defer con.Close() - - err = msmux.SelectProtoOrFail(SecioTag, con) - if err != nil { - t.Fatal(err) - } -} - -func TestFailedAccept(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - p1 := tu.RandPeerNetParamsOrFatal(t) - - l1, err := Listen(ctx, p1.Addr, p1.ID, p1.PrivKey) - if err != nil { - t.Fatal(err) - } - - p1.Addr = l1.Multiaddr() // Addr has been determined by kernel. - - done := make(chan struct{}) - go func() { - defer close(done) - con, err := net.Dial("tcp", l1.Addr().String()) - if err != nil { - t.Error("first dial failed: ", err) - } - - // write some garbage - con.Write(bytes.Repeat([]byte{255}, 1000)) - - con.Close() - - con, err = net.Dial("tcp", l1.Addr().String()) - if err != nil { - t.Error("second dial failed: ", err) - } - defer con.Close() - - err = msmux.SelectProtoOrFail(SecioTag, con) - if err != nil { - t.Error("msmux select failed: ", err) - } - }() - - c, err := l1.Accept() - if err != nil { - t.Fatal("connections after a failed accept should still work: ", err) - } - - c.Close() - <-done -} - -func TestHangingAccept(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - p1 := tu.RandPeerNetParamsOrFatal(t) - - l1, err := Listen(ctx, p1.Addr, p1.ID, p1.PrivKey) - if err != nil { - t.Fatal(err) - } - - p1.Addr = l1.Multiaddr() // Addr has been determined by kernel. - - done := make(chan struct{}) - go func() { - defer close(done) - con, err := net.Dial("tcp", l1.Addr().String()) - if err != nil { - t.Error("first dial failed: ", err) - } - // hang this connection - defer con.Close() - - // ensure that the first conn hits first - time.Sleep(time.Millisecond * 50) - - con2, err := net.Dial("tcp", l1.Addr().String()) - if err != nil { - t.Error("second dial failed: ", err) - } - defer con2.Close() - - err = msmux.SelectProtoOrFail(SecioTag, con2) - if err != nil { - t.Error("msmux select failed: ", err) - } - - _, err = con2.Write([]byte("test")) - if err != nil { - t.Error("con write failed: ", err) - } - }() - - c, err := l1.Accept() - if err != nil { - t.Fatal("connections after a failed accept should still work: ", err) - } - - c.Close() - <-done -} - -// This test kicks off N (=300) concurrent dials, which wait d (=20ms) seconds before failing. -// That wait holds up the handshake (multistream AND crypto), which will happen BEFORE -// l1.Accept() returns a connection. This test checks that the handshakes all happen -// concurrently in the listener side, and not sequentially. This ensures that a hanging dial -// will not block the listener from accepting other dials concurrently. -func TestConcurrentAccept(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - p1 := tu.RandPeerNetParamsOrFatal(t) - - l1, err := Listen(ctx, p1.Addr, p1.ID, p1.PrivKey) - if err != nil { - t.Fatal(err) - } - - n := 300 - delay := time.Millisecond * 20 - if runtime.GOOS == "darwin" { - n = 100 - } - - p1.Addr = l1.Multiaddr() // Addr has been determined by kernel. - - var wg sync.WaitGroup - for i := 0; i < n; i++ { - wg.Add(1) - go func() { - defer wg.Done() - con, err := net.Dial("tcp", l1.Addr().String()) - if err != nil { - log.Error(err) - t.Error("first dial failed: ", err) - return - } - // hang this connection - defer con.Close() - - time.Sleep(delay) - err = msmux.SelectProtoOrFail(SecioTag, con) - if err != nil { - t.Error(err) - } - }() - } - - before := time.Now() - for i := 0; i < n; i++ { - c, err := l1.Accept() - if err != nil { - t.Fatal("connections after a failed accept should still work: ", err) - } - - c.Close() - } - - limit := delay * time.Duration(n) - took := time.Since(before) - if took > limit { - t.Fatal("took too long!") - } - log.Errorf("took: %s (less than %s)", took, limit) - l1.Close() - wg.Wait() - cancel() - - time.Sleep(time.Millisecond * 100) - - err = grc.CheckForLeaks(goroFilter) - if err != nil { - t.Fatal(err) - } -} - -func TestConnectionTimeouts(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - old := NegotiateReadTimeout - NegotiateReadTimeout = time.Second * 5 - defer func() { NegotiateReadTimeout = old }() - - p1 := tu.RandPeerNetParamsOrFatal(t) - - l1, err := Listen(ctx, p1.Addr, p1.ID, p1.PrivKey) - if err != nil { - t.Fatal(err) - } - - n := 100 - if runtime.GOOS == "darwin" { - n = 50 - } - - p1.Addr = l1.Multiaddr() // Addr has been determined by kernel. - - var wg sync.WaitGroup - for i := 0; i < n; i++ { - wg.Add(1) - go func() { - defer wg.Done() - con, err := net.Dial("tcp", l1.Addr().String()) - if err != nil { - log.Error(err) - t.Error("first dial failed: ", err) - return - } - defer con.Close() - - // hang this connection until timeout - io.ReadFull(con, make([]byte, 1000)) - }() - } - - // wait to make sure the hanging dials have started - time.Sleep(time.Millisecond * 50) - - good_n := 20 - for i := 0; i < good_n; i++ { - wg.Add(1) - go func() { - defer wg.Done() - con, err := net.Dial("tcp", l1.Addr().String()) - if err != nil { - log.Error(err) - t.Error("first dial failed: ", err) - return - } - defer con.Close() - - // dial these ones through - err = msmux.SelectProtoOrFail(SecioTag, con) - if err != nil { - t.Error(err) - } - }() - } - - before := time.Now() - for i := 0; i < good_n; i++ { - c, err := l1.Accept() - if err != nil { - t.Fatal("connections during hung dials should still work: ", err) - } - - c.Close() - } - - took := time.Since(before) - - if took > time.Second*5 { - t.Fatal("hanging dials shouldnt block good dials") - } - - wg.Wait() - - go func() { - con, err := net.Dial("tcp", l1.Addr().String()) - if err != nil { - log.Error(err) - t.Error("first dial failed: ", err) - return - } - defer con.Close() - - // dial these ones through - err = msmux.SelectProtoOrFail(SecioTag, con) - if err != nil { - t.Error(err) - } - }() - - // make sure we can dial in still after a bunch of timeouts - con, err := l1.Accept() - if err != nil { - t.Fatal(err) - } - - con.Close() - l1.Close() - cancel() - - time.Sleep(time.Millisecond * 100) - - err = grc.CheckForLeaks(goroFilter) - if err != nil { - t.Fatal(err) - } -} - -func TestForcePNet(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - ipnet.ForcePrivateNetwork = true - defer func() { - ipnet.ForcePrivateNetwork = false - }() - - p := tu.RandPeerNetParamsOrFatal(t) - list, err := tcpt.NewTCPTransport().Listen(p.Addr) - if err != nil { - t.Fatal(err) - } - - _, err = WrapTransportListenerWithProtector(ctx, list, p.ID, p.PrivKey, nil) - if err != ipnet.ErrNotInPrivateNetwork { - t.Fatal("Wrong error, expected error lack of protector") - } -} - -type fakeProtector struct { - used bool -} - -func (f *fakeProtector) Fingerprint() []byte { - return make([]byte, 32) -} - -func (f *fakeProtector) Protect(c transport.Conn) (transport.Conn, error) { - f.used = true - return &rot13Crypt{c}, nil -} - -type rot13Crypt struct { - transport.Conn -} - -func (r *rot13Crypt) Read(b []byte) (int, error) { - n, err := r.Conn.Read(b) - if err != nil { - return n, err - } - - for i, _ := range b { - b[i] = byte((uint8(b[i]) - 13) & 0xff) - } - return n, err -} - -func (r *rot13Crypt) Write(b []byte) (int, error) { - for i, _ := range b { - b[i] = byte((uint8(b[i]) + 13) & 0xff) - } - return r.Conn.Write(b) -} - -func TestPNetIsUsed(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - p1 := tu.RandPeerNetParamsOrFatal(t) - p2 := tu.RandPeerNetParamsOrFatal(t) - - p1Protec := &fakeProtector{} - - list, err := tcpt.NewTCPTransport().Listen(p1.Addr) - if err != nil { - t.Fatal(err) - } - - l1, err := WrapTransportListenerWithProtector(ctx, list, p1.ID, p1.PrivKey, p1Protec) - if err != nil { - t.Fatal(err) - } - p1.Addr = l1.Multiaddr() // Addr has been determined by kernel. - - d2 := NewDialer(p2.ID, p2.PrivKey, nil) - d2.Protector = &fakeProtector{} - - d2.AddDialer(dialer(t, p2.Addr)) - _, err = d2.Dial(ctx, p1.Addr, p1.ID) - if err != nil { - t.Fatal(err) - } - - _, err = l1.Accept() - if err != nil { - t.Fatal(err) - } - - if !p1Protec.used { - t.Error("Listener did not use protector for the connection") - } - - if !d2.Protector.(*fakeProtector).used { - t.Error("Dialer did not use protector for the connection") - } -} diff --git a/listen_test.go b/listen_test.go new file mode 100644 index 0000000..31b20ae --- /dev/null +++ b/listen_test.go @@ -0,0 +1,225 @@ +package conn + +import ( + "bytes" + "context" + "fmt" + "io" + "net" + "sync" + "time" + + tpt "github.com/libp2p/go-libp2p-transport" + filter "github.com/libp2p/go-maddr-filter" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Listener", func() { + Context("accepting connections", func() { + for _, val := range transportTypes { + tr := val + + Context(fmt.Sprintf("using a %s", tr), func() { + It("returns immediately when the context is cancelled", func() { + p1 := randPeerNetParams(tr) + ctx, cancel := context.WithCancel(context.Background()) + l := getListener(ctx, p1) + + accepted := make(chan struct{}) + go func() { + _, _ = l.Accept() + close(accepted) + }() + Consistently(accepted).ShouldNot(BeClosed()) + cancel() + Eventually(accepted).Should(BeClosed()) + }) + + It("returns immediately when it is closed", func() { + p1 := randPeerNetParams(tr) + l := getListener(context.Background(), p1) + + accepted := make(chan struct{}) + go func() { + _, _ = l.Accept() + close(accepted) + }() + Consistently(accepted).ShouldNot(BeClosed()) + l.Close() + Eventually(accepted).Should(BeClosed()) + }) + + It("continues accepting connections after one accept failed", func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + p1 := randPeerNetParams(tr) + p2 := randPeerNetParams(tr) + + l1 := getListener(ctx, p1) + defer l1.Close() + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + c := dialRawConn(p2.Addr, l1.Multiaddr()) + defer c.Close() + var conn io.ReadWriteCloser + switch tr { + case singleStreamTransport: + conn = c.(tpt.SingleStreamConn) + case multiStreamTransport: + var err error + conn, err = c.(tpt.MultiStreamConn).OpenStream() + Expect(err).ToNot(HaveOccurred()) + } + // write some garbage. This will fail the protocol selection + _, err := conn.Write(bytes.Repeat([]byte{255}, 1000)) + Expect(err).ToNot(HaveOccurred()) + close(done) + }() + + accepted := make(chan struct{}) + go func() { + defer GinkgoRecover() + c, err := l1.Accept() + Expect(err).ToNot(HaveOccurred()) + c.Close() + close(accepted) + }() + + // make sure it doesn't accept the raw connection + Eventually(done).Should(BeClosed()) + Consistently(accepted).ShouldNot(BeClosed()) + + // now dial the real connection, and make sure it is accepted + d := getDialer(p2.ID, p2.PrivKey, p2.Addr) + _, err := d.Dial(ctx, p1.Addr, p1.ID) + Expect(err).ToNot(HaveOccurred()) + + Eventually(accepted).Should(BeClosed()) + }) + + // This test kicks off N (=10) concurrent dials, which wait d (=20ms) seconds before failing. + // That wait holds up the handshake (multistream AND crypto), which will happen BEFORE + // l1.Accept() returns a connection. This test checks that the handshakes all happen + // concurrently in the listener side, and not sequentially. This ensures that a hanging dial + // will not block the listener from accepting other dials concurrently. + It("accepts concurrently", func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + p1 := randPeerNetParams(tr) + p2 := randPeerNetParams(tr) + + l1 := getListener(ctx, p1) + defer l1.Close() + + n := 10 + delay := 50 * time.Millisecond + + accepted := make(chan struct{}) + go func() { + defer GinkgoRecover() + for i := 0; i < n; i++ { + conn, err := l1.Accept() + Expect(err).ToNot(HaveOccurred()) + defer conn.Close() + } + close(accepted) + }() + + var wg sync.WaitGroup + for i := 0; i < n; i++ { + wg.Add(1) + go func() { + defer GinkgoRecover() + defer wg.Done() + d2 := getDialer(p2.ID, p2.PrivKey, p2.Addr) + d2.Wrapper = func(c tpt.Conn) tpt.Conn { + time.Sleep(delay) + return c + } + before := time.Now() + _, err := d2.Dial(ctx, p1.Addr, p1.ID) + Expect(err).ToNot(HaveOccurred()) + // make sure the delay actually worked + Expect(time.Now()).To(BeTemporally(">", before.Add(delay))) + }() + } + + wg.Wait() + // the Eventually timeout is 100ms, which is a lot smaller than n*delay = 500ms + Eventually(accepted).Should(BeClosed()) + }) + + Context("address filters", func() { + It("doesn't accept connections from filtered addresses", func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + p1 := randPeerNetParams(tr) + p2 := randPeerNetParams(tr) + + filt := filter.NewFilters() + _, ipnet, err := net.ParseCIDR("127.0.1.2/16") + Expect(err).ToNot(HaveOccurred()) + filt.AddDialFilter(ipnet) + Expect(filt.AddrBlocked(p2.Addr)).To(BeTrue()) + + l := getListener(ctx, p1) + defer l.Close() + l.SetAddrFilters(filt) + + accepted := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, _ = l.Accept() + close(accepted) + }() + + d := getDialer(p2.ID, p2.PrivKey, p2.Addr) + _, err = d.Dial(ctx, p1.Addr, p1.ID) + Expect(err).To(HaveOccurred()) + Eventually(accepted).ShouldNot(BeClosed()) + }) + + It("accepts connections from addresses that are not filtered", func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + p1 := randPeerNetParams(tr) + p2 := randPeerNetParams(tr) + + filt := filter.NewFilters() + _, ipnet, err := net.ParseCIDR("192.168.1.2/16") + Expect(err).ToNot(HaveOccurred()) + filt.AddDialFilter(ipnet) + Expect(filt.AddrBlocked(p2.Addr)).To(BeFalse()) + + l := getListener(ctx, p1) + defer l.Close() + l.SetAddrFilters(filt) + + accepted := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := l.Accept() + Expect(err).ToNot(HaveOccurred()) + close(accepted) + }() + + d := getDialer(p2.ID, p2.PrivKey, p2.Addr) + c2, err := d.Dial(ctx, p1.Addr, p1.ID) + Expect(err).ToNot(HaveOccurred()) + defer c2.Close() + Eventually(accepted).Should(BeClosed()) + time.Sleep(time.Second) + }) + }) + }) + } + }) +}) diff --git a/protector_test.go b/protector_test.go new file mode 100644 index 0000000..55669f8 --- /dev/null +++ b/protector_test.go @@ -0,0 +1,266 @@ +package conn + +import ( + "context" + "errors" + + iconn "github.com/libp2p/go-libp2p-interface-conn" + ipnet "github.com/libp2p/go-libp2p-interface-pnet" + tpt "github.com/libp2p/go-libp2p-transport" + smux "github.com/libp2p/go-stream-muxer" + tcpt "github.com/libp2p/go-tcp-transport" + tu "github.com/libp2p/go-testutil" + quict "github.com/marten-seemann/libp2p-quic-transport" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +type fakeSingleStreamProtector struct { + used bool +} + +func (f *fakeSingleStreamProtector) Fingerprint() []byte { + return make([]byte, 32) +} + +func (f *fakeSingleStreamProtector) Protect(c tpt.Conn) (tpt.Conn, error) { + f.used = true + return &rot13CryptSingleStream{c.(tpt.SingleStreamConn)}, nil +} + +type rot13CryptSingleStream struct { + tpt.SingleStreamConn +} + +func (r *rot13CryptSingleStream) Read(b []byte) (int, error) { + n, err := r.SingleStreamConn.Read(b) + for i := 0; i < n; i++ { + b[i] = b[i] - 13 + } + return n, err +} + +func (r *rot13CryptSingleStream) Write(b []byte) (int, error) { + p := make([]byte, len(b)) // write MUST NOT modify b + for i := range b { + p[i] = b[i] + 13 + } + return r.SingleStreamConn.Write(p) +} + +type fakeMultiStreamProtector struct { + used bool + crypt *rot13CryptMultiStream +} + +func (f *fakeMultiStreamProtector) Fingerprint() []byte { + return make([]byte, 32) +} + +func (f *fakeMultiStreamProtector) Protect(c tpt.Conn) (tpt.Conn, error) { + f.used = true + f.crypt = &rot13CryptMultiStream{c.(tpt.MultiStreamConn), 0, 0} + return f.crypt, nil +} + +type rot13CryptMultiStream struct { + tpt.MultiStreamConn + openedStreams int + acceptedStreams int +} + +func (r *rot13CryptMultiStream) OpenStream() (smux.Stream, error) { + r.openedStreams++ + str, err := r.MultiStreamConn.OpenStream() + return &rot13Stream{str}, err +} + +func (r *rot13CryptMultiStream) AcceptStream() (smux.Stream, error) { + r.acceptedStreams++ + str, err := r.MultiStreamConn.AcceptStream() + return &rot13Stream{str}, err +} + +type rot13Stream struct { + smux.Stream +} + +var errProtect = errors.New("protecting failed") + +type erroringProtector struct{} + +func (f *erroringProtector) Fingerprint() []byte { + return make([]byte, 32) +} + +func (f *erroringProtector) Protect(c tpt.Conn) (tpt.Conn, error) { + return nil, errProtect +} + +func (r *rot13Stream) Read(b []byte) (int, error) { + n, err := r.Stream.Read(b) + for i := 0; i < n; i++ { + b[i] = b[i] - 13 + } + return n, err +} + +func (r *rot13Stream) Write(b []byte) (int, error) { + p := make([]byte, len(b)) // write MUST NOT modify b + for i := range b { + p[i] = b[i] + 13 + } + return r.Stream.Write(p) +} + +var _ = Describe("using the protector", func() { + It("uses a protector for single-stream connections", func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + p1 := randPeerNetParams(singleStreamTransport) + p2 := randPeerNetParams(singleStreamTransport) + p1Protec := &fakeSingleStreamProtector{} + p2Protec := &fakeSingleStreamProtector{} + + list, err := tcpt.NewTCPTransport().Listen(p1.Addr) + Expect(err).ToNot(HaveOccurred()) + l1, err := WrapTransportListenerWithProtector(ctx, list, p1.ID, p1.PrivKey, streamMuxer, p1Protec) + Expect(err).ToNot(HaveOccurred()) + p1.Addr = l1.Multiaddr() // Addr has been determined by kernel. + + d2 := getDialer(p2.ID, p2.PrivKey, p2.Addr) + d2.Protector = p2Protec + + accepted := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := l1.Accept() + Expect(err).ToNot(HaveOccurred()) + close(accepted) + }() + + c2, err := d2.Dial(ctx, p1.Addr, p1.ID) + Expect(err).ToNot(HaveOccurred()) + defer c2.Close() + + Expect(p2Protec.used).To(BeTrue()) + Eventually(accepted).Should(BeClosed()) + Expect(p1Protec.used).To(BeTrue()) + }) + + // TODO: enable this test when adding support for multi-stream connections + PIt("uses a protector for multi-stream connections", func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + p1 := randPeerNetParams(multiStreamTransport) + p2 := randPeerNetParams(multiStreamTransport) + p1Protec := &fakeMultiStreamProtector{} + p2Protec := &fakeMultiStreamProtector{} + + list, err := quict.NewQuicTransport().Listen(p1.Addr) + Expect(err).ToNot(HaveOccurred()) + l1, err := WrapTransportListenerWithProtector(ctx, list, p1.ID, p1.PrivKey, streamMuxer, p1Protec) + Expect(err).ToNot(HaveOccurred()) + p1.Addr = l1.Multiaddr() // Addr has been determined by kernel. + + d2 := getDialer(p2.ID, p2.PrivKey, p2.Addr) + d2.Protector = p2Protec + + var c1 iconn.Conn + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + var err error + c1, err = l1.Accept() + Expect(err).ToNot(HaveOccurred()) + close(done) + }() + + c2, err := d2.Dial(ctx, p1.Addr, p1.ID) + Expect(err).ToNot(HaveOccurred()) + defer c2.Close() + + Expect(p2Protec.used).To(BeTrue()) + <-done + Expect(p1Protec.used).To(BeTrue()) + + Expect(p1Protec.crypt.acceptedStreams).To(Equal(2)) + Expect(p1Protec.crypt.openedStreams).To(BeZero()) + Expect(p2Protec.crypt.openedStreams).To(Equal(2)) + Expect(p2Protec.crypt.acceptedStreams).To(BeZero()) + + str1, err := c1.OpenStream() + Expect(err).ToNot(HaveOccurred()) + _, err = str1.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + _, err = c2.AcceptStream() + Expect(err).ToNot(HaveOccurred()) + Expect(p1Protec.crypt.acceptedStreams).To(Equal(2)) + Expect(p1Protec.crypt.openedStreams).To(Equal(1)) + Expect(p2Protec.crypt.openedStreams).To(Equal(2)) + Expect(p2Protec.crypt.acceptedStreams).To(Equal(1)) + }) + + Context("forcing a private network", func() { + var p1, p2 *tu.PeerNetParams + var list tpt.Listener + + BeforeEach(func() { + ipnet.ForcePrivateNetwork = true + p1 = randPeerNetParams(singleStreamTransport) + p2 = randPeerNetParams(singleStreamTransport) + var err error + list, err = tcpt.NewTCPTransport().Listen(p1.Addr) + Expect(err).ToNot(HaveOccurred()) + }) + + AfterEach(func() { + ipnet.ForcePrivateNetwork = false + }) + + It("errors if no protector is specified for the listener", func() { + _, err := WrapTransportListenerWithProtector(context.Background(), list, p1.ID, p1.PrivKey, streamMuxer, nil) + Expect(err).To(Equal(ipnet.ErrNotInPrivateNetwork)) + }) + + It("errors if no protector is specified for the dialer", func() { + d := getDialer(p2.ID, p2.PrivKey, p2.Addr) + _, err := d.Dial(context.Background(), list.Multiaddr(), p1.ID) + Expect(err).To(Equal(ipnet.ErrNotInPrivateNetwork)) + }) + }) + + It("correctly handles a protected that errors", func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + p1 := randPeerNetParams(singleStreamTransport) + p2 := randPeerNetParams(singleStreamTransport) + p1Protec := &erroringProtector{} + p2Protec := &erroringProtector{} + + list, err := tcpt.NewTCPTransport().Listen(p1.Addr) + Expect(err).ToNot(HaveOccurred()) + l1, err := WrapTransportListenerWithProtector(ctx, list, p1.ID, p1.PrivKey, streamMuxer, p1Protec) + Expect(err).ToNot(HaveOccurred()) + p1.Addr = l1.Multiaddr() // Addr has been determined by kernel. + + d2 := getDialer(p2.ID, p2.PrivKey, p2.Addr) + d2.Protector = p2Protec + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, _ = l1.Accept() + close(done) + }() + + _, err = d2.Dial(ctx, p1.Addr, p1.ID) + Expect(err).To(MatchError(errProtect)) + // make sure no connection was accepted + Consistently(done).ShouldNot(BeClosed()) + }) +}) From a23e5a83a8930fb8ed623c97e81149235165e614 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 16 Aug 2017 13:33:38 +0700 Subject: [PATCH 06/10] introduce a timeout for dialing and accepting connections This removes the timeout for the multistream selection, and implements a timeout covering the whole dialing and accepting process. Timeouts can occur at multiple places during connection establishments. --- conn_test.go | 12 +++--- dial.go | 111 ++++++++++++++++++++++++-------------------------- dial_test.go | 67 ++++++++++++++++++++++++++++++ listen.go | 113 ++++++++++++++++++++++++++++----------------------- 4 files changed, 189 insertions(+), 114 deletions(-) create mode 100644 dial_test.go diff --git a/conn_test.go b/conn_test.go index 077376c..2fa133b 100644 --- a/conn_test.go +++ b/conn_test.go @@ -112,9 +112,9 @@ var _ = Describe("Connections", func() { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - old := NegotiateReadTimeout - NegotiateReadTimeout = 3 * time.Second - defer func() { NegotiateReadTimeout = old }() + old := ConnAcceptTimeout + ConnAcceptTimeout = 3 * time.Second + defer func() { ConnAcceptTimeout = old }() p1 := randPeerNetParams(tr) p2 := randPeerNetParams(tr) @@ -168,12 +168,12 @@ var _ = Describe("Connections", func() { Expect(err).ToNot(HaveOccurred()) } close(accepted) - Expect(time.Now()).To(BeTemporally("<", before.Add(NegotiateReadTimeout/4))) + Expect(time.Now()).To(BeTemporally("<", before.Add(ConnAcceptTimeout/4))) Eventually(func() bool { wg.Wait() // wait for the timeouts for the raw connections to occur return true - }, NegotiateReadTimeout).Should(BeTrue()) - Expect(time.Now()).To(BeTemporally(">", before.Add(NegotiateReadTimeout))) + }, ConnAcceptTimeout).Should(BeTrue()) + Expect(time.Now()).To(BeTemporally(">", before.Add(ConnAcceptTimeout))) // make sure we can dial in still after a bunch of timeouts done := make(chan struct{}) diff --git a/dial.go b/dial.go index 307dc3c..c71e2c8 100644 --- a/dial.go +++ b/dial.go @@ -18,14 +18,21 @@ import ( msmux "github.com/multiformats/go-multistream" ) -type WrapFunc func(tpt.Conn) tpt.Conn +// DialTimeout is the maximum duration a Dial is allowed to take. +// This includes the time between dialing the raw network connection, +// protocol selection as well the handshake, if applicable. +var DialTimeout = 60 * time.Second -type timeoutReadWriteCloser interface { - io.ReadWriteCloser - SetDeadline(time.Time) error - SetReadDeadline(time.Time) error - SetWriteDeadline(time.Time) error -} +// dialTimeoutErr occurs when the DialTimeout is exceeded. +type dialTimeoutErr struct{} + +func (dialTimeoutErr) Error() string { return "deadline exceeded" } +func (dialTimeoutErr) Temporary() bool { return true } +func (dialTimeoutErr) Timeout() bool { return true } + +// The WrapFunc is used to wrap a tpt.Conn. +// It must not block. +type WrapFunc func(tpt.Conn) tpt.Conn // Dialer is an object that can open connections. We could have a "convenience" // Dial function as before, but it would have many arguments, as dialing is @@ -78,49 +85,47 @@ func (d *Dialer) String() string { // Example: d.DialAddr(ctx, peer.Addresses()[0], peer) func (d *Dialer) Dial(ctx context.Context, raddr ma.Multiaddr, remote peer.ID) (iconn.Conn, error) { logdial := lgbl.Dial("conn", d.LocalPeer, remote, nil, raddr) + defer log.EventBegin(ctx, "connDial", logdial).Done() logdial["encrypted"] = (d.PrivateKey != nil) // log wether this will be an encrypted dial or not. logdial["inPrivNet"] = (d.Protector != nil) - defer log.EventBegin(ctx, "connDial", logdial).Done() - if d.Protector == nil && ipnet.ForcePrivateNetwork { log.Error("tried to dial with no Private Network Protector but usage" + " of Private Networks is forced by the enviroment") return nil, ipnet.ErrNotInPrivateNetwork } - var connOut iconn.Conn - var errOut error - done := make(chan struct{}) - - // do it async to ensure we respect done context - go func() { - defer func() { - select { - case done <- struct{}{}: - case <-ctx.Done(): - } - }() + c, err := d.doDial(ctx, raddr, remote) + if err != nil { + logdial["error"] = err.Error() + logdial["dial"] = "failure" + return nil, err + } + logdial["dial"] = "success" + return c, nil +} - tptConn, err := d.rawConnDial(ctx, raddr, remote) - if err != nil { - errOut = err - return - } +func (d *Dialer) doDial(ctx context.Context, raddr ma.Multiaddr, remote peer.ID) (iconn.Conn, error) { + rawConn, err := d.rawConnDial(ctx, raddr, remote) + if err != nil { + return nil, err + } + done := make(chan connOrErr, 1) + // do it async to ensure we respect the context + go func() { if d.Protector != nil { var pconn tpt.Conn - pconn, err = d.Protector.Protect(tptConn) + pconn, err = d.Protector.Protect(rawConn) if err != nil { - tptConn.Close() - errOut = err + done <- connOrErr{err: err} return } - tptConn = pconn + rawConn = pconn } if d.Wrapper != nil { - tptConn = d.Wrapper(tptConn) + rawConn = d.Wrapper(rawConn) } cryptoProtoChoice := SecioTag @@ -128,56 +133,48 @@ func (d *Dialer) Dial(ctx context.Context, raddr ma.Multiaddr, remote peer.ID) ( cryptoProtoChoice = NoEncryptionTag } - var stream timeoutReadWriteCloser - switch con := tptConn.(type) { + var stream io.ReadWriteCloser + switch con := rawConn.(type) { case tpt.SingleStreamConn: stream = con case tpt.MultiStreamConn: stream, err = con.OpenStream() if err != nil { - errOut = err + done <- connOrErr{err: err} return } defer stream.Close() } - stream.SetReadDeadline(time.Now().Add(NegotiateReadTimeout)) - err = msmux.SelectProtoOrFail(cryptoProtoChoice, stream) - if err != nil { - errOut = err + if err := msmux.SelectProtoOrFail(cryptoProtoChoice, stream); err != nil { + done <- connOrErr{err: err} return } - // clear deadline - stream.SetReadDeadline(time.Time{}) - - c, err := newSingleConn(ctx, d.LocalPeer, remote, d.PrivateKey, tptConn, d.streamMuxer, false) + c, err := newSingleConn(ctx, d.LocalPeer, remote, d.PrivateKey, rawConn, d.streamMuxer, false) if err != nil { - tptConn.Close() - errOut = err + done <- connOrErr{err: err} return } - connOut = c + done <- connOrErr{conn: c} }() + var res connOrErr select { case <-ctx.Done(): - logdial["error"] = ctx.Err().Error() - logdial["dial"] = "failure" + rawConn.Close() return nil, ctx.Err() - case <-done: - // whew, finished. - } - - if errOut != nil { - logdial["error"] = errOut.Error() - logdial["dial"] = "failure" - return nil, errOut + case <-time.After(DialTimeout): + rawConn.Close() + return nil, &dialTimeoutErr{} + case res = <-done: + if res.err != nil { + rawConn.Close() + } } - logdial["dial"] = "success" - return connOut, nil + return res.conn, res.err } func (d *Dialer) AddDialer(pd tpt.Dialer) { diff --git a/dial_test.go b/dial_test.go new file mode 100644 index 0000000..7fbfb29 --- /dev/null +++ b/dial_test.go @@ -0,0 +1,67 @@ +package conn + +import ( + "context" + "fmt" + "net" + "time" + + ma "github.com/multiformats/go-multiaddr" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("dialing", func() { + It("errors when it can't dial the raw connection", func() { + p := randPeerNetParams(singleStreamTransport) + d := getDialer(p.ID, p.PrivKey, p.Addr) + raddr, err := ma.NewMultiaddr("/ip4/1.2.3.4/tcp/0") + Expect(err).ToNot(HaveOccurred()) + _, err = d.Dial(context.Background(), raddr, p.ID) + Expect(err).To(HaveOccurred()) + }) + + for _, val := range transportTypes { + tr := val + Context(fmt.Sprintf("using a %s", tr), func() { + It("returns immediately when the context is canceled", func() { + p1 := randPeerNetParams(tr) + tptList, err := getTransport(p1.Addr).Listen(p1.Addr) + Expect(err).ToNot(HaveOccurred()) + defer tptList.Close() + + dialed := make(chan struct{}) + ctx, cancel := context.WithCancel(context.Background()) + go func() { + defer GinkgoRecover() + p2 := randPeerNetParams(tr) + d := getDialer(p2.ID, p2.PrivKey, p2.Addr) + _, err = d.Dial(ctx, tptList.Multiaddr(), p2.ID) + Expect(err).To(MatchError(context.Canceled)) + close(dialed) + }() + Consistently(dialed).ShouldNot(BeClosed()) + cancel() + Eventually(dialed).Should(BeClosed()) + }) + + It("times out during multistream selection", func() { + old := DialTimeout + DialTimeout = time.Second + defer func() { DialTimeout = old }() + + p1 := randPeerNetParams(tr) + p2 := randPeerNetParams(tr) + tptList, err := getTransport(p1.Addr).Listen(p1.Addr) + Expect(err).ToNot(HaveOccurred()) + defer tptList.Close() + + d := getDialer(p2.ID, p2.PrivKey, p2.Addr) + _, err = d.Dial(context.Background(), tptList.Multiaddr(), p2.ID) + Expect(err).To(HaveOccurred()) + Expect(err.(net.Error).Timeout()).To(BeTrue()) + Expect(err.(net.Error).Temporary()).To(BeTrue()) + }) + }) + } +}) diff --git a/listen.go b/listen.go index 25bad18..baaafcb 100644 --- a/listen.go +++ b/listen.go @@ -2,6 +2,7 @@ package conn import ( "context" + "errors" "fmt" "io" "net" @@ -28,8 +29,8 @@ const ( ) var ( - connAcceptBuffer = 32 - NegotiateReadTimeout = time.Second * 60 + connAcceptBuffer = 32 + ConnAcceptTimeout = 60 * time.Second ) // ConnWrapper is any function that wraps a raw multiaddr connection @@ -80,31 +81,23 @@ func (l *listener) SetAddrFilters(fs *filter.Filters) { } type connOrErr struct { - conn tpt.Conn + conn iconn.Conn err error } // Accept waits for and returns the next connection to the listener. func (l *listener) Accept() (iconn.Conn, error) { - for con := range l.incoming { - if con.err != nil { - return nil, con.err - } - tptConn := con.conn + if l.privk == nil || !iconn.EncryptConnections { + log.Warningf("listener %s listening INSECURELY!", l) + } - if l.privk == nil || !iconn.EncryptConnections { - log.Warningf("listener %s listening INSECURELY!", l) + for c := range l.incoming { + if c.err != nil { + return nil, c.err } - - c, err := newSingleConn(l.ctx, l.local, "", l.privk, tptConn, l.streamMuxer, true) - if err != nil { - tptConn.Close() - continue - } - - return c, nil + return c.conn, nil } - return nil, fmt.Errorf("listener is closed") + return nil, errors.New("listener is closed") } func (l *listener) Addr() net.Addr { @@ -150,63 +143,81 @@ func (l *listener) handleIncoming() { if l.catcher.IsTemporary(err) { continue } - l.incoming <- connOrErr{err: err} return } - log.Debugf("listener %s got connection: %s <---> %s", l, conn.LocalMultiaddr(), conn.RemoteMultiaddr()) - if l.filters != nil && l.filters.AddrBlocked(conn.RemoteMultiaddr()) { log.Debugf("blocked connection from %s", conn.RemoteMultiaddr()) conn.Close() continue } + log.Debugf("listener %s got connection: %s <---> %s", l, conn.LocalMultiaddr(), conn.RemoteMultiaddr()) + wg.Add(1) go func() { defer wg.Done() - if l.protec != nil { - pc, err := l.protec.Protect(conn) - if err != nil { + + ctx, cancel := context.WithTimeout(l.ctx, ConnAcceptTimeout) + defer cancel() + + done := make(chan struct{}) + go func() { + defer close(done) + + if l.protec != nil { + pc, err := l.protec.Protect(conn) + if err != nil { + conn.Close() + log.Warning("protector failed: ", err) + return + } + conn = pc + } + + // If we have a wrapper func, wrap this conn + if l.wrapper != nil { + conn = l.wrapper(conn) + } + + var stream io.ReadWriteCloser + switch conn := conn.(type) { + case tpt.SingleStreamConn: + stream = conn + case tpt.MultiStreamConn: + stream, err = conn.AcceptStream() + if err != nil { + conn.Close() + log.Warning("accepting stream failed: ", err) + return + } + defer stream.Close() + } + + if _, _, err := l.mux.Negotiate(stream); err != nil { + log.Warning("incoming conn: negotiation of crypto protocol failed: ", err) conn.Close() - log.Warning("protector failed: ", err) return } - conn = pc - } - - // If we have a wrapper func, wrap this conn - if l.wrapper != nil { - conn = l.wrapper(conn) - } - var stream timeoutReadWriteCloser - switch conn := conn.(type) { - case tpt.SingleStreamConn: - stream = conn - case tpt.MultiStreamConn: - stream, err = conn.AcceptStream() + c, err := newSingleConn(ctx, l.local, "", l.privk, conn, l.streamMuxer, true) if err != nil { + log.Warning("connection setup failed: ", err) conn.Close() - log.Warning("accepting stream failed: ", err) return } - defer stream.Close() - } - // TODO: should the negotiate timeout include the time taken by AcceptStream? - stream.SetReadDeadline(time.Now().Add(NegotiateReadTimeout)) - _, _, err = l.mux.Negotiate(stream) - if err != nil { - log.Warning("incoming conn: negotiation of crypto protocol failed: ", err) + l.incoming <- connOrErr{conn: c} + }() + + select { + case <-ctx.Done(): + log.Warning("incoming conn: conn not established in time:", ctx.Err().Error()) conn.Close() return + case <-done: // connection completed (or errored) } - // clear read readline - stream.SetReadDeadline(time.Time{}) - - l.incoming <- connOrErr{conn: conn} }() } } From dff3ff64dda3aa213612871e14a613d7c1b49cbe Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 6 Sep 2017 13:05:21 +0200 Subject: [PATCH 07/10] use the renamed transport interfaces --- conn.go | 8 +++--- conn_suite_test.go | 14 +++++----- conn_test.go | 10 +++---- dial.go | 4 +-- dial_test.go | 2 +- listen.go | 4 +-- listen_test.go | 8 +++--- protector_test.go | 70 +++++++++++++++++++++++----------------------- secure_conn.go | 32 ++++++++++----------- 9 files changed, 76 insertions(+), 76 deletions(-) diff --git a/conn.go b/conn.go index d71cca8..4361bb1 100644 --- a/conn.go +++ b/conn.go @@ -39,7 +39,7 @@ func newSingleConn(ctx context.Context, local, remote peer.ID, privKey ci.PrivKe var streamConn smux.Conn var secSession secio.Session switch conn := tptConn.(type) { - case tpt.SingleStreamConn: + case tpt.DuplexConn: c := conn // 1. secure the connection if privKey != nil && iconn.EncryptConnections { @@ -49,8 +49,8 @@ func newSingleConn(ctx context.Context, local, remote peer.ID, privKey ci.PrivKe return nil, err } c = &secureDuplexConn{ - SingleStreamConn: conn, - secure: secSession, + insecure: conn, + secure: secSession, } } else { log.Warning("creating INSECURE connection %s at %s", tptConn.LocalMultiaddr(), tptConn.RemoteMultiaddr()) @@ -62,7 +62,7 @@ func newSingleConn(ctx context.Context, local, remote peer.ID, privKey ci.PrivKe if err != nil { return nil, err } - case tpt.MultiStreamConn: + case tpt.MultiplexConn: panic("not implemented yet") } diff --git a/conn_suite_test.go b/conn_suite_test.go index ee91944..0a8580f 100644 --- a/conn_suite_test.go +++ b/conn_suite_test.go @@ -43,17 +43,17 @@ var streamMuxer = yamux.DefaultTransport type transportType uint8 const ( - singleStreamTransport transportType = 1 + iota - multiStreamTransport + duplexTransport transportType = 1 + iota + multiplexTransport ) -var transportTypes = []transportType{singleStreamTransport} +var transportTypes = []transportType{duplexTransport} func (t transportType) String() string { - if t == multiStreamTransport { - return "multi-stream transport" + if t == duplexTransport { + return "duplex transport" } - return "single-stream transport" + return "multiplex transport" } // dialRawConn dials a tpt.Conn @@ -107,7 +107,7 @@ func getDialer(localPeer peer.ID, privKey ci.PrivKey, addr ma.Multiaddr) *Dialer func randPeerNetParams(tr transportType) *tu.PeerNetParams { p, err := tu.RandPeerNetParams() Expect(err).ToNot(HaveOccurred()) - if tr == multiStreamTransport { + if tr == multiplexTransport { p.Addr, err = ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/quic") Expect(err).ToNot(HaveOccurred()) } diff --git a/conn_test.go b/conn_test.go index 2fa133b..8cdfb6d 100644 --- a/conn_test.go +++ b/conn_test.go @@ -20,7 +20,7 @@ var _ = Describe("Connections", func() { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - p1 := randPeerNetParams(singleStreamTransport) + p1 := randPeerNetParams(duplexTransport) l1 := getListener(ctx, p1) defer l1.Close() go l1.Accept() @@ -135,11 +135,11 @@ var _ = Describe("Connections", func() { c := dialRawConn(p2.Addr, l1.Multiaddr()) defer c.Close() switch tr { - case singleStreamTransport: - conn = c.(tpt.SingleStreamConn) - case multiStreamTransport: + case duplexTransport: + conn = c.(tpt.DuplexConn) + case multiplexTransport: var err error - conn, err = c.(tpt.MultiStreamConn).OpenStream() + conn, err = c.(tpt.MultiplexConn).OpenStream() Expect(err).ToNot(HaveOccurred()) } // hang this connection until timeout diff --git a/dial.go b/dial.go index c71e2c8..65907b9 100644 --- a/dial.go +++ b/dial.go @@ -135,9 +135,9 @@ func (d *Dialer) doDial(ctx context.Context, raddr ma.Multiaddr, remote peer.ID) var stream io.ReadWriteCloser switch con := rawConn.(type) { - case tpt.SingleStreamConn: + case tpt.DuplexConn: stream = con - case tpt.MultiStreamConn: + case tpt.MultiplexConn: stream, err = con.OpenStream() if err != nil { done <- connOrErr{err: err} diff --git a/dial_test.go b/dial_test.go index 7fbfb29..d98c24a 100644 --- a/dial_test.go +++ b/dial_test.go @@ -13,7 +13,7 @@ import ( var _ = Describe("dialing", func() { It("errors when it can't dial the raw connection", func() { - p := randPeerNetParams(singleStreamTransport) + p := randPeerNetParams(duplexTransport) d := getDialer(p.ID, p.PrivKey, p.Addr) raddr, err := ma.NewMultiaddr("/ip4/1.2.3.4/tcp/0") Expect(err).ToNot(HaveOccurred()) diff --git a/listen.go b/listen.go index baaafcb..7c4e8a9 100644 --- a/listen.go +++ b/listen.go @@ -183,9 +183,9 @@ func (l *listener) handleIncoming() { var stream io.ReadWriteCloser switch conn := conn.(type) { - case tpt.SingleStreamConn: + case tpt.DuplexConn: stream = conn - case tpt.MultiStreamConn: + case tpt.MultiplexConn: stream, err = conn.AcceptStream() if err != nil { conn.Close() diff --git a/listen_test.go b/listen_test.go index 31b20ae..03bd343 100644 --- a/listen_test.go +++ b/listen_test.go @@ -68,11 +68,11 @@ var _ = Describe("Listener", func() { defer c.Close() var conn io.ReadWriteCloser switch tr { - case singleStreamTransport: - conn = c.(tpt.SingleStreamConn) - case multiStreamTransport: + case duplexTransport: + conn = c.(tpt.DuplexConn) + case multiplexTransport: var err error - conn, err = c.(tpt.MultiStreamConn).OpenStream() + conn, err = c.(tpt.MultiplexConn).OpenStream() Expect(err).ToNot(HaveOccurred()) } // write some garbage. This will fail the protocol selection diff --git a/protector_test.go b/protector_test.go index 55669f8..5eb7b3a 100644 --- a/protector_test.go +++ b/protector_test.go @@ -16,69 +16,69 @@ import ( . "github.com/onsi/gomega" ) -type fakeSingleStreamProtector struct { +type fakeDuplexProtector struct { used bool } -func (f *fakeSingleStreamProtector) Fingerprint() []byte { +func (f *fakeDuplexProtector) Fingerprint() []byte { return make([]byte, 32) } -func (f *fakeSingleStreamProtector) Protect(c tpt.Conn) (tpt.Conn, error) { +func (f *fakeDuplexProtector) Protect(c tpt.Conn) (tpt.Conn, error) { f.used = true - return &rot13CryptSingleStream{c.(tpt.SingleStreamConn)}, nil + return &rot13CryptDuplex{c.(tpt.DuplexConn)}, nil } -type rot13CryptSingleStream struct { - tpt.SingleStreamConn +type rot13CryptDuplex struct { + tpt.DuplexConn } -func (r *rot13CryptSingleStream) Read(b []byte) (int, error) { - n, err := r.SingleStreamConn.Read(b) +func (r *rot13CryptDuplex) Read(b []byte) (int, error) { + n, err := r.DuplexConn.Read(b) for i := 0; i < n; i++ { b[i] = b[i] - 13 } return n, err } -func (r *rot13CryptSingleStream) Write(b []byte) (int, error) { +func (r *rot13CryptDuplex) Write(b []byte) (int, error) { p := make([]byte, len(b)) // write MUST NOT modify b for i := range b { p[i] = b[i] + 13 } - return r.SingleStreamConn.Write(p) + return r.DuplexConn.Write(p) } -type fakeMultiStreamProtector struct { +type fakeMultiplexProtector struct { used bool - crypt *rot13CryptMultiStream + crypt *rot13CryptMultiplex } -func (f *fakeMultiStreamProtector) Fingerprint() []byte { +func (f *fakeMultiplexProtector) Fingerprint() []byte { return make([]byte, 32) } -func (f *fakeMultiStreamProtector) Protect(c tpt.Conn) (tpt.Conn, error) { +func (f *fakeMultiplexProtector) Protect(c tpt.Conn) (tpt.Conn, error) { f.used = true - f.crypt = &rot13CryptMultiStream{c.(tpt.MultiStreamConn), 0, 0} + f.crypt = &rot13CryptMultiplex{c.(tpt.MultiplexConn), 0, 0} return f.crypt, nil } -type rot13CryptMultiStream struct { - tpt.MultiStreamConn +type rot13CryptMultiplex struct { + tpt.MultiplexConn openedStreams int acceptedStreams int } -func (r *rot13CryptMultiStream) OpenStream() (smux.Stream, error) { +func (r *rot13CryptMultiplex) OpenStream() (smux.Stream, error) { r.openedStreams++ - str, err := r.MultiStreamConn.OpenStream() + str, err := r.MultiplexConn.OpenStream() return &rot13Stream{str}, err } -func (r *rot13CryptMultiStream) AcceptStream() (smux.Stream, error) { +func (r *rot13CryptMultiplex) AcceptStream() (smux.Stream, error) { r.acceptedStreams++ - str, err := r.MultiStreamConn.AcceptStream() + str, err := r.MultiplexConn.AcceptStream() return &rot13Stream{str}, err } @@ -119,10 +119,10 @@ var _ = Describe("using the protector", func() { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - p1 := randPeerNetParams(singleStreamTransport) - p2 := randPeerNetParams(singleStreamTransport) - p1Protec := &fakeSingleStreamProtector{} - p2Protec := &fakeSingleStreamProtector{} + p1 := randPeerNetParams(duplexTransport) + p2 := randPeerNetParams(duplexTransport) + p1Protec := &fakeDuplexProtector{} + p2Protec := &fakeDuplexProtector{} list, err := tcpt.NewTCPTransport().Listen(p1.Addr) Expect(err).ToNot(HaveOccurred()) @@ -150,15 +150,15 @@ var _ = Describe("using the protector", func() { Expect(p1Protec.used).To(BeTrue()) }) - // TODO: enable this test when adding support for multi-stream connections - PIt("uses a protector for multi-stream connections", func() { + // TODO: enable this test when adding support for multiplex connections + PIt("uses a protector for multiplex connections", func() { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - p1 := randPeerNetParams(multiStreamTransport) - p2 := randPeerNetParams(multiStreamTransport) - p1Protec := &fakeMultiStreamProtector{} - p2Protec := &fakeMultiStreamProtector{} + p1 := randPeerNetParams(multiplexTransport) + p2 := randPeerNetParams(multiplexTransport) + p1Protec := &fakeMultiplexProtector{} + p2Protec := &fakeMultiplexProtector{} list, err := quict.NewQuicTransport().Listen(p1.Addr) Expect(err).ToNot(HaveOccurred()) @@ -210,8 +210,8 @@ var _ = Describe("using the protector", func() { BeforeEach(func() { ipnet.ForcePrivateNetwork = true - p1 = randPeerNetParams(singleStreamTransport) - p2 = randPeerNetParams(singleStreamTransport) + p1 = randPeerNetParams(duplexTransport) + p2 = randPeerNetParams(duplexTransport) var err error list, err = tcpt.NewTCPTransport().Listen(p1.Addr) Expect(err).ToNot(HaveOccurred()) @@ -237,8 +237,8 @@ var _ = Describe("using the protector", func() { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - p1 := randPeerNetParams(singleStreamTransport) - p2 := randPeerNetParams(singleStreamTransport) + p1 := randPeerNetParams(duplexTransport) + p2 := randPeerNetParams(duplexTransport) p1Protec := &erroringProtector{} p2Protec := &erroringProtector{} diff --git a/secure_conn.go b/secure_conn.go index ec0513a..b753459 100644 --- a/secure_conn.go +++ b/secure_conn.go @@ -9,54 +9,54 @@ import ( ma "github.com/multiformats/go-multiaddr" ) -// secureSingleStreamConn wraps another SingleStreamConn object with an encrypted channel. -type secureSingleStreamConn struct { - insecure tpt.SingleStreamConn // the wrapped conn - secure secio.Session // secure Session +// secureDuplexConn wraps another DuplexConn object with an encrypted channel. +type secureDuplexConn struct { + insecure tpt.DuplexConn // the wrapped conn + secure secio.Session // secure Session } -var _ tpt.SingleStreamConn = &secureSingleStreamConn{} +var _ tpt.DuplexConn = &secureDuplexConn{} -func (c *secureSingleStreamConn) Read(buf []byte) (int, error) { +func (c *secureDuplexConn) Read(buf []byte) (int, error) { return c.secure.ReadWriter().Read(buf) } -func (c *secureSingleStreamConn) Write(buf []byte) (int, error) { +func (c *secureDuplexConn) Write(buf []byte) (int, error) { return c.secure.ReadWriter().Write(buf) } -func (c *secureSingleStreamConn) Close() error { +func (c *secureDuplexConn) Close() error { return c.secure.Close() } -func (c *secureSingleStreamConn) LocalAddr() net.Addr { +func (c *secureDuplexConn) LocalAddr() net.Addr { return c.insecure.LocalAddr() } -func (c *secureSingleStreamConn) LocalMultiaddr() ma.Multiaddr { +func (c *secureDuplexConn) LocalMultiaddr() ma.Multiaddr { return c.insecure.LocalMultiaddr() } -func (c *secureSingleStreamConn) RemoteAddr() net.Addr { +func (c *secureDuplexConn) RemoteAddr() net.Addr { return c.insecure.RemoteAddr() } -func (c *secureSingleStreamConn) RemoteMultiaddr() ma.Multiaddr { +func (c *secureDuplexConn) RemoteMultiaddr() ma.Multiaddr { return c.insecure.RemoteMultiaddr() } -func (c *secureSingleStreamConn) SetDeadline(t time.Time) error { +func (c *secureDuplexConn) SetDeadline(t time.Time) error { return c.insecure.SetDeadline(t) } -func (c *secureSingleStreamConn) SetReadDeadline(t time.Time) error { +func (c *secureDuplexConn) SetReadDeadline(t time.Time) error { return c.insecure.SetDeadline(t) } -func (c *secureSingleStreamConn) SetWriteDeadline(t time.Time) error { +func (c *secureDuplexConn) SetWriteDeadline(t time.Time) error { return c.insecure.SetDeadline(t) } -func (c *secureSingleStreamConn) Transport() tpt.Transport { +func (c *secureDuplexConn) Transport() tpt.Transport { return c.insecure.Transport() } From b2606fa2d8244641c7fc339b930551fac0fb468c Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 19 Oct 2017 07:39:14 +0700 Subject: [PATCH 08/10] remove the MultiplexConn type assertions --- conn.go | 59 ++--- conn_suite_test.go | 39 +--- conn_test.go | 528 ++++++++++++++++++++++----------------------- dial.go | 16 +- dial_test.go | 78 ++++--- listen.go | 16 +- listen_test.go | 385 ++++++++++++++++----------------- protector_test.go | 146 ++----------- secure_conn.go | 32 +-- 9 files changed, 536 insertions(+), 763 deletions(-) diff --git a/conn.go b/conn.go index 4361bb1..e589fb0 100644 --- a/conn.go +++ b/conn.go @@ -38,43 +38,39 @@ func newSingleConn(ctx context.Context, local, remote peer.ID, privKey ci.PrivKe var streamConn smux.Conn var secSession secio.Session - switch conn := tptConn.(type) { - case tpt.DuplexConn: - c := conn - // 1. secure the connection - if privKey != nil && iconn.EncryptConnections { - var err error - secSession, err = setupSecureSession(ctx, local, privKey, conn) - if err != nil { - return nil, err - } - c = &secureDuplexConn{ - insecure: conn, - secure: secSession, - } - } else { - log.Warning("creating INSECURE connection %s at %s", tptConn.LocalMultiaddr(), tptConn.RemoteMultiaddr()) - } - // 2. start stream multipling + c := tptConn + // 1. secure the connection + if privKey != nil && iconn.EncryptConnections { var err error - streamConn, err = pstpt.NewConn(c, isServer) + secSession, err = setupSecureSession(ctx, local, privKey, tptConn) if err != nil { return nil, err } - case tpt.MultiplexConn: - panic("not implemented yet") + c = &secureConn{ + insecure: tptConn, + secure: secSession, + } + } else { + log.Warning("creating INSECURE connection %s at %s", tptConn.LocalMultiaddr(), tptConn.RemoteMultiaddr()) } - conn := &singleConn{ + // 2. start stream multipling + var err error + streamConn, err = pstpt.NewConn(c, isServer) + if err != nil { + return nil, err + } + + sconn := &singleConn{ streamConn: streamConn, tptConn: tptConn, secSession: secSession, event: log.EventBegin(ctx, "connLifetime", ml), } - log.Debugf("newSingleConn %p: %v to %v", conn, local, remote) - return conn, nil + log.Debugf("newSingleConn %p: %v to %v", sconn, local, remote) + return sconn, nil } func setupSecureSession(ctx context.Context, local peer.ID, privKey ci.PrivKey, ch io.ReadWriteCloser) (secio.Session, error) { @@ -88,20 +84,7 @@ func setupSecureSession(ctx context.Context, local peer.ID, privKey ci.PrivKey, LocalID: local, PrivateKey: privKey, } - secSession, err := sessgen.NewSession(ctx, ch) - if err != nil { - return nil, err - } - // force the handshake right now - // TODO: find a better solution for this - b := []byte("handshake") - if _, err := secSession.ReadWriter().Write(b); err != nil { - return nil, err - } - if _, err := io.ReadFull(secSession.ReadWriter(), b); err != nil { - return nil, err - } - return secSession, nil + return sessgen.NewSession(ctx, ch) } // close is the internal close function, called by ContextCloser.Close diff --git a/conn_suite_test.go b/conn_suite_test.go index 0a8580f..99cdf20 100644 --- a/conn_suite_test.go +++ b/conn_suite_test.go @@ -12,11 +12,9 @@ import ( tpt "github.com/libp2p/go-libp2p-transport" tcpt "github.com/libp2p/go-tcp-transport" tu "github.com/libp2p/go-testutil" - quict "github.com/marten-seemann/libp2p-quic-transport" ma "github.com/multiformats/go-multiaddr" yamux "github.com/whyrusleeping/go-smux-yamux" grc "github.com/whyrusleeping/gorocheck" - "github.com/whyrusleeping/mafmt" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -40,35 +38,11 @@ var _ = AfterEach(func() { // the stream muxer used for tests using the single stream connection var streamMuxer = yamux.DefaultTransport -type transportType uint8 - -const ( - duplexTransport transportType = 1 + iota - multiplexTransport -) - -var transportTypes = []transportType{duplexTransport} - -func (t transportType) String() string { - if t == duplexTransport { - return "duplex transport" - } - return "multiplex transport" -} - // dialRawConn dials a tpt.Conn // but it stops there. It doesn't do protocol selection and handshake func dialRawConn(laddr, raddr ma.Multiaddr) tpt.Conn { - var d tpt.Dialer - if mafmt.QUIC.Matches(laddr) { - var err error - d, err = quict.NewQuicTransport().Dialer(laddr) - Expect(err).ToNot(HaveOccurred()) - } else { - var err error - d, err = tcpt.NewTCPTransport().Dialer(laddr) - Expect(err).ToNot(HaveOccurred()) - } + d, err := tcpt.NewTCPTransport().Dialer(laddr) + Expect(err).ToNot(HaveOccurred()) c, err := d.Dial(raddr) Expect(err).ToNot(HaveOccurred()) return c @@ -76,9 +50,6 @@ func dialRawConn(laddr, raddr ma.Multiaddr) tpt.Conn { // getTransport gets the right transport for a multiaddr func getTransport(a ma.Multiaddr) tpt.Transport { - if mafmt.QUIC.Matches(a) { - return quict.NewQuicTransport() - } return tcpt.NewTCPTransport() } @@ -104,12 +75,8 @@ func getDialer(localPeer peer.ID, privKey ci.PrivKey, addr ma.Multiaddr) *Dialer // randPeerNetParams works like testutil.RandPeerNetParams // if called for a multi-stream transport, it replaces the address with a QUIC address -func randPeerNetParams(tr transportType) *tu.PeerNetParams { +func randPeerNetParams() *tu.PeerNetParams { p, err := tu.RandPeerNetParams() Expect(err).ToNot(HaveOccurred()) - if tr == multiplexTransport { - p.Addr, err = ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/quic") - Expect(err).ToNot(HaveOccurred()) - } return p } diff --git a/conn_test.go b/conn_test.go index 8cdfb6d..a96ddab 100644 --- a/conn_test.go +++ b/conn_test.go @@ -11,7 +11,6 @@ import ( . "github.com/onsi/gomega" iconn "github.com/libp2p/go-libp2p-interface-conn" - tpt "github.com/libp2p/go-libp2p-transport" smux "github.com/libp2p/go-stream-muxer" ) @@ -20,309 +19,294 @@ var _ = Describe("Connections", func() { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - p1 := randPeerNetParams(duplexTransport) + p1 := randPeerNetParams() l1 := getListener(ctx, p1) defer l1.Close() go l1.Accept() }) - for _, val := range transportTypes { - tr := val + for _, val := range []bool{true, false} { + secure := val - Context(fmt.Sprintf("using a %s", tr), func() { - for _, val := range []bool{true, false} { - secure := val + It(fmt.Sprintf("establishes a connection (secure: %t)", secure), func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - It(fmt.Sprintf("establishes a connection (secure: %t)", secure), func() { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - p1 := randPeerNetParams(tr) - p2 := randPeerNetParams(tr) - if !secure { - p1.PrivKey = nil - p2.PrivKey = nil - } - - l1 := getListener(ctx, p1) - defer l1.Close() - - // accept a connection, accept a stream on this connection and echo everything - go func() { - defer GinkgoRecover() - c, err := l1.Accept() - Expect(err).ToNot(HaveOccurred()) - str, err := c.AcceptStream() - Expect(err).ToNot(HaveOccurred()) - go io.Copy(str, str) - }() - - d2 := getDialer(p2.ID, p2.PrivKey, p2.Addr) - c, err := d2.Dial(ctx, p1.Addr, p1.ID) - Expect(err).ToNot(HaveOccurred()) - defer c.Close() - str, err := c.OpenStream() - Expect(err).ToNot(HaveOccurred()) - _, err = str.Write([]byte("beep")) - Expect(err).ToNot(HaveOccurred()) - _, err = str.Write([]byte("boop")) - Expect(err).ToNot(HaveOccurred()) - - out := make([]byte, 8) - _, err = io.ReadFull(str, out) - Expect(err).ToNot(HaveOccurred()) - Expect(out).To(Equal([]byte("beepboop"))) - }) + p1 := randPeerNetParams() + p2 := randPeerNetParams() + if !secure { + p1.PrivKey = nil + p2.PrivKey = nil } - It("continues accepting connections while another accept is hanging", func() { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + l1 := getListener(ctx, p1) + defer l1.Close() - p1 := randPeerNetParams(tr) - p2 := randPeerNetParams(tr) + // accept a connection, accept a stream on this connection and echo everything + go func() { + defer GinkgoRecover() + c, err := l1.Accept() + Expect(err).ToNot(HaveOccurred()) + str, err := c.AcceptStream() + Expect(err).ToNot(HaveOccurred()) + go io.Copy(str, str) + }() + + d2 := getDialer(p2.ID, p2.PrivKey, p2.Addr) + c, err := d2.Dial(ctx, p1.Addr, p1.ID) + Expect(err).ToNot(HaveOccurred()) + defer c.Close() + str, err := c.OpenStream() + Expect(err).ToNot(HaveOccurred()) + _, err = str.Write([]byte("beep")) + Expect(err).ToNot(HaveOccurred()) + _, err = str.Write([]byte("boop")) + Expect(err).ToNot(HaveOccurred()) + + out := make([]byte, 8) + _, err = io.ReadFull(str, out) + Expect(err).ToNot(HaveOccurred()) + Expect(out).To(Equal([]byte("beepboop"))) + }) + } - l1 := getListener(ctx, p1) - defer l1.Close() + It("continues accepting connections while another accept is hanging", func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - go func() { - defer GinkgoRecover() - conn := dialRawConn(p2.Addr, l1.Multiaddr()) - defer conn.Close() // hang this connection + p1 := randPeerNetParams() + p2 := randPeerNetParams() - // ensure that the first conn hits first - time.Sleep(50 * time.Millisecond) - d2 := getDialer(p2.ID, p2.PrivKey, p2.Addr) - conn2, err := d2.Dial(ctx, p1.Addr, p1.ID) - Expect(err).ToNot(HaveOccurred()) - defer conn2.Close() - }() + l1 := getListener(ctx, p1) + defer l1.Close() - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := l1.Accept() - Expect(err).ToNot(HaveOccurred()) - close(done) - }() - Eventually(done).Should(BeClosed()) - }) + go func() { + defer GinkgoRecover() + conn := dialRawConn(p2.Addr, l1.Multiaddr()) + defer conn.Close() // hang this connection + + // ensure that the first conn hits first + time.Sleep(50 * time.Millisecond) + d2 := getDialer(p2.ID, p2.PrivKey, p2.Addr) + conn2, err := d2.Dial(ctx, p1.Addr, p1.ID) + Expect(err).ToNot(HaveOccurred()) + defer conn2.Close() + }() + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := l1.Accept() + Expect(err).ToNot(HaveOccurred()) + close(done) + }() + Eventually(done).Should(BeClosed()) + }) - It("timeouts", func() { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - old := ConnAcceptTimeout - ConnAcceptTimeout = 3 * time.Second - defer func() { ConnAcceptTimeout = old }() - - p1 := randPeerNetParams(tr) - p2 := randPeerNetParams(tr) - - l1 := getListener(ctx, p1) - defer l1.Close() - - n := 20 - - before := time.Now() - var wg sync.WaitGroup - for i := 0; i < n; i++ { - wg.Add(1) - go func() { - defer GinkgoRecover() - defer wg.Done() - var conn io.Reader - c := dialRawConn(p2.Addr, l1.Multiaddr()) - defer c.Close() - switch tr { - case duplexTransport: - conn = c.(tpt.DuplexConn) - case multiplexTransport: - var err error - conn, err = c.(tpt.MultiplexConn).OpenStream() - Expect(err).ToNot(HaveOccurred()) - } - // hang this connection until timeout - io.ReadFull(conn, make([]byte, 1000)) - }() - } + It("timeouts", func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - // wait to make sure the hanging dials have started - time.Sleep(50 * time.Millisecond) + old := ConnAcceptTimeout + ConnAcceptTimeout = 3 * time.Second + defer func() { ConnAcceptTimeout = old }() - accepted := make(chan struct{}) // this chan is closed once all good connections have been accepted - goodN := 10 - for i := 0; i < goodN; i++ { - go func(i int) { - defer GinkgoRecover() - d2 := getDialer(p2.ID, p2.PrivKey, p2.Addr) - conn, err := d2.Dial(ctx, p1.Addr, p1.ID) - Expect(err).ToNot(HaveOccurred()) - <-accepted - conn.Close() - }(i) - } + p1 := randPeerNetParams() + p2 := randPeerNetParams() - for i := 0; i < goodN; i++ { - _, err := l1.Accept() - Expect(err).ToNot(HaveOccurred()) - } - close(accepted) - Expect(time.Now()).To(BeTemporally("<", before.Add(ConnAcceptTimeout/4))) - Eventually(func() bool { - wg.Wait() // wait for the timeouts for the raw connections to occur - return true - }, ConnAcceptTimeout).Should(BeTrue()) - Expect(time.Now()).To(BeTemporally(">", before.Add(ConnAcceptTimeout))) - - // make sure we can dial in still after a bunch of timeouts - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := l1.Accept() - Expect(err).ToNot(HaveOccurred()) - close(done) - }() + l1 := getListener(ctx, p1) + defer l1.Close() + n := 20 + + before := time.Now() + var wg sync.WaitGroup + for i := 0; i < n; i++ { + wg.Add(1) + go func() { + defer GinkgoRecover() + defer wg.Done() + c := dialRawConn(p2.Addr, l1.Multiaddr()) + defer c.Close() + // hang this connection until timeout + io.ReadFull(c, make([]byte, 1000)) + }() + } + + // wait to make sure the hanging dials have started + time.Sleep(50 * time.Millisecond) + + accepted := make(chan struct{}) // this chan is closed once all good connections have been accepted + goodN := 10 + for i := 0; i < goodN; i++ { + go func(i int) { + defer GinkgoRecover() d2 := getDialer(p2.ID, p2.PrivKey, p2.Addr) conn, err := d2.Dial(ctx, p1.Addr, p1.ID) Expect(err).ToNot(HaveOccurred()) - defer conn.Close() - Eventually(done).Should(BeClosed()) - }) + <-accepted + conn.Close() + }(i) + } + + for i := 0; i < goodN; i++ { + _, err := l1.Accept() + Expect(err).ToNot(HaveOccurred()) + } + close(accepted) + Expect(time.Now()).To(BeTemporally("<", before.Add(ConnAcceptTimeout/4))) + Eventually(func() bool { + wg.Wait() // wait for the timeouts for the raw connections to occur + return true + }, ConnAcceptTimeout).Should(BeTrue()) + Expect(time.Now()).To(BeTemporally(">", before.Add(ConnAcceptTimeout))) + + // make sure we can dial in still after a bunch of timeouts + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := l1.Accept() + Expect(err).ToNot(HaveOccurred()) + close(done) + }() + + d2 := getDialer(p2.ID, p2.PrivKey, p2.Addr) + conn, err := d2.Dial(ctx, p1.Addr, p1.ID) + Expect(err).ToNot(HaveOccurred()) + defer conn.Close() + Eventually(done).Should(BeClosed()) + }) - It("doesn't complete the handshake with the wrong keys", func() { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - p1 := randPeerNetParams(tr) - p2 := randPeerNetParams(tr) - - l1 := getListener(ctx, p1) - defer l1.Close() - - // use the wrong private key here, correct would be: p2.PrivKey - d2 := getDialer(p2.ID, p1.PrivKey, p2.Addr) - - accepted := make(chan struct{}) - go func() { - l1.Accept() - close(accepted) - }() - - _, err := d2.Dial(ctx, p1.Addr, p1.ID) - Expect(err).To(MatchError("peer.ID does not match PrivateKey")) - // make sure no connection was accepted - Consistently(accepted).ShouldNot(BeClosed()) - }) - - Context("closing", func() { - setupConn := func(ctx context.Context, tr transportType) (iconn.Conn, iconn.Conn) { - p1 := randPeerNetParams(tr) - p2 := randPeerNetParams(tr) - - l1 := getListener(ctx, p1) - - var c2 iconn.Conn - d2 := getDialer(p2.ID, p2.PrivKey, p2.Addr) - done := make(chan error) - go func() { - defer GinkgoRecover() - var err error - c2, err = d2.Dial(ctx, p1.Addr, p1.ID) - Expect(err).ToNot(HaveOccurred()) - close(done) - }() + It("doesn't complete the handshake with the wrong keys", func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - c1, err := l1.Accept() - Expect(err).ToNot(HaveOccurred()) - Eventually(done).Should(BeClosed()) - return c1, c2 - } + p1 := randPeerNetParams() + p2 := randPeerNetParams() - openStreamAndSend := func(c1, c2 iconn.Conn) { - str1, err := c1.OpenStream() - Expect(err).ToNot(HaveOccurred()) - m1 := []byte("hello") - _, err = str1.Write(m1) - Expect(err).ToNot(HaveOccurred()) - str2, err := c2.AcceptStream() - Expect(err).ToNot(HaveOccurred()) - m2 := make([]byte, len(m1)) - _, err = str2.Read(m2) - Expect(err).ToNot(HaveOccurred()) - Expect(m1).To(Equal(m2)) - } + l1 := getListener(ctx, p1) + defer l1.Close() - checkStreamOpenAcceptFails := func(c1, c2 iconn.Conn) { - _, err := c1.OpenStream() - Expect(err).To(HaveOccurred()) - accepted := make(chan struct{}) - go func() { - _, err := c2.AcceptStream() - Expect(err).To(HaveOccurred()) - close(accepted) - }() - Eventually(accepted).Should(BeClosed()) - } + // use the wrong private key here, correct would be: p2.PrivKey + d2 := getDialer(p2.ID, p1.PrivKey, p2.Addr) - It("closes", func() { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + accepted := make(chan struct{}) + go func() { + l1.Accept() + close(accepted) + }() - c1, c2 := setupConn(ctx, tr) - openStreamAndSend(c1, c2) - openStreamAndSend(c2, c1) + _, err := d2.Dial(ctx, p1.Addr, p1.ID) + Expect(err).To(MatchError("peer.ID does not match PrivateKey")) + // make sure no connection was accepted + Consistently(accepted).ShouldNot(BeClosed()) + }) - c1.Close() - Expect(c1.IsClosed()).To(BeTrue()) - Eventually(c2.IsClosed).Should(BeTrue()) - checkStreamOpenAcceptFails(c2, c1) - checkStreamOpenAcceptFails(c1, c2) - }) - - It("doesn't leak", func() { - // runPair opens one stream and sends num messages - runPair := func(c1, c2 iconn.Conn, num int) { - var str2 smux.Stream - str1, err := c1.OpenStream() - Expect(err).ToNot(HaveOccurred()) + Context("closing", func() { + setupConn := func(ctx context.Context) (iconn.Conn, iconn.Conn) { + p1 := randPeerNetParams() + p2 := randPeerNetParams() - for i := 0; i < num; i++ { - b1 := []byte("beep") - _, err := str1.Write(b1) - Expect(err).ToNot(HaveOccurred()) - if str2 == nil { - str2, err = c2.AcceptStream() - Expect(err).ToNot(HaveOccurred()) - } - b2 := make([]byte, len(b1)) - _, err = str2.Read(b2) - Expect(err).ToNot(HaveOccurred()) - Expect(b1).To(Equal(b2)) - } - } + l1 := getListener(ctx, p1) + + var c2 iconn.Conn + d2 := getDialer(p2.ID, p2.PrivKey, p2.Addr) + done := make(chan error) + go func() { + defer GinkgoRecover() + var err error + c2, err = d2.Dial(ctx, p1.Addr, p1.ID) + Expect(err).ToNot(HaveOccurred()) + close(done) + }() + + c1, err := l1.Accept() + Expect(err).ToNot(HaveOccurred()) + Eventually(done).Should(BeClosed()) + return c1, c2 + } + + openStreamAndSend := func(c1, c2 iconn.Conn) { + str1, err := c1.OpenStream() + Expect(err).ToNot(HaveOccurred()) + m1 := []byte("hello") + _, err = str1.Write(m1) + Expect(err).ToNot(HaveOccurred()) + str2, err := c2.AcceptStream() + Expect(err).ToNot(HaveOccurred()) + m2 := make([]byte, len(m1)) + _, err = str2.Read(m2) + Expect(err).ToNot(HaveOccurred()) + Expect(m1).To(Equal(m2)) + } + + checkStreamOpenAcceptFails := func(c1, c2 iconn.Conn) { + _, err := c1.OpenStream() + Expect(err).To(HaveOccurred()) + accepted := make(chan struct{}) + go func() { + _, err := c2.AcceptStream() + Expect(err).To(HaveOccurred()) + close(accepted) + }() + Eventually(accepted).Should(BeClosed()) + } + + It("closes", func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + c1, c2 := setupConn(ctx) + openStreamAndSend(c1, c2) + openStreamAndSend(c2, c1) + + c1.Close() + Expect(c1.IsClosed()).To(BeTrue()) + Eventually(c2.IsClosed).Should(BeTrue()) + checkStreamOpenAcceptFails(c2, c1) + checkStreamOpenAcceptFails(c1, c2) + }) - var cons = 10 - var msgs = 10 - var wg sync.WaitGroup - for i := 0; i < cons; i++ { - wg.Add(1) - ctx, cancel := context.WithCancel(context.Background()) - c1, c2 := setupConn(ctx, tr) - go func(c1, c2 iconn.Conn) { - defer GinkgoRecover() - defer cancel() - runPair(c1, c2, msgs) - c1.Close() - c2.Close() - wg.Done() - }(c1, c2) + It("doesn't leak", func() { + // runPair opens one stream and sends num messages + runPair := func(c1, c2 iconn.Conn, num int) { + var str2 smux.Stream + str1, err := c1.OpenStream() + Expect(err).ToNot(HaveOccurred()) + + for i := 0; i < num; i++ { + b1 := []byte("beep") + _, err := str1.Write(b1) + Expect(err).ToNot(HaveOccurred()) + if str2 == nil { + str2, err = c2.AcceptStream() + Expect(err).ToNot(HaveOccurred()) } + b2 := make([]byte, len(b1)) + _, err = str2.Read(b2) + Expect(err).ToNot(HaveOccurred()) + Expect(b1).To(Equal(b2)) + } + } + + var cons = 10 + var msgs = 10 + var wg sync.WaitGroup + for i := 0; i < cons; i++ { + wg.Add(1) + ctx, cancel := context.WithCancel(context.Background()) + c1, c2 := setupConn(ctx) + go func(c1, c2 iconn.Conn) { + defer GinkgoRecover() + defer cancel() + runPair(c1, c2, msgs) + c1.Close() + c2.Close() + wg.Done() + }(c1, c2) + } - wg.Wait() - }) - }) + wg.Wait() }) - } + }) }) diff --git a/dial.go b/dial.go index 65907b9..bc11f38 100644 --- a/dial.go +++ b/dial.go @@ -3,7 +3,6 @@ package conn import ( "context" "fmt" - "io" "strings" "time" @@ -133,20 +132,7 @@ func (d *Dialer) doDial(ctx context.Context, raddr ma.Multiaddr, remote peer.ID) cryptoProtoChoice = NoEncryptionTag } - var stream io.ReadWriteCloser - switch con := rawConn.(type) { - case tpt.DuplexConn: - stream = con - case tpt.MultiplexConn: - stream, err = con.OpenStream() - if err != nil { - done <- connOrErr{err: err} - return - } - defer stream.Close() - } - - if err := msmux.SelectProtoOrFail(cryptoProtoChoice, stream); err != nil { + if err := msmux.SelectProtoOrFail(cryptoProtoChoice, rawConn); err != nil { done <- connOrErr{err: err} return } diff --git a/dial_test.go b/dial_test.go index d98c24a..0eb5155 100644 --- a/dial_test.go +++ b/dial_test.go @@ -2,7 +2,6 @@ package conn import ( "context" - "fmt" "net" "time" @@ -13,7 +12,7 @@ import ( var _ = Describe("dialing", func() { It("errors when it can't dial the raw connection", func() { - p := randPeerNetParams(duplexTransport) + p := randPeerNetParams() d := getDialer(p.ID, p.PrivKey, p.Addr) raddr, err := ma.NewMultiaddr("/ip4/1.2.3.4/tcp/0") Expect(err).ToNot(HaveOccurred()) @@ -21,47 +20,42 @@ var _ = Describe("dialing", func() { Expect(err).To(HaveOccurred()) }) - for _, val := range transportTypes { - tr := val - Context(fmt.Sprintf("using a %s", tr), func() { - It("returns immediately when the context is canceled", func() { - p1 := randPeerNetParams(tr) - tptList, err := getTransport(p1.Addr).Listen(p1.Addr) - Expect(err).ToNot(HaveOccurred()) - defer tptList.Close() - - dialed := make(chan struct{}) - ctx, cancel := context.WithCancel(context.Background()) - go func() { - defer GinkgoRecover() - p2 := randPeerNetParams(tr) - d := getDialer(p2.ID, p2.PrivKey, p2.Addr) - _, err = d.Dial(ctx, tptList.Multiaddr(), p2.ID) - Expect(err).To(MatchError(context.Canceled)) - close(dialed) - }() - Consistently(dialed).ShouldNot(BeClosed()) - cancel() - Eventually(dialed).Should(BeClosed()) - }) + It("returns immediately when the context is canceled", func() { + p1 := randPeerNetParams() + tptList, err := getTransport(p1.Addr).Listen(p1.Addr) + Expect(err).ToNot(HaveOccurred()) + defer tptList.Close() + + dialed := make(chan struct{}) + ctx, cancel := context.WithCancel(context.Background()) + go func() { + defer GinkgoRecover() + p2 := randPeerNetParams() + d := getDialer(p2.ID, p2.PrivKey, p2.Addr) + _, err = d.Dial(ctx, tptList.Multiaddr(), p2.ID) + Expect(err).To(MatchError(context.Canceled)) + close(dialed) + }() + Consistently(dialed).ShouldNot(BeClosed()) + cancel() + Eventually(dialed).Should(BeClosed()) + }) - It("times out during multistream selection", func() { - old := DialTimeout - DialTimeout = time.Second - defer func() { DialTimeout = old }() + It("times out during multistream selection", func() { + old := DialTimeout + DialTimeout = time.Second + defer func() { DialTimeout = old }() - p1 := randPeerNetParams(tr) - p2 := randPeerNetParams(tr) - tptList, err := getTransport(p1.Addr).Listen(p1.Addr) - Expect(err).ToNot(HaveOccurred()) - defer tptList.Close() + p1 := randPeerNetParams() + p2 := randPeerNetParams() + tptList, err := getTransport(p1.Addr).Listen(p1.Addr) + Expect(err).ToNot(HaveOccurred()) + defer tptList.Close() - d := getDialer(p2.ID, p2.PrivKey, p2.Addr) - _, err = d.Dial(context.Background(), tptList.Multiaddr(), p2.ID) - Expect(err).To(HaveOccurred()) - Expect(err.(net.Error).Timeout()).To(BeTrue()) - Expect(err.(net.Error).Temporary()).To(BeTrue()) - }) - }) - } + d := getDialer(p2.ID, p2.PrivKey, p2.Addr) + _, err = d.Dial(context.Background(), tptList.Multiaddr(), p2.ID) + Expect(err).To(HaveOccurred()) + Expect(err.(net.Error).Timeout()).To(BeTrue()) + Expect(err.(net.Error).Temporary()).To(BeTrue()) + }) }) diff --git a/listen.go b/listen.go index 7c4e8a9..36ae19b 100644 --- a/listen.go +++ b/listen.go @@ -181,21 +181,7 @@ func (l *listener) handleIncoming() { conn = l.wrapper(conn) } - var stream io.ReadWriteCloser - switch conn := conn.(type) { - case tpt.DuplexConn: - stream = conn - case tpt.MultiplexConn: - stream, err = conn.AcceptStream() - if err != nil { - conn.Close() - log.Warning("accepting stream failed: ", err) - return - } - defer stream.Close() - } - - if _, _, err := l.mux.Negotiate(stream); err != nil { + if _, _, err := l.mux.Negotiate(conn); err != nil { log.Warning("incoming conn: negotiation of crypto protocol failed: ", err) conn.Close() return diff --git a/listen_test.go b/listen_test.go index 03bd343..7ffffe5 100644 --- a/listen_test.go +++ b/listen_test.go @@ -3,8 +3,6 @@ package conn import ( "bytes" "context" - "fmt" - "io" "net" "sync" "time" @@ -18,208 +16,193 @@ import ( var _ = Describe("Listener", func() { Context("accepting connections", func() { - for _, val := range transportTypes { - tr := val - - Context(fmt.Sprintf("using a %s", tr), func() { - It("returns immediately when the context is cancelled", func() { - p1 := randPeerNetParams(tr) - ctx, cancel := context.WithCancel(context.Background()) - l := getListener(ctx, p1) - - accepted := make(chan struct{}) - go func() { - _, _ = l.Accept() - close(accepted) - }() - Consistently(accepted).ShouldNot(BeClosed()) - cancel() - Eventually(accepted).Should(BeClosed()) - }) - - It("returns immediately when it is closed", func() { - p1 := randPeerNetParams(tr) - l := getListener(context.Background(), p1) - - accepted := make(chan struct{}) - go func() { - _, _ = l.Accept() - close(accepted) - }() - Consistently(accepted).ShouldNot(BeClosed()) - l.Close() - Eventually(accepted).Should(BeClosed()) - }) - - It("continues accepting connections after one accept failed", func() { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - p1 := randPeerNetParams(tr) - p2 := randPeerNetParams(tr) - - l1 := getListener(ctx, p1) - defer l1.Close() - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - c := dialRawConn(p2.Addr, l1.Multiaddr()) - defer c.Close() - var conn io.ReadWriteCloser - switch tr { - case duplexTransport: - conn = c.(tpt.DuplexConn) - case multiplexTransport: - var err error - conn, err = c.(tpt.MultiplexConn).OpenStream() - Expect(err).ToNot(HaveOccurred()) - } - // write some garbage. This will fail the protocol selection - _, err := conn.Write(bytes.Repeat([]byte{255}, 1000)) - Expect(err).ToNot(HaveOccurred()) - close(done) - }() - - accepted := make(chan struct{}) - go func() { - defer GinkgoRecover() - c, err := l1.Accept() - Expect(err).ToNot(HaveOccurred()) - c.Close() - close(accepted) - }() - - // make sure it doesn't accept the raw connection - Eventually(done).Should(BeClosed()) - Consistently(accepted).ShouldNot(BeClosed()) - - // now dial the real connection, and make sure it is accepted - d := getDialer(p2.ID, p2.PrivKey, p2.Addr) - _, err := d.Dial(ctx, p1.Addr, p1.ID) + It("returns immediately when the context is cancelled", func() { + p1 := randPeerNetParams() + ctx, cancel := context.WithCancel(context.Background()) + l := getListener(ctx, p1) + + accepted := make(chan struct{}) + go func() { + _, _ = l.Accept() + close(accepted) + }() + Consistently(accepted).ShouldNot(BeClosed()) + cancel() + Eventually(accepted).Should(BeClosed()) + }) + + It("returns immediately when it is closed", func() { + p1 := randPeerNetParams() + l := getListener(context.Background(), p1) + + accepted := make(chan struct{}) + go func() { + _, _ = l.Accept() + close(accepted) + }() + Consistently(accepted).ShouldNot(BeClosed()) + l.Close() + Eventually(accepted).Should(BeClosed()) + }) + + It("continues accepting connections after one accept failed", func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + p1 := randPeerNetParams() + p2 := randPeerNetParams() + + l1 := getListener(ctx, p1) + defer l1.Close() + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + c := dialRawConn(p2.Addr, l1.Multiaddr()) + defer c.Close() + // write some garbage. This will fail the protocol selection + _, err := c.Write(bytes.Repeat([]byte{255}, 1000)) + Expect(err).ToNot(HaveOccurred()) + close(done) + }() + + accepted := make(chan struct{}) + go func() { + defer GinkgoRecover() + c, err := l1.Accept() + Expect(err).ToNot(HaveOccurred()) + c.Close() + close(accepted) + }() + + // make sure it doesn't accept the raw connection + Eventually(done).Should(BeClosed()) + Consistently(accepted).ShouldNot(BeClosed()) + + // now dial the real connection, and make sure it is accepted + d := getDialer(p2.ID, p2.PrivKey, p2.Addr) + _, err := d.Dial(ctx, p1.Addr, p1.ID) + Expect(err).ToNot(HaveOccurred()) + + Eventually(accepted).Should(BeClosed()) + }) + + // This test kicks off N (=10) concurrent dials, which wait d (=20ms) seconds before failing. + // That wait holds up the handshake (multistream AND crypto), which will happen BEFORE + // l1.Accept() returns a connection. This test checks that the handshakes all happen + // concurrently in the listener side, and not sequentially. This ensures that a hanging dial + // will not block the listener from accepting other dials concurrently. + It("accepts concurrently", func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + p1 := randPeerNetParams() + p2 := randPeerNetParams() + + l1 := getListener(ctx, p1) + defer l1.Close() + + n := 10 + delay := 50 * time.Millisecond + + accepted := make(chan struct{}) + go func() { + defer GinkgoRecover() + for i := 0; i < n; i++ { + conn, err := l1.Accept() Expect(err).ToNot(HaveOccurred()) - - Eventually(accepted).Should(BeClosed()) - }) - - // This test kicks off N (=10) concurrent dials, which wait d (=20ms) seconds before failing. - // That wait holds up the handshake (multistream AND crypto), which will happen BEFORE - // l1.Accept() returns a connection. This test checks that the handshakes all happen - // concurrently in the listener side, and not sequentially. This ensures that a hanging dial - // will not block the listener from accepting other dials concurrently. - It("accepts concurrently", func() { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - p1 := randPeerNetParams(tr) - p2 := randPeerNetParams(tr) - - l1 := getListener(ctx, p1) - defer l1.Close() - - n := 10 - delay := 50 * time.Millisecond - - accepted := make(chan struct{}) - go func() { - defer GinkgoRecover() - for i := 0; i < n; i++ { - conn, err := l1.Accept() - Expect(err).ToNot(HaveOccurred()) - defer conn.Close() - } - close(accepted) - }() - - var wg sync.WaitGroup - for i := 0; i < n; i++ { - wg.Add(1) - go func() { - defer GinkgoRecover() - defer wg.Done() - d2 := getDialer(p2.ID, p2.PrivKey, p2.Addr) - d2.Wrapper = func(c tpt.Conn) tpt.Conn { - time.Sleep(delay) - return c - } - before := time.Now() - _, err := d2.Dial(ctx, p1.Addr, p1.ID) - Expect(err).ToNot(HaveOccurred()) - // make sure the delay actually worked - Expect(time.Now()).To(BeTemporally(">", before.Add(delay))) - }() + defer conn.Close() + } + close(accepted) + }() + + var wg sync.WaitGroup + for i := 0; i < n; i++ { + wg.Add(1) + go func() { + defer GinkgoRecover() + defer wg.Done() + d2 := getDialer(p2.ID, p2.PrivKey, p2.Addr) + d2.Wrapper = func(c tpt.Conn) tpt.Conn { + time.Sleep(delay) + return c } + before := time.Now() + _, err := d2.Dial(ctx, p1.Addr, p1.ID) + Expect(err).ToNot(HaveOccurred()) + // make sure the delay actually worked + Expect(time.Now()).To(BeTemporally(">", before.Add(delay))) + }() + } + + wg.Wait() + // the Eventually timeout is 100ms, which is a lot smaller than n*delay = 500ms + Eventually(accepted).Should(BeClosed()) + }) + + Context("address filters", func() { + It("doesn't accept connections from filtered addresses", func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + p1 := randPeerNetParams() + p2 := randPeerNetParams() + + filt := filter.NewFilters() + _, ipnet, err := net.ParseCIDR("127.0.1.2/16") + Expect(err).ToNot(HaveOccurred()) + filt.AddDialFilter(ipnet) + Expect(filt.AddrBlocked(p2.Addr)).To(BeTrue()) + + l := getListener(ctx, p1) + defer l.Close() + l.SetAddrFilters(filt) + + accepted := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, _ = l.Accept() + close(accepted) + }() + + d := getDialer(p2.ID, p2.PrivKey, p2.Addr) + _, err = d.Dial(ctx, p1.Addr, p1.ID) + Expect(err).To(HaveOccurred()) + Eventually(accepted).ShouldNot(BeClosed()) + }) + + It("accepts connections from addresses that are not filtered", func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - wg.Wait() - // the Eventually timeout is 100ms, which is a lot smaller than n*delay = 500ms - Eventually(accepted).Should(BeClosed()) - }) - - Context("address filters", func() { - It("doesn't accept connections from filtered addresses", func() { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - p1 := randPeerNetParams(tr) - p2 := randPeerNetParams(tr) - - filt := filter.NewFilters() - _, ipnet, err := net.ParseCIDR("127.0.1.2/16") - Expect(err).ToNot(HaveOccurred()) - filt.AddDialFilter(ipnet) - Expect(filt.AddrBlocked(p2.Addr)).To(BeTrue()) - - l := getListener(ctx, p1) - defer l.Close() - l.SetAddrFilters(filt) - - accepted := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, _ = l.Accept() - close(accepted) - }() - - d := getDialer(p2.ID, p2.PrivKey, p2.Addr) - _, err = d.Dial(ctx, p1.Addr, p1.ID) - Expect(err).To(HaveOccurred()) - Eventually(accepted).ShouldNot(BeClosed()) - }) - - It("accepts connections from addresses that are not filtered", func() { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - p1 := randPeerNetParams(tr) - p2 := randPeerNetParams(tr) - - filt := filter.NewFilters() - _, ipnet, err := net.ParseCIDR("192.168.1.2/16") - Expect(err).ToNot(HaveOccurred()) - filt.AddDialFilter(ipnet) - Expect(filt.AddrBlocked(p2.Addr)).To(BeFalse()) - - l := getListener(ctx, p1) - defer l.Close() - l.SetAddrFilters(filt) - - accepted := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := l.Accept() - Expect(err).ToNot(HaveOccurred()) - close(accepted) - }() - - d := getDialer(p2.ID, p2.PrivKey, p2.Addr) - c2, err := d.Dial(ctx, p1.Addr, p1.ID) - Expect(err).ToNot(HaveOccurred()) - defer c2.Close() - Eventually(accepted).Should(BeClosed()) - time.Sleep(time.Second) - }) - }) + p1 := randPeerNetParams() + p2 := randPeerNetParams() + + filt := filter.NewFilters() + _, ipnet, err := net.ParseCIDR("192.168.1.2/16") + Expect(err).ToNot(HaveOccurred()) + filt.AddDialFilter(ipnet) + Expect(filt.AddrBlocked(p2.Addr)).To(BeFalse()) + + l := getListener(ctx, p1) + defer l.Close() + l.SetAddrFilters(filt) + + accepted := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := l.Accept() + Expect(err).ToNot(HaveOccurred()) + close(accepted) + }() + + d := getDialer(p2.ID, p2.PrivKey, p2.Addr) + c2, err := d.Dial(ctx, p1.Addr, p1.ID) + Expect(err).ToNot(HaveOccurred()) + defer c2.Close() + Eventually(accepted).Should(BeClosed()) + time.Sleep(time.Second) }) - } + }) }) }) diff --git a/protector_test.go b/protector_test.go index 5eb7b3a..0c0b6ae 100644 --- a/protector_test.go +++ b/protector_test.go @@ -4,86 +4,46 @@ import ( "context" "errors" - iconn "github.com/libp2p/go-libp2p-interface-conn" ipnet "github.com/libp2p/go-libp2p-interface-pnet" tpt "github.com/libp2p/go-libp2p-transport" - smux "github.com/libp2p/go-stream-muxer" tcpt "github.com/libp2p/go-tcp-transport" tu "github.com/libp2p/go-testutil" - quict "github.com/marten-seemann/libp2p-quic-transport" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) -type fakeDuplexProtector struct { +type fakeProtector struct { used bool } -func (f *fakeDuplexProtector) Fingerprint() []byte { +func (f *fakeProtector) Fingerprint() []byte { return make([]byte, 32) } -func (f *fakeDuplexProtector) Protect(c tpt.Conn) (tpt.Conn, error) { +func (f *fakeProtector) Protect(c tpt.Conn) (tpt.Conn, error) { f.used = true - return &rot13CryptDuplex{c.(tpt.DuplexConn)}, nil + return &rot13Crypt{c}, nil } -type rot13CryptDuplex struct { - tpt.DuplexConn +type rot13Crypt struct { + tpt.Conn } -func (r *rot13CryptDuplex) Read(b []byte) (int, error) { - n, err := r.DuplexConn.Read(b) +func (r *rot13Crypt) Read(b []byte) (int, error) { + n, err := r.Conn.Read(b) for i := 0; i < n; i++ { b[i] = b[i] - 13 } return n, err } -func (r *rot13CryptDuplex) Write(b []byte) (int, error) { +func (r *rot13Crypt) Write(b []byte) (int, error) { p := make([]byte, len(b)) // write MUST NOT modify b for i := range b { p[i] = b[i] + 13 } - return r.DuplexConn.Write(p) -} - -type fakeMultiplexProtector struct { - used bool - crypt *rot13CryptMultiplex -} - -func (f *fakeMultiplexProtector) Fingerprint() []byte { - return make([]byte, 32) -} - -func (f *fakeMultiplexProtector) Protect(c tpt.Conn) (tpt.Conn, error) { - f.used = true - f.crypt = &rot13CryptMultiplex{c.(tpt.MultiplexConn), 0, 0} - return f.crypt, nil -} - -type rot13CryptMultiplex struct { - tpt.MultiplexConn - openedStreams int - acceptedStreams int -} - -func (r *rot13CryptMultiplex) OpenStream() (smux.Stream, error) { - r.openedStreams++ - str, err := r.MultiplexConn.OpenStream() - return &rot13Stream{str}, err -} - -func (r *rot13CryptMultiplex) AcceptStream() (smux.Stream, error) { - r.acceptedStreams++ - str, err := r.MultiplexConn.AcceptStream() - return &rot13Stream{str}, err -} - -type rot13Stream struct { - smux.Stream + return r.Conn.Write(p) } var errProtect = errors.New("protecting failed") @@ -98,31 +58,15 @@ func (f *erroringProtector) Protect(c tpt.Conn) (tpt.Conn, error) { return nil, errProtect } -func (r *rot13Stream) Read(b []byte) (int, error) { - n, err := r.Stream.Read(b) - for i := 0; i < n; i++ { - b[i] = b[i] - 13 - } - return n, err -} - -func (r *rot13Stream) Write(b []byte) (int, error) { - p := make([]byte, len(b)) // write MUST NOT modify b - for i := range b { - p[i] = b[i] + 13 - } - return r.Stream.Write(p) -} - var _ = Describe("using the protector", func() { It("uses a protector for single-stream connections", func() { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - p1 := randPeerNetParams(duplexTransport) - p2 := randPeerNetParams(duplexTransport) - p1Protec := &fakeDuplexProtector{} - p2Protec := &fakeDuplexProtector{} + p1 := randPeerNetParams() + p2 := randPeerNetParams() + p1Protec := &fakeProtector{} + p2Protec := &fakeProtector{} list, err := tcpt.NewTCPTransport().Listen(p1.Addr) Expect(err).ToNot(HaveOccurred()) @@ -150,68 +94,14 @@ var _ = Describe("using the protector", func() { Expect(p1Protec.used).To(BeTrue()) }) - // TODO: enable this test when adding support for multiplex connections - PIt("uses a protector for multiplex connections", func() { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - p1 := randPeerNetParams(multiplexTransport) - p2 := randPeerNetParams(multiplexTransport) - p1Protec := &fakeMultiplexProtector{} - p2Protec := &fakeMultiplexProtector{} - - list, err := quict.NewQuicTransport().Listen(p1.Addr) - Expect(err).ToNot(HaveOccurred()) - l1, err := WrapTransportListenerWithProtector(ctx, list, p1.ID, p1.PrivKey, streamMuxer, p1Protec) - Expect(err).ToNot(HaveOccurred()) - p1.Addr = l1.Multiaddr() // Addr has been determined by kernel. - - d2 := getDialer(p2.ID, p2.PrivKey, p2.Addr) - d2.Protector = p2Protec - - var c1 iconn.Conn - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - var err error - c1, err = l1.Accept() - Expect(err).ToNot(HaveOccurred()) - close(done) - }() - - c2, err := d2.Dial(ctx, p1.Addr, p1.ID) - Expect(err).ToNot(HaveOccurred()) - defer c2.Close() - - Expect(p2Protec.used).To(BeTrue()) - <-done - Expect(p1Protec.used).To(BeTrue()) - - Expect(p1Protec.crypt.acceptedStreams).To(Equal(2)) - Expect(p1Protec.crypt.openedStreams).To(BeZero()) - Expect(p2Protec.crypt.openedStreams).To(Equal(2)) - Expect(p2Protec.crypt.acceptedStreams).To(BeZero()) - - str1, err := c1.OpenStream() - Expect(err).ToNot(HaveOccurred()) - _, err = str1.Write([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - _, err = c2.AcceptStream() - Expect(err).ToNot(HaveOccurred()) - Expect(p1Protec.crypt.acceptedStreams).To(Equal(2)) - Expect(p1Protec.crypt.openedStreams).To(Equal(1)) - Expect(p2Protec.crypt.openedStreams).To(Equal(2)) - Expect(p2Protec.crypt.acceptedStreams).To(Equal(1)) - }) - Context("forcing a private network", func() { var p1, p2 *tu.PeerNetParams var list tpt.Listener BeforeEach(func() { ipnet.ForcePrivateNetwork = true - p1 = randPeerNetParams(duplexTransport) - p2 = randPeerNetParams(duplexTransport) + p1 = randPeerNetParams() + p2 = randPeerNetParams() var err error list, err = tcpt.NewTCPTransport().Listen(p1.Addr) Expect(err).ToNot(HaveOccurred()) @@ -237,8 +127,8 @@ var _ = Describe("using the protector", func() { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - p1 := randPeerNetParams(duplexTransport) - p2 := randPeerNetParams(duplexTransport) + p1 := randPeerNetParams() + p2 := randPeerNetParams() p1Protec := &erroringProtector{} p2Protec := &erroringProtector{} diff --git a/secure_conn.go b/secure_conn.go index b753459..ee01dc5 100644 --- a/secure_conn.go +++ b/secure_conn.go @@ -9,54 +9,54 @@ import ( ma "github.com/multiformats/go-multiaddr" ) -// secureDuplexConn wraps another DuplexConn object with an encrypted channel. -type secureDuplexConn struct { - insecure tpt.DuplexConn // the wrapped conn - secure secio.Session // secure Session +// secureConn wraps another Conn object with an encrypted channel. +type secureConn struct { + insecure tpt.Conn // the wrapped conn + secure secio.Session // secure Session } -var _ tpt.DuplexConn = &secureDuplexConn{} +var _ tpt.Conn = &secureConn{} -func (c *secureDuplexConn) Read(buf []byte) (int, error) { +func (c *secureConn) Read(buf []byte) (int, error) { return c.secure.ReadWriter().Read(buf) } -func (c *secureDuplexConn) Write(buf []byte) (int, error) { +func (c *secureConn) Write(buf []byte) (int, error) { return c.secure.ReadWriter().Write(buf) } -func (c *secureDuplexConn) Close() error { +func (c *secureConn) Close() error { return c.secure.Close() } -func (c *secureDuplexConn) LocalAddr() net.Addr { +func (c *secureConn) LocalAddr() net.Addr { return c.insecure.LocalAddr() } -func (c *secureDuplexConn) LocalMultiaddr() ma.Multiaddr { +func (c *secureConn) LocalMultiaddr() ma.Multiaddr { return c.insecure.LocalMultiaddr() } -func (c *secureDuplexConn) RemoteAddr() net.Addr { +func (c *secureConn) RemoteAddr() net.Addr { return c.insecure.RemoteAddr() } -func (c *secureDuplexConn) RemoteMultiaddr() ma.Multiaddr { +func (c *secureConn) RemoteMultiaddr() ma.Multiaddr { return c.insecure.RemoteMultiaddr() } -func (c *secureDuplexConn) SetDeadline(t time.Time) error { +func (c *secureConn) SetDeadline(t time.Time) error { return c.insecure.SetDeadline(t) } -func (c *secureDuplexConn) SetReadDeadline(t time.Time) error { +func (c *secureConn) SetReadDeadline(t time.Time) error { return c.insecure.SetDeadline(t) } -func (c *secureDuplexConn) SetWriteDeadline(t time.Time) error { +func (c *secureConn) SetWriteDeadline(t time.Time) error { return c.insecure.SetDeadline(t) } -func (c *secureDuplexConn) Transport() tpt.Transport { +func (c *secureConn) Transport() tpt.Transport { return c.insecure.Transport() } From 1820ac9478a5b7477347538c3f3fe5d4548c7015 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 27 Oct 2017 12:15:50 +0700 Subject: [PATCH 09/10] simplify accepting new connections --- listen.go | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/listen.go b/listen.go index 36ae19b..4fe44ba 100644 --- a/listen.go +++ b/listen.go @@ -2,7 +2,6 @@ package conn import ( "context" - "errors" "fmt" "io" "net" @@ -91,13 +90,11 @@ func (l *listener) Accept() (iconn.Conn, error) { log.Warningf("listener %s listening INSECURELY!", l) } - for c := range l.incoming { - if c.err != nil { - return nil, c.err - } - return c.conn, nil + c, ok := <-l.incoming + if !ok { + return nil, fmt.Errorf("listener is closed") } - return nil, errors.New("listener is closed") + return c.conn, c.err } func (l *listener) Addr() net.Addr { From 8e5f94182d91254233a9b5454896f410d78cc59d Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 6 Nov 2017 13:21:50 +0700 Subject: [PATCH 10/10] fix race condition when accepting connections --- listen.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/listen.go b/listen.go index 4fe44ba..b36b02d 100644 --- a/listen.go +++ b/listen.go @@ -160,6 +160,7 @@ func (l *listener) handleIncoming() { defer cancel() done := make(chan struct{}) + var singleConn iconn.Conn go func() { defer close(done) @@ -184,22 +185,21 @@ func (l *listener) handleIncoming() { return } - c, err := newSingleConn(ctx, l.local, "", l.privk, conn, l.streamMuxer, true) + singleConn, err = newSingleConn(ctx, l.local, "", l.privk, conn, l.streamMuxer, true) if err != nil { log.Warning("connection setup failed: ", err) conn.Close() - return } - - l.incoming <- connOrErr{conn: c} }() select { case <-ctx.Done(): log.Warning("incoming conn: conn not established in time:", ctx.Err().Error()) conn.Close() - return case <-done: // connection completed (or errored) + if singleConn != nil { + l.incoming <- connOrErr{conn: singleConn} + } } }() }