From 6b19c27306ecdfd7d02b8cabcd185e8fbc274b1d Mon Sep 17 00:00:00 2001 From: Aleksandr Lyapunov Date: Tue, 30 Aug 2022 12:32:50 +0300 Subject: [PATCH] Client: implement SSL connection By default SSL is turned off. Enable it with -DTNTCXX_ENABLE_SSL. Closes #28 --- CMakeLists.txt | 38 ++- examples/Simple.cpp | 8 +- src/Client/Connector.hpp | 36 ++- src/Client/EpollNetProvider.hpp | 19 +- src/Client/LibevNetProvider.hpp | 19 +- src/Client/Stream.hpp | 8 + src/Client/UnixSSLStream.hpp | 458 ++++++++++++++++++++++++++++++++ test/ClientTest.cpp | 71 +++-- test/Utils/System.hpp | 36 ++- test/cfg.lua | 3 + test/cfg_ssl.lua | 40 +++ test/gen_ssl.sh | 68 +++++ 12 files changed, 744 insertions(+), 60 deletions(-) create mode 100644 src/Client/UnixSSLStream.hpp create mode 100644 test/cfg_ssl.lua create mode 100755 test/gen_ssl.sh diff --git a/CMakeLists.txt b/CMakeLists.txt index 654dce9c4..10e54f808 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -9,10 +9,34 @@ ENDIF() MESSAGE(STATUS "Build type: ${CMAKE_BUILD_TYPE}") FIND_PACKAGE (benchmark QUIET) +SET(COMMON_LIB ev) + +# OpenSSL +IF (TNTCXX_ENABLE_SSL) +FIND_PACKAGE(OpenSSL) +IF (OPENSSL_FOUND) + MESSAGE(STATUS "OpenSSL ${OPENSSL_VERSION} found") + INCLUDE_DIRECTORIES(${OPENSSL_INCLUDE_DIR}) +ELSE() + MESSAGE(FATAL_ERROR "Could NOT find OpenSSL development files (libssl-dev/openssl-devel package)") +ENDIF() + +# OpenSSL can require Z library (depending on build time options), so we add +# it to libraries list in case of static openssl linking. +IF(OPENSSL_USE_STATIC_LIBS) + SET(OPENSSL_LIBRARIES ${OPENSSL_LIBRARIES} ${ZLIB_LIBRARIES}) +ENDIF() + +SET(COMMON_LIB ${COMMON_LIB} ${OPENSSL_LIBRARIES}) + +ENDIF() # IF (TNTCXX_ENABLE_SSL) + SET(CMAKE_CXX_STANDARD 17) SET(CMAKE_C_STANDARD 11) ADD_COMPILE_OPTIONS(-Wall -Wextra -Werror) CONFIGURE_FILE(./test/cfg.lua test_cfg.lua COPYONLY) +CONFIGURE_FILE(./test/cfg_ssl.lua test_cfg_ssl.lua COPYONLY) +CONFIGURE_FILE(./test/gen_ssl.sh test_gen_ssl.sh COPYONLY) INCLUDE_DIRECTORIES(.) INCLUDE_DIRECTORIES(./src) @@ -36,9 +60,14 @@ ADD_EXECUTABLE(EncDecUnit.test src/mpp/mpp.hpp test/EncDecTest.cpp) ADD_EXECUTABLE(Client.test src/Client/Connector.hpp test/ClientTest.cpp) ADD_EXECUTABLE(ClientPerfTest.test src/Client/Connector.hpp test/ClientPerfTest.cpp) ADD_EXECUTABLE(SimpleExample examples/Simple.cpp) -TARGET_LINK_LIBRARIES(SimpleExample ev) -TARGET_LINK_LIBRARIES(ClientPerfTest.test ev) -TARGET_LINK_LIBRARIES(Client.test ev) +TARGET_LINK_LIBRARIES(SimpleExample ${COMMON_LIB}) +TARGET_LINK_LIBRARIES(ClientPerfTest.test ${COMMON_LIB}) +TARGET_LINK_LIBRARIES(Client.test ${COMMON_LIB}) +IF (TNTCXX_ENABLE_SSL) + ADD_EXECUTABLE(ClientSSL.test src/Client/Connector.hpp test/ClientTest.cpp) + TARGET_LINK_LIBRARIES(ClientSSL.test ${COMMON_LIB}) + TARGET_COMPILE_DEFINITIONS(ClientSSL.test PUBLIC TNTCXX_ENABLE_SSL) +ENDIF() IF (benchmark_FOUND) ADD_EXECUTABLE(BufferGPerf.test src/Buffer/Buffer.hpp test/BufferGPerfTest.cpp) @@ -59,3 +88,6 @@ ADD_TEST(NAME ListUnit.test COMMAND ListUnit.test) ADD_TEST(NAME RulesUnit.test COMMAND RulesUnit.test) ADD_TEST(NAME EncDecUnit.test COMMAND EncDecUnit.test) ADD_TEST(NAME Client.test COMMAND Client.test) +IF (TNTCXX_ENABLE_SSL) + ADD_TEST(NAME ClientSSL.test COMMAND ClientSSL.test) +ENDIF() diff --git a/examples/Simple.cpp b/examples/Simple.cpp index 60bd5927a..c9e44e535 100644 --- a/examples/Simple.cpp +++ b/examples/Simple.cpp @@ -125,7 +125,9 @@ main() * exception free, so we rely only on return codes. */ //doclabel06-1 - int rc = client.connect(conn, address, port); + int rc = client.connect(conn, {.address = address, + .service = std::to_string(port), + /* .transport = STREAM_SSL, */}); //doclabel06-2 if (rc != 0) { //assert(conn.getError().saved_errno != 0); @@ -217,7 +219,9 @@ main() //doclabel11-3 /* Let's create another connection. */ Connection another(client); - if (client.connect(another, address, port) != 0) { + if (client.connect(another, {.address = address, + .service = std::to_string(port), + /* .transport = STREAM_SSL, */}) != 0) { std::cerr << conn.getError().msg << std::endl; return -1; } diff --git a/src/Client/Connector.hpp b/src/Client/Connector.hpp index 664b8ed10..bdf12deaf 100644 --- a/src/Client/Connector.hpp +++ b/src/Client/Connector.hpp @@ -30,12 +30,22 @@ * SUCH DAMAGE. */ #include "Connection.hpp" + +#ifdef TNTCXX_ENABLE_SSL +#include "UnixSSLStream.hpp" +#else #include "UnixPlainStream.hpp" +#endif + #include "../Utils/Timer.hpp" #include +#ifdef TNTCXX_ENABLE_SSL +using DefaultStream = UnixSSLStream; +#else using DefaultStream = UnixPlainStream; +#endif /** * MacOS does not have epoll so let's use Libev as default network provider. @@ -59,6 +69,8 @@ class Connector Connector(const Connector& connector) = delete; Connector& operator = (const Connector& connector) = delete; //////////////////////////////Main API////////////////////////////////// + int connect(Connection &conn, + const ConnectOptions &opts); int connect(Connection &conn, const std::string& addr, unsigned port); @@ -94,19 +106,33 @@ Connector::~Connector() template int Connector::connect(Connection &conn, - const std::string& addr, - unsigned port) + const ConnectOptions &opts) { //Make sure that connection is not yet established. assert(conn.get_strm().has_status(SS_DEAD)); - if (m_NetProvider.connect(conn, addr, port) != 0) { - LOG_ERROR("Failed to connect to ", addr, ':', port); + if (m_NetProvider.connect(conn, opts) != 0) { + LOG_ERROR("Failed to connect to ", + opts.address, ':', opts.service); return -1; } - LOG_DEBUG("Connection to ", addr, ':', port, " has been established"); + LOG_DEBUG("Connection to ", opts.address, ':', opts.service, + " has been established"); return 0; } +template +int +Connector::connect(Connection &conn, + const std::string& addr, + unsigned port) +{ + std::string service = port == 0 ? std::string{} : std::to_string(port); + return connect(conn, { + .address = addr, + .service = service, + }); +} + template void Connector::close(Connection &conn) diff --git a/src/Client/EpollNetProvider.hpp b/src/Client/EpollNetProvider.hpp index 4b04e124d..962a267d4 100644 --- a/src/Client/EpollNetProvider.hpp +++ b/src/Client/EpollNetProvider.hpp @@ -58,7 +58,7 @@ class EpollNetProvider { using Connector_t = Connector; EpollNetProvider(Connector_t &connector); ~EpollNetProvider(); - int connect(Conn_t &conn, const std::string& addr, uint16_t port); + int connect(Conn_t &conn, const ConnectOptions &opts); void close(Conn_t &conn); /** Read and write to sockets; polling using epoll. */ int wait(int timeout); @@ -135,21 +135,16 @@ EpollNetProvider::setPollSetting(Conn_t &conn, int setting) { template int -EpollNetProvider::connect(Conn_t &conn, const std::string &addr, - uint16_t port) +EpollNetProvider::connect(Conn_t &conn, + const ConnectOptions &opts) { auto &strm = conn.get_strm(); - std::string service = port == 0 ? std::string{} : std::to_string(port); - if (strm.connect({ - .address = addr, - .service = service, - }) < 0) { - conn.setError( - std::string("Failed to establish connection to ") + - std::string(addr)); + if (strm.connect(opts) < 0) { + conn.setError("Failed to establish connection to " + + opts.address); return -1; } - LOG_DEBUG("Connected to ", addr, ", socket is ", strm.get_fd()); + LOG_DEBUG("Connected to ", opts.address, ", socket is ", strm.get_fd()); conn.getImpl()->is_greeting_received = false; registerEpoll(conn); diff --git a/src/Client/LibevNetProvider.hpp b/src/Client/LibevNetProvider.hpp index 49bc6c477..3f803afa8 100644 --- a/src/Client/LibevNetProvider.hpp +++ b/src/Client/LibevNetProvider.hpp @@ -90,7 +90,7 @@ class LibevNetProvider { using Connector_t = Connector; LibevNetProvider(Connector_t &connector, struct ev_loop *loop = nullptr); - int connect(Conn_t &conn, const std::string& addr, uint16_t port); + int connect(Conn_t &conn, const ConnectOptions &opts); void close(Conn_t &conn); int wait(int timeout); @@ -285,21 +285,16 @@ LibevNetProvider::registerWatchers(Conn_t &conn, int fd) template int -LibevNetProvider::connect(Conn_t &conn, const std::string &addr, - uint16_t port) +LibevNetProvider::connect(Conn_t &conn, + const ConnectOptions &opts) { auto &strm = conn.get_strm(); - std::string service = port == 0 ? std::string{} : std::to_string(port); - if (strm.connect({ - .address = addr, - .service = service, - }) < 0) { - conn.setError( - std::string("Failed to establish connection to ") + - std::string(addr)); + if (strm.connect(opts) < 0) { + conn.setError("Failed to establish connection to " + + opts.address); return -1; } - LOG_DEBUG("Connected to ", addr, ", socket is ", strm.get_fd()); + LOG_DEBUG("Connected to ", opts.address, ", socket is ", strm.get_fd()); conn.getImpl()->is_greeting_received = false; registerWatchers(conn, strm.get_fd()); diff --git a/src/Client/Stream.hpp b/src/Client/Stream.hpp index 224ea91b6..5306fa6d9 100644 --- a/src/Client/Stream.hpp +++ b/src/Client/Stream.hpp @@ -98,6 +98,14 @@ struct ConnectOptions { /** Time span limit for connection establishment. */ size_t connect_timeout = DEFAULT_CONNECT_TIMEOUT; + /** SSL settings. */ + std::string ssl_cert_file{}; + std::string ssl_key_file{}; + std::string ssl_ca_file{}; + std::string ssl_ciphers{}; + std::string ssl_passwd{}; + std::string ssl_passwd_file{}; + /** Standard output. */ friend inline std::ostream & operator<<(std::ostream &strm, const ConnectOptions &opts); diff --git a/src/Client/UnixSSLStream.hpp b/src/Client/UnixSSLStream.hpp new file mode 100644 index 000000000..7eda0106d --- /dev/null +++ b/src/Client/UnixSSLStream.hpp @@ -0,0 +1,458 @@ +/* + Copyright 2010-2022 Tarantool AUTHORS: please see AUTHORS file. + + Redistribution and use in source and binary forms, with or + without modification, are permitted provided that the following + conditions are met: + + 1. Redistributions of source code must retain the above + copyright notice, this list of conditions and the + following disclaimer. + + 2. Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + disclaimer in the documentation and/or other materials + provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY AUTHORS ``AS IS'' AND + ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED + TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL + AUTHORS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, + INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR + BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF + THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF + SUCH DAMAGE. +*/ +#pragma once + +#include +#include +#include +#include + +#include "UnixPlainStream.hpp" + +#ifdef TNTCXX_ENABLE_SSL_GOST +extern void +ENGINE_load_gost(void); +#else +static inline void +ENGINE_load_gost(void) {} +#endif // TNTCXX_ENABLE_SSL_GOST + +/** + * Holder of SSL context. + */ +class SSLContext { +public: + SSLContext() = default; + inline ~SSLContext(); + SSLContext(const SSLContext&) = delete; + SSLContext &operator=(const SSLContext&) = delete; + SSLContext(SSLContext&&) = default; + SSLContext &operator=(SSLContext&&) = default; + + /** + * Create context with options. Return 0 on success, -1 on error. + * See get_last_error() in case of error. + */ + inline int create(const ConnectOptions &opts); + /** Cast to ssl context pointer. */ + operator SSL_CTX *() const { return ssl_ctx; } + /** Get last error that happend upon creation. */ + const char *get_last_error() const { return last_error; } + +private: + /** + * Dummy callback passed to SSL_CTX_set_default_passwd_cb. + * Used to disable command-line prompt. + */ + static inline int + dummy_passwd_cb(char *, int , int , void *) { return 0; } + + /** + * Loads SSL private key and returns 0 on success or -1 on error. + * + * The private key file may be encrypted. This function tries to decrypt + * the key using passwords in the following order: + * 1. String stored in the passwd argument. Skipped if passwd is NULL. + * 2. Every line from the file specified by the passwd_file argument. + * Skipped if passwd_file is NULL. + * 3. Empty password. + */ + inline int + load_private_key(const ConnectOptions &opts); + + Resource ssl_ctx; + const char *last_error = nullptr; +}; + +/** + * Unix stream that supports SSL encryption. + * Support non-encrypted connections too. + */ +class UnixSSLStream : public UnixPlainStream { +public: + UnixSSLStream() noexcept = default; + inline ~UnixSSLStream() noexcept; + UnixSSLStream(const UnixSSLStream&) = delete; + UnixSSLStream &operator=(const UnixSSLStream&) = delete; + UnixSSLStream(UnixSSLStream &&a) noexcept = default; + UnixSSLStream &operator=(UnixSSLStream &&a) noexcept = default; + + /** + * Connect to address. Return 0 on success, -1 on error. + * Pending (inprogress) connection has a successfull result. + */ + int connect(const ConnectOptions &opts); + /** + * Receive data to connection. + * Return positive number - number of bytes was received. + * Return 0 if nothing was received but there's no error. + * Return -1 on error. + * One must check the stream status to understand what happens. + */ + ssize_t send(struct iovec *iov, size_t iov_count); + /** + * Receive data to connection. + * Return positive number - number of bytes was received. + * Return 0 if nothing was received but there's no error. + * Return -1 on error. + * One must check the stream status to understand what happens. + */ + ssize_t recv(struct iovec *iov, size_t iov_count); + +private: + SSLContext ssl_context; + Resource ssl; +}; + +///////////////////////////////////////////////////////////////////// +////////////////////////// Implementation ////////////////////////// +///////////////////////////////////////////////////////////////////// + +namespace { + +class SSLInit { +public: + static SSLInit &instance(); +private: + inline SSLInit(); + inline ~SSLInit(); +}; + +SSLInit& SSLInit::instance() +{ + static SSLInit instance; + return instance; +} + +SSLInit::SSLInit() +{ + /* NB: GOST engine must be loaded before OpenSSL initialization. */ + ENGINE_load_gost(); +#if OPENSSL_VERSION_NUMBER < 0x10100000L || defined(LIBRESSL_VERSION_NUMBER) + OpenSSL_add_all_digests(); + OpenSSL_add_all_ciphers(); + ERR_load_crypto_strings(); +#else + OPENSSL_init_crypto(0, NULL); + OPENSSL_init_ssl(0, NULL); +#endif +} + +SSLInit::~SSLInit() +{ +#ifdef OPENSSL_cleanup + OPENSSL_cleanup(); +#endif +} + +} // anonymous namespace + +SSLContext::~SSLContext() +{ + if (ssl_ctx != nullptr) + SSL_CTX_free(ssl_ctx); +} + +int SSLContext::create(const ConnectOptions &opts) +{ + SSLInit::instance(); + + const char *cert_file = opts.ssl_cert_file.empty() ? + nullptr : opts.ssl_cert_file.c_str(); + const char *key_file = opts.ssl_key_file.empty() ? + nullptr : opts.ssl_key_file.c_str(); + const char *ca_file = opts.ssl_ca_file.empty() ? + nullptr : opts.ssl_ca_file.c_str(); + const char *ciphers = opts.ssl_ciphers.empty() ? + nullptr : opts.ssl_ciphers.c_str(); + + if (ssl_ctx != nullptr) + SSL_CTX_free(ssl_ctx); + const SSL_METHOD *method = TLS_client_method(); + ssl_ctx = SSL_CTX_new(method); + if (ssl_ctx == NULL) { + last_error = "SSL_CTX_new failed"; + return -1; + } + + /* + * Require TLSv1.2, because other protocol versions don't seem to + * support the GOST cipher: + * + * $ openssl ciphers -s -tls1_2 | tr ':' '\n' | grep GOST + * + * (Should we add a configuration parameter for this?) + */ + if (SSL_CTX_set_min_proto_version(ssl_ctx, TLS1_2_VERSION) != 1 || + SSL_CTX_set_max_proto_version(ssl_ctx, TLS1_2_VERSION) != 1) { + last_error = "Error setting SSL protocol version"; + return -1; + } + if (cert_file != NULL && + SSL_CTX_use_certificate_file(ssl_ctx, cert_file, + SSL_FILETYPE_PEM) != 1) { + last_error = "Error loading SSL certificate"; + return -1; + } + if (key_file != NULL && + load_private_key(opts) != 0) + return -1; + if (ca_file != NULL && + SSL_CTX_load_verify_locations(ssl_ctx, ca_file, NULL) != 1) { + last_error = "Error loading SSL CA"; + return -1; + } + if (ca_file != NULL) { + SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_PEER | + SSL_VERIFY_FAIL_IF_NO_PEER_CERT, + NULL); + } + /* + * NB: SSL_CTX_set_cipher_list() only works for procol versions TLSv1.2 + * and below. For TLSv1.3 we'd have to use SSL_CTX_set_ciphersuites() + * instead. + */ + if (ciphers != NULL && + SSL_CTX_set_cipher_list(ssl_ctx, ciphers) != 1) { + last_error = "Error setting SSL ciphers"; + return -1; + } + return 0; +} + +int SSLContext::load_private_key(const ConnectOptions &opts) +{ + const char *key_file = opts.ssl_key_file.empty() ? + nullptr : opts.ssl_key_file.c_str(); + const char *passwd = opts.ssl_passwd.empty() ? + nullptr : opts.ssl_passwd.c_str(); + const char *passwd_file = opts.ssl_passwd_file.empty() ? + nullptr : opts.ssl_passwd_file.c_str(); + + /* + * Set the password callback to NULL to make the SSL library use + * the callback userdata for a password. + */ + SSL_CTX_set_default_passwd_cb(ssl_ctx, NULL); + + if (passwd != NULL) { + /* + * Try to load the key file using the password specified + * in the passwd argument. + */ + SSL_CTX_set_default_passwd_cb_userdata(ssl_ctx, (void *)passwd); + int ret = SSL_CTX_use_PrivateKey_file(ssl_ctx, key_file, + SSL_FILETYPE_PEM); + SSL_CTX_set_default_passwd_cb_userdata(ssl_ctx, NULL); + if (ret == 1) + return 0; + } + if (passwd_file != NULL) { + /* + * Try to load the key file using every password stored in + * the password file. + */ + FILE *f = fopen(passwd_file, "r"); + if (f == NULL) { + last_error = "Error reading SSL password file"; + return -1; + } + char *buf = NULL; + size_t buf_size = 0; + bool is_error = false; + bool is_loaded = false; + while (true) { + /* Read a line from the password file. */ + errno = 0; + ssize_t len = getline(&buf, &buf_size, f); + if (len <= 0) { + if (errno == 0) + break; /* EOF */ + last_error = "Error reading SSL password file"; + is_error = true; + break; + } + char *s = buf; + /* Trim a terminating new line. */ + if (s[len - 1] == '\n') + s[len - 1] = '\0'; + /* Try to load the key file using the password. */ + SSL_CTX_set_default_passwd_cb_userdata(ssl_ctx, s); + int ret = SSL_CTX_use_PrivateKey_file(ssl_ctx, key_file, + SSL_FILETYPE_PEM); + SSL_CTX_set_default_passwd_cb_userdata(ssl_ctx, NULL); + if (ret == 1) { + is_loaded = true; + break; + } + /* Ignore the error and try another password. */ + ERR_clear_error(); + } + free(buf); + fclose(f); + if (is_loaded) + return 0; + if (is_error) + return -1; + } + /* Try to load the key file without a password. */ + SSL_CTX_set_default_passwd_cb(ssl_ctx, dummy_passwd_cb); + int ret = SSL_CTX_use_PrivateKey_file(ssl_ctx, key_file, + SSL_FILETYPE_PEM); + SSL_CTX_set_default_passwd_cb(ssl_ctx, NULL); + if (ret != 1) { + last_error = "Error loading SSL private key"; + return -1; + } + return 0; +} + +UnixSSLStream::~UnixSSLStream() +{ + if (ssl != nullptr) + SSL_free(ssl); +} + +int UnixSSLStream::connect(const ConnectOptions &opts_arg) +{ + if (ssl != nullptr) { + SSL_free(ssl); + ssl = nullptr; + } + + if (UnixStream::connect(opts_arg) != 0) + return -1; + if (opts.transport == STREAM_PLAIN) + return 0; + assert(opts.transport == STREAM_SSL); + + if (ssl_context.create(opts) != 0) + return US_DIE("SSL_context create failed", + ssl_context.get_last_error()); + + if ((ssl = SSL_new(ssl_context)) == NULL) + return US_DIE("SSL_new failed"); + + if (SSL_set_fd(ssl, get_fd()) != 1) + return US_DIE("SSL_set_fd failed"); + + SSL_set_connect_state(ssl); + + return 0; +} + +ssize_t UnixSSLStream::send(struct iovec *iov, size_t iov_count) +{ + if (opts.transport == STREAM_PLAIN) + return UnixPlainStream::send(iov, iov_count); + assert(opts.transport == STREAM_SSL); + + if (!(has_status(SS_ESTABLISHED))) { + if (has_status(SS_DEAD)) + return US_DIE("Send to dead stream"); + if (check_pending() != 0) + return -1; + if (iov_count == 0) + return 0; + } + + reset_status(SS_NEED_EVENT_FOR_WRITE); + size_t sent; + int ret = SSL_write_ex(ssl, iov->iov_base, iov->iov_len, &sent); + if (ret == 1) + return sent; + + int err = SSL_get_error(ssl, ret); + switch (err) { + case SSL_ERROR_WANT_READ: + return set_status(SS_NEED_READ_EVENT_FOR_WRITE); + case SSL_ERROR_WANT_WRITE: + return set_status(SS_NEED_WRITE_EVENT_FOR_WRITE); + case SSL_ERROR_SSL: + return US_DIE("SSL send failed"); + default: + assert(err == SSL_ERROR_SYSCALL); + if (errno == 0) { + /* + * The remote end closed the socket for reading. + * The OpenSSL library treats this situation as + * a system error with errno = 0. We report it + * as EPIPE. + */ + errno = EPIPE; + } + return US_DIE("Send failed", strerror(errno)); + } +} + +ssize_t UnixSSLStream::recv(struct iovec *iov, size_t iov_count) +{ + if (opts.transport == STREAM_PLAIN) + return UnixPlainStream::recv(iov, iov_count); + assert(opts.transport == STREAM_SSL); + + if (!(has_status(SS_ESTABLISHED))) { + if (has_status(SS_DEAD)) + return US_DIE("Recv from dead stream"); + else + return US_DIE("Recv from pending stream"); + } + + reset_status(SS_NEED_EVENT_FOR_READ); + errno = 0; + size_t rcvd; + int ret = SSL_read_ex(ssl, iov->iov_base, iov->iov_len, &rcvd); + if (ret == 1) + return rcvd; + + int err = SSL_get_error(ssl, ret); + switch (err) { + case SSL_ERROR_ZERO_RETURN: + return 0; + case SSL_ERROR_WANT_READ: + return set_status(SS_NEED_READ_EVENT_FOR_READ); + case SSL_ERROR_WANT_WRITE: + return set_status(SS_NEED_WRITE_EVENT_FOR_READ); + case SSL_ERROR_SSL: + return US_DIE("SSL revc failed"); + default: + assert(err == SSL_ERROR_SYSCALL); + if (errno == 0) { + /* + * The remote end closed the socket for writing. + * The OpenSSL library treats this situation as + * a system error with errno = 0. We ignore it. + */ + return 0; + } + return US_DIE("Send failed", strerror(errno)); + } +} diff --git a/test/ClientTest.cpp b/test/ClientTest.cpp index facc5b276..b46529959 100644 --- a/test/ClientTest.cpp +++ b/test/ClientTest.cpp @@ -40,6 +40,27 @@ int port = 3301; const char *unixsocket = "./tnt.sock"; int WAIT_TIMEOUT = 1000; //milliseconds +#ifdef TNTCXX_ENABLE_SSL +constexpr bool enable_ssl = true; +constexpr StreamTransport transport = STREAM_SSL; +#else +constexpr bool enable_ssl = false; +constexpr StreamTransport transport = STREAM_PLAIN; +#endif + +template +static int +test_connect(Connector &client, Connection &conn, const std::string &addr, + unsigned port) +{ + std::string service = port == 0 ? std::string{} : std::to_string(port); + return client.connect(conn, { + .address = addr, + .service = service, + .transport = transport, + }); +} + enum ResultFormat { TUPLES = 0, MULTI_RETURN, @@ -105,16 +126,16 @@ trivial(Connector &client) fail_unless(rc != 0); /* Connect to the wrong address. */ TEST_CASE("Bad address"); - rc = client.connect(conn, "asdasd", port); + rc = test_connect(client, conn, "asdasd", port); fail_unless(rc != 0); TEST_CASE("Unreachable address"); - rc = client.connect(conn, "101.101.101", port); + rc = test_connect(client, conn, "101.101.101", port); fail_unless(rc != 0); TEST_CASE("Wrong port"); - rc = client.connect(conn, localhost, -666); + rc = test_connect(client, conn, localhost, -666); fail_unless(rc != 0); TEST_CASE("Connect timeout"); - rc = client.connect(conn, "8.8.8.8", port); + rc = test_connect(client, conn, "8.8.8.8", port); fail_unless(rc != 0); } @@ -125,7 +146,7 @@ single_conn_ping(Connector &client) { TEST_INIT(0); Connection conn(client); - int rc = client.connect(conn, localhost, port); + int rc = test_connect(client, conn, localhost, port); fail_unless(rc == 0); rid_t f = conn.ping(); fail_unless(!conn.futureIsReady(f)); @@ -180,18 +201,18 @@ many_conn_ping(Connector &client) Connection conn1(client); Connection conn2(client); Connection conn3(client); - int rc = client.connect(conn1, localhost, port); + int rc = test_connect(client, conn1, localhost, port); fail_unless(rc == 0); /* Try to connect to the same port */ - rc = client.connect(conn2, localhost, port); + rc = test_connect(client, conn2, localhost, port); fail_unless(rc == 0); /* * Try to re-connect to another address whithout closing * current connection. */ - //rc = client.connect(conn2, localhost, port + 2); + //rc = test_connect(client, conn2, localhost, port + 2); //fail_unless(rc != 0); - rc = client.connect(conn3, localhost, port); + rc = test_connect(client, conn3, localhost, port); fail_unless(rc == 0); rid_t f1 = conn1.ping(); rid_t f2 = conn2.ping(); @@ -212,7 +233,7 @@ single_conn_error(Connector &client) { TEST_INIT(0); Connection conn(client); - int rc = client.connect(conn, localhost, port); + int rc = test_connect(client, conn, localhost, port); fail_unless(rc == 0); /* Fake space id. */ uint32_t space_id = -111; @@ -249,7 +270,7 @@ single_conn_replace(Connector &client) { TEST_INIT(0); Connection conn(client); - int rc = client.connect(conn, localhost, port); + int rc = test_connect(client, conn, localhost, port); fail_unless(rc == 0); uint32_t space_id = 512; std::tuple data = std::make_tuple(666, "111", 1.01); @@ -268,6 +289,7 @@ single_conn_replace(Connector &client) client.wait(conn, f2, WAIT_TIMEOUT); fail_unless(conn.futureIsReady(f2)); response = conn.getResponse(f2); + printResponse(conn, *response); fail_unless(response != std::nullopt); fail_unless(response->body.data != std::nullopt); fail_unless(response->body.error_stack == std::nullopt); @@ -282,7 +304,7 @@ single_conn_insert(Connector &client) { TEST_INIT(0); Connection conn(client); - int rc = client.connect(conn, localhost, port); + int rc = test_connect(client, conn, localhost, port); fail_unless(rc == 0); TEST_CASE("Successful inserts"); uint32_t space_id = 512; @@ -326,7 +348,7 @@ single_conn_update(Connector &client) { TEST_INIT(0); Connection conn(client); - int rc = client.connect(conn, localhost, port); + int rc = test_connect(client, conn, localhost, port); fail_unless(rc == 0); TEST_CASE("Successful update"); uint32_t space_id = 512; @@ -361,7 +383,7 @@ single_conn_delete(Connector &client) { TEST_INIT(0); Connection conn(client); - int rc = client.connect(conn, localhost, port); + int rc = test_connect(client, conn, localhost, port); fail_unless(rc == 0); TEST_CASE("Successful deletes"); uint32_t space_id = 512; @@ -405,7 +427,7 @@ single_conn_upsert(Connector &client) { TEST_INIT(0); Connection conn(client); - int rc = client.connect(conn, localhost, port); + int rc = test_connect(client, conn, localhost, port); fail_unless(rc == 0); TEST_CASE("upsert-insert"); uint32_t space_id = 512; @@ -436,7 +458,7 @@ single_conn_select(Connector &client) { TEST_INIT(0); Connection conn(client); - int rc = client.connect(conn, localhost, port); + int rc = test_connect(client, conn, localhost, port); fail_unless(rc == 0); uint32_t space_id = 512; uint32_t index_id = 0; @@ -497,7 +519,7 @@ single_conn_call(Connector &client) const static char *return_multi = "remote_multi"; Connection conn(client); - int rc = client.connect(conn, localhost, port); + int rc = test_connect(client, conn, localhost, port); fail_unless(rc == 0); TEST_CASE("call remote_replace"); @@ -508,17 +530,17 @@ single_conn_call(Connector &client) fail_unless(conn.futureIsReady(f1)); std::optional> response = conn.getResponse(f1); fail_unless(response != std::nullopt); + printResponse(conn, *response); fail_unless(response->body.data != std::nullopt); fail_unless(response->body.error_stack == std::nullopt); - printResponse(conn, *response); client.wait(conn, f2, WAIT_TIMEOUT); fail_unless(conn.futureIsReady(f2)); response = conn.getResponse(f2); fail_unless(response != std::nullopt); + printResponse(conn, *response); fail_unless(response->body.data != std::nullopt); fail_unless(response->body.error_stack == std::nullopt); - printResponse(conn, *response); TEST_CASE("call remote_uint"); rid_t f4 = conn.call(return_uint, std::make_tuple()); @@ -574,7 +596,7 @@ replace_unix_socket(Connector &client) TEST_INIT(0); Connection conn(client); - int rc = client.connect(conn, unixsocket, 0); + int rc = test_connect(client, conn, unixsocket, 0); fail_unless(rc == 0); TEST_CASE("select from unix socket"); @@ -597,8 +619,15 @@ int main() { if (cleanDir() != 0) return -1; - if (launchTarantool() != 0) + +#ifdef TNTCXX_ENABLE_SSL + if (genSSLCert() != 0) return -1; +#endif + + if (launchTarantool(enable_ssl) != 0) + return -1; + sleep(1); #ifdef __linux__ using NetEpoll_t = EpollNetProvider; diff --git a/test/Utils/System.hpp b/test/Utils/System.hpp index 2e73f8b1c..29760c7ce 100644 --- a/test/Utils/System.hpp +++ b/test/Utils/System.hpp @@ -38,8 +38,8 @@ #include #endif -int -launchTarantool() +inline int +launchTarantool(bool enable_ssl = false) { pid_t ppid_before_fork = getpid(); pid_t pid = fork(); @@ -66,7 +66,8 @@ launchTarantool() "just before prctl call"); exit(EXIT_FAILURE); } - if (execlp("tarantool", "tarantool", "test_cfg.lua", NULL) == -1) { + const char *script = enable_ssl ? "test_cfg_ssl.lua" : "test_cfg.lua"; + if (execlp("tarantool", "tarantool", script, NULL) == -1) { fprintf(stderr, "Can't launch Tarantool: execlp failed! %s\n", strerror(errno)); kill(getppid(), SIGKILL); @@ -74,7 +75,7 @@ launchTarantool() exit(EXIT_FAILURE); } -int +inline int cleanDir() { pid_t pid = fork(); if (pid == -1) { @@ -90,9 +91,34 @@ cleanDir() { fprintf(stderr, "wait: child finished with error \n"); return -1; } - if (execlp("/bin/sh", "/bin/sh", "-c", "rm *xlog *snap", NULL) == -1) { + if (execlp("/bin/sh", "/bin/sh", "-c", + "rm -f *.xlog *.snap tarantool.log", NULL) == -1) { fprintf(stderr, "Failed to clean directory: execlp failed! %s\n", strerror(errno)); } exit(EXIT_FAILURE); } + +inline int +genSSLCert() { + pid_t pid = fork(); + if (pid == -1) { + fprintf(stderr, "Failed to clean directory: fork failed! %s\n", + strerror(errno)); + return -1; + } + if (pid != 0) { + int status; + wait(&status); + if (WIFEXITED(status) != 0) + return 0; + fprintf(stderr, "wait: child finished with error \n"); + return -1; + } + if (execlp("/bin/sh", "/bin/sh", "-c", + "./test_gen_ssl.sh", NULL) == -1) { + fprintf(stderr, "Failed to generate ssl: execlp failed! %s\n", + strerror(errno)); + } + exit(EXIT_FAILURE); +} diff --git a/test/cfg.lua b/test/cfg.lua index 905bd85ab..d24bebee2 100644 --- a/test/cfg.lua +++ b/test/cfg.lua @@ -30,3 +30,6 @@ end function get_rps() return box.stat.net().REQUESTS.rps end + +box.schema.user.grant('guest', 'read,write', 'space', 'T', nil, {if_not_exists=true}) +box.schema.user.grant('guest', 'execute', 'universe', nil, {if_not_exists=true}) diff --git a/test/cfg_ssl.lua b/test/cfg_ssl.lua new file mode 100644 index 000000000..2299885c1 --- /dev/null +++ b/test/cfg_ssl.lua @@ -0,0 +1,40 @@ +local ssl_params = {transport = 'ssl', ssl_cert_file = './ssl_test/server.crt', ssl_key_file = './ssl_test/server.key'} +box.cfg{listen = {{uri = 'localhost:3301', params = ssl_params}, + {uri = 'unix/:./tnt.sock', params = ssl_params}}, + net_msg_max = 10000, + readahead = 163200, + log = 'tarantool.log', +} + +if box.space.t then box.space.t:drop() end +s = box.schema.space.create('T') +s:format{{name='id',type='integer'},{name='a',type='string'},{name='b',type='number'}} +s:create_index('primary') +s:replace{1, 'asd', 1.123} + +function remote_replace(arg1, arg2, arg3) + return box.space.T:replace({arg1, arg2, arg3}) +end + +function remote_select() + return box.space.T:select() +end + +function remote_uint() + return 666 +end + +function remote_multi() + return 'Hello', 1, 6.66 +end + +function bench_func(...) + return {...} +end + +function get_rps() + return box.stat.net().REQUESTS.rps +end + +box.schema.user.grant('guest', 'read,write', 'space', 'T', nil, {if_not_exists=true}) +box.schema.user.grant('guest', 'execute', 'universe', nil, {if_not_exists=true}) diff --git a/test/gen_ssl.sh b/test/gen_ssl.sh new file mode 100755 index 000000000..797751bac --- /dev/null +++ b/test/gen_ssl.sh @@ -0,0 +1,68 @@ +#!/bin/bash +# +# This script generates SSL keys and certificates used for testing. + +mkdir -p ssl_test +cd ssl_test + +HOST=tarantool.io +NEWKEY_ARG=rsa:4096 +DAYS_ARG=36500 + +# +# Generates new CA. +# +# The new private key and certificate are written to files "${ca}.key" and +# "${ca}.crt" respectively where "${ca}" is the name of the new CA as given +# in the first argument. +# +gen_ca() +{ + local ca="${1}" + openssl req -new -nodes -newkey "${NEWKEY_ARG}" -days "${DAYS_ARG}" -x509 \ + -subj "/OU=Unknown/O=Unknown/L=Unknown/ST=unknown/C=AU" \ + -keyout "${ca}.key" -out "${ca}.crt" +} + +# +# Generates new certificate and private key signed by the given CA. +# +# The new private key and certificate are written to files "${cert}.key" and +# "${cert}.crt" respectively where "${cert}" is the certificate name as given +# in the first argument. The CA and private key used for signing the new +# certificate should be located in "${ca}.cert" and "${ca}.key" where "${ca}" +# is the value of the second argument. +# +gen_cert() +{ + local cert="${1}" + local ca="${2}" + openssl req -new -nodes -newkey "${NEWKEY_ARG}" \ + -subj "/CN=${HOST}/OU=Unknown/O=Unknown/L=Unknown/ST=unknown/C=AU" \ + -keyout "${cert}.key" -out "${cert}.csr" + openssl x509 -req -days "${DAYS_ARG}" \ + -CAcreateserial -CA "${ca}.crt" -CAkey "${ca}.key" \ + -in "${cert}.csr" -out "${cert}.crt" + rm -f "${cert}.csr" +} + +# +# Encrypt private key file. +# $1 - file name without extension +# $2 - pass phrase +# +# Encrypted key is written to ${1}.enc.key +# +encrypt_key() +{ + local key="${1}" + local pass="${2}" + openssl rsa -aes256 -passout "pass:${pass}" \ + -in "${key}.key" -out "${key}.enc.key" +} + +gen_ca ca +gen_cert server ca +gen_cert client ca +encrypt_key server 1q2w3e +encrypt_key client 123qwe