From ce89fd90708811313c483fe8edb709ec095b7858 Mon Sep 17 00:00:00 2001 From: Lizan Zhou Date: Tue, 14 Aug 2018 19:46:20 -0700 Subject: [PATCH 01/10] alts: add gRPC TSI socket Signed-off-by: Lizan Zhou --- .../extensions/transport_sockets/alts/BUILD | 20 ++ .../transport_sockets/alts/tsi_socket.cc | 214 ++++++++++++++ .../transport_sockets/alts/tsi_socket.h | 119 ++++++++ test/extensions/transport_sockets/alts/BUILD | 13 + .../transport_sockets/alts/tsi_socket_test.cc | 265 ++++++++++++++++++ test/mocks/network/mocks.cc | 9 +- test/mocks/network/mocks.h | 16 ++ 7 files changed, 655 insertions(+), 1 deletion(-) create mode 100644 source/extensions/transport_sockets/alts/tsi_socket.cc create mode 100644 source/extensions/transport_sockets/alts/tsi_socket.h create mode 100644 test/extensions/transport_sockets/alts/tsi_socket_test.cc diff --git a/source/extensions/transport_sockets/alts/BUILD b/source/extensions/transport_sockets/alts/BUILD index da086ff2d6e2..f54fbe16e69b 100644 --- a/source/extensions/transport_sockets/alts/BUILD +++ b/source/extensions/transport_sockets/alts/BUILD @@ -53,3 +53,23 @@ envoy_cc_library( "//source/common/buffer:buffer_lib", ], ) + +envoy_cc_library( + name = "tsi_socket", + srcs = [ + "tsi_socket.cc", + ], + hdrs = [ + "tsi_socket.h", + ], + repository = "@envoy", + deps = [ + ":tsi_frame_protector", + ":tsi_handshaker", + "//include/envoy/network:transport_socket_interface", + "//source/common/buffer:buffer_lib", + "//source/common/common:enum_to_int", + "//source/common/network:raw_buffer_socket_lib", + "//source/common/protobuf:utility_lib", + ], +) diff --git a/source/extensions/transport_sockets/alts/tsi_socket.cc b/source/extensions/transport_sockets/alts/tsi_socket.cc new file mode 100644 index 000000000000..89114917b087 --- /dev/null +++ b/source/extensions/transport_sockets/alts/tsi_socket.cc @@ -0,0 +1,214 @@ +#include "extensions/transport_sockets/alts/tsi_socket.h" + +#include "common/common/assert.h" +#include "common/common/enum_to_int.h" + +namespace Envoy { +namespace Extensions { +namespace TransportSockets { +namespace Alts { + +TsiSocket::TsiSocket(HandshakerFactory handshaker_factory, HandshakeValidator handshake_validator, + Network::TransportSocketPtr&& raw_socket) + : handshaker_factory_(handshaker_factory), handshake_validator_(handshake_validator), + raw_buffer_callbacks_(*this), raw_buffer_socket_(std::move(raw_socket)) { + raw_buffer_socket_->setTransportSocketCallbacks(raw_buffer_callbacks_); +} + +TsiSocket::TsiSocket(HandshakerFactory handshaker_factory, HandshakeValidator handshake_validator) + : TsiSocket(handshaker_factory, handshake_validator, + std::make_unique()) {} + +TsiSocket::~TsiSocket() { ASSERT(!handshaker_); } + +void TsiSocket::setTransportSocketCallbacks(Envoy::Network::TransportSocketCallbacks& callbacks) { + callbacks_ = &callbacks; + + handshaker_ = handshaker_factory_(callbacks.connection().dispatcher()); + handshaker_->setHandshakerCallbacks(*this); +} + +std::string TsiSocket::protocol() const { return ""; } + +Network::PostIoAction TsiSocket::doHandshake() { + ASSERT(!handshake_complete_); + ENVOY_CONN_LOG(debug, "TSI: doHandshake", callbacks_->connection()); + + if (!handshaker_next_calling_) { + doHandshakeNext(); + } + return Network::PostIoAction::KeepOpen; +} + +void TsiSocket::doHandshakeNext() { + ENVOY_CONN_LOG(debug, "TSI: doHandshake next: received: {}", callbacks_->connection(), + raw_read_buffer_.length()); + handshaker_next_calling_ = true; + Buffer::OwnedImpl handshaker_buffer; + handshaker_buffer.move(raw_read_buffer_); + handshaker_->next(handshaker_buffer); +} + +Network::PostIoAction TsiSocket::doHandshakeNextDone(NextResultPtr&& next_result) { + ASSERT(next_result); + + ENVOY_CONN_LOG(debug, "TSI: doHandshake next done: status: {} to_send: {}", + callbacks_->connection(), next_result->status_, next_result->to_send_->length()); + + tsi_result status = next_result->status_; + tsi_handshaker_result* handshaker_result = next_result->result_.get(); + + if (status != TSI_INCOMPLETE_DATA && status != TSI_OK) { + ENVOY_CONN_LOG(debug, "TSI: Handshake failed: status: {}", callbacks_->connection(), status); + return Network::PostIoAction::Close; + } + + if (next_result->to_send_->length() > 0) { + raw_write_buffer_.move(*next_result->to_send_); + } + + if (status == TSI_OK && handshaker_result != nullptr) { + tsi_peer peer; + tsi_handshaker_result_extract_peer(handshaker_result, &peer); + ENVOY_CONN_LOG(debug, "TSI: Handshake successful: peer properties: {}", + callbacks_->connection(), peer.property_count); + for (size_t i = 0; i < peer.property_count; ++i) { + ENVOY_CONN_LOG(debug, " {}: {}", callbacks_->connection(), peer.properties[i].name, + std::string(peer.properties[i].value.data, peer.properties[i].value.length)); + } + if (handshake_validator_) { + std::string err; + bool peer_validated = handshake_validator_(peer, err); + if (peer_validated) { + ENVOY_CONN_LOG(info, "TSI: Handshake validation succeeded.", callbacks_->connection()); + } else { + ENVOY_CONN_LOG(warn, "TSI: Handshake validation failed: {}", callbacks_->connection(), err); + tsi_peer_destruct(&peer); + return Network::PostIoAction::Close; + } + } else { + ENVOY_CONN_LOG(info, "TSI: Handshake validation skipped.", callbacks_->connection()); + } + tsi_peer_destruct(&peer); + + const unsigned char* unused_bytes; + size_t unused_byte_size; + + status = + tsi_handshaker_result_get_unused_bytes(handshaker_result, &unused_bytes, &unused_byte_size); + ASSERT(status == TSI_OK); + if (unused_byte_size > 0) { + raw_read_buffer_.add(unused_bytes, unused_byte_size); + } + ENVOY_CONN_LOG(debug, "TSI: Handshake successful: unused_bytes: {}", callbacks_->connection(), + unused_byte_size); + + tsi_frame_protector* frame_protector; + status = + tsi_handshaker_result_create_frame_protector(handshaker_result, NULL, &frame_protector); + ASSERT(status == TSI_OK); + frame_protector_ = std::make_unique(frame_protector); + + handshake_complete_ = true; + callbacks_->raiseEvent(Network::ConnectionEvent::Connected); + } + + if (raw_read_buffer_.length() > 0) { + callbacks_->setReadBufferReady(); + } + + // Try to write raw buffer when next call is done, even this is not in do[Read|Write] stack. + if (raw_write_buffer_.length() > 0) { + return raw_buffer_socket_->doWrite(raw_write_buffer_, false).action_; + } + + return Network::PostIoAction::KeepOpen; +} + +Network::IoResult TsiSocket::doRead(Buffer::Instance& buffer) { + Network::IoResult result = raw_buffer_socket_->doRead(raw_read_buffer_); + ENVOY_CONN_LOG(debug, "TSI: raw read result action {} bytes {} end_stream {}", + callbacks_->connection(), enumToInt(result.action_), result.bytes_processed_, + result.end_stream_read_); + if (result.action_ == Network::PostIoAction::Close && result.bytes_processed_ == 0) { + return result; + } + + if (!handshake_complete_) { + Network::PostIoAction action = doHandshake(); + if (action == Network::PostIoAction::Close || !handshake_complete_) { + return {action, 0, false}; + } + } + + if (handshake_complete_) { + ASSERT(frame_protector_); + + uint64_t read_size = raw_read_buffer_.length(); + ENVOY_CONN_LOG(debug, "TSI: unprotecting buffer size: {}", callbacks_->connection(), + raw_read_buffer_.length()); + tsi_result status = frame_protector_->unprotect(raw_read_buffer_, buffer); + ENVOY_CONN_LOG(debug, "TSI: unprotected buffer left: {} result: {}", callbacks_->connection(), + raw_read_buffer_.length(), tsi_result_to_string(status)); + result.bytes_processed_ = read_size - raw_read_buffer_.length(); + } + + ENVOY_CONN_LOG(debug, "TSI: do read result action {} bytes {} end_stream {}", + callbacks_->connection(), enumToInt(result.action_), result.bytes_processed_, + result.end_stream_read_); + return result; +} + +Network::IoResult TsiSocket::doWrite(Buffer::Instance& buffer, bool end_stream) { + if (!handshake_complete_) { + Network::PostIoAction action = doHandshake(); + if (action == Network::PostIoAction::Close) { + return {action, 0, false}; + } + } + + if (handshake_complete_) { + ASSERT(frame_protector_); + ENVOY_CONN_LOG(debug, "TSI: protecting buffer size: {}", callbacks_->connection(), + buffer.length()); + tsi_result status = frame_protector_->protect(buffer, raw_write_buffer_); + ENVOY_CONN_LOG(debug, "TSI: protected buffer left: {} result: {}", callbacks_->connection(), + buffer.length(), tsi_result_to_string(status)); + } + + if (raw_write_buffer_.length() > 0) { + ENVOY_CONN_LOG(debug, "TSI: raw_write length {} end_stream {}", callbacks_->connection(), + raw_write_buffer_.length(), end_stream); + return raw_buffer_socket_->doWrite(raw_write_buffer_, end_stream && (buffer.length() == 0)); + } + return {Network::PostIoAction::KeepOpen, 0, false}; +} + +void TsiSocket::closeSocket(Network::ConnectionEvent) { handshaker_.release()->deferredDelete(); } + +void TsiSocket::onConnected() { ASSERT(!handshake_complete_); } + +void TsiSocket::onNextDone(NextResultPtr&& result) { + handshaker_next_calling_ = false; + + Network::PostIoAction action = doHandshakeNextDone(std::move(result)); + if (action == Network::PostIoAction::Close) { + callbacks_->connection().close(Network::ConnectionCloseType::NoFlush); + } +} + +TsiSocketFactory::TsiSocketFactory(HandshakerFactory handshaker_factory, + HandshakeValidator handshake_validator) + : handshaker_factory_(std::move(handshaker_factory)), + handshake_validator_(std::move(handshake_validator)) {} + +bool TsiSocketFactory::implementsSecureTransport() const { return true; } + +Network::TransportSocketPtr TsiSocketFactory::createTransportSocket() const { + return std::make_unique(handshaker_factory_, handshake_validator_); +} + +} // namespace Alts +} // namespace TransportSockets +} // namespace Extensions +} // namespace Envoy \ No newline at end of file diff --git a/source/extensions/transport_sockets/alts/tsi_socket.h b/source/extensions/transport_sockets/alts/tsi_socket.h new file mode 100644 index 000000000000..5a1cf0ed25fc --- /dev/null +++ b/source/extensions/transport_sockets/alts/tsi_socket.h @@ -0,0 +1,119 @@ +#pragma once + +#include "envoy/network/transport_socket.h" + +#include "common/buffer/buffer_impl.h" +#include "common/network/raw_buffer_socket.h" + +#include "extensions/transport_sockets/alts/tsi_frame_protector.h" +#include "extensions/transport_sockets/alts/tsi_handshaker.h" + +namespace Envoy { +namespace Extensions { +namespace TransportSockets { +namespace Alts { + +typedef std::function HandshakerFactory; + +/** + * A function to validate the peer of the connection. + * @param peer the detail peer information of the connection. + * @param err an error message to indicate why the peer is invalid. This is an + * output param that should be populated by the function implementation. + * @return true if the peer is valid or false if the peer is invalid. + */ +typedef std::function HandshakeValidator; + +/** + * A implementation of Network::TransportSocket based on gRPC TSI + */ +class TsiSocket : public Network::TransportSocket, + public TsiHandshakerCallbacks, + public Logger::Loggable { +public: + // For Test + TsiSocket(HandshakerFactory handshaker_factory, HandshakeValidator handshake_validator, + Network::TransportSocketPtr&& raw_socket_ptr); + + /** + * @param handshaker_factory a function to initiate a TsiHandshaker + * @param handshake_validator a function to validate the peer. Called right + * after the handshake completed with peer data to do the peer validation. + * The connection will be closed immediately if it returns false. + */ + TsiSocket(HandshakerFactory handshaker_factory, HandshakeValidator handshake_validator); + virtual ~TsiSocket(); + + // Network::TransportSocket + void setTransportSocketCallbacks(Envoy::Network::TransportSocketCallbacks& callbacks) override; + std::string protocol() const override; + bool canFlushClose() override { return handshake_complete_; } + Envoy::Ssl::Connection* ssl() override { return nullptr; } + const Envoy::Ssl::Connection* ssl() const override { return nullptr; } + Network::IoResult doWrite(Buffer::Instance& buffer, bool end_stream) override; + void closeSocket(Network::ConnectionEvent event) override; + Network::IoResult doRead(Buffer::Instance& buffer) override; + void onConnected() override; + + // TsiHandshakerCallbacks + void onNextDone(NextResultPtr&& result) override; + +private: + /** + * Callbacks for underlying RawBufferSocket, it proxies fd() and connection() + * but not raising event or flow control since they have to be handled in + * TsiSocket. + */ + class RawBufferCallbacks : public Network::TransportSocketCallbacks { + public: + explicit RawBufferCallbacks(TsiSocket& parent) : parent_(parent) {} + + int fd() const override { return parent_.callbacks_->fd(); } + Network::Connection& connection() override { return parent_.callbacks_->connection(); } + bool shouldDrainReadBuffer() override { return false; } + void setReadBufferReady() override {} + void raiseEvent(Network::ConnectionEvent) override {} + + private: + TsiSocket& parent_; + }; + + Network::PostIoAction doHandshake(); + void doHandshakeNext(); + Network::PostIoAction doHandshakeNextDone(NextResultPtr&& next_result); + + HandshakerFactory handshaker_factory_; + HandshakeValidator handshake_validator_; + TsiHandshakerPtr handshaker_{}; + bool handshaker_next_calling_{}; + + TsiFrameProtectorPtr frame_protector_; + + Envoy::Network::TransportSocketCallbacks* callbacks_{}; + RawBufferCallbacks raw_buffer_callbacks_; + Network::TransportSocketPtr raw_buffer_socket_; + + Envoy::Buffer::OwnedImpl raw_read_buffer_; + Envoy::Buffer::OwnedImpl raw_write_buffer_; + bool handshake_complete_{}; +}; + +/** + * An implementation of Network::TransportSocketFactory for TsiSocket + */ +class TsiSocketFactory : public Network::TransportSocketFactory { +public: + TsiSocketFactory(HandshakerFactory handshaker_factory, HandshakeValidator handshake_validator); + + bool implementsSecureTransport() const override; + Network::TransportSocketPtr createTransportSocket() const override; + +private: + HandshakerFactory handshaker_factory_; + HandshakeValidator handshake_validator_; +}; + +} // namespace Alts +} // namespace TransportSockets +} // namespace Extensions +} // namespace Envoy \ No newline at end of file diff --git a/test/extensions/transport_sockets/alts/BUILD b/test/extensions/transport_sockets/alts/BUILD index 171a97fd28e4..b57d518eb1f1 100644 --- a/test/extensions/transport_sockets/alts/BUILD +++ b/test/extensions/transport_sockets/alts/BUILD @@ -32,3 +32,16 @@ envoy_extension_cc_test( "//test/mocks/event:event_mocks", ], ) + +envoy_extension_cc_test( + name = "tsi_socket_test", + srcs = ["tsi_socket_test.cc"], + extension_name = "envoy.transport_sockets.alts", + deps = [ + "//include/envoy/event:dispatcher_interface", + "//source/extensions/transport_sockets/alts:tsi_socket", + "//test/mocks/buffer:buffer_mocks", + "//test/mocks/event:event_mocks", + "//test/mocks/network:network_mocks", + ], +) diff --git a/test/extensions/transport_sockets/alts/tsi_socket_test.cc b/test/extensions/transport_sockets/alts/tsi_socket_test.cc new file mode 100644 index 000000000000..8a65121e1888 --- /dev/null +++ b/test/extensions/transport_sockets/alts/tsi_socket_test.cc @@ -0,0 +1,265 @@ +#include "common/buffer/buffer_impl.h" + +#include "extensions/transport_sockets/alts/tsi_socket.h" + +#include "test/mocks/network/mocks.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "src/core/tsi/fake_transport_security.h" + +namespace Envoy { +namespace Extensions { +namespace TransportSockets { +namespace Alts { + +using testing::NiceMock; +using testing::Return; +using testing::ReturnRef; +using testing::StrictMock; + +class TsiSocketTest : public testing::Test { +protected: + TsiSocketTest() {} + + void initialize(HandshakeValidator server_validator, HandshakeValidator client_validator) { + auto server_handshaker_factory = [](Event::Dispatcher& dispatcher) { + CHandshakerPtr handshaker{tsi_create_fake_handshaker(/*is_client=*/0)}; + + return std::make_unique(std::move(handshaker), dispatcher); + }; + + server_.raw_socket_ = new NiceMock(); + + server_.tsi_socket_ = + std::make_unique(server_handshaker_factory, server_validator, + Network::TransportSocketPtr{server_.raw_socket_}); + + auto client_handshaker_factory = [](Event::Dispatcher& dispatcher) { + CHandshakerPtr handshaker{tsi_create_fake_handshaker(/*is_client=*/1)}; + + return std::make_unique(std::move(handshaker), dispatcher); + }; + + client_.raw_socket_ = new NiceMock(); + + client_.tsi_socket_ = + std::make_unique(client_handshaker_factory, client_validator, + Network::TransportSocketPtr{client_.raw_socket_}); + + ON_CALL(client_.callbacks_.connection_, dispatcher()).WillByDefault(ReturnRef(dispatcher_)); + ON_CALL(server_.callbacks_.connection_, dispatcher()).WillByDefault(ReturnRef(dispatcher_)); + + ON_CALL(client_.callbacks_.connection_, id()).WillByDefault(Return(11)); + ON_CALL(server_.callbacks_.connection_, id()).WillByDefault(Return(12)); + + ON_CALL(*client_.raw_socket_, doWrite(_, _)) + .WillByDefault(Invoke([&](Buffer::Instance& buffer, bool) { + Network::IoResult result = {Network::PostIoAction::KeepOpen, buffer.length(), false}; + client_to_server_.move(buffer); + return result; + })); + ON_CALL(*server_.raw_socket_, doWrite(_, _)) + .WillByDefault(Invoke([&](Buffer::Instance& buffer, bool) { + Network::IoResult result = {Network::PostIoAction::KeepOpen, buffer.length(), false}; + server_to_client_.move(buffer); + return result; + })); + + ON_CALL(*client_.raw_socket_, doRead(_)).WillByDefault(Invoke([&](Buffer::Instance& buffer) { + Network::IoResult result = {Network::PostIoAction::KeepOpen, server_to_client_.length(), + false}; + buffer.move(server_to_client_); + return result; + })); + ON_CALL(*server_.raw_socket_, doRead(_)).WillByDefault(Invoke([&](Buffer::Instance& buffer) { + Network::IoResult result = {Network::PostIoAction::KeepOpen, client_to_server_.length(), + false}; + buffer.move(client_to_server_); + return result; + })); + + client_.tsi_socket_->setTransportSocketCallbacks(client_.callbacks_); + client_.tsi_socket_->onConnected(); + + server_.tsi_socket_->setTransportSocketCallbacks(server_.callbacks_); + server_.tsi_socket_->onConnected(); + } + + void expectIoResult(Network::IoResult expected, Network::IoResult actual) { + EXPECT_EQ(expected.action_, actual.action_); + EXPECT_EQ(expected.bytes_processed_, actual.bytes_processed_); + EXPECT_EQ(expected.end_stream_read_, actual.end_stream_read_); + } + + std::string makeFakeTsiFrame(const std::string& payload) { + uint32_t length = static_cast(payload.length()) + 4; + std::string frame; + frame.reserve(length); + frame.push_back(static_cast(length)); + length >>= 8; + frame.push_back(static_cast(length)); + length >>= 8; + frame.push_back(static_cast(length)); + length >>= 8; + frame.push_back(static_cast(length)); + + frame.append(payload); + return frame; + } + + void doFakeInitHandshake() { + EXPECT_CALL(*client_.raw_socket_, doWrite(_, false)); + expectIoResult({Network::PostIoAction::KeepOpen, 0UL, false}, + client_.tsi_socket_->doWrite(client_.write_buffer_, false)); + EXPECT_EQ(makeFakeTsiFrame("CLIENT_INIT"), client_to_server_.toString()); + + EXPECT_CALL(*server_.raw_socket_, doRead(_)); + EXPECT_CALL(*server_.raw_socket_, doWrite(_, false)); + expectIoResult({Network::PostIoAction::KeepOpen, 0UL, false}, + server_.tsi_socket_->doRead(server_.read_buffer_)); + EXPECT_EQ(makeFakeTsiFrame("SERVER_INIT"), server_to_client_.toString()); + EXPECT_EQ(0L, server_.read_buffer_.length()); + } + + void doHandshakeAndExpectSuccess() { + doFakeInitHandshake(); + + EXPECT_CALL(*client_.raw_socket_, doRead(_)); + EXPECT_CALL(*client_.raw_socket_, doWrite(_, false)); + expectIoResult({Network::PostIoAction::KeepOpen, 0UL, false}, + client_.tsi_socket_->doRead(client_.read_buffer_)); + EXPECT_EQ(makeFakeTsiFrame("CLIENT_FINISHED"), client_to_server_.toString()); + EXPECT_EQ(0L, client_.read_buffer_.length()); + + EXPECT_CALL(*server_.raw_socket_, doRead(_)); + EXPECT_CALL(*server_.raw_socket_, doWrite(_, false)); + expectIoResult({Network::PostIoAction::KeepOpen, 0UL, false}, + server_.tsi_socket_->doRead(server_.read_buffer_)); + EXPECT_EQ(makeFakeTsiFrame("SERVER_FINISHED"), server_to_client_.toString()); + + EXPECT_CALL(*client_.raw_socket_, doRead(_)); + expectIoResult({Network::PostIoAction::KeepOpen, 0UL, false}, + client_.tsi_socket_->doRead(client_.read_buffer_)); + } + + struct SocketForTest { + std::unique_ptr tsi_socket_; + NiceMock* raw_socket_{}; + NiceMock callbacks_; + Buffer::OwnedImpl read_buffer_; + Buffer::OwnedImpl write_buffer_; + }; + + SocketForTest client_; + SocketForTest server_; + + Buffer::OwnedImpl client_to_server_; + Buffer::OwnedImpl server_to_client_; + + NiceMock dispatcher_; +}; + +TEST_F(TsiSocketTest, HandshakeWithoutValidationAndTransferData) { + initialize(nullptr, nullptr); + + client_.write_buffer_.add("hello from client"); + + doHandshakeAndExpectSuccess(); + EXPECT_EQ(0L, server_.read_buffer_.length()); + EXPECT_EQ(0L, client_.read_buffer_.length()); + + EXPECT_EQ("", client_.tsi_socket_->protocol()); + + EXPECT_CALL(*client_.raw_socket_, doWrite(_, false)); + expectIoResult({Network::PostIoAction::KeepOpen, 21UL, false}, + client_.tsi_socket_->doWrite(client_.write_buffer_, false)); + EXPECT_EQ(makeFakeTsiFrame("hello from client"), client_to_server_.toString()); + + EXPECT_CALL(*server_.raw_socket_, doRead(_)); + expectIoResult({Network::PostIoAction::KeepOpen, 21UL, false}, + server_.tsi_socket_->doRead(server_.read_buffer_)); + EXPECT_EQ("hello from client", server_.read_buffer_.toString()); + + client_.tsi_socket_->closeSocket(Network::ConnectionEvent::LocalClose); + server_.tsi_socket_->closeSocket(Network::ConnectionEvent::RemoteClose); +} + +TEST_F(TsiSocketTest, HandshakeWithSucessfulValidationAndTransferData) { + auto validator = [](const tsi_peer&, std::string&) { return true; }; + initialize(validator, validator); + + client_.write_buffer_.add("hello from client"); + + doHandshakeAndExpectSuccess(); + EXPECT_EQ(0L, server_.read_buffer_.length()); + EXPECT_EQ(0L, client_.read_buffer_.length()); + + EXPECT_EQ("", client_.tsi_socket_->protocol()); + + EXPECT_CALL(*client_.raw_socket_, doWrite(_, false)); + expectIoResult({Network::PostIoAction::KeepOpen, 21UL, false}, + client_.tsi_socket_->doWrite(client_.write_buffer_, false)); + EXPECT_EQ(makeFakeTsiFrame("hello from client"), client_to_server_.toString()); + + EXPECT_CALL(*server_.raw_socket_, doRead(_)); + expectIoResult({Network::PostIoAction::KeepOpen, 21UL, false}, + server_.tsi_socket_->doRead(server_.read_buffer_)); + EXPECT_EQ("hello from client", server_.read_buffer_.toString()); + + client_.tsi_socket_->closeSocket(Network::ConnectionEvent::LocalClose); + server_.tsi_socket_->closeSocket(Network::ConnectionEvent::RemoteClose); +} + +TEST_F(TsiSocketTest, HandshakeValidationFail) { + auto validator = [](const tsi_peer&, std::string&) { return false; }; + initialize(validator, validator); + + client_.write_buffer_.add("hello from client"); + + doFakeInitHandshake(); + + EXPECT_CALL(*client_.raw_socket_, doRead(_)); + EXPECT_CALL(*client_.raw_socket_, doWrite(_, false)); + expectIoResult({Network::PostIoAction::KeepOpen, 0UL, false}, + client_.tsi_socket_->doRead(client_.read_buffer_)); + EXPECT_EQ(makeFakeTsiFrame("CLIENT_FINISHED"), client_to_server_.toString()); + EXPECT_EQ(0L, client_.read_buffer_.length()); + + EXPECT_CALL(*server_.raw_socket_, doRead(_)); + EXPECT_CALL(server_.callbacks_.connection_, close(Network::ConnectionCloseType::NoFlush)); + // doRead won't immediately fail, but it will result connection close. + expectIoResult({Network::PostIoAction::KeepOpen, 0UL, false}, + server_.tsi_socket_->doRead(server_.read_buffer_)); + EXPECT_EQ(0, server_to_client_.length()); + + client_.tsi_socket_->closeSocket(Network::ConnectionEvent::LocalClose); + server_.tsi_socket_->closeSocket(Network::ConnectionEvent::RemoteClose); +} + +class TsiSocketFactoryTest : public testing::Test { +protected: + void SetUp() override { + auto handshaker_factory = [](Event::Dispatcher& dispatcher) { + CHandshakerPtr handshaker{tsi_create_fake_handshaker(/*is_client=*/0)}; + + return std::make_unique(std::move(handshaker), dispatcher); + }; + + socket_factory_ = std::make_unique(handshaker_factory, nullptr); + } + Network::TransportSocketFactoryPtr socket_factory_; +}; + +TEST_F(TsiSocketFactoryTest, CreateTransportSocket) { + EXPECT_NE(nullptr, socket_factory_->createTransportSocket()); +} + +TEST_F(TsiSocketFactoryTest, ImplementsSecureTransport) { + EXPECT_TRUE(socket_factory_->implementsSecureTransport()); +} + +} // namespace Alts +} // namespace TransportSockets +} // namespace Extensions +} // namespace Envoy \ No newline at end of file diff --git a/test/mocks/network/mocks.cc b/test/mocks/network/mocks.cc index 31fba7361201..c37a49d2671d 100644 --- a/test/mocks/network/mocks.cc +++ b/test/mocks/network/mocks.cc @@ -203,11 +203,18 @@ MockListener::~MockListener() { onDestroy(); } MockConnectionHandler::MockConnectionHandler() {} MockConnectionHandler::~MockConnectionHandler() {} -MockTransportSocket::MockTransportSocket() {} +MockTransportSocket::MockTransportSocket() { + ON_CALL(*this, setTransportSocketCallbacks(_)).WillByDefault(SaveArg<0>(callbacks_)); +} MockTransportSocket::~MockTransportSocket() {} MockTransportSocketFactory::MockTransportSocketFactory() {} MockTransportSocketFactory::~MockTransportSocketFactory() {} +MockTransportSocketCallbacks::MockTransportSocketCallbacks() { + ON_CALL(*this, connection()).WillByDefault(ReturnRef(connection_)); +} +MockTransportSocketCallbacks::~MockTransportSocketCallbacks() {} + } // namespace Network } // namespace Envoy diff --git a/test/mocks/network/mocks.h b/test/mocks/network/mocks.h index 651da87c4174..30ad89566b73 100644 --- a/test/mocks/network/mocks.h +++ b/test/mocks/network/mocks.h @@ -437,6 +437,8 @@ class MockTransportSocket : public TransportSocket { MOCK_METHOD0(onConnected, void()); MOCK_METHOD0(ssl, Ssl::Connection*()); MOCK_CONST_METHOD0(ssl, const Ssl::Connection*()); + + TransportSocketCallbacks* callbacks_{}; }; class MockTransportSocketFactory : public TransportSocketFactory { @@ -448,5 +450,19 @@ class MockTransportSocketFactory : public TransportSocketFactory { MOCK_CONST_METHOD0(createTransportSocket, TransportSocketPtr()); }; +class MockTransportSocketCallbacks : public TransportSocketCallbacks { +public: + MockTransportSocketCallbacks(); + ~MockTransportSocketCallbacks(); + + MOCK_CONST_METHOD0(fd, int()); + MOCK_METHOD0(connection, Connection&()); + MOCK_METHOD0(shouldDrainReadBuffer, bool()); + MOCK_METHOD0(setReadBufferReady, void()); + MOCK_METHOD1(raiseEvent, void(ConnectionEvent)); + + testing::NiceMock connection_; +}; + } // namespace Network } // namespace Envoy From 62f43391691ed9a345e3e81bfb9d8525ec410b88 Mon Sep 17 00:00:00 2001 From: Lizan Zhou Date: Wed, 15 Aug 2018 02:14:47 -0700 Subject: [PATCH 02/10] more tests, fix asan Signed-off-by: Lizan Zhou --- .../transport_sockets/alts/tsi_socket_test.cc | 36 ++++++++++++++----- test/mocks/network/mocks.cc | 4 ++- 2 files changed, 30 insertions(+), 10 deletions(-) diff --git a/test/extensions/transport_sockets/alts/tsi_socket_test.cc b/test/extensions/transport_sockets/alts/tsi_socket_test.cc index 8a65121e1888..00f1f54efadb 100644 --- a/test/extensions/transport_sockets/alts/tsi_socket_test.cc +++ b/test/extensions/transport_sockets/alts/tsi_socket_test.cc @@ -22,6 +22,11 @@ class TsiSocketTest : public testing::Test { protected: TsiSocketTest() {} + void TearDown() override { + client_.tsi_socket_->closeSocket(Network::ConnectionEvent::LocalClose); + server_.tsi_socket_->closeSocket(Network::ConnectionEvent::RemoteClose); + } + void initialize(HandshakeValidator server_validator, HandshakeValidator client_validator) { auto server_handshaker_factory = [](Event::Dispatcher& dispatcher) { CHandshakerPtr handshaker{tsi_create_fake_handshaker(/*is_client=*/0)}; @@ -134,13 +139,16 @@ class TsiSocketTest : public testing::Test { EXPECT_CALL(*server_.raw_socket_, doRead(_)); EXPECT_CALL(*server_.raw_socket_, doWrite(_, false)); + EXPECT_CALL(server_.callbacks_, raiseEvent(Network::ConnectionEvent::Connected)); expectIoResult({Network::PostIoAction::KeepOpen, 0UL, false}, server_.tsi_socket_->doRead(server_.read_buffer_)); EXPECT_EQ(makeFakeTsiFrame("SERVER_FINISHED"), server_to_client_.toString()); EXPECT_CALL(*client_.raw_socket_, doRead(_)); + EXPECT_CALL(client_.callbacks_, raiseEvent(Network::ConnectionEvent::Connected)); expectIoResult({Network::PostIoAction::KeepOpen, 0UL, false}, client_.tsi_socket_->doRead(client_.read_buffer_)); + } struct SocketForTest { @@ -160,6 +168,25 @@ class TsiSocketTest : public testing::Test { NiceMock dispatcher_; }; +TEST_F(TsiSocketTest, DoesNotHaveSsl) { + initialize(nullptr, nullptr); + EXPECT_EQ(nullptr, client_.tsi_socket_->ssl()); + + const auto& socket_ = *client_.tsi_socket_; + EXPECT_EQ(nullptr, socket_.ssl()); +} + +TEST_F(TsiSocketTest, ProxyCallbacks) { + initialize(nullptr, nullptr); + + EXPECT_CALL(client_.callbacks_, fd()).WillOnce(Return(111)); + EXPECT_EQ(111, client_.raw_socket_->callbacks_->fd()); + + EXPECT_EQ(&client_.callbacks_.connection_, &client_.raw_socket_->callbacks_->connection()); + + EXPECT_FALSE(client_.raw_socket_->callbacks_->shouldDrainReadBuffer()); +} + TEST_F(TsiSocketTest, HandshakeWithoutValidationAndTransferData) { initialize(nullptr, nullptr); @@ -180,9 +207,6 @@ TEST_F(TsiSocketTest, HandshakeWithoutValidationAndTransferData) { expectIoResult({Network::PostIoAction::KeepOpen, 21UL, false}, server_.tsi_socket_->doRead(server_.read_buffer_)); EXPECT_EQ("hello from client", server_.read_buffer_.toString()); - - client_.tsi_socket_->closeSocket(Network::ConnectionEvent::LocalClose); - server_.tsi_socket_->closeSocket(Network::ConnectionEvent::RemoteClose); } TEST_F(TsiSocketTest, HandshakeWithSucessfulValidationAndTransferData) { @@ -206,9 +230,6 @@ TEST_F(TsiSocketTest, HandshakeWithSucessfulValidationAndTransferData) { expectIoResult({Network::PostIoAction::KeepOpen, 21UL, false}, server_.tsi_socket_->doRead(server_.read_buffer_)); EXPECT_EQ("hello from client", server_.read_buffer_.toString()); - - client_.tsi_socket_->closeSocket(Network::ConnectionEvent::LocalClose); - server_.tsi_socket_->closeSocket(Network::ConnectionEvent::RemoteClose); } TEST_F(TsiSocketTest, HandshakeValidationFail) { @@ -232,9 +253,6 @@ TEST_F(TsiSocketTest, HandshakeValidationFail) { expectIoResult({Network::PostIoAction::KeepOpen, 0UL, false}, server_.tsi_socket_->doRead(server_.read_buffer_)); EXPECT_EQ(0, server_to_client_.length()); - - client_.tsi_socket_->closeSocket(Network::ConnectionEvent::LocalClose); - server_.tsi_socket_->closeSocket(Network::ConnectionEvent::RemoteClose); } class TsiSocketFactoryTest : public testing::Test { diff --git a/test/mocks/network/mocks.cc b/test/mocks/network/mocks.cc index c37a49d2671d..0aece109da9e 100644 --- a/test/mocks/network/mocks.cc +++ b/test/mocks/network/mocks.cc @@ -204,7 +204,9 @@ MockConnectionHandler::MockConnectionHandler() {} MockConnectionHandler::~MockConnectionHandler() {} MockTransportSocket::MockTransportSocket() { - ON_CALL(*this, setTransportSocketCallbacks(_)).WillByDefault(SaveArg<0>(callbacks_)); + ON_CALL(*this, setTransportSocketCallbacks(_)).WillByDefault(Invoke([&](TransportSocketCallbacks& callbacks) { + callbacks_ = &callbacks; + })); } MockTransportSocket::~MockTransportSocket() {} From 9d8c0d656cdd916bf159b23713a3406be6d802fa Mon Sep 17 00:00:00 2001 From: Lizan Zhou Date: Wed, 15 Aug 2018 02:45:44 -0700 Subject: [PATCH 03/10] fix format Signed-off-by: Lizan Zhou --- test/extensions/transport_sockets/alts/tsi_socket_test.cc | 1 - test/mocks/network/mocks.cc | 5 ++--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/test/extensions/transport_sockets/alts/tsi_socket_test.cc b/test/extensions/transport_sockets/alts/tsi_socket_test.cc index 00f1f54efadb..d6d60effe727 100644 --- a/test/extensions/transport_sockets/alts/tsi_socket_test.cc +++ b/test/extensions/transport_sockets/alts/tsi_socket_test.cc @@ -148,7 +148,6 @@ class TsiSocketTest : public testing::Test { EXPECT_CALL(client_.callbacks_, raiseEvent(Network::ConnectionEvent::Connected)); expectIoResult({Network::PostIoAction::KeepOpen, 0UL, false}, client_.tsi_socket_->doRead(client_.read_buffer_)); - } struct SocketForTest { diff --git a/test/mocks/network/mocks.cc b/test/mocks/network/mocks.cc index 0aece109da9e..b62ab6a1e7c3 100644 --- a/test/mocks/network/mocks.cc +++ b/test/mocks/network/mocks.cc @@ -204,9 +204,8 @@ MockConnectionHandler::MockConnectionHandler() {} MockConnectionHandler::~MockConnectionHandler() {} MockTransportSocket::MockTransportSocket() { - ON_CALL(*this, setTransportSocketCallbacks(_)).WillByDefault(Invoke([&](TransportSocketCallbacks& callbacks) { - callbacks_ = &callbacks; - })); + ON_CALL(*this, setTransportSocketCallbacks(_)) + .WillByDefault(Invoke([&](TransportSocketCallbacks& callbacks) { callbacks_ = &callbacks; })); } MockTransportSocket::~MockTransportSocket() {} From 10dc8099d86e575575c212c6d471458c13a7ca2e Mon Sep 17 00:00:00 2001 From: Lizan Zhou Date: Tue, 21 Aug 2018 18:39:30 -0700 Subject: [PATCH 04/10] address comment Signed-off-by: Lizan Zhou --- .../extensions/transport_sockets/alts/BUILD | 1 + .../alts/noop_transport_socket_callbacks.h | 2 + .../transport_sockets/alts/tsi_socket.cc | 27 ++++--- .../transport_sockets/alts/tsi_socket.h | 33 +++----- .../transport_sockets/alts/tsi_socket_test.cc | 80 ++++++++----------- 5 files changed, 68 insertions(+), 75 deletions(-) diff --git a/source/extensions/transport_sockets/alts/BUILD b/source/extensions/transport_sockets/alts/BUILD index 96315d434f1f..4182dd0fcafe 100644 --- a/source/extensions/transport_sockets/alts/BUILD +++ b/source/extensions/transport_sockets/alts/BUILD @@ -64,6 +64,7 @@ envoy_cc_library( ], repository = "@envoy", deps = [ + ":noop_transport_socket_callbacks_lib", ":tsi_frame_protector", ":tsi_handshaker", "//include/envoy/network:transport_socket_interface", diff --git a/source/extensions/transport_sockets/alts/noop_transport_socket_callbacks.h b/source/extensions/transport_sockets/alts/noop_transport_socket_callbacks.h index 3fe6a038104a..d72b68e8eb20 100644 --- a/source/extensions/transport_sockets/alts/noop_transport_socket_callbacks.h +++ b/source/extensions/transport_sockets/alts/noop_transport_socket_callbacks.h @@ -29,6 +29,8 @@ class NoOpTransportSocketCallbacks : public Network::TransportSocketCallbacks { Network::TransportSocketCallbacks& parent_; }; +typedef std::unique_ptr NoOpTransportSocketCallbacksPtr; + } // namespace Alts } // namespace TransportSockets } // namespace Extensions diff --git a/source/extensions/transport_sockets/alts/tsi_socket.cc b/source/extensions/transport_sockets/alts/tsi_socket.cc index 89114917b087..041711ff2a0c 100644 --- a/source/extensions/transport_sockets/alts/tsi_socket.cc +++ b/source/extensions/transport_sockets/alts/tsi_socket.cc @@ -11,9 +11,7 @@ namespace Alts { TsiSocket::TsiSocket(HandshakerFactory handshaker_factory, HandshakeValidator handshake_validator, Network::TransportSocketPtr&& raw_socket) : handshaker_factory_(handshaker_factory), handshake_validator_(handshake_validator), - raw_buffer_callbacks_(*this), raw_buffer_socket_(std::move(raw_socket)) { - raw_buffer_socket_->setTransportSocketCallbacks(raw_buffer_callbacks_); -} + raw_buffer_socket_(std::move(raw_socket)) {} TsiSocket::TsiSocket(HandshakerFactory handshaker_factory, HandshakeValidator handshake_validator) : TsiSocket(handshaker_factory, handshake_validator, @@ -24,8 +22,8 @@ TsiSocket::~TsiSocket() { ASSERT(!handshaker_); } void TsiSocket::setTransportSocketCallbacks(Envoy::Network::TransportSocketCallbacks& callbacks) { callbacks_ = &callbacks; - handshaker_ = handshaker_factory_(callbacks.connection().dispatcher()); - handshaker_->setHandshakerCallbacks(*this); + noop_callbacks_ = std::make_unique(callbacks); + raw_buffer_socket_->setTransportSocketCallbacks(*noop_callbacks_); } std::string TsiSocket::protocol() const { return ""; } @@ -34,6 +32,13 @@ Network::PostIoAction TsiSocket::doHandshake() { ASSERT(!handshake_complete_); ENVOY_CONN_LOG(debug, "TSI: doHandshake", callbacks_->connection()); + if (!handshaker_) { + handshaker_ = handshaker_factory_(callbacks_->connection().dispatcher(), + callbacks_->connection().localAddress(), + callbacks_->connection().remoteAddress()); + handshaker_->setHandshakerCallbacks(*this); + } + if (!handshaker_next_calling_) { doHandshakeNext(); } @@ -80,14 +85,14 @@ Network::PostIoAction TsiSocket::doHandshakeNextDone(NextResultPtr&& next_result std::string err; bool peer_validated = handshake_validator_(peer, err); if (peer_validated) { - ENVOY_CONN_LOG(info, "TSI: Handshake validation succeeded.", callbacks_->connection()); + ENVOY_CONN_LOG(debug, "TSI: Handshake validation succeeded.", callbacks_->connection()); } else { - ENVOY_CONN_LOG(warn, "TSI: Handshake validation failed: {}", callbacks_->connection(), err); + ENVOY_CONN_LOG(info, "TSI: Handshake validation failed: {}", callbacks_->connection(), err); tsi_peer_destruct(&peer); return Network::PostIoAction::Close; } } else { - ENVOY_CONN_LOG(info, "TSI: Handshake validation skipped.", callbacks_->connection()); + ENVOY_CONN_LOG(debug, "TSI: Handshake validation skipped.", callbacks_->connection()); } tsi_peer_destruct(&peer); @@ -184,7 +189,11 @@ Network::IoResult TsiSocket::doWrite(Buffer::Instance& buffer, bool end_stream) return {Network::PostIoAction::KeepOpen, 0, false}; } -void TsiSocket::closeSocket(Network::ConnectionEvent) { handshaker_.release()->deferredDelete(); } +void TsiSocket::closeSocket(Network::ConnectionEvent) { + if (handshaker_) { + handshaker_.release()->deferredDelete(); + } +} void TsiSocket::onConnected() { ASSERT(!handshake_complete_); } diff --git a/source/extensions/transport_sockets/alts/tsi_socket.h b/source/extensions/transport_sockets/alts/tsi_socket.h index 83d8d68352a8..f94095ef6ee4 100644 --- a/source/extensions/transport_sockets/alts/tsi_socket.h +++ b/source/extensions/transport_sockets/alts/tsi_socket.h @@ -5,6 +5,7 @@ #include "common/buffer/buffer_impl.h" #include "common/network/raw_buffer_socket.h" +#include "extensions/transport_sockets/alts/noop_transport_socket_callbacks.h" #include "extensions/transport_sockets/alts/tsi_frame_protector.h" #include "extensions/transport_sockets/alts/tsi_handshaker.h" @@ -13,7 +14,16 @@ namespace Extensions { namespace TransportSockets { namespace Alts { -typedef std::function HandshakerFactory; +/** + * A factory function to create TsiHandshaker + * @param dispatcher the dispatcher for the thread where the socket is running on. + * @param local_address the local address of the connection. + * @param remote_address the remote address of the connection. + */ +typedef std::function + HandshakerFactory; /** * A function to validate the peer of the connection. @@ -58,25 +68,6 @@ class TsiSocket : public Network::TransportSocket, void onNextDone(NextResultPtr&& result) override; private: - /** - * Callbacks for underlying RawBufferSocket, it proxies fd() and connection() - * but not raising event or flow control since they have to be handled in - * TsiSocket. - */ - class RawBufferCallbacks : public Network::TransportSocketCallbacks { - public: - explicit RawBufferCallbacks(TsiSocket& parent) : parent_(parent) {} - - int fd() const override { return parent_.callbacks_->fd(); } - Network::Connection& connection() override { return parent_.callbacks_->connection(); } - bool shouldDrainReadBuffer() override { return false; } - void setReadBufferReady() override {} - void raiseEvent(Network::ConnectionEvent) override {} - - private: - TsiSocket& parent_; - }; - Network::PostIoAction doHandshake(); void doHandshakeNext(); Network::PostIoAction doHandshakeNextDone(NextResultPtr&& next_result); @@ -89,7 +80,7 @@ class TsiSocket : public Network::TransportSocket, TsiFrameProtectorPtr frame_protector_; Envoy::Network::TransportSocketCallbacks* callbacks_{}; - RawBufferCallbacks raw_buffer_callbacks_; + NoOpTransportSocketCallbacksPtr noop_callbacks_; Network::TransportSocketPtr raw_buffer_socket_; Envoy::Buffer::OwnedImpl raw_read_buffer_; diff --git a/test/extensions/transport_sockets/alts/tsi_socket_test.cc b/test/extensions/transport_sockets/alts/tsi_socket_test.cc index d6d60effe727..b8dd63e799a8 100644 --- a/test/extensions/transport_sockets/alts/tsi_socket_test.cc +++ b/test/extensions/transport_sockets/alts/tsi_socket_test.cc @@ -28,7 +28,9 @@ class TsiSocketTest : public testing::Test { } void initialize(HandshakeValidator server_validator, HandshakeValidator client_validator) { - auto server_handshaker_factory = [](Event::Dispatcher& dispatcher) { + auto server_handshaker_factory = [](Event::Dispatcher& dispatcher, + const Network::Address::InstanceConstSharedPtr&, + const Network::Address::InstanceConstSharedPtr&) { CHandshakerPtr handshaker{tsi_create_fake_handshaker(/*is_client=*/0)}; return std::make_unique(std::move(handshaker), dispatcher); @@ -40,7 +42,9 @@ class TsiSocketTest : public testing::Test { std::make_unique(server_handshaker_factory, server_validator, Network::TransportSocketPtr{server_.raw_socket_}); - auto client_handshaker_factory = [](Event::Dispatcher& dispatcher) { + auto client_handshaker_factory = [](Event::Dispatcher& dispatcher, + const Network::Address::InstanceConstSharedPtr&, + const Network::Address::InstanceConstSharedPtr&) { CHandshakerPtr handshaker{tsi_create_fake_handshaker(/*is_client=*/1)}; return std::make_unique(std::move(handshaker), dispatcher); @@ -150,6 +154,24 @@ class TsiSocketTest : public testing::Test { client_.tsi_socket_->doRead(client_.read_buffer_)); } + void expectTransferDataFromClientToServer(const std::string& data) { + + EXPECT_EQ(0L, server_.read_buffer_.length()); + EXPECT_EQ(0L, client_.read_buffer_.length()); + + EXPECT_EQ("", client_.tsi_socket_->protocol()); + + EXPECT_CALL(*client_.raw_socket_, doWrite(_, false)); + expectIoResult({Network::PostIoAction::KeepOpen, 21UL, false}, + client_.tsi_socket_->doWrite(client_.write_buffer_, false)); + EXPECT_EQ(makeFakeTsiFrame(data), client_to_server_.toString()); + + EXPECT_CALL(*server_.raw_socket_, doRead(_)); + expectIoResult({Network::PostIoAction::KeepOpen, 21UL, false}, + server_.tsi_socket_->doRead(server_.read_buffer_)); + EXPECT_EQ(data, server_.read_buffer_.toString()); + } + struct SocketForTest { std::unique_ptr tsi_socket_; NiceMock* raw_socket_{}; @@ -167,6 +189,8 @@ class TsiSocketTest : public testing::Test { NiceMock dispatcher_; }; +static const std::string ClientToServerData = "hello from client"; + TEST_F(TsiSocketTest, DoesNotHaveSsl) { initialize(nullptr, nullptr); EXPECT_EQ(nullptr, client_.tsi_socket_->ssl()); @@ -175,67 +199,31 @@ TEST_F(TsiSocketTest, DoesNotHaveSsl) { EXPECT_EQ(nullptr, socket_.ssl()); } -TEST_F(TsiSocketTest, ProxyCallbacks) { - initialize(nullptr, nullptr); - - EXPECT_CALL(client_.callbacks_, fd()).WillOnce(Return(111)); - EXPECT_EQ(111, client_.raw_socket_->callbacks_->fd()); - - EXPECT_EQ(&client_.callbacks_.connection_, &client_.raw_socket_->callbacks_->connection()); - - EXPECT_FALSE(client_.raw_socket_->callbacks_->shouldDrainReadBuffer()); -} - TEST_F(TsiSocketTest, HandshakeWithoutValidationAndTransferData) { + // pass a nullptr validator to skip validation. initialize(nullptr, nullptr); - client_.write_buffer_.add("hello from client"); + client_.write_buffer_.add(ClientToServerData); doHandshakeAndExpectSuccess(); - EXPECT_EQ(0L, server_.read_buffer_.length()); - EXPECT_EQ(0L, client_.read_buffer_.length()); - - EXPECT_EQ("", client_.tsi_socket_->protocol()); - - EXPECT_CALL(*client_.raw_socket_, doWrite(_, false)); - expectIoResult({Network::PostIoAction::KeepOpen, 21UL, false}, - client_.tsi_socket_->doWrite(client_.write_buffer_, false)); - EXPECT_EQ(makeFakeTsiFrame("hello from client"), client_to_server_.toString()); - - EXPECT_CALL(*server_.raw_socket_, doRead(_)); - expectIoResult({Network::PostIoAction::KeepOpen, 21UL, false}, - server_.tsi_socket_->doRead(server_.read_buffer_)); - EXPECT_EQ("hello from client", server_.read_buffer_.toString()); + expectTransferDataFromClientToServer(ClientToServerData); } TEST_F(TsiSocketTest, HandshakeWithSucessfulValidationAndTransferData) { auto validator = [](const tsi_peer&, std::string&) { return true; }; initialize(validator, validator); - client_.write_buffer_.add("hello from client"); + client_.write_buffer_.add(ClientToServerData); doHandshakeAndExpectSuccess(); - EXPECT_EQ(0L, server_.read_buffer_.length()); - EXPECT_EQ(0L, client_.read_buffer_.length()); - - EXPECT_EQ("", client_.tsi_socket_->protocol()); - - EXPECT_CALL(*client_.raw_socket_, doWrite(_, false)); - expectIoResult({Network::PostIoAction::KeepOpen, 21UL, false}, - client_.tsi_socket_->doWrite(client_.write_buffer_, false)); - EXPECT_EQ(makeFakeTsiFrame("hello from client"), client_to_server_.toString()); - - EXPECT_CALL(*server_.raw_socket_, doRead(_)); - expectIoResult({Network::PostIoAction::KeepOpen, 21UL, false}, - server_.tsi_socket_->doRead(server_.read_buffer_)); - EXPECT_EQ("hello from client", server_.read_buffer_.toString()); + expectTransferDataFromClientToServer(ClientToServerData); } TEST_F(TsiSocketTest, HandshakeValidationFail) { auto validator = [](const tsi_peer&, std::string&) { return false; }; initialize(validator, validator); - client_.write_buffer_.add("hello from client"); + client_.write_buffer_.add(ClientToServerData); doFakeInitHandshake(); @@ -257,7 +245,9 @@ TEST_F(TsiSocketTest, HandshakeValidationFail) { class TsiSocketFactoryTest : public testing::Test { protected: void SetUp() override { - auto handshaker_factory = [](Event::Dispatcher& dispatcher) { + auto handshaker_factory = [](Event::Dispatcher& dispatcher, + const Network::Address::InstanceConstSharedPtr&, + const Network::Address::InstanceConstSharedPtr&) { CHandshakerPtr handshaker{tsi_create_fake_handshaker(/*is_client=*/0)}; return std::make_unique(std::move(handshaker), dispatcher); From 79c50823940171342510936eabe472b4fa20817d Mon Sep 17 00:00:00 2001 From: Lizan Zhou Date: Wed, 22 Aug 2018 03:02:51 -0700 Subject: [PATCH 05/10] more tests Signed-off-by: Lizan Zhou --- .../transport_sockets/alts/tsi_socket_test.cc | 94 +++++++++++++++---- 1 file changed, 75 insertions(+), 19 deletions(-) diff --git a/test/extensions/transport_sockets/alts/tsi_socket_test.cc b/test/extensions/transport_sockets/alts/tsi_socket_test.cc index b8dd63e799a8..588a2721b70d 100644 --- a/test/extensions/transport_sockets/alts/tsi_socket_test.cc +++ b/test/extensions/transport_sockets/alts/tsi_socket_test.cc @@ -20,7 +20,23 @@ using testing::StrictMock; class TsiSocketTest : public testing::Test { protected: - TsiSocketTest() {} + TsiSocketTest() { + server_.handshaker_factory_ = [](Event::Dispatcher& dispatcher, + const Network::Address::InstanceConstSharedPtr&, + const Network::Address::InstanceConstSharedPtr&) { + CHandshakerPtr handshaker{tsi_create_fake_handshaker(/*is_client=*/0)}; + + return std::make_unique(std::move(handshaker), dispatcher); + }; + + client_.handshaker_factory_ = [](Event::Dispatcher& dispatcher, + const Network::Address::InstanceConstSharedPtr&, + const Network::Address::InstanceConstSharedPtr&) { + CHandshakerPtr handshaker{tsi_create_fake_handshaker(/*is_client=*/1)}; + + return std::make_unique(std::move(handshaker), dispatcher); + }; + } void TearDown() override { client_.tsi_socket_->closeSocket(Network::ConnectionEvent::LocalClose); @@ -28,32 +44,16 @@ class TsiSocketTest : public testing::Test { } void initialize(HandshakeValidator server_validator, HandshakeValidator client_validator) { - auto server_handshaker_factory = [](Event::Dispatcher& dispatcher, - const Network::Address::InstanceConstSharedPtr&, - const Network::Address::InstanceConstSharedPtr&) { - CHandshakerPtr handshaker{tsi_create_fake_handshaker(/*is_client=*/0)}; - - return std::make_unique(std::move(handshaker), dispatcher); - }; - server_.raw_socket_ = new NiceMock(); server_.tsi_socket_ = - std::make_unique(server_handshaker_factory, server_validator, + std::make_unique(server_.handshaker_factory_, server_validator, Network::TransportSocketPtr{server_.raw_socket_}); - auto client_handshaker_factory = [](Event::Dispatcher& dispatcher, - const Network::Address::InstanceConstSharedPtr&, - const Network::Address::InstanceConstSharedPtr&) { - CHandshakerPtr handshaker{tsi_create_fake_handshaker(/*is_client=*/1)}; - - return std::make_unique(std::move(handshaker), dispatcher); - }; - client_.raw_socket_ = new NiceMock(); client_.tsi_socket_ = - std::make_unique(client_handshaker_factory, client_validator, + std::make_unique(client_.handshaker_factory_, client_validator, Network::TransportSocketPtr{client_.raw_socket_}); ON_CALL(client_.callbacks_.connection_, dispatcher()).WillByDefault(ReturnRef(dispatcher_)); @@ -173,6 +173,7 @@ class TsiSocketTest : public testing::Test { } struct SocketForTest { + HandshakerFactory handshaker_factory_; std::unique_ptr tsi_socket_; NiceMock* raw_socket_{}; NiceMock callbacks_; @@ -242,6 +243,61 @@ TEST_F(TsiSocketTest, HandshakeValidationFail) { EXPECT_EQ(0, server_to_client_.length()); } +TEST_F(TsiSocketTest, HandshakeWithUnusedData) { + initialize(nullptr, nullptr); + + doFakeInitHandshake(); + EXPECT_CALL(*client_.raw_socket_, doRead(_)); + EXPECT_CALL(*client_.raw_socket_, doWrite(_, false)); + expectIoResult({Network::PostIoAction::KeepOpen, 0UL, false}, + client_.tsi_socket_->doRead(client_.read_buffer_)); + EXPECT_EQ(makeFakeTsiFrame("CLIENT_FINISHED"), client_to_server_.toString()); + EXPECT_EQ(0L, client_.read_buffer_.length()); + + // Inject unused data + client_to_server_.add(makeFakeTsiFrame(ClientToServerData)); + + EXPECT_CALL(*server_.raw_socket_, doRead(_)); + EXPECT_CALL(*server_.raw_socket_, doWrite(_, false)); + EXPECT_CALL(server_.callbacks_, raiseEvent(Network::ConnectionEvent::Connected)); + expectIoResult({Network::PostIoAction::KeepOpen, 21UL, false}, + server_.tsi_socket_->doRead(server_.read_buffer_)); + EXPECT_EQ(makeFakeTsiFrame("SERVER_FINISHED"), server_to_client_.toString()); + EXPECT_EQ(ClientToServerData, server_.read_buffer_.toString()); + + EXPECT_CALL(*client_.raw_socket_, doRead(_)); + EXPECT_CALL(client_.callbacks_, raiseEvent(Network::ConnectionEvent::Connected)); + expectIoResult({Network::PostIoAction::KeepOpen, 0UL, false}, + client_.tsi_socket_->doRead(client_.read_buffer_)); +} + +TEST_F(TsiSocketTest, HandshakeWithInternalError) { + auto raw_handshaker = tsi_create_fake_handshaker(/* is_client= */ 1); + const tsi_handshaker_vtable* vtable = raw_handshaker->vtable; + tsi_handshaker_vtable mock_vtable = *vtable; + mock_vtable.next = [](tsi_handshaker*, const unsigned char*, size_t, const unsigned char**, + size_t*, tsi_handshaker_result**, tsi_handshaker_on_next_done_cb, + void*) { return TSI_INTERNAL_ERROR; }; + raw_handshaker->vtable = &mock_vtable; + + client_.handshaker_factory_ = [&](Event::Dispatcher& dispatcher, + const Network::Address::InstanceConstSharedPtr&, + const Network::Address::InstanceConstSharedPtr&) { + CHandshakerPtr handshaker{raw_handshaker}; + + return std::make_unique(std::move(handshaker), dispatcher); + }; + + initialize(nullptr, nullptr); + + EXPECT_CALL(client_.callbacks_.connection_, close(Network::ConnectionCloseType::NoFlush)); + // doWrite won't immediately fail, but it will result connection close. + expectIoResult({Network::PostIoAction::KeepOpen, 0UL, false}, + client_.tsi_socket_->doWrite(client_.write_buffer_, false)); + + raw_handshaker->vtable = vtable; +} + class TsiSocketFactoryTest : public testing::Test { protected: void SetUp() override { From aeab41feb2d68b3884f740bae740060d619533b3 Mon Sep 17 00:00:00 2001 From: Lizan Zhou Date: Thu, 23 Aug 2018 15:13:54 -0700 Subject: [PATCH 06/10] prepend Signed-off-by: Lizan Zhou --- source/extensions/transport_sockets/alts/tsi_socket.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/source/extensions/transport_sockets/alts/tsi_socket.cc b/source/extensions/transport_sockets/alts/tsi_socket.cc index 041711ff2a0c..2f72d9362b9e 100644 --- a/source/extensions/transport_sockets/alts/tsi_socket.cc +++ b/source/extensions/transport_sockets/alts/tsi_socket.cc @@ -103,7 +103,8 @@ Network::PostIoAction TsiSocket::doHandshakeNextDone(NextResultPtr&& next_result tsi_handshaker_result_get_unused_bytes(handshaker_result, &unused_bytes, &unused_byte_size); ASSERT(status == TSI_OK); if (unused_byte_size > 0) { - raw_read_buffer_.add(unused_bytes, unused_byte_size); + raw_read_buffer_.prepend( + absl::string_view{reinterpret_cast(unused_bytes), unused_byte_size}); } ENVOY_CONN_LOG(debug, "TSI: Handshake successful: unused_bytes: {}", callbacks_->connection(), unused_byte_size); @@ -220,4 +221,4 @@ Network::TransportSocketPtr TsiSocketFactory::createTransportSocket() const { } // namespace Alts } // namespace TransportSockets } // namespace Extensions -} // namespace Envoy \ No newline at end of file +} // namespace Envoy From 2687a25fd32aa76f8450784cf7078a7597a81fdd Mon Sep 17 00:00:00 2001 From: Lizan Zhou Date: Thu, 30 Aug 2018 01:21:25 -0700 Subject: [PATCH 07/10] handle end of stream and read error more rigidly Signed-off-by: Lizan Zhou --- .../transport_sockets/alts/tsi_socket.cc | 24 ++++++--- .../transport_sockets/alts/tsi_socket.h | 2 + .../transport_sockets/alts/tsi_socket_test.cc | 50 +++++++++++++++++++ 3 files changed, 70 insertions(+), 6 deletions(-) diff --git a/source/extensions/transport_sockets/alts/tsi_socket.cc b/source/extensions/transport_sockets/alts/tsi_socket.cc index 2f72d9362b9e..f8c5f0903c56 100644 --- a/source/extensions/transport_sockets/alts/tsi_socket.cc +++ b/source/extensions/transport_sockets/alts/tsi_socket.cc @@ -119,6 +119,12 @@ Network::PostIoAction TsiSocket::doHandshakeNextDone(NextResultPtr&& next_result callbacks_->raiseEvent(Network::ConnectionEvent::Connected); } + if (read_error_ || (!handshake_complete_ && end_stream_read_)) { + ENVOY_CONN_LOG(debug, "TSI: Handshake failed: end of stream without enough data", + callbacks_->connection()); + return Network::PostIoAction::Close; + } + if (raw_read_buffer_.length() > 0) { callbacks_->setReadBufferReady(); } @@ -132,12 +138,18 @@ Network::PostIoAction TsiSocket::doHandshakeNextDone(NextResultPtr&& next_result } Network::IoResult TsiSocket::doRead(Buffer::Instance& buffer) { - Network::IoResult result = raw_buffer_socket_->doRead(raw_read_buffer_); - ENVOY_CONN_LOG(debug, "TSI: raw read result action {} bytes {} end_stream {}", - callbacks_->connection(), enumToInt(result.action_), result.bytes_processed_, - result.end_stream_read_); - if (result.action_ == Network::PostIoAction::Close && result.bytes_processed_ == 0) { - return result; + Network::IoResult result = {Network::PostIoAction::KeepOpen, 0, false}; + if (!end_stream_read_ && !read_error_) { + result = raw_buffer_socket_->doRead(raw_read_buffer_); + ENVOY_CONN_LOG(debug, "TSI: raw read result action {} bytes {} end_stream {}", + callbacks_->connection(), enumToInt(result.action_), result.bytes_processed_, + result.end_stream_read_); + if (result.action_ == Network::PostIoAction::Close && result.bytes_processed_ == 0) { + return result; + } + + end_stream_read_ = result.end_stream_read_; + read_error_ = result.action_ == Network::PostIoAction::Close; } if (!handshake_complete_) { diff --git a/source/extensions/transport_sockets/alts/tsi_socket.h b/source/extensions/transport_sockets/alts/tsi_socket.h index f94095ef6ee4..70f8a1d7aeff 100644 --- a/source/extensions/transport_sockets/alts/tsi_socket.h +++ b/source/extensions/transport_sockets/alts/tsi_socket.h @@ -86,6 +86,8 @@ class TsiSocket : public Network::TransportSocket, Envoy::Buffer::OwnedImpl raw_read_buffer_; Envoy::Buffer::OwnedImpl raw_write_buffer_; bool handshake_complete_{}; + bool end_stream_read_{}; + bool read_error_{}; }; /** diff --git a/test/extensions/transport_sockets/alts/tsi_socket_test.cc b/test/extensions/transport_sockets/alts/tsi_socket_test.cc index 588a2721b70d..eca68bba149d 100644 --- a/test/extensions/transport_sockets/alts/tsi_socket_test.cc +++ b/test/extensions/transport_sockets/alts/tsi_socket_test.cc @@ -271,6 +271,56 @@ TEST_F(TsiSocketTest, HandshakeWithUnusedData) { client_.tsi_socket_->doRead(client_.read_buffer_)); } +TEST_F(TsiSocketTest, HandshakeWithUnusedDataAndEndOfStream) { + initialize(nullptr, nullptr); + + doFakeInitHandshake(); + EXPECT_CALL(*client_.raw_socket_, doRead(_)); + EXPECT_CALL(*client_.raw_socket_, doWrite(_, false)); + expectIoResult({Network::PostIoAction::KeepOpen, 0UL, false}, + client_.tsi_socket_->doRead(client_.read_buffer_)); + EXPECT_EQ(makeFakeTsiFrame("CLIENT_FINISHED"), client_to_server_.toString()); + EXPECT_EQ(0L, client_.read_buffer_.length()); + + // Inject unused data + client_to_server_.add(makeFakeTsiFrame(ClientToServerData)); + + EXPECT_CALL(*server_.raw_socket_, doRead(_)).WillOnce(Invoke([&](Buffer::Instance& buffer) { + Network::IoResult result = {Network::PostIoAction::KeepOpen, client_to_server_.length(), true}; + buffer.move(client_to_server_); + return result; + })); + EXPECT_CALL(*server_.raw_socket_, doWrite(_, false)); + EXPECT_CALL(server_.callbacks_, raiseEvent(Network::ConnectionEvent::Connected)); + expectIoResult({Network::PostIoAction::KeepOpen, 21UL, true}, + server_.tsi_socket_->doRead(server_.read_buffer_)); + EXPECT_EQ(makeFakeTsiFrame("SERVER_FINISHED"), server_to_client_.toString()); + EXPECT_EQ(ClientToServerData, server_.read_buffer_.toString()); + + EXPECT_CALL(*client_.raw_socket_, doRead(_)); + EXPECT_CALL(client_.callbacks_, raiseEvent(Network::ConnectionEvent::Connected)); + expectIoResult({Network::PostIoAction::KeepOpen, 0UL, false}, + client_.tsi_socket_->doRead(client_.read_buffer_)); +} + +TEST_F(TsiSocketTest, HandshakeWithReadError) { + initialize(nullptr, nullptr); + + doFakeInitHandshake(); + + EXPECT_CALL(*client_.raw_socket_, doRead(_)).WillOnce(Invoke([&](Buffer::Instance& buffer) { + Network::IoResult result = {Network::PostIoAction::Close, server_to_client_.length(), false}; + buffer.move(server_to_client_); + return result; + })); + EXPECT_CALL(*client_.raw_socket_, doWrite(_, false)).Times(0); + EXPECT_CALL(client_.callbacks_.connection_, close(Network::ConnectionCloseType::NoFlush)); + expectIoResult({Network::PostIoAction::KeepOpen, 0UL, false}, + client_.tsi_socket_->doRead(client_.read_buffer_)); + EXPECT_EQ("", client_to_server_.toString()); + EXPECT_EQ(0L, client_.read_buffer_.length()); +} + TEST_F(TsiSocketTest, HandshakeWithInternalError) { auto raw_handshaker = tsi_create_fake_handshaker(/* is_client= */ 1); const tsi_handshaker_vtable* vtable = raw_handshaker->vtable; From 39c373a12fa7dfbc0a18b519198225a899aabd42 Mon Sep 17 00:00:00 2001 From: Lizan Zhou Date: Thu, 30 Aug 2018 17:30:10 -0700 Subject: [PATCH 08/10] address comments Signed-off-by: Lizan Zhou --- .../extensions/transport_sockets/alts/BUILD | 4 ++-- .../transport_sockets/alts/tsi_socket.cc | 22 +++++++++++++------ .../transport_sockets/alts/tsi_socket_test.cc | 17 +++++++++++++- 3 files changed, 33 insertions(+), 10 deletions(-) diff --git a/source/extensions/transport_sockets/alts/BUILD b/source/extensions/transport_sockets/alts/BUILD index 4182dd0fcafe..8f7df7a46a13 100644 --- a/source/extensions/transport_sockets/alts/BUILD +++ b/source/extensions/transport_sockets/alts/BUILD @@ -32,7 +32,6 @@ envoy_cc_library( hdrs = [ "tsi_frame_protector.h", ], - repository = "@envoy", deps = [ ":grpc_tsi_wrapper", "//source/common/buffer:buffer_lib", @@ -62,13 +61,14 @@ envoy_cc_library( hdrs = [ "tsi_socket.h", ], - repository = "@envoy", deps = [ ":noop_transport_socket_callbacks_lib", ":tsi_frame_protector", ":tsi_handshaker", "//include/envoy/network:transport_socket_interface", "//source/common/buffer:buffer_lib", + "//source/common/common:cleanup_lib", + "//source/common/common:empty_string", "//source/common/common:enum_to_int", "//source/common/network:raw_buffer_socket_lib", "//source/common/protobuf:utility_lib", diff --git a/source/extensions/transport_sockets/alts/tsi_socket.cc b/source/extensions/transport_sockets/alts/tsi_socket.cc index f8c5f0903c56..6f36d65875f3 100644 --- a/source/extensions/transport_sockets/alts/tsi_socket.cc +++ b/source/extensions/transport_sockets/alts/tsi_socket.cc @@ -1,6 +1,8 @@ #include "extensions/transport_sockets/alts/tsi_socket.h" #include "common/common/assert.h" +#include "common/common/cleanup.h" +#include "common/common/empty_string.h" #include "common/common/enum_to_int.h" namespace Envoy { @@ -26,7 +28,11 @@ void TsiSocket::setTransportSocketCallbacks(Envoy::Network::TransportSocketCallb raw_buffer_socket_->setTransportSocketCallbacks(*noop_callbacks_); } -std::string TsiSocket::protocol() const { return ""; } +std::string TsiSocket::protocol() const { + // TSI doesn't have a generic way to indicate application layer protocol. + // TODO(lizan): support application layer protocol from TSI for known TSIs. + return EMPTY_STRING; +} Network::PostIoAction TsiSocket::doHandshake() { ASSERT(!handshake_complete_); @@ -74,7 +80,10 @@ Network::PostIoAction TsiSocket::doHandshakeNextDone(NextResultPtr&& next_result if (status == TSI_OK && handshaker_result != nullptr) { tsi_peer peer; - tsi_handshaker_result_extract_peer(handshaker_result, &peer); + // returns TSI_OK assuming there is no fatal error. Asserting OK. + status = tsi_handshaker_result_extract_peer(handshaker_result, &peer); + ASSERT(status == TSI_OK); + Cleanup peer_cleanup([&peer]() { tsi_peer_destruct(&peer); }); ENVOY_CONN_LOG(debug, "TSI: Handshake successful: peer properties: {}", callbacks_->connection(), peer.property_count); for (size_t i = 0; i < peer.property_count; ++i) { @@ -88,17 +97,16 @@ Network::PostIoAction TsiSocket::doHandshakeNextDone(NextResultPtr&& next_result ENVOY_CONN_LOG(debug, "TSI: Handshake validation succeeded.", callbacks_->connection()); } else { ENVOY_CONN_LOG(info, "TSI: Handshake validation failed: {}", callbacks_->connection(), err); - tsi_peer_destruct(&peer); return Network::PostIoAction::Close; } } else { ENVOY_CONN_LOG(debug, "TSI: Handshake validation skipped.", callbacks_->connection()); } - tsi_peer_destruct(&peer); const unsigned char* unused_bytes; size_t unused_byte_size; + // returns TSI_OK assuming there is no fatal error. Asserting OK. status = tsi_handshaker_result_get_unused_bytes(handshaker_result, &unused_bytes, &unused_byte_size); ASSERT(status == TSI_OK); @@ -109,6 +117,7 @@ Network::PostIoAction TsiSocket::doHandshakeNextDone(NextResultPtr&& next_result ENVOY_CONN_LOG(debug, "TSI: Handshake successful: unused_bytes: {}", callbacks_->connection(), unused_byte_size); + // returns TSI_OK assuming there is no fatal error. Asserting OK. tsi_frame_protector* frame_protector; status = tsi_handshaker_result_create_frame_protector(handshaker_result, NULL, &frame_protector); @@ -180,9 +189,8 @@ Network::IoResult TsiSocket::doRead(Buffer::Instance& buffer) { Network::IoResult TsiSocket::doWrite(Buffer::Instance& buffer, bool end_stream) { if (!handshake_complete_) { Network::PostIoAction action = doHandshake(); - if (action == Network::PostIoAction::Close) { - return {action, 0, false}; - } + ASSERT(action == Network::PostIoAction::KeepOpen); + // TODO(lizan): Handle synchronous handshake when TsiHandshaker supports it. } if (handshake_complete_) { diff --git a/test/extensions/transport_sockets/alts/tsi_socket_test.cc b/test/extensions/transport_sockets/alts/tsi_socket_test.cc index eca68bba149d..69ad81494f4c 100644 --- a/test/extensions/transport_sockets/alts/tsi_socket_test.cc +++ b/test/extensions/transport_sockets/alts/tsi_socket_test.cc @@ -303,6 +303,21 @@ TEST_F(TsiSocketTest, HandshakeWithUnusedDataAndEndOfStream) { client_.tsi_socket_->doRead(client_.read_buffer_)); } +TEST_F(TsiSocketTest, HandshakeWithImmediateReadError) { + initialize(nullptr, nullptr); + + EXPECT_CALL(*client_.raw_socket_, doRead(_)).WillOnce(Invoke([&](Buffer::Instance& buffer) { + Network::IoResult result = {Network::PostIoAction::Close, server_to_client_.length(), false}; + buffer.move(server_to_client_); + return result; + })); + EXPECT_CALL(*client_.raw_socket_, doWrite(_, false)).Times(0); + expectIoResult({Network::PostIoAction::Close, 0UL, false}, + client_.tsi_socket_->doRead(client_.read_buffer_)); + EXPECT_EQ("", client_to_server_.toString()); + EXPECT_EQ(0L, client_.read_buffer_.length()); +} + TEST_F(TsiSocketTest, HandshakeWithReadError) { initialize(nullptr, nullptr); @@ -375,4 +390,4 @@ TEST_F(TsiSocketFactoryTest, ImplementsSecureTransport) { } // namespace Alts } // namespace TransportSockets } // namespace Extensions -} // namespace Envoy \ No newline at end of file +} // namespace Envoy From f13f64a2160ab586d925654e990c32953a77ce42 Mon Sep 17 00:00:00 2001 From: Lizan Zhou Date: Thu, 30 Aug 2018 17:37:34 -0700 Subject: [PATCH 09/10] nits Signed-off-by: Lizan Zhou --- source/extensions/transport_sockets/alts/tsi_socket.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/source/extensions/transport_sockets/alts/tsi_socket.cc b/source/extensions/transport_sockets/alts/tsi_socket.cc index 6f36d65875f3..7e986f544399 100644 --- a/source/extensions/transport_sockets/alts/tsi_socket.cc +++ b/source/extensions/transport_sockets/alts/tsi_socket.cc @@ -92,11 +92,11 @@ Network::PostIoAction TsiSocket::doHandshakeNextDone(NextResultPtr&& next_result } if (handshake_validator_) { std::string err; - bool peer_validated = handshake_validator_(peer, err); + const bool peer_validated = handshake_validator_(peer, err); if (peer_validated) { ENVOY_CONN_LOG(debug, "TSI: Handshake validation succeeded.", callbacks_->connection()); } else { - ENVOY_CONN_LOG(info, "TSI: Handshake validation failed: {}", callbacks_->connection(), err); + ENVOY_CONN_LOG(debug, "TSI: Handshake validation failed: {}", callbacks_->connection(), err); return Network::PostIoAction::Close; } } else { From fc0b98e12c65a5d356364256ef4c60b2bcfbe617 Mon Sep 17 00:00:00 2001 From: Lizan Zhou Date: Thu, 30 Aug 2018 17:54:05 -0700 Subject: [PATCH 10/10] fix format Signed-off-by: Lizan Zhou --- source/extensions/transport_sockets/alts/tsi_socket.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/source/extensions/transport_sockets/alts/tsi_socket.cc b/source/extensions/transport_sockets/alts/tsi_socket.cc index 7e986f544399..c20887f46dc0 100644 --- a/source/extensions/transport_sockets/alts/tsi_socket.cc +++ b/source/extensions/transport_sockets/alts/tsi_socket.cc @@ -96,7 +96,8 @@ Network::PostIoAction TsiSocket::doHandshakeNextDone(NextResultPtr&& next_result if (peer_validated) { ENVOY_CONN_LOG(debug, "TSI: Handshake validation succeeded.", callbacks_->connection()); } else { - ENVOY_CONN_LOG(debug, "TSI: Handshake validation failed: {}", callbacks_->connection(), err); + ENVOY_CONN_LOG(debug, "TSI: Handshake validation failed: {}", callbacks_->connection(), + err); return Network::PostIoAction::Close; } } else {