diff --git a/conn_test.go b/conn_test.go index 67d48e7..172acff 100644 --- a/conn_test.go +++ b/conn_test.go @@ -8,54 +8,22 @@ import ( "io/ioutil" mrand "math/rand" "net" - "sync" "sync/atomic" "time" - "github.com/libp2p/go-libp2p-core/control" + gomock "github.com/golang/mock/gomock" ic "github.com/libp2p/go-libp2p-core/crypto" - "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" tpt "github.com/libp2p/go-libp2p-core/transport" quicproxy "github.com/lucas-clemente/quic-go/integrationtests/tools/proxy" ma "github.com/multiformats/go-multiaddr" - manet "github.com/multiformats/go-multiaddr/net" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) -type mockGater struct { - lk sync.Mutex - acceptAll bool - blockedPeer peer.ID -} - -func (c *mockGater) InterceptAccept(addrs network.ConnMultiaddrs) bool { - c.lk.Lock() - defer c.lk.Unlock() - return c.acceptAll || !manet.IsIPLoopback(addrs.RemoteMultiaddr()) -} - -func (c *mockGater) InterceptPeerDial(p peer.ID) (allow bool) { - return true -} - -func (c *mockGater) InterceptAddrDial(peer.ID, ma.Multiaddr) (allow bool) { - return true -} - -func (c *mockGater) InterceptSecured(_ network.Direction, p peer.ID, _ network.ConnMultiaddrs) (allow bool) { - c.lk.Lock() - defer c.lk.Unlock() - return p != c.blockedPeer -} - -func (c *mockGater) InterceptUpgraded(network.Conn) (allow bool, reason control.DisconnectReason) { - return true, 0 -} - +//go:generate sh -c "mockgen -package libp2pquic -destination mock_connection_gater_test.go github.com/libp2p/go-libp2p-core/connmgr ConnectionGater && goimports -w mock_connection_gater_test.go" var _ = Describe("Connection", func() { var ( serverKey, clientKey ic.PrivKey @@ -200,33 +168,38 @@ var _ = Describe("Connection", func() { }) It("gates accepted connections", func() { - testMA, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/1234/quic") - Expect(err).ToNot(HaveOccurred()) - cg := &mockGater{} - Expect(cg.InterceptAccept(&connAddrs{rmAddr: testMA})).To(BeFalse()) - + cg := NewMockConnectionGater(mockCtrl) + cg.EXPECT().InterceptAccept(gomock.Any()) serverTransport, err := NewTransport(serverKey, nil, cg) Expect(err).ToNot(HaveOccurred()) ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic") defer ln.Close() + accepted := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(accepted) + _, err := ln.Accept() + Expect(err).ToNot(HaveOccurred()) + }() + clientTransport, err := NewTransport(clientKey, nil, nil) Expect(err).ToNot(HaveOccurred()) - // make sure that connection attempts fails - clientTransport.(*transport).clientConfig.HandshakeTimeout = 250 * time.Millisecond - _, err = clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID) + conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID) + Expect(err).ToNot(HaveOccurred()) + _, err = conn.AcceptStream() Expect(err).To(HaveOccurred()) - Expect(err.(net.Error).Timeout()).To(BeTrue()) + Expect(err.Error()).To(ContainSubstring("connection gated")) // now allow the address and make sure the connection goes through + cg.EXPECT().InterceptAccept(gomock.Any()).Return(true) + cg.EXPECT().InterceptSecured(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) clientTransport.(*transport).clientConfig.HandshakeTimeout = 2 * time.Second - cg.lk.Lock() - cg.acceptAll = true - cg.lk.Unlock() - conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID) + conn, err = clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID) Expect(err).ToNot(HaveOccurred()) - conn.Close() + defer conn.Close() + Eventually(accepted).Should(BeClosed()) }) It("gates secured connections", func() { @@ -235,20 +208,20 @@ var _ = Describe("Connection", func() { ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic") defer ln.Close() - cg := &mockGater{acceptAll: true, blockedPeer: serverID} + cg := NewMockConnectionGater(mockCtrl) + cg.EXPECT().InterceptSecured(gomock.Any(), gomock.Any(), gomock.Any()) + clientTransport, err := NewTransport(clientKey, nil, cg) Expect(err).ToNot(HaveOccurred()) // make sure that connection attempts fails - clientTransport.(*transport).clientConfig.HandshakeTimeout = 250 * time.Millisecond _, err = clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID) Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("connection gated")) // now allow the peerId and make sure the connection goes through + cg.EXPECT().InterceptSecured(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) clientTransport.(*transport).clientConfig.HandshakeTimeout = 2 * time.Second - cg.lk.Lock() - cg.blockedPeer = "none" - cg.lk.Unlock() conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID) Expect(err).ToNot(HaveOccurred()) conn.Close() diff --git a/filtered_conn.go b/filtered_conn.go deleted file mode 100644 index c5f046e..0000000 --- a/filtered_conn.go +++ /dev/null @@ -1,58 +0,0 @@ -package libp2pquic - -import ( - "net" - - "github.com/libp2p/go-libp2p-core/connmgr" - - ma "github.com/multiformats/go-multiaddr" -) - -type connAddrs struct { - lmAddr ma.Multiaddr - rmAddr ma.Multiaddr -} - -func (c *connAddrs) LocalMultiaddr() ma.Multiaddr { - return c.lmAddr -} - -func (c *connAddrs) RemoteMultiaddr() ma.Multiaddr { - return c.rmAddr -} - -type filteredConn struct { - net.PacketConn - - lmAddr ma.Multiaddr - gater connmgr.ConnectionGater -} - -func newFilteredConn(c net.PacketConn, gater connmgr.ConnectionGater) net.PacketConn { - lmAddr, err := toQuicMultiaddr(c.LocalAddr()) - if err != nil { - panic(err) - } - - return &filteredConn{PacketConn: c, gater: gater, lmAddr: lmAddr} -} - -func (c *filteredConn) ReadFrom(b []byte) (n int, addr net.Addr, rerr error) { - for { - n, addr, rerr = c.PacketConn.ReadFrom(b) - // Short Header packet, see https://tools.ietf.org/html/draft-ietf-quic-invariants-07#section-4.2. - if n < 1 || b[0]&0x80 == 0 { - return - } - rmAddr, err := toQuicMultiaddr(addr) - if err != nil { - panic(err) - } - - connAddrs := &connAddrs{lmAddr: c.lmAddr, rmAddr: rmAddr} - - if c.gater.InterceptAccept(connAddrs) { - return - } - } -} diff --git a/go.mod b/go.mod index c9c56e3..5b53601 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/libp2p/go-libp2p-quic-transport go 1.14 require ( + github.com/golang/mock v1.4.4 github.com/ipfs/go-log v1.0.4 github.com/libp2p/go-libp2p-core v0.7.0 github.com/libp2p/go-libp2p-tls v0.1.3 diff --git a/libp2pquic_suite_test.go b/libp2pquic_suite_test.go index 5905763..0415fed 100644 --- a/libp2pquic_suite_test.go +++ b/libp2pquic_suite_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + gomock "github.com/golang/mock/gomock" "github.com/lucas-clemente/quic-go" . "github.com/onsi/ginkgo" @@ -27,6 +28,7 @@ var ( garbageCollectIntervalOrig time.Duration maxUnusedDurationOrig time.Duration origQuicConfig *quic.Config + mockCtrl *gomock.Controller ) func isGarbageCollectorRunning() bool { @@ -36,6 +38,8 @@ func isGarbageCollectorRunning() bool { } var _ = BeforeEach(func() { + mockCtrl = gomock.NewController(GinkgoT()) + Expect(isGarbageCollectorRunning()).To(BeFalse()) garbageCollectIntervalOrig = garbageCollectInterval maxUnusedDurationOrig = maxUnusedDuration @@ -46,6 +50,8 @@ var _ = BeforeEach(func() { }) var _ = AfterEach(func() { + mockCtrl.Finish() + Eventually(isGarbageCollectorRunning).Should(BeFalse()) garbageCollectInterval = garbageCollectIntervalOrig maxUnusedDuration = maxUnusedDurationOrig diff --git a/listener.go b/listener.go index 1574ea2..70fe471 100644 --- a/listener.go +++ b/listener.go @@ -3,7 +3,6 @@ package libp2pquic import ( "context" "crypto/tls" - "fmt" "net" ic "github.com/libp2p/go-libp2p-core/crypto" @@ -17,6 +16,8 @@ import ( ma "github.com/multiformats/go-multiaddr" ) +var quicListen = quic.Listen // so we can mock it in tests + // A listener listens for QUIC connections. type listener struct { quicListener quic.Listener @@ -39,7 +40,7 @@ func newListener(rconn *reuseConn, t *transport, localPeer peer.ID, key ic.PrivK conf, _ := identity.ConfigForAny() return conf, nil } - ln, err := quic.Listen(rconn, &tlsConf, t.serverConfig) + ln, err := quicListen(rconn, &tlsConf, t.serverConfig) if err != nil { return nil, err } @@ -69,11 +70,15 @@ func (l *listener) Accept() (tpt.CapableConn, error) { sess.CloseWithError(0, err.Error()) continue } + if l.transport.gater != nil && !(l.transport.gater.InterceptAccept(conn) && l.transport.gater.InterceptSecured(n.DirInbound, conn.remotePeerID, conn)) { + sess.CloseWithError(errorCodeConnectionGating, "connection gated") + continue + } return conn, nil } } -func (l *listener) setupConn(sess quic.Session) (tpt.CapableConn, error) { +func (l *listener) setupConn(sess quic.Session) (*conn, error) { // The tls.Config used to establish this connection already verified the certificate chain. // Since we don't have any way of knowing which tls.Config was used though, // we have to re-determine the peer's identity here. @@ -92,11 +97,6 @@ func (l *listener) setupConn(sess quic.Session) (tpt.CapableConn, error) { return nil, err } - connaddrs := &connAddrs{lmAddr: l.localMultiaddr, rmAddr: remoteMultiaddr} - if l.transport.gater != nil && !l.transport.gater.InterceptSecured(n.DirInbound, remotePeerID, connaddrs) { - return nil, fmt.Errorf("secured connection gated") - } - return &conn{ sess: sess, transport: l.transport, diff --git a/listener_test.go b/listener_test.go index 3388f1a..cb5836e 100644 --- a/listener_test.go +++ b/listener_test.go @@ -3,18 +3,29 @@ package libp2pquic import ( "crypto/rand" "crypto/rsa" + "crypto/tls" "crypto/x509" + "errors" "fmt" "net" + "syscall" ic "github.com/libp2p/go-libp2p-core/crypto" tpt "github.com/libp2p/go-libp2p-core/transport" + quic "github.com/lucas-clemente/quic-go" ma "github.com/multiformats/go-multiaddr" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) +// interface containing some methods defined on the net.UDPConn, but not the net.PacketConn +type udpConn interface { + ReadFromUDP(b []byte) (int, *net.UDPAddr, error) + SetReadBuffer(bytes int) error + SyscallConn() (syscall.RawConn, error) +} + var _ = Describe("Listener", func() { var t tpt.Transport @@ -27,6 +38,25 @@ var _ = Describe("Listener", func() { Expect(err).ToNot(HaveOccurred()) }) + It("uses a conn that can interface assert to a UDPConn for listening", func() { + origQuicListen := quicListen + defer func() { quicListen = origQuicListen }() + + var conn net.PacketConn + quicListen = func(c net.PacketConn, _ *tls.Config, _ *quic.Config) (quic.Listener, error) { + conn = c + return nil, errors.New("listen error") + } + localAddr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/quic") + Expect(err).ToNot(HaveOccurred()) + _, err = t.Listen(localAddr) + Expect(err).To(MatchError("listen error")) + Expect(conn).ToNot(BeNil()) + defer conn.Close() + _, ok := conn.(udpConn) + Expect(ok).To(BeTrue()) + }) + Context("listening on the right address", func() { It("returns the address it is listening on", func() { localAddr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/quic") diff --git a/mock_connection_gater_test.go b/mock_connection_gater_test.go new file mode 100644 index 0000000..899a0c6 --- /dev/null +++ b/mock_connection_gater_test.go @@ -0,0 +1,109 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/libp2p/go-libp2p-core/connmgr (interfaces: ConnectionGater) + +// Package libp2pquic is a generated GoMock package. +package libp2pquic + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + control "github.com/libp2p/go-libp2p-core/control" + network "github.com/libp2p/go-libp2p-core/network" + peer "github.com/libp2p/go-libp2p-core/peer" + multiaddr "github.com/multiformats/go-multiaddr" +) + +// MockConnectionGater is a mock of ConnectionGater interface +type MockConnectionGater struct { + ctrl *gomock.Controller + recorder *MockConnectionGaterMockRecorder +} + +// MockConnectionGaterMockRecorder is the mock recorder for MockConnectionGater +type MockConnectionGaterMockRecorder struct { + mock *MockConnectionGater +} + +// NewMockConnectionGater creates a new mock instance +func NewMockConnectionGater(ctrl *gomock.Controller) *MockConnectionGater { + mock := &MockConnectionGater{ctrl: ctrl} + mock.recorder = &MockConnectionGaterMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockConnectionGater) EXPECT() *MockConnectionGaterMockRecorder { + return m.recorder +} + +// InterceptAccept mocks base method +func (m *MockConnectionGater) InterceptAccept(arg0 network.ConnMultiaddrs) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InterceptAccept", arg0) + ret0, _ := ret[0].(bool) + return ret0 +} + +// InterceptAccept indicates an expected call of InterceptAccept +func (mr *MockConnectionGaterMockRecorder) InterceptAccept(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InterceptAccept", reflect.TypeOf((*MockConnectionGater)(nil).InterceptAccept), arg0) +} + +// InterceptAddrDial mocks base method +func (m *MockConnectionGater) InterceptAddrDial(arg0 peer.ID, arg1 multiaddr.Multiaddr) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InterceptAddrDial", arg0, arg1) + ret0, _ := ret[0].(bool) + return ret0 +} + +// InterceptAddrDial indicates an expected call of InterceptAddrDial +func (mr *MockConnectionGaterMockRecorder) InterceptAddrDial(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InterceptAddrDial", reflect.TypeOf((*MockConnectionGater)(nil).InterceptAddrDial), arg0, arg1) +} + +// InterceptPeerDial mocks base method +func (m *MockConnectionGater) InterceptPeerDial(arg0 peer.ID) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InterceptPeerDial", arg0) + ret0, _ := ret[0].(bool) + return ret0 +} + +// InterceptPeerDial indicates an expected call of InterceptPeerDial +func (mr *MockConnectionGaterMockRecorder) InterceptPeerDial(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InterceptPeerDial", reflect.TypeOf((*MockConnectionGater)(nil).InterceptPeerDial), arg0) +} + +// InterceptSecured mocks base method +func (m *MockConnectionGater) InterceptSecured(arg0 network.Direction, arg1 peer.ID, arg2 network.ConnMultiaddrs) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InterceptSecured", arg0, arg1, arg2) + ret0, _ := ret[0].(bool) + return ret0 +} + +// InterceptSecured indicates an expected call of InterceptSecured +func (mr *MockConnectionGaterMockRecorder) InterceptSecured(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InterceptSecured", reflect.TypeOf((*MockConnectionGater)(nil).InterceptSecured), arg0, arg1, arg2) +} + +// InterceptUpgraded mocks base method +func (m *MockConnectionGater) InterceptUpgraded(arg0 network.Conn) (bool, control.DisconnectReason) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InterceptUpgraded", arg0) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(control.DisconnectReason) + return ret0, ret1 +} + +// InterceptUpgraded indicates an expected call of InterceptUpgraded +func (mr *MockConnectionGaterMockRecorder) InterceptUpgraded(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InterceptUpgraded", reflect.TypeOf((*MockConnectionGater)(nil).InterceptUpgraded), arg0) +} diff --git a/reuse.go b/reuse.go index d7f7f76..3aa6940 100644 --- a/reuse.go +++ b/reuse.go @@ -17,18 +17,15 @@ var ( ) type reuseConn struct { - net.PacketConn + *net.UDPConn mutex sync.Mutex refCount int unusedSince time.Time } -func newReuseConn(conn net.PacketConn, gater connmgr.ConnectionGater) *reuseConn { - if gater != nil { - conn = newFilteredConn(conn, gater) - } - return &reuseConn{PacketConn: conn} +func newReuseConn(conn *net.UDPConn, gater connmgr.ConnectionGater) *reuseConn { + return &reuseConn{UDPConn: conn} } func (c *reuseConn) IncreaseCount() { diff --git a/transport.go b/transport.go index 25d0e8a..ea850f1 100644 --- a/transport.go +++ b/transport.go @@ -27,6 +27,8 @@ import ( var log = logging.Logger("quic-transport") +var quicDialContext = quic.DialContext // so we can mock it in tests + var quicConfig = &quic.Config{ MaxIncomingStreams: 1000, MaxIncomingUniStreams: -1, // disable unidirectional streams @@ -40,6 +42,7 @@ var quicConfig = &quic.Config{ } const statelessResetKeyInfo = "libp2p quic stateless reset key" +const errorCodeConnectionGating = 0x47415445 // GATE in ASCII type connManager struct { reuseUDP4 *reuse @@ -156,7 +159,7 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp if err != nil { return nil, err } - sess, err := quic.DialContext(ctx, pconn, addr, host, tlsConf, t.clientConfig) + sess, err := quicDialContext(ctx, pconn, addr, host, tlsConf, t.clientConfig) if err != nil { pconn.DecreaseCount() return nil, err @@ -181,14 +184,7 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp sess.CloseWithError(0, "") return nil, err } - - connaddrs := &connAddrs{lmAddr: localMultiaddr, rmAddr: remoteMultiaddr} - if t.gater != nil && !t.gater.InterceptSecured(n.DirOutbound, p, connaddrs) { - sess.CloseWithError(0, "") - return nil, fmt.Errorf("secured connection gated") - } - - return &conn{ + conn := &conn{ sess: sess, transport: t, privKey: t.privKey, @@ -197,7 +193,12 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp remotePubKey: remotePubKey, remotePeerID: p, remoteMultiaddr: remoteMultiaddr, - }, nil + } + if t.gater != nil && !t.gater.InterceptSecured(n.DirOutbound, p, conn) { + sess.CloseWithError(errorCodeConnectionGating, "connection gated") + return nil, fmt.Errorf("secured connection gated") + } + return conn, nil } // Don't use mafmt.QUIC as we don't want to dial DNS addresses. Just /ip{4,6}/udp/quic @@ -222,7 +223,12 @@ func (t *transport) Listen(addr ma.Multiaddr) (tpt.Listener, error) { if err != nil { return nil, err } - return newListener(conn, t, t.localPeer, t.privKey, t.identity) + ln, err := newListener(conn, t, t.localPeer, t.privKey, t.identity) + if err != nil { + conn.DecreaseCount() + return nil, err + } + return ln, nil } // Proxy returns true if this transport proxies. diff --git a/transport_test.go b/transport_test.go index 111b702..226ec2f 100644 --- a/transport_test.go +++ b/transport_test.go @@ -1,7 +1,17 @@ package libp2pquic import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "errors" + "net" + + ic "github.com/libp2p/go-libp2p-core/crypto" tpt "github.com/libp2p/go-libp2p-core/transport" + quic "github.com/lucas-clemente/quic-go" ma "github.com/multiformats/go-multiaddr" . "github.com/onsi/ginkgo" @@ -12,7 +22,12 @@ var _ = Describe("Transport", func() { var t tpt.Transport BeforeEach(func() { - t = &transport{} + rsaKey, err := rsa.GenerateKey(rand.Reader, 2048) + Expect(err).ToNot(HaveOccurred()) + key, err := ic.UnmarshalRsaPrivateKey(x509.MarshalPKCS1PrivateKey(rsaKey)) + Expect(err).ToNot(HaveOccurred()) + t, err = NewTransport(key, nil, nil) + Expect(err).ToNot(HaveOccurred()) }) It("says if it can dial an address", func() { @@ -35,4 +50,23 @@ var _ = Describe("Transport", func() { Expect(protocols).To(HaveLen(1)) Expect(protocols[0]).To(Equal(ma.P_QUIC)) }) + + It("uses a conn that can interface assert to a UDPConn for dialing", func() { + origQuicDialContext := quicDialContext + defer func() { quicDialContext = origQuicDialContext }() + + var conn net.PacketConn + quicDialContext = func(_ context.Context, c net.PacketConn, _ net.Addr, _ string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { + conn = c + return nil, errors.New("listen error") + } + remoteAddr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/quic") + Expect(err).ToNot(HaveOccurred()) + _, err = t.Dial(context.Background(), remoteAddr, "remote peer id") + Expect(err).To(MatchError("listen error")) + Expect(conn).ToNot(BeNil()) + defer conn.Close() + _, ok := conn.(udpConn) + Expect(ok).To(BeTrue()) + }) })