diff --git a/lib/base/tlsstream.cpp b/lib/base/tlsstream.cpp index db54c919ed5..98400a93bbc 100644 --- a/lib/base/tlsstream.cpp +++ b/lib/base/tlsstream.cpp @@ -33,6 +33,11 @@ std::shared_ptr UnbufferedAsioTlsStream::GetPeerCertificate() return std::shared_ptr(SSL_get_peer_certificate(native_handle()), X509_free); } +STACK_OF(X509) *UnbufferedAsioTlsStream::GetPeerCertificateChain() +{ + return SSL_get_peer_cert_chain(native_handle()); +} + void UnbufferedAsioTlsStream::BeforeHandshake(handshake_type type) { namespace ssl = boost::asio::ssl; diff --git a/lib/base/tlsstream.hpp b/lib/base/tlsstream.hpp index f6e52097e11..c8b99b0dab0 100644 --- a/lib/base/tlsstream.hpp +++ b/lib/base/tlsstream.hpp @@ -77,6 +77,7 @@ class UnbufferedAsioTlsStream : public AsioTcpTlsStream bool IsVerifyOK() const; String GetVerifyError() const; std::shared_ptr GetPeerCertificate(); + STACK_OF(X509) *GetPeerCertificateChain(); template inline diff --git a/lib/base/tlsutility.cpp b/lib/base/tlsutility.cpp index 246bd5aee42..23b13cfd5a6 100644 --- a/lib/base/tlsutility.cpp +++ b/lib/base/tlsutility.cpp @@ -981,7 +981,7 @@ String BinaryToHex(const unsigned char* data, size_t length) { return output; } -bool VerifyCertificate(const std::shared_ptr &caCertificate, const std::shared_ptr &certificate, const String& crlFile) +bool VerifyCertificate(const std::shared_ptr &caCertificate, const std::shared_ptr &certificate, const String& crlFile, STACK_OF(X509) *chain) { X509_STORE *store = X509_STORE_new(); @@ -995,7 +995,7 @@ bool VerifyCertificate(const std::shared_ptr &caCertificate, const std::sh } X509_STORE_CTX *csc = X509_STORE_CTX_new(); - X509_STORE_CTX_init(csc, store, certificate.get(), nullptr); + X509_STORE_CTX_init(csc, store, certificate.get(), chain); int rc = X509_verify_cert(csc); diff --git a/lib/base/tlsutility.hpp b/lib/base/tlsutility.hpp index b0641202011..74f6964b3a8 100644 --- a/lib/base/tlsutility.hpp +++ b/lib/base/tlsutility.hpp @@ -78,7 +78,7 @@ String SHA256(const String& s); String RandomString(int length); String BinaryToHex(const unsigned char* data, size_t length); -bool VerifyCertificate(const std::shared_ptr& caCertificate, const std::shared_ptr& certificate, const String& crlFile); +bool VerifyCertificate(const std::shared_ptr& caCertificate, const std::shared_ptr& certificate, const String& crlFile, STACK_OF(X509) *chain = nullptr); bool IsCa(const std::shared_ptr& cacert); int GetCertificateVersion(const std::shared_ptr& cert); String GetSignatureAlgorithm(const std::shared_ptr& cert); diff --git a/lib/remote/jsonrpcconnection-pki.cpp b/lib/remote/jsonrpcconnection-pki.cpp index 340e12b301e..d25130c0aa3 100644 --- a/lib/remote/jsonrpcconnection-pki.cpp +++ b/lib/remote/jsonrpcconnection-pki.cpp @@ -30,6 +30,7 @@ Value RequestCertificateHandler(const MessageOrigin::Ptr& origin, const Dictiona String certText = params->Get("cert_request"); std::shared_ptr cert; + STACK_OF(X509) *chain; Dictionary::Ptr result = new Dictionary(); auto& tlsConn (origin->FromClient->GetStream()->next_layer()); @@ -37,8 +38,10 @@ Value RequestCertificateHandler(const MessageOrigin::Ptr& origin, const Dictiona /* Use the presented client certificate if not provided. */ if (certText.IsEmpty()) { cert = tlsConn.GetPeerCertificate(); + chain = tlsConn.GetPeerCertificateChain(); } else { cert = StringToCertificate(certText); + chain = nullptr; } if (!cert) { @@ -62,7 +65,7 @@ Value RequestCertificateHandler(const MessageOrigin::Ptr& origin, const Dictiona logmsg << "Received certificate request for CN '" << cn << "'"; try { - signedByCA = VerifyCertificate(cacert, cert, listener->GetCrlPath()); + signedByCA = VerifyCertificate(cacert, cert, listener->GetCrlPath(), chain); if (!signedByCA) { logmsg << " not"; }