From abc670ed0c35d03ade43f52c024350dc6c63b7e1 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 7 Aug 2021 11:43:15 +0100 Subject: [PATCH] fix: don't fail the handshake when the libp2p extension is critical --- crypto.go | 22 ++++++++----- transport_test.go | 79 +++++++++++++++++++++++++++-------------------- 2 files changed, 60 insertions(+), 41 deletions(-) diff --git a/crypto.go b/crypto.go index e6d6d5f..5a2fcf4 100644 --- a/crypto.go +++ b/crypto.go @@ -24,6 +24,7 @@ const certificatePrefix = "libp2p-tls-handshake:" const alpn string = "libp2p" var extensionID = getPrefixedExtensionID([]int{1, 1}) +var extensionCritical bool // so we can mark the extension critical in tests type signedKey struct { PubKey []byte @@ -113,12 +114,6 @@ func PubKeyFromCertChain(chain []*x509.Certificate) (ic.PubKey, error) { cert := chain[0] pool := x509.NewCertPool() pool.AddCert(cert) - if _, err := cert.Verify(x509.VerifyOptions{Roots: pool}); err != nil { - // If we return an x509 error here, it will be sent on the wire. - // Wrap the error to avoid that. - return nil, fmt.Errorf("certificate verification failed: %s", err) - } - var found bool var keyExt pkix.Extension // find the libp2p key extension, skipping all unknown extensions @@ -126,12 +121,25 @@ func PubKeyFromCertChain(chain []*x509.Certificate) (ic.PubKey, error) { if extensionIDEqual(ext.Id, extensionID) { keyExt = ext found = true + for i, oident := range cert.UnhandledCriticalExtensions { + if oident.Equal(ext.Id) { + // delete the extension from UnhandledCriticalExtensions + cert.UnhandledCriticalExtensions = append(cert.UnhandledCriticalExtensions[:i], cert.UnhandledCriticalExtensions[i+1:]...) + break + } + } break } } if !found { return nil, errors.New("expected certificate to contain the key extension") } + if _, err := cert.Verify(x509.VerifyOptions{Roots: pool}); err != nil { + // If we return an x509 error here, it will be sent on the wire. + // Wrap the error to avoid that. + return nil, fmt.Errorf("certificate verification failed: %s", err) + } + var sk signedKey if _, err := asn1.Unmarshal(keyExt.Value, &sk); err != nil { return nil, fmt.Errorf("unmarshalling signed certificate failed: %s", err) @@ -190,7 +198,7 @@ func keyToCertificate(sk ic.PrivKey) (*tls.Certificate, error) { NotAfter: time.Now().Add(certValidityPeriod), // after calling CreateCertificate, these will end up in Certificate.Extensions ExtraExtensions: []pkix.Extension{ - {Id: extensionID, Value: value}, + {Id: extensionID, Critical: extensionCritical, Value: value}, }, } certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, certKey.Public(), certKey) diff --git a/transport_test.go b/transport_test.go index 776c45e..a68164f 100644 --- a/transport_test.go +++ b/transport_test.go @@ -87,42 +87,50 @@ var _ = Describe("Transport", func() { clientID, clientKey = createPeer() }) - It("handshakes", func() { - clientTransport, err := New(clientKey) - Expect(err).ToNot(HaveOccurred()) - serverTransport, err := New(serverKey) - Expect(err).ToNot(HaveOccurred()) + Context("successful handshakes", func() { + for _, critical := range []bool{true, false} { + crit := critical - clientInsecureConn, serverInsecureConn := connect() + It(fmt.Sprintf("handshakes, extension critical: %t", crit), func() { + extensionCritical = crit + defer func() { extensionCritical = false }() + clientTransport, err := New(clientKey) + Expect(err).ToNot(HaveOccurred()) + serverTransport, err := New(serverKey) + Expect(err).ToNot(HaveOccurred()) - serverConnChan := make(chan sec.SecureConn) - go func() { - defer GinkgoRecover() - serverConn, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn) - Expect(err).ToNot(HaveOccurred()) - serverConnChan <- serverConn - }() - clientConn, err := clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID) - Expect(err).ToNot(HaveOccurred()) - var serverConn sec.SecureConn - Eventually(serverConnChan).Should(Receive(&serverConn)) - defer clientConn.Close() - defer serverConn.Close() - Expect(clientConn.LocalPeer()).To(Equal(clientID)) - Expect(serverConn.LocalPeer()).To(Equal(serverID)) - Expect(clientConn.LocalPrivateKey()).To(Equal(clientKey)) - Expect(serverConn.LocalPrivateKey()).To(Equal(serverKey)) - Expect(clientConn.RemotePeer()).To(Equal(serverID)) - Expect(serverConn.RemotePeer()).To(Equal(clientID)) - Expect(clientConn.RemotePublicKey()).To(Equal(serverKey.GetPublic())) - Expect(serverConn.RemotePublicKey()).To(Equal(clientKey.GetPublic())) - // exchange some data - _, err = serverConn.Write([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - b := make([]byte, 6) - _, err = clientConn.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(string(b)).To(Equal("foobar")) + clientInsecureConn, serverInsecureConn := connect() + + serverConnChan := make(chan sec.SecureConn) + go func() { + defer GinkgoRecover() + serverConn, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn) + Expect(err).ToNot(HaveOccurred()) + serverConnChan <- serverConn + }() + clientConn, err := clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID) + Expect(err).ToNot(HaveOccurred()) + var serverConn sec.SecureConn + Eventually(serverConnChan).Should(Receive(&serverConn)) + defer clientConn.Close() + defer serverConn.Close() + Expect(clientConn.LocalPeer()).To(Equal(clientID)) + Expect(serverConn.LocalPeer()).To(Equal(serverID)) + Expect(clientConn.LocalPrivateKey()).To(Equal(clientKey)) + Expect(serverConn.LocalPrivateKey()).To(Equal(serverKey)) + Expect(clientConn.RemotePeer()).To(Equal(serverID)) + Expect(serverConn.RemotePeer()).To(Equal(clientID)) + Expect(clientConn.RemotePublicKey()).To(Equal(serverKey.GetPublic())) + Expect(serverConn.RemotePublicKey()).To(Equal(clientKey.GetPublic())) + // exchange some data + _, err = serverConn.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + b := make([]byte, 6) + _, err = clientConn.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(string(b)).To(Equal("foobar")) + }) + } }) It("fails when the context of the outgoing connection is canceled", func() { @@ -243,6 +251,9 @@ var _ = Describe("Transport", func() { SerialNumber: big.NewInt(1), NotBefore: time.Now().Add(-time.Hour), NotAfter: time.Now().Add(-time.Minute), + ExtraExtensions: []pkix.Extension{ + {Id: extensionID, Value: []byte("foobar")}, + }, }) identity.config.Certificates = []tls.Certificate{cert} }