diff --git a/source/extensions/transport_sockets/alts/BUILD b/source/extensions/transport_sockets/alts/BUILD index e3b9374c1cab..8f7df7a46a13 100644 --- a/source/extensions/transport_sockets/alts/BUILD +++ b/source/extensions/transport_sockets/alts/BUILD @@ -32,7 +32,6 @@ envoy_cc_library( hdrs = [ "tsi_frame_protector.h", ], - repository = "@envoy", deps = [ ":grpc_tsi_wrapper", "//source/common/buffer:buffer_lib", @@ -54,6 +53,28 @@ envoy_cc_library( ], ) +envoy_cc_library( + name = "tsi_socket", + srcs = [ + "tsi_socket.cc", + ], + hdrs = [ + "tsi_socket.h", + ], + deps = [ + ":noop_transport_socket_callbacks_lib", + ":tsi_frame_protector", + ":tsi_handshaker", + "//include/envoy/network:transport_socket_interface", + "//source/common/buffer:buffer_lib", + "//source/common/common:cleanup_lib", + "//source/common/common:empty_string", + "//source/common/common:enum_to_int", + "//source/common/network:raw_buffer_socket_lib", + "//source/common/protobuf:utility_lib", + ], +) + envoy_cc_library( name = "noop_transport_socket_callbacks_lib", hdrs = ["noop_transport_socket_callbacks.h"], diff --git a/source/extensions/transport_sockets/alts/noop_transport_socket_callbacks.h b/source/extensions/transport_sockets/alts/noop_transport_socket_callbacks.h index 3fe6a038104a..d72b68e8eb20 100644 --- a/source/extensions/transport_sockets/alts/noop_transport_socket_callbacks.h +++ b/source/extensions/transport_sockets/alts/noop_transport_socket_callbacks.h @@ -29,6 +29,8 @@ class NoOpTransportSocketCallbacks : public Network::TransportSocketCallbacks { Network::TransportSocketCallbacks& parent_; }; +typedef std::unique_ptr NoOpTransportSocketCallbacksPtr; + } // namespace Alts } // namespace TransportSockets } // namespace Extensions diff --git a/source/extensions/transport_sockets/alts/tsi_socket.cc b/source/extensions/transport_sockets/alts/tsi_socket.cc new file mode 100644 index 000000000000..c20887f46dc0 --- /dev/null +++ b/source/extensions/transport_sockets/alts/tsi_socket.cc @@ -0,0 +1,245 @@ +#include "extensions/transport_sockets/alts/tsi_socket.h" + +#include "common/common/assert.h" +#include "common/common/cleanup.h" +#include "common/common/empty_string.h" +#include "common/common/enum_to_int.h" + +namespace Envoy { +namespace Extensions { +namespace TransportSockets { +namespace Alts { + +TsiSocket::TsiSocket(HandshakerFactory handshaker_factory, HandshakeValidator handshake_validator, + Network::TransportSocketPtr&& raw_socket) + : handshaker_factory_(handshaker_factory), handshake_validator_(handshake_validator), + raw_buffer_socket_(std::move(raw_socket)) {} + +TsiSocket::TsiSocket(HandshakerFactory handshaker_factory, HandshakeValidator handshake_validator) + : TsiSocket(handshaker_factory, handshake_validator, + std::make_unique()) {} + +TsiSocket::~TsiSocket() { ASSERT(!handshaker_); } + +void TsiSocket::setTransportSocketCallbacks(Envoy::Network::TransportSocketCallbacks& callbacks) { + callbacks_ = &callbacks; + + noop_callbacks_ = std::make_unique(callbacks); + raw_buffer_socket_->setTransportSocketCallbacks(*noop_callbacks_); +} + +std::string TsiSocket::protocol() const { + // TSI doesn't have a generic way to indicate application layer protocol. + // TODO(lizan): support application layer protocol from TSI for known TSIs. + return EMPTY_STRING; +} + +Network::PostIoAction TsiSocket::doHandshake() { + ASSERT(!handshake_complete_); + ENVOY_CONN_LOG(debug, "TSI: doHandshake", callbacks_->connection()); + + if (!handshaker_) { + handshaker_ = handshaker_factory_(callbacks_->connection().dispatcher(), + callbacks_->connection().localAddress(), + callbacks_->connection().remoteAddress()); + handshaker_->setHandshakerCallbacks(*this); + } + + if (!handshaker_next_calling_) { + doHandshakeNext(); + } + return Network::PostIoAction::KeepOpen; +} + +void TsiSocket::doHandshakeNext() { + ENVOY_CONN_LOG(debug, "TSI: doHandshake next: received: {}", callbacks_->connection(), + raw_read_buffer_.length()); + handshaker_next_calling_ = true; + Buffer::OwnedImpl handshaker_buffer; + handshaker_buffer.move(raw_read_buffer_); + handshaker_->next(handshaker_buffer); +} + +Network::PostIoAction TsiSocket::doHandshakeNextDone(NextResultPtr&& next_result) { + ASSERT(next_result); + + ENVOY_CONN_LOG(debug, "TSI: doHandshake next done: status: {} to_send: {}", + callbacks_->connection(), next_result->status_, next_result->to_send_->length()); + + tsi_result status = next_result->status_; + tsi_handshaker_result* handshaker_result = next_result->result_.get(); + + if (status != TSI_INCOMPLETE_DATA && status != TSI_OK) { + ENVOY_CONN_LOG(debug, "TSI: Handshake failed: status: {}", callbacks_->connection(), status); + return Network::PostIoAction::Close; + } + + if (next_result->to_send_->length() > 0) { + raw_write_buffer_.move(*next_result->to_send_); + } + + if (status == TSI_OK && handshaker_result != nullptr) { + tsi_peer peer; + // returns TSI_OK assuming there is no fatal error. Asserting OK. + status = tsi_handshaker_result_extract_peer(handshaker_result, &peer); + ASSERT(status == TSI_OK); + Cleanup peer_cleanup([&peer]() { tsi_peer_destruct(&peer); }); + ENVOY_CONN_LOG(debug, "TSI: Handshake successful: peer properties: {}", + callbacks_->connection(), peer.property_count); + for (size_t i = 0; i < peer.property_count; ++i) { + ENVOY_CONN_LOG(debug, " {}: {}", callbacks_->connection(), peer.properties[i].name, + std::string(peer.properties[i].value.data, peer.properties[i].value.length)); + } + if (handshake_validator_) { + std::string err; + const bool peer_validated = handshake_validator_(peer, err); + if (peer_validated) { + ENVOY_CONN_LOG(debug, "TSI: Handshake validation succeeded.", callbacks_->connection()); + } else { + ENVOY_CONN_LOG(debug, "TSI: Handshake validation failed: {}", callbacks_->connection(), + err); + return Network::PostIoAction::Close; + } + } else { + ENVOY_CONN_LOG(debug, "TSI: Handshake validation skipped.", callbacks_->connection()); + } + + const unsigned char* unused_bytes; + size_t unused_byte_size; + + // returns TSI_OK assuming there is no fatal error. Asserting OK. + status = + tsi_handshaker_result_get_unused_bytes(handshaker_result, &unused_bytes, &unused_byte_size); + ASSERT(status == TSI_OK); + if (unused_byte_size > 0) { + raw_read_buffer_.prepend( + absl::string_view{reinterpret_cast(unused_bytes), unused_byte_size}); + } + ENVOY_CONN_LOG(debug, "TSI: Handshake successful: unused_bytes: {}", callbacks_->connection(), + unused_byte_size); + + // returns TSI_OK assuming there is no fatal error. Asserting OK. + tsi_frame_protector* frame_protector; + status = + tsi_handshaker_result_create_frame_protector(handshaker_result, NULL, &frame_protector); + ASSERT(status == TSI_OK); + frame_protector_ = std::make_unique(frame_protector); + + handshake_complete_ = true; + callbacks_->raiseEvent(Network::ConnectionEvent::Connected); + } + + if (read_error_ || (!handshake_complete_ && end_stream_read_)) { + ENVOY_CONN_LOG(debug, "TSI: Handshake failed: end of stream without enough data", + callbacks_->connection()); + return Network::PostIoAction::Close; + } + + if (raw_read_buffer_.length() > 0) { + callbacks_->setReadBufferReady(); + } + + // Try to write raw buffer when next call is done, even this is not in do[Read|Write] stack. + if (raw_write_buffer_.length() > 0) { + return raw_buffer_socket_->doWrite(raw_write_buffer_, false).action_; + } + + return Network::PostIoAction::KeepOpen; +} + +Network::IoResult TsiSocket::doRead(Buffer::Instance& buffer) { + Network::IoResult result = {Network::PostIoAction::KeepOpen, 0, false}; + if (!end_stream_read_ && !read_error_) { + result = raw_buffer_socket_->doRead(raw_read_buffer_); + ENVOY_CONN_LOG(debug, "TSI: raw read result action {} bytes {} end_stream {}", + callbacks_->connection(), enumToInt(result.action_), result.bytes_processed_, + result.end_stream_read_); + if (result.action_ == Network::PostIoAction::Close && result.bytes_processed_ == 0) { + return result; + } + + end_stream_read_ = result.end_stream_read_; + read_error_ = result.action_ == Network::PostIoAction::Close; + } + + if (!handshake_complete_) { + Network::PostIoAction action = doHandshake(); + if (action == Network::PostIoAction::Close || !handshake_complete_) { + return {action, 0, false}; + } + } + + if (handshake_complete_) { + ASSERT(frame_protector_); + + uint64_t read_size = raw_read_buffer_.length(); + ENVOY_CONN_LOG(debug, "TSI: unprotecting buffer size: {}", callbacks_->connection(), + raw_read_buffer_.length()); + tsi_result status = frame_protector_->unprotect(raw_read_buffer_, buffer); + ENVOY_CONN_LOG(debug, "TSI: unprotected buffer left: {} result: {}", callbacks_->connection(), + raw_read_buffer_.length(), tsi_result_to_string(status)); + result.bytes_processed_ = read_size - raw_read_buffer_.length(); + } + + ENVOY_CONN_LOG(debug, "TSI: do read result action {} bytes {} end_stream {}", + callbacks_->connection(), enumToInt(result.action_), result.bytes_processed_, + result.end_stream_read_); + return result; +} + +Network::IoResult TsiSocket::doWrite(Buffer::Instance& buffer, bool end_stream) { + if (!handshake_complete_) { + Network::PostIoAction action = doHandshake(); + ASSERT(action == Network::PostIoAction::KeepOpen); + // TODO(lizan): Handle synchronous handshake when TsiHandshaker supports it. + } + + if (handshake_complete_) { + ASSERT(frame_protector_); + ENVOY_CONN_LOG(debug, "TSI: protecting buffer size: {}", callbacks_->connection(), + buffer.length()); + tsi_result status = frame_protector_->protect(buffer, raw_write_buffer_); + ENVOY_CONN_LOG(debug, "TSI: protected buffer left: {} result: {}", callbacks_->connection(), + buffer.length(), tsi_result_to_string(status)); + } + + if (raw_write_buffer_.length() > 0) { + ENVOY_CONN_LOG(debug, "TSI: raw_write length {} end_stream {}", callbacks_->connection(), + raw_write_buffer_.length(), end_stream); + return raw_buffer_socket_->doWrite(raw_write_buffer_, end_stream && (buffer.length() == 0)); + } + return {Network::PostIoAction::KeepOpen, 0, false}; +} + +void TsiSocket::closeSocket(Network::ConnectionEvent) { + if (handshaker_) { + handshaker_.release()->deferredDelete(); + } +} + +void TsiSocket::onConnected() { ASSERT(!handshake_complete_); } + +void TsiSocket::onNextDone(NextResultPtr&& result) { + handshaker_next_calling_ = false; + + Network::PostIoAction action = doHandshakeNextDone(std::move(result)); + if (action == Network::PostIoAction::Close) { + callbacks_->connection().close(Network::ConnectionCloseType::NoFlush); + } +} + +TsiSocketFactory::TsiSocketFactory(HandshakerFactory handshaker_factory, + HandshakeValidator handshake_validator) + : handshaker_factory_(std::move(handshaker_factory)), + handshake_validator_(std::move(handshake_validator)) {} + +bool TsiSocketFactory::implementsSecureTransport() const { return true; } + +Network::TransportSocketPtr TsiSocketFactory::createTransportSocket() const { + return std::make_unique(handshaker_factory_, handshake_validator_); +} + +} // namespace Alts +} // namespace TransportSockets +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/transport_sockets/alts/tsi_socket.h b/source/extensions/transport_sockets/alts/tsi_socket.h new file mode 100644 index 000000000000..70f8a1d7aeff --- /dev/null +++ b/source/extensions/transport_sockets/alts/tsi_socket.h @@ -0,0 +1,111 @@ +#pragma once + +#include "envoy/network/transport_socket.h" + +#include "common/buffer/buffer_impl.h" +#include "common/network/raw_buffer_socket.h" + +#include "extensions/transport_sockets/alts/noop_transport_socket_callbacks.h" +#include "extensions/transport_sockets/alts/tsi_frame_protector.h" +#include "extensions/transport_sockets/alts/tsi_handshaker.h" + +namespace Envoy { +namespace Extensions { +namespace TransportSockets { +namespace Alts { + +/** + * A factory function to create TsiHandshaker + * @param dispatcher the dispatcher for the thread where the socket is running on. + * @param local_address the local address of the connection. + * @param remote_address the remote address of the connection. + */ +typedef std::function + HandshakerFactory; + +/** + * A function to validate the peer of the connection. + * @param peer the detail peer information of the connection. + * @param err an error message to indicate why the peer is invalid. This is an + * output param that should be populated by the function implementation. + * @return true if the peer is valid or false if the peer is invalid. + */ +typedef std::function HandshakeValidator; + +/** + * A implementation of Network::TransportSocket based on gRPC TSI + */ +class TsiSocket : public Network::TransportSocket, + public TsiHandshakerCallbacks, + public Logger::Loggable { +public: + // For Test + TsiSocket(HandshakerFactory handshaker_factory, HandshakeValidator handshake_validator, + Network::TransportSocketPtr&& raw_socket_ptr); + + /** + * @param handshaker_factory a function to initiate a TsiHandshaker + * @param handshake_validator a function to validate the peer. Called right + * after the handshake completed with peer data to do the peer validation. + * The connection will be closed immediately if it returns false. + */ + TsiSocket(HandshakerFactory handshaker_factory, HandshakeValidator handshake_validator); + virtual ~TsiSocket(); + + // Network::TransportSocket + void setTransportSocketCallbacks(Envoy::Network::TransportSocketCallbacks& callbacks) override; + std::string protocol() const override; + bool canFlushClose() override { return handshake_complete_; } + const Envoy::Ssl::Connection* ssl() const override { return nullptr; } + Network::IoResult doWrite(Buffer::Instance& buffer, bool end_stream) override; + void closeSocket(Network::ConnectionEvent event) override; + Network::IoResult doRead(Buffer::Instance& buffer) override; + void onConnected() override; + + // TsiHandshakerCallbacks + void onNextDone(NextResultPtr&& result) override; + +private: + Network::PostIoAction doHandshake(); + void doHandshakeNext(); + Network::PostIoAction doHandshakeNextDone(NextResultPtr&& next_result); + + HandshakerFactory handshaker_factory_; + HandshakeValidator handshake_validator_; + TsiHandshakerPtr handshaker_{}; + bool handshaker_next_calling_{}; + + TsiFrameProtectorPtr frame_protector_; + + Envoy::Network::TransportSocketCallbacks* callbacks_{}; + NoOpTransportSocketCallbacksPtr noop_callbacks_; + Network::TransportSocketPtr raw_buffer_socket_; + + Envoy::Buffer::OwnedImpl raw_read_buffer_; + Envoy::Buffer::OwnedImpl raw_write_buffer_; + bool handshake_complete_{}; + bool end_stream_read_{}; + bool read_error_{}; +}; + +/** + * An implementation of Network::TransportSocketFactory for TsiSocket + */ +class TsiSocketFactory : public Network::TransportSocketFactory { +public: + TsiSocketFactory(HandshakerFactory handshaker_factory, HandshakeValidator handshake_validator); + + bool implementsSecureTransport() const override; + Network::TransportSocketPtr createTransportSocket() const override; + +private: + HandshakerFactory handshaker_factory_; + HandshakeValidator handshake_validator_; +}; + +} // namespace Alts +} // namespace TransportSockets +} // namespace Extensions +} // namespace Envoy diff --git a/test/extensions/transport_sockets/alts/BUILD b/test/extensions/transport_sockets/alts/BUILD index d0ef9cb73ecc..590d74bced42 100644 --- a/test/extensions/transport_sockets/alts/BUILD +++ b/test/extensions/transport_sockets/alts/BUILD @@ -33,6 +33,19 @@ envoy_extension_cc_test( ], ) +envoy_extension_cc_test( + name = "tsi_socket_test", + srcs = ["tsi_socket_test.cc"], + extension_name = "envoy.transport_sockets.alts", + deps = [ + "//include/envoy/event:dispatcher_interface", + "//source/extensions/transport_sockets/alts:tsi_socket", + "//test/mocks/buffer:buffer_mocks", + "//test/mocks/event:event_mocks", + "//test/mocks/network:network_mocks", + ], +) + envoy_extension_cc_test( name = "noop_transport_socket_callbacks_test", srcs = ["noop_transport_socket_callbacks_test.cc"], diff --git a/test/extensions/transport_sockets/alts/tsi_socket_test.cc b/test/extensions/transport_sockets/alts/tsi_socket_test.cc new file mode 100644 index 000000000000..69ad81494f4c --- /dev/null +++ b/test/extensions/transport_sockets/alts/tsi_socket_test.cc @@ -0,0 +1,393 @@ +#include "common/buffer/buffer_impl.h" + +#include "extensions/transport_sockets/alts/tsi_socket.h" + +#include "test/mocks/network/mocks.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "src/core/tsi/fake_transport_security.h" + +namespace Envoy { +namespace Extensions { +namespace TransportSockets { +namespace Alts { + +using testing::NiceMock; +using testing::Return; +using testing::ReturnRef; +using testing::StrictMock; + +class TsiSocketTest : public testing::Test { +protected: + TsiSocketTest() { + server_.handshaker_factory_ = [](Event::Dispatcher& dispatcher, + const Network::Address::InstanceConstSharedPtr&, + const Network::Address::InstanceConstSharedPtr&) { + CHandshakerPtr handshaker{tsi_create_fake_handshaker(/*is_client=*/0)}; + + return std::make_unique(std::move(handshaker), dispatcher); + }; + + client_.handshaker_factory_ = [](Event::Dispatcher& dispatcher, + const Network::Address::InstanceConstSharedPtr&, + const Network::Address::InstanceConstSharedPtr&) { + CHandshakerPtr handshaker{tsi_create_fake_handshaker(/*is_client=*/1)}; + + return std::make_unique(std::move(handshaker), dispatcher); + }; + } + + void TearDown() override { + client_.tsi_socket_->closeSocket(Network::ConnectionEvent::LocalClose); + server_.tsi_socket_->closeSocket(Network::ConnectionEvent::RemoteClose); + } + + void initialize(HandshakeValidator server_validator, HandshakeValidator client_validator) { + server_.raw_socket_ = new NiceMock(); + + server_.tsi_socket_ = + std::make_unique(server_.handshaker_factory_, server_validator, + Network::TransportSocketPtr{server_.raw_socket_}); + + client_.raw_socket_ = new NiceMock(); + + client_.tsi_socket_ = + std::make_unique(client_.handshaker_factory_, client_validator, + Network::TransportSocketPtr{client_.raw_socket_}); + + ON_CALL(client_.callbacks_.connection_, dispatcher()).WillByDefault(ReturnRef(dispatcher_)); + ON_CALL(server_.callbacks_.connection_, dispatcher()).WillByDefault(ReturnRef(dispatcher_)); + + ON_CALL(client_.callbacks_.connection_, id()).WillByDefault(Return(11)); + ON_CALL(server_.callbacks_.connection_, id()).WillByDefault(Return(12)); + + ON_CALL(*client_.raw_socket_, doWrite(_, _)) + .WillByDefault(Invoke([&](Buffer::Instance& buffer, bool) { + Network::IoResult result = {Network::PostIoAction::KeepOpen, buffer.length(), false}; + client_to_server_.move(buffer); + return result; + })); + ON_CALL(*server_.raw_socket_, doWrite(_, _)) + .WillByDefault(Invoke([&](Buffer::Instance& buffer, bool) { + Network::IoResult result = {Network::PostIoAction::KeepOpen, buffer.length(), false}; + server_to_client_.move(buffer); + return result; + })); + + ON_CALL(*client_.raw_socket_, doRead(_)).WillByDefault(Invoke([&](Buffer::Instance& buffer) { + Network::IoResult result = {Network::PostIoAction::KeepOpen, server_to_client_.length(), + false}; + buffer.move(server_to_client_); + return result; + })); + ON_CALL(*server_.raw_socket_, doRead(_)).WillByDefault(Invoke([&](Buffer::Instance& buffer) { + Network::IoResult result = {Network::PostIoAction::KeepOpen, client_to_server_.length(), + false}; + buffer.move(client_to_server_); + return result; + })); + + client_.tsi_socket_->setTransportSocketCallbacks(client_.callbacks_); + client_.tsi_socket_->onConnected(); + + server_.tsi_socket_->setTransportSocketCallbacks(server_.callbacks_); + server_.tsi_socket_->onConnected(); + } + + void expectIoResult(Network::IoResult expected, Network::IoResult actual) { + EXPECT_EQ(expected.action_, actual.action_); + EXPECT_EQ(expected.bytes_processed_, actual.bytes_processed_); + EXPECT_EQ(expected.end_stream_read_, actual.end_stream_read_); + } + + std::string makeFakeTsiFrame(const std::string& payload) { + uint32_t length = static_cast(payload.length()) + 4; + std::string frame; + frame.reserve(length); + frame.push_back(static_cast(length)); + length >>= 8; + frame.push_back(static_cast(length)); + length >>= 8; + frame.push_back(static_cast(length)); + length >>= 8; + frame.push_back(static_cast(length)); + + frame.append(payload); + return frame; + } + + void doFakeInitHandshake() { + EXPECT_CALL(*client_.raw_socket_, doWrite(_, false)); + expectIoResult({Network::PostIoAction::KeepOpen, 0UL, false}, + client_.tsi_socket_->doWrite(client_.write_buffer_, false)); + EXPECT_EQ(makeFakeTsiFrame("CLIENT_INIT"), client_to_server_.toString()); + + EXPECT_CALL(*server_.raw_socket_, doRead(_)); + EXPECT_CALL(*server_.raw_socket_, doWrite(_, false)); + expectIoResult({Network::PostIoAction::KeepOpen, 0UL, false}, + server_.tsi_socket_->doRead(server_.read_buffer_)); + EXPECT_EQ(makeFakeTsiFrame("SERVER_INIT"), server_to_client_.toString()); + EXPECT_EQ(0L, server_.read_buffer_.length()); + } + + void doHandshakeAndExpectSuccess() { + doFakeInitHandshake(); + + EXPECT_CALL(*client_.raw_socket_, doRead(_)); + EXPECT_CALL(*client_.raw_socket_, doWrite(_, false)); + expectIoResult({Network::PostIoAction::KeepOpen, 0UL, false}, + client_.tsi_socket_->doRead(client_.read_buffer_)); + EXPECT_EQ(makeFakeTsiFrame("CLIENT_FINISHED"), client_to_server_.toString()); + EXPECT_EQ(0L, client_.read_buffer_.length()); + + EXPECT_CALL(*server_.raw_socket_, doRead(_)); + EXPECT_CALL(*server_.raw_socket_, doWrite(_, false)); + EXPECT_CALL(server_.callbacks_, raiseEvent(Network::ConnectionEvent::Connected)); + expectIoResult({Network::PostIoAction::KeepOpen, 0UL, false}, + server_.tsi_socket_->doRead(server_.read_buffer_)); + EXPECT_EQ(makeFakeTsiFrame("SERVER_FINISHED"), server_to_client_.toString()); + + EXPECT_CALL(*client_.raw_socket_, doRead(_)); + EXPECT_CALL(client_.callbacks_, raiseEvent(Network::ConnectionEvent::Connected)); + expectIoResult({Network::PostIoAction::KeepOpen, 0UL, false}, + client_.tsi_socket_->doRead(client_.read_buffer_)); + } + + void expectTransferDataFromClientToServer(const std::string& data) { + + EXPECT_EQ(0L, server_.read_buffer_.length()); + EXPECT_EQ(0L, client_.read_buffer_.length()); + + EXPECT_EQ("", client_.tsi_socket_->protocol()); + + EXPECT_CALL(*client_.raw_socket_, doWrite(_, false)); + expectIoResult({Network::PostIoAction::KeepOpen, 21UL, false}, + client_.tsi_socket_->doWrite(client_.write_buffer_, false)); + EXPECT_EQ(makeFakeTsiFrame(data), client_to_server_.toString()); + + EXPECT_CALL(*server_.raw_socket_, doRead(_)); + expectIoResult({Network::PostIoAction::KeepOpen, 21UL, false}, + server_.tsi_socket_->doRead(server_.read_buffer_)); + EXPECT_EQ(data, server_.read_buffer_.toString()); + } + + struct SocketForTest { + HandshakerFactory handshaker_factory_; + std::unique_ptr tsi_socket_; + NiceMock* raw_socket_{}; + NiceMock callbacks_; + Buffer::OwnedImpl read_buffer_; + Buffer::OwnedImpl write_buffer_; + }; + + SocketForTest client_; + SocketForTest server_; + + Buffer::OwnedImpl client_to_server_; + Buffer::OwnedImpl server_to_client_; + + NiceMock dispatcher_; +}; + +static const std::string ClientToServerData = "hello from client"; + +TEST_F(TsiSocketTest, DoesNotHaveSsl) { + initialize(nullptr, nullptr); + EXPECT_EQ(nullptr, client_.tsi_socket_->ssl()); + + const auto& socket_ = *client_.tsi_socket_; + EXPECT_EQ(nullptr, socket_.ssl()); +} + +TEST_F(TsiSocketTest, HandshakeWithoutValidationAndTransferData) { + // pass a nullptr validator to skip validation. + initialize(nullptr, nullptr); + + client_.write_buffer_.add(ClientToServerData); + + doHandshakeAndExpectSuccess(); + expectTransferDataFromClientToServer(ClientToServerData); +} + +TEST_F(TsiSocketTest, HandshakeWithSucessfulValidationAndTransferData) { + auto validator = [](const tsi_peer&, std::string&) { return true; }; + initialize(validator, validator); + + client_.write_buffer_.add(ClientToServerData); + + doHandshakeAndExpectSuccess(); + expectTransferDataFromClientToServer(ClientToServerData); +} + +TEST_F(TsiSocketTest, HandshakeValidationFail) { + auto validator = [](const tsi_peer&, std::string&) { return false; }; + initialize(validator, validator); + + client_.write_buffer_.add(ClientToServerData); + + doFakeInitHandshake(); + + EXPECT_CALL(*client_.raw_socket_, doRead(_)); + EXPECT_CALL(*client_.raw_socket_, doWrite(_, false)); + expectIoResult({Network::PostIoAction::KeepOpen, 0UL, false}, + client_.tsi_socket_->doRead(client_.read_buffer_)); + EXPECT_EQ(makeFakeTsiFrame("CLIENT_FINISHED"), client_to_server_.toString()); + EXPECT_EQ(0L, client_.read_buffer_.length()); + + EXPECT_CALL(*server_.raw_socket_, doRead(_)); + EXPECT_CALL(server_.callbacks_.connection_, close(Network::ConnectionCloseType::NoFlush)); + // doRead won't immediately fail, but it will result connection close. + expectIoResult({Network::PostIoAction::KeepOpen, 0UL, false}, + server_.tsi_socket_->doRead(server_.read_buffer_)); + EXPECT_EQ(0, server_to_client_.length()); +} + +TEST_F(TsiSocketTest, HandshakeWithUnusedData) { + initialize(nullptr, nullptr); + + doFakeInitHandshake(); + EXPECT_CALL(*client_.raw_socket_, doRead(_)); + EXPECT_CALL(*client_.raw_socket_, doWrite(_, false)); + expectIoResult({Network::PostIoAction::KeepOpen, 0UL, false}, + client_.tsi_socket_->doRead(client_.read_buffer_)); + EXPECT_EQ(makeFakeTsiFrame("CLIENT_FINISHED"), client_to_server_.toString()); + EXPECT_EQ(0L, client_.read_buffer_.length()); + + // Inject unused data + client_to_server_.add(makeFakeTsiFrame(ClientToServerData)); + + EXPECT_CALL(*server_.raw_socket_, doRead(_)); + EXPECT_CALL(*server_.raw_socket_, doWrite(_, false)); + EXPECT_CALL(server_.callbacks_, raiseEvent(Network::ConnectionEvent::Connected)); + expectIoResult({Network::PostIoAction::KeepOpen, 21UL, false}, + server_.tsi_socket_->doRead(server_.read_buffer_)); + EXPECT_EQ(makeFakeTsiFrame("SERVER_FINISHED"), server_to_client_.toString()); + EXPECT_EQ(ClientToServerData, server_.read_buffer_.toString()); + + EXPECT_CALL(*client_.raw_socket_, doRead(_)); + EXPECT_CALL(client_.callbacks_, raiseEvent(Network::ConnectionEvent::Connected)); + expectIoResult({Network::PostIoAction::KeepOpen, 0UL, false}, + client_.tsi_socket_->doRead(client_.read_buffer_)); +} + +TEST_F(TsiSocketTest, HandshakeWithUnusedDataAndEndOfStream) { + initialize(nullptr, nullptr); + + doFakeInitHandshake(); + EXPECT_CALL(*client_.raw_socket_, doRead(_)); + EXPECT_CALL(*client_.raw_socket_, doWrite(_, false)); + expectIoResult({Network::PostIoAction::KeepOpen, 0UL, false}, + client_.tsi_socket_->doRead(client_.read_buffer_)); + EXPECT_EQ(makeFakeTsiFrame("CLIENT_FINISHED"), client_to_server_.toString()); + EXPECT_EQ(0L, client_.read_buffer_.length()); + + // Inject unused data + client_to_server_.add(makeFakeTsiFrame(ClientToServerData)); + + EXPECT_CALL(*server_.raw_socket_, doRead(_)).WillOnce(Invoke([&](Buffer::Instance& buffer) { + Network::IoResult result = {Network::PostIoAction::KeepOpen, client_to_server_.length(), true}; + buffer.move(client_to_server_); + return result; + })); + EXPECT_CALL(*server_.raw_socket_, doWrite(_, false)); + EXPECT_CALL(server_.callbacks_, raiseEvent(Network::ConnectionEvent::Connected)); + expectIoResult({Network::PostIoAction::KeepOpen, 21UL, true}, + server_.tsi_socket_->doRead(server_.read_buffer_)); + EXPECT_EQ(makeFakeTsiFrame("SERVER_FINISHED"), server_to_client_.toString()); + EXPECT_EQ(ClientToServerData, server_.read_buffer_.toString()); + + EXPECT_CALL(*client_.raw_socket_, doRead(_)); + EXPECT_CALL(client_.callbacks_, raiseEvent(Network::ConnectionEvent::Connected)); + expectIoResult({Network::PostIoAction::KeepOpen, 0UL, false}, + client_.tsi_socket_->doRead(client_.read_buffer_)); +} + +TEST_F(TsiSocketTest, HandshakeWithImmediateReadError) { + initialize(nullptr, nullptr); + + EXPECT_CALL(*client_.raw_socket_, doRead(_)).WillOnce(Invoke([&](Buffer::Instance& buffer) { + Network::IoResult result = {Network::PostIoAction::Close, server_to_client_.length(), false}; + buffer.move(server_to_client_); + return result; + })); + EXPECT_CALL(*client_.raw_socket_, doWrite(_, false)).Times(0); + expectIoResult({Network::PostIoAction::Close, 0UL, false}, + client_.tsi_socket_->doRead(client_.read_buffer_)); + EXPECT_EQ("", client_to_server_.toString()); + EXPECT_EQ(0L, client_.read_buffer_.length()); +} + +TEST_F(TsiSocketTest, HandshakeWithReadError) { + initialize(nullptr, nullptr); + + doFakeInitHandshake(); + + EXPECT_CALL(*client_.raw_socket_, doRead(_)).WillOnce(Invoke([&](Buffer::Instance& buffer) { + Network::IoResult result = {Network::PostIoAction::Close, server_to_client_.length(), false}; + buffer.move(server_to_client_); + return result; + })); + EXPECT_CALL(*client_.raw_socket_, doWrite(_, false)).Times(0); + EXPECT_CALL(client_.callbacks_.connection_, close(Network::ConnectionCloseType::NoFlush)); + expectIoResult({Network::PostIoAction::KeepOpen, 0UL, false}, + client_.tsi_socket_->doRead(client_.read_buffer_)); + EXPECT_EQ("", client_to_server_.toString()); + EXPECT_EQ(0L, client_.read_buffer_.length()); +} + +TEST_F(TsiSocketTest, HandshakeWithInternalError) { + auto raw_handshaker = tsi_create_fake_handshaker(/* is_client= */ 1); + const tsi_handshaker_vtable* vtable = raw_handshaker->vtable; + tsi_handshaker_vtable mock_vtable = *vtable; + mock_vtable.next = [](tsi_handshaker*, const unsigned char*, size_t, const unsigned char**, + size_t*, tsi_handshaker_result**, tsi_handshaker_on_next_done_cb, + void*) { return TSI_INTERNAL_ERROR; }; + raw_handshaker->vtable = &mock_vtable; + + client_.handshaker_factory_ = [&](Event::Dispatcher& dispatcher, + const Network::Address::InstanceConstSharedPtr&, + const Network::Address::InstanceConstSharedPtr&) { + CHandshakerPtr handshaker{raw_handshaker}; + + return std::make_unique(std::move(handshaker), dispatcher); + }; + + initialize(nullptr, nullptr); + + EXPECT_CALL(client_.callbacks_.connection_, close(Network::ConnectionCloseType::NoFlush)); + // doWrite won't immediately fail, but it will result connection close. + expectIoResult({Network::PostIoAction::KeepOpen, 0UL, false}, + client_.tsi_socket_->doWrite(client_.write_buffer_, false)); + + raw_handshaker->vtable = vtable; +} + +class TsiSocketFactoryTest : public testing::Test { +protected: + void SetUp() override { + auto handshaker_factory = [](Event::Dispatcher& dispatcher, + const Network::Address::InstanceConstSharedPtr&, + const Network::Address::InstanceConstSharedPtr&) { + CHandshakerPtr handshaker{tsi_create_fake_handshaker(/*is_client=*/0)}; + + return std::make_unique(std::move(handshaker), dispatcher); + }; + + socket_factory_ = std::make_unique(handshaker_factory, nullptr); + } + Network::TransportSocketFactoryPtr socket_factory_; +}; + +TEST_F(TsiSocketFactoryTest, CreateTransportSocket) { + EXPECT_NE(nullptr, socket_factory_->createTransportSocket()); +} + +TEST_F(TsiSocketFactoryTest, ImplementsSecureTransport) { + EXPECT_TRUE(socket_factory_->implementsSecureTransport()); +} + +} // namespace Alts +} // namespace TransportSockets +} // namespace Extensions +} // namespace Envoy diff --git a/test/mocks/network/mocks.cc b/test/mocks/network/mocks.cc index 43ade3273e44..be9ca04110b6 100644 --- a/test/mocks/network/mocks.cc +++ b/test/mocks/network/mocks.cc @@ -203,11 +203,19 @@ MockListener::~MockListener() { onDestroy(); } MockConnectionHandler::MockConnectionHandler() {} MockConnectionHandler::~MockConnectionHandler() {} -MockTransportSocket::MockTransportSocket() {} +MockTransportSocket::MockTransportSocket() { + ON_CALL(*this, setTransportSocketCallbacks(_)) + .WillByDefault(Invoke([&](TransportSocketCallbacks& callbacks) { callbacks_ = &callbacks; })); +} MockTransportSocket::~MockTransportSocket() {} MockTransportSocketFactory::MockTransportSocketFactory() {} MockTransportSocketFactory::~MockTransportSocketFactory() {} +MockTransportSocketCallbacks::MockTransportSocketCallbacks() { + ON_CALL(*this, connection()).WillByDefault(ReturnRef(connection_)); +} +MockTransportSocketCallbacks::~MockTransportSocketCallbacks() {} + } // namespace Network } // namespace Envoy diff --git a/test/mocks/network/mocks.h b/test/mocks/network/mocks.h index 0bd73a87dbb9..318ed47d080c 100644 --- a/test/mocks/network/mocks.h +++ b/test/mocks/network/mocks.h @@ -434,6 +434,8 @@ class MockTransportSocket : public TransportSocket { MOCK_METHOD2(doWrite, IoResult(Buffer::Instance& buffer, bool end_stream)); MOCK_METHOD0(onConnected, void()); MOCK_CONST_METHOD0(ssl, const Ssl::Connection*()); + + TransportSocketCallbacks* callbacks_{}; }; class MockTransportSocketFactory : public TransportSocketFactory { @@ -445,5 +447,19 @@ class MockTransportSocketFactory : public TransportSocketFactory { MOCK_CONST_METHOD0(createTransportSocket, TransportSocketPtr()); }; +class MockTransportSocketCallbacks : public TransportSocketCallbacks { +public: + MockTransportSocketCallbacks(); + ~MockTransportSocketCallbacks(); + + MOCK_CONST_METHOD0(fd, int()); + MOCK_METHOD0(connection, Connection&()); + MOCK_METHOD0(shouldDrainReadBuffer, bool()); + MOCK_METHOD0(setReadBufferReady, void()); + MOCK_METHOD1(raiseEvent, void(ConnectionEvent)); + + testing::NiceMock connection_; +}; + } // namespace Network } // namespace Envoy