Skip to content
This repository has been archived by the owner on May 26, 2022. It is now read-only.

pass a conn that can be type asserted to a net.UDPConn to quic-go #180

Merged
merged 6 commits into from
Oct 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 26 additions & 53 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,54 +8,22 @@ import (
"io/ioutil"
mrand "math/rand"
"net"
"sync"
"sync/atomic"
"time"

"github.com/libp2p/go-libp2p-core/control"
gomock "github.com/golang/mock/gomock"
ic "github.com/libp2p/go-libp2p-core/crypto"
"github.com/libp2p/go-libp2p-core/network"
"github.com/libp2p/go-libp2p-core/peer"
tpt "github.com/libp2p/go-libp2p-core/transport"

quicproxy "github.com/lucas-clemente/quic-go/integrationtests/tools/proxy"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"

. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)

type mockGater struct {
lk sync.Mutex
acceptAll bool
blockedPeer peer.ID
}

func (c *mockGater) InterceptAccept(addrs network.ConnMultiaddrs) bool {
c.lk.Lock()
defer c.lk.Unlock()
return c.acceptAll || !manet.IsIPLoopback(addrs.RemoteMultiaddr())
}

func (c *mockGater) InterceptPeerDial(p peer.ID) (allow bool) {
return true
}

func (c *mockGater) InterceptAddrDial(peer.ID, ma.Multiaddr) (allow bool) {
return true
}

func (c *mockGater) InterceptSecured(_ network.Direction, p peer.ID, _ network.ConnMultiaddrs) (allow bool) {
c.lk.Lock()
defer c.lk.Unlock()
return p != c.blockedPeer
}

func (c *mockGater) InterceptUpgraded(network.Conn) (allow bool, reason control.DisconnectReason) {
return true, 0
}

