Skip to content

Commit

Permalink
add integration tests using a very long certificate chain
Browse files Browse the repository at this point in the history
This will trigger the amplification protection.
  • Loading branch information
marten-seemann committed May 11, 2020
1 parent 918bfd5 commit 3bc2e98
Show file tree
Hide file tree
Showing 3 changed files with 188 additions and 93 deletions.
97 changes: 55 additions & 42 deletions integrationtests/self/handshake_drop_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package self_test

import (
"context"
"crypto/tls"
"fmt"
mrand "math/rand"
"net"
Expand Down Expand Up @@ -32,7 +33,7 @@ var _ = Describe("Handshake drop tests", func() {

const timeout = 10 * time.Minute

startListenerAndProxy := func(dropCallback quicproxy.DropCallback, doRetry bool, version protocol.VersionNumber) {
startListenerAndProxy := func(dropCallback quicproxy.DropCallback, doRetry bool, longCertChain bool, version protocol.VersionNumber) {
conf := getQuicConfigForServer(&quic.Config{
MaxIdleTimeout: timeout,
HandshakeTimeout: timeout,
Expand All @@ -41,8 +42,14 @@ var _ = Describe("Handshake drop tests", func() {
if !doRetry {
conf.AcceptToken = func(net.Addr, *quic.Token) bool { return true }
}
var tlsConf *tls.Config
if longCertChain {
tlsConf = getTLSConfigWithLongCertChain()
} else {
tlsConf = getTLSConfig()
}
var err error
ln, err = quic.ListenAddr("localhost:0", getTLSConfig(), conf)
ln, err = quic.ListenAddr("localhost:0", tlsConf, conf)
Expect(err).ToNot(HaveOccurred())
serverPort := ln.Addr().(*net.UDPAddr).Port
proxy, err = quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
Expand Down Expand Up @@ -184,46 +191,52 @@ var _ = Describe("Handshake drop tests", func() {
}

Context(desc, func() {
for _, a := range []*applicationProtocol{clientSpeaksFirst, serverSpeaksFirst, nobodySpeaks} {
app := a

Context(app.name, func() {
It(fmt.Sprintf("establishes a connection when the first packet is lost in %s direction", direction), func() {
var incoming, outgoing int32
startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
var p int32
switch d {
case quicproxy.DirectionIncoming:
p = atomic.AddInt32(&incoming, 1)
case quicproxy.DirectionOutgoing:
p = atomic.AddInt32(&outgoing, 1)
}
return p == 1 && d.Is(direction)
}, doRetry, version)
app.run(version)
})

It(fmt.Sprintf("establishes a connection when the second packet is lost in %s direction", direction), func() {
var incoming, outgoing int32
startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
var p int32
switch d {
case quicproxy.DirectionIncoming:
p = atomic.AddInt32(&incoming, 1)
case quicproxy.DirectionOutgoing:
p = atomic.AddInt32(&outgoing, 1)
}
return p == 2 && d.Is(direction)
}, doRetry, version)
app.run(version)
})

It(fmt.Sprintf("establishes a connection when 1/3 of the packets are lost in %s direction", direction), func() {
startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
return d.Is(direction) && stochasticDropper(3)
}, doRetry, version)
app.run(version)
})
for _, lcc := range []bool{false, true} {
longCertChain := lcc

Context(fmt.Sprintf("using a long certificate chain: %t", longCertChain), func() {
for _, a := range []*applicationProtocol{clientSpeaksFirst, serverSpeaksFirst, nobodySpeaks} {
app := a

Context(app.name, func() {
It(fmt.Sprintf("establishes a connection when the first packet is lost in %s direction", direction), func() {
var incoming, outgoing int32
startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
var p int32
switch d {
case quicproxy.DirectionIncoming:
p = atomic.AddInt32(&incoming, 1)
case quicproxy.DirectionOutgoing:
p = atomic.AddInt32(&outgoing, 1)
}
return p == 1 && d.Is(direction)
}, doRetry, longCertChain, version)
app.run(version)
})

It(fmt.Sprintf("establishes a connection when the second packet is lost in %s direction", direction), func() {
var incoming, outgoing int32
startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
var p int32
switch d {
case quicproxy.DirectionIncoming:
p = atomic.AddInt32(&incoming, 1)
case quicproxy.DirectionOutgoing:
p = atomic.AddInt32(&outgoing, 1)
}
return p == 2 && d.Is(direction)
}, doRetry, longCertChain, version)
app.run(version)
})

It(fmt.Sprintf("establishes a connection when 1/3 of the packets are lost in %s direction", direction), func() {
startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
return d.Is(direction) && stochasticDropper(3)
}, doRetry, longCertChain, version)
app.run(version)
})
})
}
})
}
})
Expand Down
46 changes: 28 additions & 18 deletions integrationtests/self/handshake_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,12 @@ var _ = Describe("Handshake tests", func() {
server quic.Listener
serverConfig *quic.Config
acceptStopped chan struct{}
tlsServerConf *tls.Config
)

BeforeEach(func() {
server = nil
acceptStopped = make(chan struct{})
serverConfig = getQuicConfigForServer(nil)
tlsServerConf = getTLSConfig()
})

AfterEach(func() {
Expand All @@ -68,10 +66,10 @@ var _ = Describe("Handshake tests", func() {
}
})

runServer := func() quic.Listener {
runServer := func(tlsConf *tls.Config) {
var err error
// start the server
server, err = quic.ListenAddr("localhost:0", tlsServerConf, serverConfig)
server, err = quic.ListenAddr("localhost:0", tlsConf, serverConfig)
Expect(err).ToNot(HaveOccurred())

go func() {
Expand All @@ -83,7 +81,6 @@ var _ = Describe("Handshake tests", func() {
}
}
}()
return server
}

if !israce.Enabled {
Expand All @@ -103,7 +100,7 @@ var _ = Describe("Handshake tests", func() {
// the server doesn't support the highest supported version, which is the first one the client will try
// but it supports a bunch of versions that the client doesn't speak
serverConfig.Versions = []protocol.VersionNumber{7, 8, protocol.SupportedVersions[0], 9}
server := runServer()
runServer(getTLSConfig())
defer server.Close()
sess, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
Expand All @@ -119,7 +116,7 @@ var _ = Describe("Handshake tests", func() {
// the server doesn't support the highest supported version, which is the first one the client will try
// but it supports a bunch of versions that the client doesn't speak
serverConfig.Versions = supportedVersions
server := runServer()
runServer(getTLSConfig())
defer server.Close()
sess, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
Expand All @@ -145,9 +142,11 @@ var _ = Describe("Handshake tests", func() {
suiteID := id

It(fmt.Sprintf("using %s", name), func() {
tlsServerConf.CipherSuites = []uint16{suiteID}
ln, err := quic.ListenAddr("localhost:0", tlsServerConf, serverConfig)
tlsConf := getTLSConfig()
tlsConf.CipherSuites = []uint16{suiteID}
ln, err := quic.ListenAddr("localhost:0", tlsConf, serverConfig)
Expect(err).ToNot(HaveOccurred())
defer ln.Close()

go func() {
defer GinkgoRecover()
Expand Down Expand Up @@ -177,7 +176,7 @@ var _ = Describe("Handshake tests", func() {
}
})

Context("Certifiate validation", func() {
Context("Certificate validation", func() {
for _, v := range protocol.SupportedVersions {
version := v

Expand All @@ -189,11 +188,8 @@ var _ = Describe("Handshake tests", func() {
clientConfig = getQuicConfigForClient(&quic.Config{Versions: []protocol.VersionNumber{version}})
})

JustBeforeEach(func() {
runServer()
})

It("accepts the certificate", func() {
runServer(getTLSConfig())
_, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
Expand All @@ -202,7 +198,18 @@ var _ = Describe("Handshake tests", func() {
Expect(err).ToNot(HaveOccurred())
})

It("works with a long certificate chain", func() {
runServer(getTLSConfigWithLongCertChain())
_, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
getQuicConfigForClient(&quic.Config{Versions: []protocol.VersionNumber{version}}),
)
Expect(err).ToNot(HaveOccurred())
})

It("errors if the server name doesn't match", func() {
runServer(getTLSConfig())
_, err := quic.DialAddr(
fmt.Sprintf("127.0.0.1:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
Expand All @@ -212,7 +219,10 @@ var _ = Describe("Handshake tests", func() {
})

It("fails the handshake if the client fails to provide the requested client cert", func() {
tlsServerConf.ClientAuth = tls.RequireAndVerifyClientCert
tlsConf := getTLSConfig()
tlsConf.ClientAuth = tls.RequireAndVerifyClientCert
runServer(tlsConf)

sess, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
Expand All @@ -234,6 +244,7 @@ var _ = Describe("Handshake tests", func() {
})

It("uses the ServerName in the tls.Config", func() {
runServer(getTLSConfig())
tlsConf := getTLSClientConfig()
tlsConf.ServerName = "localhost"
_, err := quic.DialAddr(
Expand Down Expand Up @@ -350,7 +361,7 @@ var _ = Describe("Handshake tests", func() {

Context("ALPN", func() {
It("negotiates an application protocol", func() {
ln, err := quic.ListenAddr("localhost:0", tlsServerConf, serverConfig)
ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
Expect(err).ToNot(HaveOccurred())

done := make(chan struct{})
Expand Down Expand Up @@ -379,7 +390,7 @@ var _ = Describe("Handshake tests", func() {
})

It("errors if application protocol negotiation fails", func() {
server := runServer()
runServer(getTLSConfig())

tlsConf := getTLSClientConfig()
tlsConf.NextProtos = []string{"foobar"}
Expand All @@ -391,7 +402,6 @@ var _ = Describe("Handshake tests", func() {
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("CRYPTO_ERROR"))
Expect(err.Error()).To(ContainSubstring("no application protocol"))
Expect(server.Close()).To(Succeed())
})
})

Expand Down
Loading

0 comments on commit 3bc2e98

Please sign in to comment.