Skip to content
This repository has been archived by the owner on Aug 19, 2022. It is now read-only.

Commit

Permalink
Merge pull request #88 from libp2p/handle-critical-extension
Browse files Browse the repository at this point in the history
fix: don't fail the handshake when the libp2p extension is critical
  • Loading branch information
marten-seemann authored Aug 7, 2021
2 parents 7530faa + abc670e commit 9a75550
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 41 deletions.
22 changes: 15 additions & 7 deletions crypto.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ const certificatePrefix = "libp2p-tls-handshake:"
const alpn string = "libp2p"

var extensionID = getPrefixedExtensionID([]int{1, 1})
var extensionCritical bool // so we can mark the extension critical in tests

type signedKey struct {
PubKey []byte
Expand Down Expand Up @@ -113,25 +114,32 @@ func PubKeyFromCertChain(chain []*x509.Certificate) (ic.PubKey, error) {
cert := chain[0]
pool := x509.NewCertPool()
pool.AddCert(cert)
if _, err := cert.Verify(x509.VerifyOptions{Roots: pool}); err != nil {
// If we return an x509 error here, it will be sent on the wire.
// Wrap the error to avoid that.
return nil, fmt.Errorf("certificate verification failed: %s", err)
}

var found bool
var keyExt pkix.Extension
// find the libp2p key extension, skipping all unknown extensions
for _, ext := range cert.Extensions {
if extensionIDEqual(ext.Id, extensionID) {
keyExt = ext
found = true
for i, oident := range cert.UnhandledCriticalExtensions {
if oident.Equal(ext.Id) {
// delete the extension from UnhandledCriticalExtensions
cert.UnhandledCriticalExtensions = append(cert.UnhandledCriticalExtensions[:i], cert.UnhandledCriticalExtensions[i+1:]...)
break
}
}
break
}
}
if !found {
return nil, errors.New("expected certificate to contain the key extension")
}
if _, err := cert.Verify(x509.VerifyOptions{Roots: pool}); err != nil {
// If we return an x509 error here, it will be sent on the wire.
// Wrap the error to avoid that.
return nil, fmt.Errorf("certificate verification failed: %s", err)
}

var sk signedKey
if _, err := asn1.Unmarshal(keyExt.Value, &sk); err != nil {
return nil, fmt.Errorf("unmarshalling signed certificate failed: %s", err)
Expand Down Expand Up @@ -190,7 +198,7 @@ func keyToCertificate(sk ic.PrivKey) (*tls.Certificate, error) {
NotAfter: time.Now().Add(certValidityPeriod),
// after calling CreateCertificate, these will end up in Certificate.Extensions
ExtraExtensions: []pkix.Extension{
{Id: extensionID, Value: value},
{Id: extensionID, Critical: extensionCritical, Value: value},
},
}
certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, certKey.Public(), certKey)
Expand Down
79 changes: 45 additions & 34 deletions transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,42 +87,50 @@ var _ = Describe("Transport", func() {
clientID, clientKey = createPeer()
})

It("handshakes", func() {
clientTransport, err := New(clientKey)
Expect(err).ToNot(HaveOccurred())
serverTransport, err := New(serverKey)
Expect(err).ToNot(HaveOccurred())
Context("successful handshakes", func() {
for _, critical := range []bool{true, false} {
crit := critical

clientInsecureConn, serverInsecureConn := connect()
It(fmt.Sprintf("handshakes, extension critical: %t", crit), func() {
extensionCritical = crit
defer func() { extensionCritical = false }()
clientTransport, err := New(clientKey)
Expect(err).ToNot(HaveOccurred())
serverTransport, err := New(serverKey)
Expect(err).ToNot(HaveOccurred())

serverConnChan := make(chan sec.SecureConn)
go func() {
defer GinkgoRecover()
serverConn, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn)
Expect(err).ToNot(HaveOccurred())
serverConnChan <- serverConn
}()
clientConn, err := clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID)
Expect(err).ToNot(HaveOccurred())
var serverConn sec.SecureConn
Eventually(serverConnChan).Should(Receive(&serverConn))
defer clientConn.Close()
defer serverConn.Close()
Expect(clientConn.LocalPeer()).To(Equal(clientID))
Expect(serverConn.LocalPeer()).To(Equal(serverID))
Expect(clientConn.LocalPrivateKey()).To(Equal(clientKey))
Expect(serverConn.LocalPrivateKey()).To(Equal(serverKey))
Expect(clientConn.RemotePeer()).To(Equal(serverID))
Expect(serverConn.RemotePeer()).To(Equal(clientID))
Expect(clientConn.RemotePublicKey()).To(Equal(serverKey.GetPublic()))
Expect(serverConn.RemotePublicKey()).To(Equal(clientKey.GetPublic()))
// exchange some data
_, err = serverConn.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
b := make([]byte, 6)
_, err = clientConn.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(string(b)).To(Equal("foobar"))
clientInsecureConn, serverInsecureConn := connect()

serverConnChan := make(chan sec.SecureConn)
go func() {
defer GinkgoRecover()
serverConn, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn)
Expect(err).ToNot(HaveOccurred())
serverConnChan <- serverConn
}()
clientConn, err := clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID)
Expect(err).ToNot(HaveOccurred())
var serverConn sec.SecureConn
Eventually(serverConnChan).Should(Receive(&serverConn))
defer clientConn.Close()
defer serverConn.Close()
Expect(clientConn.LocalPeer()).To(Equal(clientID))
Expect(serverConn.LocalPeer()).To(Equal(serverID))
Expect(clientConn.LocalPrivateKey()).To(Equal(clientKey))
Expect(serverConn.LocalPrivateKey()).To(Equal(serverKey))
Expect(clientConn.RemotePeer()).To(Equal(serverID))
Expect(serverConn.RemotePeer()).To(Equal(clientID))
Expect(clientConn.RemotePublicKey()).To(Equal(serverKey.GetPublic()))
Expect(serverConn.RemotePublicKey()).To(Equal(clientKey.GetPublic()))
// exchange some data
_, err = serverConn.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
b := make([]byte, 6)
_, err = clientConn.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(string(b)).To(Equal("foobar"))
})
}
})

It("fails when the context of the outgoing connection is canceled", func() {
Expand Down Expand Up @@ -243,6 +251,9 @@ var _ = Describe("Transport", func() {
SerialNumber: big.NewInt(1),
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(-time.Minute),
ExtraExtensions: []pkix.Extension{
{Id: extensionID, Value: []byte("foobar")},
},
})
identity.config.Certificates = []tls.Certificate{cert}
}
Expand Down

0 comments on commit 9a75550

Please sign in to comment.