//go:generate sh -c "mockgen -package libp2pquic -destination mock_connection_gater_test.go github.com/libp2p/go-libp2p-core/connmgr ConnectionGater && goimports -w mock_connection_gater_test.go"
var _ = Describe("Connection", func() {
var (
serverKey, clientKey ic.PrivKey
Expand Down Expand Up @@ -200,33 +168,38 @@ var _ = Describe("Connection", func() {
})

It("gates accepted connections", func() {
testMA, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/1234/quic")
Expect(err).ToNot(HaveOccurred())
cg := &mockGater{}
Expect(cg.InterceptAccept(&connAddrs{rmAddr: testMA})).To(BeFalse())

cg := NewMockConnectionGater(mockCtrl)
cg.EXPECT().InterceptAccept(gomock.Any())
serverTransport, err := NewTransport(serverKey, nil, cg)
Expect(err).ToNot(HaveOccurred())
ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic")
defer ln.Close()

accepted := make(chan struct{})
go func() {
defer GinkgoRecover()
defer close(accepted)
_, err := ln.Accept()
Expect(err).ToNot(HaveOccurred())
}()

clientTransport, err := NewTransport(clientKey, nil, nil)
Expect(err).ToNot(HaveOccurred())

// make sure that connection attempts fails
clientTransport.(*transport).clientConfig.HandshakeTimeout = 250 * time.Millisecond
_, err = clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
Expect(err).ToNot(HaveOccurred())
_, err = conn.AcceptStream()
Expect(err).To(HaveOccurred())
Expect(err.(net.Error).Timeout()).To(BeTrue())
Expect(err.Error()).To(ContainSubstring("connection gated"))

// now allow the address and make sure the connection goes through
cg.EXPECT().InterceptAccept(gomock.Any()).Return(true)
cg.EXPECT().InterceptSecured(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
clientTransport.(*transport).clientConfig.HandshakeTimeout = 2 * time.Second
cg.lk.Lock()
cg.acceptAll = true
cg.lk.Unlock()
conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
conn, err = clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
Expect(err).ToNot(HaveOccurred())
conn.Close()
defer conn.Close()
Eventually(accepted).Should(BeClosed())
})

It("gates secured connections", func() {
Expand All @@ -235,20 +208,20 @@ var _ = Describe("Connection", func() {
ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic")
defer ln.Close()

cg := &mockGater{acceptAll: true, blockedPeer: serverID}
cg := NewMockConnectionGater(mockCtrl)
cg.EXPECT().InterceptSecured(gomock.Any(), gomock.Any(), gomock.Any())

clientTransport, err := NewTransport(clientKey, nil, cg)
Expect(err).ToNot(HaveOccurred())

// make sure that connection attempts fails
clientTransport.(*transport).clientConfig.HandshakeTimeout = 250 * time.Millisecond
_, err = clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("connection gated"))

// now allow the peerId and make sure the connection goes through
cg.EXPECT().InterceptSecured(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
clientTransport.(*transport).clientConfig.HandshakeTimeout = 2 * time.Second
cg.lk.Lock()
cg.blockedPeer = "none"
cg.lk.Unlock()
conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
Expect(err).ToNot(HaveOccurred())
conn.Close()
Expand Down
58 changes: 0 additions & 58 deletions filtered_conn.go

This file was deleted.

1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module github.com/libp2p/go-libp2p-quic-transport
go 1.14

require (
github.com/golang/mock v1.4.4
github.com/ipfs/go-log v1.0.4
github.com/libp2p/go-libp2p-core v0.7.0
github.com/libp2p/go-libp2p-tls v0.1.3
Expand Down
6 changes: 6 additions & 0 deletions libp2pquic_suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"testing"
"time"

gomock "github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go"

. "github.com/onsi/ginkgo"
Expand All @@ -27,6 +28,7 @@ var (
garbageCollectIntervalOrig time.Duration
maxUnusedDurationOrig time.Duration
origQuicConfig *quic.Config
mockCtrl *gomock.Controller
)

func isGarbageCollectorRunning() bool {
Expand All @@ -36,6 +38,8 @@ func isGarbageCollectorRunning() bool {
}

var _ = BeforeEach(func() {
mockCtrl = gomock.NewController(GinkgoT())

Expect(isGarbageCollectorRunning()).To(BeFalse())
garbageCollectIntervalOrig = garbageCollectInterval
maxUnusedDurationOrig = maxUnusedDuration
Expand All @@ -46,6 +50,8 @@ var _ = BeforeEach(func() {
})

var _ = AfterEach(func() {
mockCtrl.Finish()

Eventually(isGarbageCollectorRunning).Should(BeFalse())
garbageCollectInterval = garbageCollectIntervalOrig
maxUnusedDuration = maxUnusedDurationOrig
Expand Down
16 changes: 8 additions & 8 deletions listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package libp2pquic
import (
"context"
"crypto/tls"
"fmt"
"net"

ic "github.com/libp2p/go-libp2p-core/crypto"
Expand All @@ -17,6 +16,8 @@ import (
ma "github.com/multiformats/go-multiaddr"
)

var quicListen = quic.Listen // so we can mock it in tests

// A listener listens for QUIC connections.
type listener struct {
quicListener quic.Listener
Expand All @@ -39,7 +40,7 @@ func newListener(rconn *reuseConn, t *transport, localPeer peer.ID, key ic.PrivK
conf, _ := identity.ConfigForAny()
return conf, nil
}
ln, err := quic.Listen(rconn, &tlsConf, t.serverConfig)
ln, err := quicListen(rconn, &tlsConf, t.serverConfig)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -69,11 +70,15 @@ func (l *listener) Accept() (tpt.CapableConn, error) {
sess.CloseWithError(0, err.Error())
continue
}
if l.transport.gater != nil && !(l.transport.gater.InterceptAccept(conn) && l.transport.gater.InterceptSecured(n.DirInbound, conn.remotePeerID, conn)) {
sess.CloseWithError(errorCodeConnectionGating, "connection gated")
continue
}
return conn, nil
}
}

func (l *listener) setupConn(sess quic.Session) (tpt.CapableConn, error) {
func (l *listener) setupConn(sess quic.Session) (*conn, error) {
// The tls.Config used to establish this connection already verified the certificate chain.
// Since we don't have any way of knowing which tls.Config was used though,
// we have to re-determine the peer's identity here.
Expand All @@ -92,11 +97,6 @@ func (l *listener) setupConn(sess quic.Session) (tpt.CapableConn, error) {
return nil, err
}

connaddrs := &connAddrs{lmAddr: l.localMultiaddr, rmAddr: remoteMultiaddr}
if l.transport.gater != nil && !l.transport.gater.InterceptSecured(n.DirInbound, remotePeerID, connaddrs) {
return nil, fmt.Errorf("secured connection gated")
}

return &conn{
sess: sess,
transport: l.transport,
Expand Down
30 changes: 30 additions & 0 deletions listener_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,29 @@ package libp2pquic
import (
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"net"
"syscall"

ic "github.com/libp2p/go-libp2p-core/crypto"
tpt "github.com/libp2p/go-libp2p-core/transport"
quic "github.com/lucas-clemente/quic-go"

ma "github.com/multiformats/go-multiaddr"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)

// interface containing some methods defined on the net.UDPConn, but not the net.PacketConn
type udpConn interface {
ReadFromUDP(b []byte) (int, *net.UDPAddr, error)
SetReadBuffer(bytes int) error
SyscallConn() (syscall.RawConn, error)
}

var _ = Describe("Listener", func() {
var t tpt.Transport

Expand All @@ -27,6 +38,25 @@ var _ = Describe("Listener", func() {
Expect(err).ToNot(HaveOccurred())
})

It("uses a conn that can interface assert to a UDPConn for listening", func() {
origQuicListen := quicListen
defer func() { quicListen = origQuicListen }()

var conn net.PacketConn
quicListen = func(c net.PacketConn, _ *tls.Config, _ *quic.Config) (quic.Listener, error) {
conn = c
return nil, errors.New("listen error")
}
localAddr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/quic")
Expect(err).ToNot(HaveOccurred())
_, err = t.Listen(localAddr)
Expect(err).To(MatchError("listen error"))
Expect(conn).ToNot(BeNil())
defer conn.Close()
_, ok := conn.(udpConn)
Expect(ok).To(BeTrue())
})

Context("listening on the right address", func() {
It("returns the address it is listening on", func() {
localAddr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/quic")
Expand Down
Loading