Skip to content
This repository has been archived by the owner on Aug 19, 2022. It is now read-only.

fix: don't fail the handshake when the libp2p extension is critical #88

Merged
merged 1 commit into from
Aug 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions crypto.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ const certificatePrefix = "libp2p-tls-handshake:"
const alpn string = "libp2p"

var extensionID = getPrefixedExtensionID([]int{1, 1})
var extensionCritical bool // so we can mark the extension critical in tests

type signedKey struct {
PubKey []byte
Expand Down Expand Up @@ -113,25 +114,32 @@ func PubKeyFromCertChain(chain []*x509.Certificate) (ic.PubKey, error) {
cert := chain[0]
pool := x509.NewCertPool()
pool.AddCert(cert)
if _, err := cert.Verify(x509.VerifyOptions{Roots: pool}); err != nil {
// If we return an x509 error here, it will be sent on the wire.
// Wrap the error to avoid that.
return nil, fmt.Errorf("certificate verification failed: %s", err)
}

var found bool
var keyExt pkix.Extension
// find the libp2p key extension, skipping all unknown extensions
for _, ext := range cert.Extensions {
if extensionIDEqual(ext.Id, extensionID) {
keyExt = ext
found = true
for i, oident := range cert.UnhandledCriticalExtensions {
if oident.Equal(ext.Id) {
// delete the extension from UnhandledCriticalExtensions
cert.UnhandledCriticalExtensions = append(cert.UnhandledCriticalExtensions[:i], cert.UnhandledCriticalExtensions[i+1:]...)
break
}
}
break
}
}
if !found {
return nil, errors.New("expected certificate to contain the key extension")
}
if _, err := cert.Verify(x509.VerifyOptions{Roots: pool}); err != nil {
// If we return an x509 error here, it will be sent on the wire.
// Wrap the error to avoid that.
return nil, fmt.Errorf("certificate verification failed: %s", err)
}

var sk signedKey
if _, err := asn1.Unmarshal(keyExt.Value, &sk); err != nil {
return nil, fmt.Errorf("unmarshalling signed certificate failed: %s", err)
Expand Down Expand Up @@ -190,7 +198,7 @@ func keyToCertificate(sk ic.PrivKey) (*tls.Certificate, error) {
NotAfter: time.Now().Add(certValidityPeriod),
// after calling CreateCertificate, these will end up in Certificate.Extensions
ExtraExtensions: []pkix.Extension{
{Id: extensionID, Value: value},
{Id: extensionID, Critical: extensionCritical, Value: value},
},
}
certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, certKey.Public(), certKey)
Expand Down
79 changes: 45 additions & 34 deletions transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,42 +87,50 @@ var _ = Describe("Transport", func() {
clientID, clientKey = createPeer()
})

It("handshakes", func() {
clientTransport, err := New(clientKey)
Expect(err).ToNot(HaveOccurred())
serverTransport, err := New(serverKey)
Expect(err).ToNot(HaveOccurred())
Context("successful handshakes", func() {
for _, critical := range []bool{true, false} {
crit := critical

clientInsecureConn, serverInsecureConn := connect()
It(fmt.Sprintf("handshakes, extension critical: %t", crit), func() {
extensionCritical = crit
defer func() { extensionCritical = false }()
clientTransport, err := New(clientKey)
Expect(err).ToNot(HaveOccurred())
serverTransport, err := New(serverKey)
Expect(err).ToNot(HaveOccurred())

serverConnChan := make(chan sec.SecureConn)
go func() {
defer GinkgoRecover()
serverConn, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn)
Expect(err).ToNot(HaveOccurred())
serverConnChan <- serverConn
}()
clientConn, err := clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID)
Expect(err).ToNot(HaveOccurred())
var serverConn sec.SecureConn
Eventually(serverConnChan).Should(Receive(&serverConn))
defer clientConn.Close()
defer serverConn.Close()
Expect(clientConn.LocalPeer()).To(Equal(clientID))
Expect(serverConn.LocalPeer()).To(Equal(serverID))
Expect(clientConn.LocalPrivateKey()).To(Equal(clientKey))
Expect(serverConn.LocalPrivateKey()).To(Equal(serverKey))
Expect(clientConn.RemotePeer()).To(Equal(serverID))
Expect(serverConn.RemotePeer()).To(Equal(clientID))
Expect(clientConn.RemotePublicKey()).To(Equal(serverKey.GetPublic()))
Expect(serverConn.RemotePublicKey()).To(Equal(clientKey.GetPublic()))
// exchange some data
_, err = serverConn.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
b := make([]byte, 6)
_, err = clientConn.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(string(b)).To(Equal("foobar"))
clientInsecureConn, serverInsecureConn := connect()

serverConnChan := make(chan sec.SecureConn)
go func() {
defer GinkgoRecover()
serverConn, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn)
Expect(err).ToNot(HaveOccurred())
serverConnChan <- serverConn
}()
clientConn, err := clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID)
Expect(err).ToNot(HaveOccurred())
var serverConn sec.SecureConn
Eventually(serverConnChan).Should(Receive(&serverConn))
defer clientConn.Close()
defer serverConn.Close()
Expect(clientConn.LocalPeer()).To(Equal(clientID))
Expect(serverConn.LocalPeer()).To(Equal(serverID))
Expect(clientConn.LocalPrivateKey()).To(Equal(clientKey))
Expect(serverConn.LocalPrivateKey()).To(Equal(serverKey))
Expect(clientConn.RemotePeer()).To(Equal(serverID))
Expect(serverConn.RemotePeer()).To(Equal(clientID))
Expect(clientConn.RemotePublicKey()).To(Equal(serverKey.GetPublic()))
Expect(serverConn.RemotePublicKey()).To(Equal(clientKey.GetPublic()))
// exchange some data
_, err = serverConn.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
b := make([]byte, 6)
_, err = clientConn.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(string(b)).To(Equal("foobar"))
})
}
})

It("fails when the context of the outgoing connection is canceled", func() {
Expand Down Expand Up @@ -243,6 +251,9 @@ var _ = Describe("Transport", func() {
SerialNumber: big.NewInt(1),
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(-time.Minute),
ExtraExtensions: []pkix.Extension{
{Id: extensionID, Value: []byte("foobar")},
},
})
identity.config.Certificates = []tls.Certificate{cert}
}
Expand Down