Skip to content

Commit

Permalink
webtransport: only add cert hashes if we already started listening (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann authored and MarcoPolo committed May 5, 2023
1 parent a48be6a commit 6644291
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 23 deletions.
27 changes: 27 additions & 0 deletions libp2p_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -329,3 +329,30 @@ func TestTransportCustomAddressWebTransport(t *testing.T) {
require.Equal(t, secondToLastComp.Protocol().Code, ma.P_CERTHASH)
require.True(t, restOfAddr.Equal(customAddr))
}

// TestTransportCustomAddressWebTransportDoesNotStall tests that if the user
// manually returns a webtransport address from AddrsFactory, but we aren't
// listening on a webtranport address, we don't stall.
func TestTransportCustomAddressWebTransportDoesNotStall(t *testing.T) {
customAddr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/quic-v1/webtransport")
if err != nil {
t.Fatal(err)
}
h, err := New(
Transport(webtransport.New),
// Purposely not listening on the custom address so that we make sure the node doesn't stall if it fails to add a certhash to the multiaddr
// ListenAddrs(customAddr),
DisableRelay(),
AddrsFactory(func(multiaddrs []ma.Multiaddr) []ma.Multiaddr {
return []ma.Multiaddr{customAddr}
}),
)
require.NoError(t, err)
defer h.Close()
addrs := h.Addrs()
require.Len(t, addrs, 1)
_, lastComp := ma.SplitLast(addrs[0])
require.NotEqual(t, lastComp.Protocol().Code, ma.P_CERTHASH)
// We did not add the certhash to the multiaddr
require.Equal(t, addrs[0], customAddr)
}
8 changes: 6 additions & 2 deletions p2p/host/basic/basic_host.go
Original file line number Diff line number Diff line change
Expand Up @@ -771,7 +771,7 @@ func (h *BasicHost) Addrs() []ma.Multiaddr {
}

type addCertHasher interface {
AddCertHashes(m ma.Multiaddr) ma.Multiaddr
AddCertHashes(m ma.Multiaddr) (ma.Multiaddr, bool)
}

addrs := h.AddrsFactory(h.AllAddrs())
Expand All @@ -793,7 +793,11 @@ func (h *BasicHost) Addrs() []ma.Multiaddr {
if !ok {
continue
}
addrs[i] = tpt.AddCertHashes(addr)
addrWithCerthash, added := tpt.AddCertHashes(addr)
addrs[i] = addrWithCerthash
if !added {
log.Debug("Couldn't add certhashes to webtransport multiaddr because we aren't listening on webtransport")
}
}
}
return addrs
Expand Down
42 changes: 21 additions & 21 deletions p2p/transport/webtransport/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"fmt"
"io"
"sync"
"sync/atomic"
"time"

"github.com/libp2p/go-libp2p/core/connmgr"
Expand Down Expand Up @@ -68,12 +69,12 @@ type transport struct {
rcmgr network.ResourceManager
gater connmgr.ConnectionGater

listenOnce sync.Once
listenOnceErr error
certManager *certManager
certManagerReady chan struct{} // Closed when the certManager has been instantiated.
staticTLSConf *tls.Config
tlsClientConf *tls.Config
listenOnce sync.Once
listenOnceErr error
certManager *certManager
hasCertManager atomic.Bool // set to true once the certManager is initialized
staticTLSConf *tls.Config
tlsClientConf *tls.Config

noise *noise.Transport

Expand All @@ -98,14 +99,13 @@ func New(key ic.PrivKey, psk pnet.PSK, connManager *quicreuse.ConnManager, gater
return nil, err
}
t := &transport{
pid: id,
privKey: key,
rcmgr: rcmgr,
gater: gater,
clock: clock.New(),
connManager: connManager,
conns: map[uint64]*conn{},
certManagerReady: make(chan struct{}),
pid: id,
privKey: key,
rcmgr: rcmgr,
gater: gater,
clock: clock.New(),
connManager: connManager,
conns: map[uint64]*conn{},
}
for _, opt := range opts {
if err := opt(t); err != nil {
Expand Down Expand Up @@ -300,13 +300,12 @@ func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) {
if t.staticTLSConf == nil {
t.listenOnce.Do(func() {
t.certManager, t.listenOnceErr = newCertManager(t.privKey, t.clock)
close(t.certManagerReady)
t.hasCertManager.Store(true)
})
if t.listenOnceErr != nil {
return nil, t.listenOnceErr
}
} else {
close(t.certManagerReady)
return nil, errors.New("static TLS config not supported on WebTransport")
}
tlsConf := t.staticTLSConf.Clone()
Expand Down Expand Up @@ -405,10 +404,11 @@ func (t *transport) Resolve(_ context.Context, maddr ma.Multiaddr) ([]ma.Multiad
return []ma.Multiaddr{beforeQuicMA.Encapsulate(quicComponent).Encapsulate(sniComponent).Encapsulate(afterQuicMA)}, nil
}

func (t *transport) AddCertHashes(m ma.Multiaddr) ma.Multiaddr {
<-t.certManagerReady
if t.certManager == nil {
return m
// AddCertHashes adds the current certificate hashes to a multiaddress.
// If called before Listen, it's a no-op.
func (t *transport) AddCertHashes(m ma.Multiaddr) (ma.Multiaddr, bool) {
if !t.hasCertManager.Load() {
return m, false
}
return m.Encapsulate(t.certManager.AddrComponent())
return m.Encapsulate(t.certManager.AddrComponent()), true
}

0 comments on commit 6644291

Please sign in to comment.