diff --git a/include/envoy/network/transport_socket.h b/include/envoy/network/transport_socket.h index a5390c29853a..8f849211ebd3 100644 --- a/include/envoy/network/transport_socket.h +++ b/include/envoy/network/transport_socket.h @@ -17,7 +17,9 @@ enum class PostIoAction { // Close the connection. Close, // Keep the connection open. - KeepOpen + KeepOpen, + // Wait for the connection to complete asynchronously. + Wait, }; /** @@ -73,6 +75,9 @@ class TransportSocketCallbacks { * @param event supplies the connection event */ virtual void raiseEvent(ConnectionEvent event) PURE; + + virtual void registerFdReadyCallback(int fd, std::function) PURE; + virtual void disableFdReadyCallback(std::function) PURE; }; /** diff --git a/source/common/network/connection_impl.h b/source/common/network/connection_impl.h index a9e796e3b02f..b31c40c204bb 100644 --- a/source/common/network/connection_impl.h +++ b/source/common/network/connection_impl.h @@ -112,6 +112,17 @@ class ConnectionImpl : public virtual Connection, // Reconsider how to make fairness happen. void setReadBufferReady() override { file_event_->activate(Event::FileReadyType::Read); } + void registerFdReadyCallback(int fd, std::function cb) override { + ssl_async_event_ = dispatcher_.createFileEvent( + fd, [this, cb, fd](uint32_t events) -> void { + cb(fd); + }, Event::FileTriggerType::Edge, Event::FileReadyType::Read); + } + + void disableFdReadyCallback(std::function cb) override { + ssl_async_event_->setEnabled(0); + } + // Obtain global next connection ID. This should only be used in tests. static uint64_t nextGlobalIdForTest() { return next_global_id_; } @@ -135,6 +146,7 @@ class ConnectionImpl : public virtual Connection, ConnectionEvent immediate_error_event_{ConnectionEvent::Connected}; bool bind_error_{false}; Event::FileEventPtr file_event_; + Event::FileEventPtr ssl_async_event_; private: void onFileEvent(uint32_t events); diff --git a/source/common/ssl/ssl_socket.cc b/source/common/ssl/ssl_socket.cc index 1700249c7764..633bcae7df67 100644 --- a/source/common/ssl/ssl_socket.cc +++ b/source/common/ssl/ssl_socket.cc @@ -7,22 +7,36 @@ #include "common/ssl/utility.h" #include "absl/strings/str_replace.h" +#include "openssl/async.h" #include "openssl/err.h" #include "openssl/x509v3.h" +#include + + using Envoy::Network::PostIoAction; namespace Envoy { namespace Ssl { SslSocket::SslSocket(ContextSharedPtr ctx, InitialState state) - : ctx_(std::dynamic_pointer_cast(ctx)), ssl_(ctx_->newSsl()) { + : ctx_(std::dynamic_pointer_cast(ctx)), ssl_(ctx_->newSsl()), + handshake_in_progress_(false) { if (state == InitialState::Client) { SSL_set_connect_state(ssl_.get()); } else { ASSERT(state == InitialState::Server); SSL_set_accept_state(ssl_.get()); } + SSL_set_mode(ssl_.get(), SSL_MODE_ASYNC); +} + +SslSocket::~SslSocket() { + // If we let the SSL socket be destroyed while there is a pending async SSL operation, + // it seems that the callback handler will use already freed memory. Busyloop here (yuck!) + // to prevent a crash. + + while (SSL_waiting_for_async(ssl_.get())); } void SslSocket::setTransportSocketCallbacks(Network::TransportSocketCallbacks& callbacks) { @@ -93,14 +107,34 @@ Network::IoResult SslSocket::doRead(Buffer::Instance& read_buffer) { return {action, bytes_read, end_stream}; } +void SslSocket::asyncCb(int fd) { + ENVOY_CONN_LOG(debug, "SSL async done! fd: {}", callbacks_->connection(), fd); + + // TODO: how do we know which OpenSSL async request just ended, because we always + // seem to get the same FD back for all requests? + + // We lose the return value here, so might consider propagating it with an event + // in case we run into "Close" result from the handshake handler. + if (!handshake_complete_) + this->doHandshake(); +} + PostIoAction SslSocket::doHandshake() { ASSERT(!handshake_complete_); + + std::function cb = [this](int fd) void { this->asyncCb(fd); }; + int rc = SSL_do_handshake(ssl_.get()); if (rc == 1) { ENVOY_CONN_LOG(debug, "handshake complete", callbacks_->connection()); handshake_complete_ = true; + if (handshake_in_progress_) { + // No need to disable if no async engine in use. + callbacks_->disableFdReadyCallback(cb); + } + handshake_in_progress_ = false; ctx_->logHandshake(ssl_.get()); callbacks_->raiseEvent(Network::ConnectionEvent::Connected); @@ -109,12 +143,48 @@ PostIoAction SslSocket::doHandshake() { ? PostIoAction::KeepOpen : PostIoAction::Close; } else { + OSSL_ASYNC_FD *fds; + size_t numfds; int err = SSL_get_error(ssl_.get(), rc); ENVOY_CONN_LOG(debug, "handshake error: {}", callbacks_->connection(), err); switch (err) { case SSL_ERROR_WANT_READ: + ENVOY_CONN_LOG(debug, "SSL handshake: SSL_ERROR_WANT_READ", callbacks_->connection()); + return PostIoAction::KeepOpen; case SSL_ERROR_WANT_WRITE: + ENVOY_CONN_LOG(debug, "SSL handshake: SSL_ERROR_WANT_WRITE", callbacks_->connection()); return PostIoAction::KeepOpen; + case SSL_ERROR_WANT_ASYNC: + ENVOY_CONN_LOG(debug, "SSL handshake: request async handling", callbacks_->connection()); + + if (handshake_in_progress_) { + return PostIoAction::Wait; + } + + handshake_in_progress_ = true; + rc = SSL_get_all_async_fds(ssl_.get(), NULL, &numfds); + + if (rc == 0) { + drainErrorQueue(); + return PostIoAction::Close; + } + + /* FIXME: we only wait for the first fd here! Wiil fail if multiple async engines. */ + ASSERT(numfds == 1); + + fds = (OSSL_ASYNC_FD *) malloc(numfds * sizeof(OSSL_ASYNC_FD)); + if (fds == NULL) { + drainErrorQueue(); + return PostIoAction::Close; + } + + rc = SSL_get_all_async_fds(ssl_.get(), fds, &numfds); + + callbacks_->registerFdReadyCallback(fds[0], cb); + ENVOY_CONN_LOG(debug, "SSL async fd: {}, numfds: {}", callbacks_->connection(), fds[0], numfds); + + free(fds); + return PostIoAction::Wait; default: drainErrorQueue(); return PostIoAction::Close; @@ -210,8 +280,19 @@ void SslSocket::onConnected() { void SslSocket::shutdownSsl() { ASSERT(handshake_complete_); if (!shutdown_sent_ && callbacks_->connection().state() != Network::Connection::State::Closed) { - int rc = SSL_shutdown(ssl_.get()); - ENVOY_CONN_LOG(debug, "SSL shutdown: rc={}", callbacks_->connection(), rc); + int rc = 0, err = 0; + + // The SSL_shutdown() function also becomes asynchronous. However, we don't attach + // it to the main event loop. Just poll until it is done. + do { + rc = SSL_shutdown(ssl_.get()); + ENVOY_CONN_LOG(debug, "SSL shutdown: rc={}", callbacks_->connection(), rc); + if (rc < 0) { + err = SSL_get_error(ssl_.get(), rc); + ENVOY_CONN_LOG(debug, "SSL shutdown: err {}", callbacks_->connection(), err); + } + } while (rc < 0 && err == SSL_ERROR_WANT_ASYNC); + drainErrorQueue(); shutdown_sent_ = true; } diff --git a/source/common/ssl/ssl_socket.h b/source/common/ssl/ssl_socket.h index 68fec106eb91..3806bb527256 100644 --- a/source/common/ssl/ssl_socket.h +++ b/source/common/ssl/ssl_socket.h @@ -21,6 +21,7 @@ class SslSocket : public Network::TransportSocket, protected Logger::Loggable { public: SslSocket(ContextSharedPtr ctx, InitialState state); + ~SslSocket(); // Ssl::Connection bool peerCertificatePresented() const override; @@ -51,6 +52,7 @@ class SslSocket : public Network::TransportSocket, Network::PostIoAction doHandshake(); void drainErrorQueue(); void shutdownSsl(); + void asyncCb(int fd); // TODO: Move helper functions to the `Ssl::Utility` namespace. std::string getUriSanFromCertificate(X509* cert) const; @@ -61,6 +63,7 @@ class SslSocket : public Network::TransportSocket, ContextImplSharedPtr ctx_; bssl::UniquePtr ssl_; bool handshake_complete_{}; + bool handshake_in_progress_{}; bool shutdown_sent_{}; uint64_t bytes_to_retry_{}; mutable std::string cached_sha_256_peer_certificate_digest_;