Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

webtransport: add PSK to constructor, and fail if it is used #1929

Merged
merged 1 commit into from
Dec 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion p2p/net/swarm/swarm_addr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func TestDialAddressSelection(t *testing.T) {
quicTr, err := quic.NewTransport(priv, reuse, nil, nil, nil)
require.NoError(t, err)
require.NoError(t, s.AddTransport(quicTr))
webtransportTr, err := webtransport.New(priv, reuse, nil, nil)
webtransportTr, err := webtransport.New(priv, nil, reuse, nil, nil)
require.NoError(t, err)
require.NoError(t, s.AddTransport(webtransportTr))
h := sha256.Sum256([]byte("foo"))
Expand Down
7 changes: 6 additions & 1 deletion p2p/transport/webtransport/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
ic "github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/pnet"
tpt "github.com/libp2p/go-libp2p/core/transport"
"github.com/libp2p/go-libp2p/p2p/security/noise"
"github.com/libp2p/go-libp2p/p2p/security/noise/pb"
Expand Down Expand Up @@ -93,7 +94,11 @@ var _ tpt.Transport = &transport{}
var _ tpt.Resolver = &transport{}
var _ io.Closer = &transport{}

func New(key ic.PrivKey, connManager *quicreuse.ConnManager, gater connmgr.ConnectionGater, rcmgr network.ResourceManager, opts ...Option) (tpt.Transport, error) {
func New(key ic.PrivKey, psk pnet.PSK, connManager *quicreuse.ConnManager, gater connmgr.ConnectionGater, rcmgr network.ResourceManager, opts ...Option) (tpt.Transport, error) {
if len(psk) > 0 {
log.Error("WebTransport doesn't support private networks yet.")
return nil, errors.New("WebTransport doesn't support private networks yet")
}
id, err := peer.IDFromPrivateKey(key)
if err != nil {
return nil, err
Expand Down
64 changes: 32 additions & 32 deletions p2p/transport/webtransport/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ func newConnManager(t *testing.T, opts ...quicreuse.Option) *quicreuse.ConnManag

func TestTransport(t *testing.T) {
serverID, serverKey := newIdentity(t)
tr, err := libp2pwebtransport.New(serverKey, newConnManager(t), nil, &network.NullResourceManager{})
tr, err := libp2pwebtransport.New(serverKey, nil, newConnManager(t), nil, &network.NullResourceManager{})
require.NoError(t, err)
defer tr.(io.Closer).Close()
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
Expand All @@ -122,7 +122,7 @@ func TestTransport(t *testing.T) {
addrChan := make(chan ma.Multiaddr)
go func() {
_, clientKey := newIdentity(t)
tr2, err := libp2pwebtransport.New(clientKey, newConnManager(t), nil, &network.NullResourceManager{})
tr2, err := libp2pwebtransport.New(clientKey, nil, newConnManager(t), nil, &network.NullResourceManager{})
require.NoError(t, err)
defer tr2.(io.Closer).Close()

Expand Down Expand Up @@ -158,7 +158,7 @@ func TestTransport(t *testing.T) {

func TestHashVerification(t *testing.T) {
serverID, serverKey := newIdentity(t)
tr, err := libp2pwebtransport.New(serverKey, newConnManager(t), nil, &network.NullResourceManager{})
tr, err := libp2pwebtransport.New(serverKey, nil, newConnManager(t), nil, &network.NullResourceManager{})
require.NoError(t, err)
defer tr.(io.Closer).Close()
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
Expand All @@ -171,7 +171,7 @@ func TestHashVerification(t *testing.T) {
}()

_, clientKey := newIdentity(t)
tr2, err := libp2pwebtransport.New(clientKey, newConnManager(t), nil, &network.NullResourceManager{})
tr2, err := libp2pwebtransport.New(clientKey, nil, newConnManager(t), nil, &network.NullResourceManager{})
require.NoError(t, err)
defer tr2.(io.Closer).Close()

Expand Down Expand Up @@ -209,7 +209,7 @@ func TestCanDial(t *testing.T) {
}

_, key := newIdentity(t)
tr, err := libp2pwebtransport.New(key, newConnManager(t), nil, &network.NullResourceManager{})
tr, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, &network.NullResourceManager{})
require.NoError(t, err)
defer tr.(io.Closer).Close()

Expand All @@ -235,7 +235,7 @@ func TestListenAddrValidity(t *testing.T) {
}

_, key := newIdentity(t)
tr, err := libp2pwebtransport.New(key, newConnManager(t), nil, &network.NullResourceManager{})
tr, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, &network.NullResourceManager{})
require.NoError(t, err)
defer tr.(io.Closer).Close()

Expand All @@ -252,7 +252,7 @@ func TestListenAddrValidity(t *testing.T) {

func TestListenerAddrs(t *testing.T) {
_, key := newIdentity(t)
tr, err := libp2pwebtransport.New(key, newConnManager(t), nil, &network.NullResourceManager{})
tr, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, &network.NullResourceManager{})
require.NoError(t, err)
defer tr.(io.Closer).Close()

Expand All @@ -275,7 +275,7 @@ func TestResourceManagerDialing(t *testing.T) {
p := peer.ID("foobar")

_, key := newIdentity(t)
tr, err := libp2pwebtransport.New(key, newConnManager(t), nil, rcmgr)
tr, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, rcmgr)
require.NoError(t, err)
defer tr.(io.Closer).Close()

Expand All @@ -290,7 +290,7 @@ func TestResourceManagerDialing(t *testing.T) {

func TestResourceManagerListening(t *testing.T) {
clientID, key := newIdentity(t)
cl, err := libp2pwebtransport.New(key, newConnManager(t), nil, &network.NullResourceManager{})
cl, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, &network.NullResourceManager{})
require.NoError(t, err)
defer cl.(io.Closer).Close()

Expand All @@ -299,7 +299,7 @@ func TestResourceManagerListening(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
rcmgr := mocknetwork.NewMockResourceManager(ctrl)
tr, err := libp2pwebtransport.New(key, newConnManager(t), nil, rcmgr)
tr, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, rcmgr)
require.NoError(t, err)
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
require.NoError(t, err)
Expand All @@ -325,7 +325,7 @@ func TestResourceManagerListening(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
rcmgr := mocknetwork.NewMockResourceManager(ctrl)
tr, err := libp2pwebtransport.New(key, newConnManager(t), nil, rcmgr)
tr, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, rcmgr)
require.NoError(t, err)
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
require.NoError(t, err)
Expand Down Expand Up @@ -369,7 +369,7 @@ func TestConnectionGaterDialing(t *testing.T) {
connGater := NewMockConnectionGater(ctrl)

serverID, serverKey := newIdentity(t)
tr, err := libp2pwebtransport.New(serverKey, newConnManager(t), nil, &network.NullResourceManager{})
tr, err := libp2pwebtransport.New(serverKey, nil, newConnManager(t), nil, &network.NullResourceManager{})
require.NoError(t, err)
defer tr.(io.Closer).Close()
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
Expand All @@ -380,7 +380,7 @@ func TestConnectionGaterDialing(t *testing.T) {
require.Equal(t, stripCertHashes(ln.Multiaddr()), addrs.RemoteMultiaddr())
})
_, key := newIdentity(t)
cl, err := libp2pwebtransport.New(key, newConnManager(t), connGater, &network.NullResourceManager{})
cl, err := libp2pwebtransport.New(key, nil, newConnManager(t), connGater, &network.NullResourceManager{})
require.NoError(t, err)
defer cl.(io.Closer).Close()
_, err = cl.Dial(context.Background(), ln.Multiaddr(), serverID)
Expand All @@ -393,7 +393,7 @@ func TestConnectionGaterInterceptAccept(t *testing.T) {
connGater := NewMockConnectionGater(ctrl)

serverID, serverKey := newIdentity(t)
tr, err := libp2pwebtransport.New(serverKey, newConnManager(t), connGater, &network.NullResourceManager{})
tr, err := libp2pwebtransport.New(serverKey, nil, newConnManager(t), connGater, &network.NullResourceManager{})
require.NoError(t, err)
defer tr.(io.Closer).Close()
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
Expand All @@ -406,7 +406,7 @@ func TestConnectionGaterInterceptAccept(t *testing.T) {
})

_, key := newIdentity(t)
cl, err := libp2pwebtransport.New(key, newConnManager(t), nil, &network.NullResourceManager{})
cl, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, &network.NullResourceManager{})
require.NoError(t, err)
defer cl.(io.Closer).Close()
_, err = cl.Dial(context.Background(), ln.Multiaddr(), serverID)
Expand All @@ -419,15 +419,15 @@ func TestConnectionGaterInterceptSecured(t *testing.T) {
connGater := NewMockConnectionGater(ctrl)

serverID, serverKey := newIdentity(t)
tr, err := libp2pwebtransport.New(serverKey, newConnManager(t), connGater, &network.NullResourceManager{})
tr, err := libp2pwebtransport.New(serverKey, nil, newConnManager(t), connGater, &network.NullResourceManager{})
require.NoError(t, err)
defer tr.(io.Closer).Close()
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
require.NoError(t, err)
defer ln.Close()

clientID, key := newIdentity(t)
cl, err := libp2pwebtransport.New(key, newConnManager(t), nil, &network.NullResourceManager{})
cl, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, &network.NullResourceManager{})
require.NoError(t, err)
defer cl.(io.Closer).Close()

Expand Down Expand Up @@ -485,7 +485,7 @@ func TestStaticTLSConf(t *testing.T) {
tlsConf := getTLSConf(t, net.ParseIP("127.0.0.1"), time.Now(), time.Now().Add(365*24*time.Hour))

serverID, serverKey := newIdentity(t)
tr, err := libp2pwebtransport.New(serverKey, newConnManager(t), nil, &network.NullResourceManager{}, libp2pwebtransport.WithTLSConfig(tlsConf))
tr, err := libp2pwebtransport.New(serverKey, nil, newConnManager(t), nil, &network.NullResourceManager{}, libp2pwebtransport.WithTLSConfig(tlsConf))
require.NoError(t, err)
defer tr.(io.Closer).Close()
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
Expand All @@ -495,7 +495,7 @@ func TestStaticTLSConf(t *testing.T) {

t.Run("fails when the certificate is invalid", func(t *testing.T) {
_, key := newIdentity(t)
cl, err := libp2pwebtransport.New(key, newConnManager(t), nil, &network.NullResourceManager{})
cl, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, &network.NullResourceManager{})
require.NoError(t, err)
defer cl.(io.Closer).Close()

Expand All @@ -509,7 +509,7 @@ func TestStaticTLSConf(t *testing.T) {

t.Run("fails when dialing with a wrong certhash", func(t *testing.T) {
_, key := newIdentity(t)
cl, err := libp2pwebtransport.New(key, newConnManager(t), nil, &network.NullResourceManager{})
cl, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, &network.NullResourceManager{})
require.NoError(t, err)
defer cl.(io.Closer).Close()

Expand All @@ -524,7 +524,7 @@ func TestStaticTLSConf(t *testing.T) {
store := x509.NewCertPool()
store.AddCert(tlsConf.Certificates[0].Leaf)
tlsConf := &tls.Config{RootCAs: store}
cl, err := libp2pwebtransport.New(key, newConnManager(t), nil, &network.NullResourceManager{}, libp2pwebtransport.WithTLSClientConfig(tlsConf))
cl, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, &network.NullResourceManager{}, libp2pwebtransport.WithTLSClientConfig(tlsConf))
require.NoError(t, err)
defer cl.(io.Closer).Close()

Expand All @@ -537,7 +537,7 @@ func TestStaticTLSConf(t *testing.T) {

func TestAcceptQueueFilledUp(t *testing.T) {
serverID, serverKey := newIdentity(t)
tr, err := libp2pwebtransport.New(serverKey, newConnManager(t), nil, &network.NullResourceManager{})
tr, err := libp2pwebtransport.New(serverKey, nil, newConnManager(t), nil, &network.NullResourceManager{})
require.NoError(t, err)
defer tr.(io.Closer).Close()
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
Expand All @@ -547,7 +547,7 @@ func TestAcceptQueueFilledUp(t *testing.T) {
newConn := func() (tpt.CapableConn, error) {
t.Helper()
_, key := newIdentity(t)
cl, err := libp2pwebtransport.New(key, newConnManager(t), nil, &network.NullResourceManager{})
cl, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, &network.NullResourceManager{})
require.NoError(t, err)
defer cl.(io.Closer).Close()
return cl.Dial(context.Background(), ln.Multiaddr(), serverID)
Expand Down Expand Up @@ -577,15 +577,15 @@ func TestSNIIsSent(t *testing.T) {
return tlsConf, nil
},
}
tr, err := libp2pwebtransport.New(key, newConnManager(t), nil, &network.NullResourceManager{}, libp2pwebtransport.WithTLSConfig(tlsConf))
tr, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, &network.NullResourceManager{}, libp2pwebtransport.WithTLSConfig(tlsConf))
require.NoError(t, err)
defer tr.(io.Closer).Close()

ln1, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
require.NoError(t, err)

_, key2 := newIdentity(t)
clientTr, err := libp2pwebtransport.New(key2, newConnManager(t), nil, &network.NullResourceManager{})
clientTr, err := libp2pwebtransport.New(key2, nil, newConnManager(t), nil, &network.NullResourceManager{})
require.NoError(t, err)
defer tr.(io.Closer).Close()

Expand Down Expand Up @@ -643,7 +643,7 @@ func TestFlowControlWindowIncrease(t *testing.T) {
serverID, serverKey := newIdentity(t)
serverWindowIncreases := make(chan int, 100)
serverRcmgr := &reportingRcmgr{report: serverWindowIncreases}
tr, err := libp2pwebtransport.New(serverKey, newConnManager(t), nil, serverRcmgr)
tr, err := libp2pwebtransport.New(serverKey, nil, newConnManager(t), nil, serverRcmgr)
require.NoError(t, err)
defer tr.(io.Closer).Close()
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
Expand All @@ -670,7 +670,7 @@ func TestFlowControlWindowIncrease(t *testing.T) {
_, clientKey := newIdentity(t)
clientWindowIncreases := make(chan int, 100)
clientRcmgr := &reportingRcmgr{report: clientWindowIncreases}
tr2, err := libp2pwebtransport.New(clientKey, newConnManager(t), nil, clientRcmgr)
tr2, err := libp2pwebtransport.New(clientKey, nil, newConnManager(t), nil, clientRcmgr)
require.NoError(t, err)
defer tr2.(io.Closer).Close()

Expand Down Expand Up @@ -754,7 +754,7 @@ func serverSendsBackValidCert(t *testing.T, timeSinceUnixEpoch time.Duration, ke

priv, _, err := test.SeededTestKeyPair(ic.Ed25519, 256, keySeed)
require.NoError(t, err)
tr, err := libp2pwebtransport.New(priv, newConnManager(t), nil, &network.NullResourceManager{}, libp2pwebtransport.WithClock(cl))
tr, err := libp2pwebtransport.New(priv, nil, newConnManager(t), nil, &network.NullResourceManager{}, libp2pwebtransport.WithClock(cl))
require.NoError(t, err)
l, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
require.NoError(t, err)
Expand Down Expand Up @@ -833,7 +833,7 @@ func TestServerRotatesCertCorrectly(t *testing.T) {
if err != nil {
return false
}
tr, err := libp2pwebtransport.New(priv, newConnManager(t), nil, &network.NullResourceManager{}, libp2pwebtransport.WithClock(cl))
tr, err := libp2pwebtransport.New(priv, nil, newConnManager(t), nil, &network.NullResourceManager{}, libp2pwebtransport.WithClock(cl))
if err != nil {
return false
}
Expand All @@ -847,7 +847,7 @@ func TestServerRotatesCertCorrectly(t *testing.T) {

// These two certificates together are valid for at most certValidity - (4*clockSkewAllowance)
cl.Add(certValidity - (4 * clockSkewAllowance) - time.Second)
tr, err = libp2pwebtransport.New(priv, newConnManager(t), nil, &network.NullResourceManager{}, libp2pwebtransport.WithClock(cl))
tr, err = libp2pwebtransport.New(priv, nil, newConnManager(t), nil, &network.NullResourceManager{}, libp2pwebtransport.WithClock(cl))
if err != nil {
return false
}
Expand Down Expand Up @@ -883,7 +883,7 @@ func TestServerRotatesCertCorrectlyAfterSteps(t *testing.T) {

priv, _, err := test.RandTestKeyPair(ic.Ed25519, 256)
require.NoError(t, err)
tr, err := libp2pwebtransport.New(priv, newConnManager(t), nil, &network.NullResourceManager{}, libp2pwebtransport.WithClock(cl))
tr, err := libp2pwebtransport.New(priv, nil, newConnManager(t), nil, &network.NullResourceManager{}, libp2pwebtransport.WithClock(cl))
require.NoError(t, err)

l, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
Expand All @@ -896,7 +896,7 @@ func TestServerRotatesCertCorrectlyAfterSteps(t *testing.T) {
// e.g. certhash/A/certhash/B ... -> ... certhash/B/certhash/C ... -> ... certhash/C/certhash/D
for i := 0; i < 200; i++ {
cl.Add(24 * time.Hour)
tr, err := libp2pwebtransport.New(priv, newConnManager(t), nil, &network.NullResourceManager{}, libp2pwebtransport.WithClock(cl))
tr, err := libp2pwebtransport.New(priv, nil, newConnManager(t), nil, &network.NullResourceManager{}, libp2pwebtransport.WithClock(cl))
require.NoError(t, err)
l, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
require.NoError(t, err)
Expand Down