Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

alts: add gRPC TSI socket #4153

Merged
merged 15 commits into from
Aug 31, 2018
23 changes: 22 additions & 1 deletion source/extensions/transport_sockets/alts/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ envoy_cc_library(
hdrs = [
"tsi_frame_protector.h",
],
repository = "@envoy",
deps = [
":grpc_tsi_wrapper",
"//source/common/buffer:buffer_lib",
Expand All @@ -54,6 +53,28 @@ envoy_cc_library(
],
)

envoy_cc_library(
name = "tsi_socket",
srcs = [
"tsi_socket.cc",
],
hdrs = [
"tsi_socket.h",
],
deps = [
":noop_transport_socket_callbacks_lib",
":tsi_frame_protector",
":tsi_handshaker",
"//include/envoy/network:transport_socket_interface",
"//source/common/buffer:buffer_lib",
"//source/common/common:cleanup_lib",
"//source/common/common:empty_string",
"//source/common/common:enum_to_int",
"//source/common/network:raw_buffer_socket_lib",
"//source/common/protobuf:utility_lib",
],
)

envoy_cc_library(
name = "noop_transport_socket_callbacks_lib",
hdrs = ["noop_transport_socket_callbacks.h"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class NoOpTransportSocketCallbacks : public Network::TransportSocketCallbacks {
Network::TransportSocketCallbacks& parent_;
};

typedef std::unique_ptr<NoOpTransportSocketCallbacks> NoOpTransportSocketCallbacksPtr;

} // namespace Alts
} // namespace TransportSockets
} // namespace Extensions
Expand Down
245 changes: 245 additions & 0 deletions source/extensions/transport_sockets/alts/tsi_socket.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
#include "extensions/transport_sockets/alts/tsi_socket.h"

#include "common/common/assert.h"
#include "common/common/cleanup.h"
#include "common/common/empty_string.h"
#include "common/common/enum_to_int.h"

namespace Envoy {
namespace Extensions {
namespace TransportSockets {
namespace Alts {

TsiSocket::TsiSocket(HandshakerFactory handshaker_factory, HandshakeValidator handshake_validator,
Network::TransportSocketPtr&& raw_socket)
: handshaker_factory_(handshaker_factory), handshake_validator_(handshake_validator),
raw_buffer_socket_(std::move(raw_socket)) {}

TsiSocket::TsiSocket(HandshakerFactory handshaker_factory, HandshakeValidator handshake_validator)
: TsiSocket(handshaker_factory, handshake_validator,
std::make_unique<Network::RawBufferSocket>()) {}

TsiSocket::~TsiSocket() { ASSERT(!handshaker_); }

void TsiSocket::setTransportSocketCallbacks(Envoy::Network::TransportSocketCallbacks& callbacks) {
callbacks_ = &callbacks;

noop_callbacks_ = std::make_unique<NoOpTransportSocketCallbacks>(callbacks);
raw_buffer_socket_->setTransportSocketCallbacks(*noop_callbacks_);
}

std::string TsiSocket::protocol() const {
// TSI doesn't have a generic way to indicate application layer protocol.
// TODO(lizan): support application layer protocol from TSI for known TSIs.
return EMPTY_STRING;
}

Network::PostIoAction TsiSocket::doHandshake() {
ASSERT(!handshake_complete_);
ENVOY_CONN_LOG(debug, "TSI: doHandshake", callbacks_->connection());

if (!handshaker_) {
handshaker_ = handshaker_factory_(callbacks_->connection().dispatcher(),
callbacks_->connection().localAddress(),
callbacks_->connection().remoteAddress());
handshaker_->setHandshakerCallbacks(*this);
}

if (!handshaker_next_calling_) {
doHandshakeNext();
}
return Network::PostIoAction::KeepOpen;
}

void TsiSocket::doHandshakeNext() {
ENVOY_CONN_LOG(debug, "TSI: doHandshake next: received: {}", callbacks_->connection(),
raw_read_buffer_.length());
handshaker_next_calling_ = true;
Buffer::OwnedImpl handshaker_buffer;
handshaker_buffer.move(raw_read_buffer_);
handshaker_->next(handshaker_buffer);
}

Network::PostIoAction TsiSocket::doHandshakeNextDone(NextResultPtr&& next_result) {
ASSERT(next_result);

ENVOY_CONN_LOG(debug, "TSI: doHandshake next done: status: {} to_send: {}",
callbacks_->connection(), next_result->status_, next_result->to_send_->length());

tsi_result status = next_result->status_;
tsi_handshaker_result* handshaker_result = next_result->result_.get();

if (status != TSI_INCOMPLETE_DATA && status != TSI_OK) {
ENVOY_CONN_LOG(debug, "TSI: Handshake failed: status: {}", callbacks_->connection(), status);
return Network::PostIoAction::Close;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this code covered in test?

}

if (next_result->to_send_->length() > 0) {
raw_write_buffer_.move(*next_result->to_send_);
}

if (status == TSI_OK && handshaker_result != nullptr) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if status == TSI_OK && handshaker_results == nullptr? I see various error handling below, but it's not immediately clear how/why it covers these cases. Can we put in some more ASSERTS or make this more robust?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

status == TSI_OK && handshaker_results == nullptr is a common case when a handshake is in progress, i.e. need to read/write more data. It doesn't follow in any error handlings below, just need more calls to doHandshake.

tsi_peer peer;
// returns TSI_OK assuming there is no fatal error. Asserting OK.
status = tsi_handshaker_result_extract_peer(handshaker_result, &peer);
ASSERT(status == TSI_OK);
Cleanup peer_cleanup([&peer]() { tsi_peer_destruct(&peer); });
ENVOY_CONN_LOG(debug, "TSI: Handshake successful: peer properties: {}",
callbacks_->connection(), peer.property_count);
for (size_t i = 0; i < peer.property_count; ++i) {
ENVOY_CONN_LOG(debug, " {}: {}", callbacks_->connection(), peer.properties[i].name,
std::string(peer.properties[i].value.data, peer.properties[i].value.length));
}
if (handshake_validator_) {
std::string err;
const bool peer_validated = handshake_validator_(peer, err);
if (peer_validated) {
ENVOY_CONN_LOG(debug, "TSI: Handshake validation succeeded.", callbacks_->connection());
} else {
ENVOY_CONN_LOG(debug, "TSI: Handshake validation failed: {}", callbacks_->connection(),
err);
return Network::PostIoAction::Close;
}
} else {
ENVOY_CONN_LOG(debug, "TSI: Handshake validation skipped.", callbacks_->connection());
}

const unsigned char* unused_bytes;
size_t unused_byte_size;

// returns TSI_OK assuming there is no fatal error. Asserting OK.
status =
tsi_handshaker_result_get_unused_bytes(handshaker_result, &unused_bytes, &unused_byte_size);
ASSERT(status == TSI_OK);
if (unused_byte_size > 0) {
raw_read_buffer_.prepend(
absl::string_view{reinterpret_cast<const char*>(unused_bytes), unused_byte_size});
}
ENVOY_CONN_LOG(debug, "TSI: Handshake successful: unused_bytes: {}", callbacks_->connection(),
unused_byte_size);

// returns TSI_OK assuming there is no fatal error. Asserting OK.
tsi_frame_protector* frame_protector;
status =
tsi_handshaker_result_create_frame_protector(handshaker_result, NULL, &frame_protector);
ASSERT(status == TSI_OK);
frame_protector_ = std::make_unique<TsiFrameProtector>(frame_protector);

handshake_complete_ = true;
callbacks_->raiseEvent(Network::ConnectionEvent::Connected);
}

if (read_error_ || (!handshake_complete_ && end_stream_read_)) {
ENVOY_CONN_LOG(debug, "TSI: Handshake failed: end of stream without enough data",
callbacks_->connection());
return Network::PostIoAction::Close;
}

if (raw_read_buffer_.length() > 0) {
callbacks_->setReadBufferReady();
}

// Try to write raw buffer when next call is done, even this is not in do[Read|Write] stack.
if (raw_write_buffer_.length() > 0) {
return raw_buffer_socket_->doWrite(raw_write_buffer_, false).action_;
}

return Network::PostIoAction::KeepOpen;
}

Network::IoResult TsiSocket::doRead(Buffer::Instance& buffer) {
Network::IoResult result = {Network::PostIoAction::KeepOpen, 0, false};
if (!end_stream_read_ && !read_error_) {
result = raw_buffer_socket_->doRead(raw_read_buffer_);
ENVOY_CONN_LOG(debug, "TSI: raw read result action {} bytes {} end_stream {}",
callbacks_->connection(), enumToInt(result.action_), result.bytes_processed_,
result.end_stream_read_);
if (result.action_ == Network::PostIoAction::Close && result.bytes_processed_ == 0) {
return result;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing coverage.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

}

end_stream_read_ = result.end_stream_read_;
read_error_ = result.action_ == Network::PostIoAction::Close;
}

if (!handshake_complete_) {
Network::PostIoAction action = doHandshake();
if (action == Network::PostIoAction::Close || !handshake_complete_) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this condition ever satisfied? It seems that doHandshake() always return Network::PostIoAction::KeepOpen.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll leave this uncovered until next PR to follow up on synchronous TSI optimization.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, that's for synchronous handshake! Makes sense.

return {action, 0, false};
}
}

if (handshake_complete_) {
ASSERT(frame_protector_);

uint64_t read_size = raw_read_buffer_.length();
ENVOY_CONN_LOG(debug, "TSI: unprotecting buffer size: {}", callbacks_->connection(),
raw_read_buffer_.length());
tsi_result status = frame_protector_->unprotect(raw_read_buffer_, buffer);
ENVOY_CONN_LOG(debug, "TSI: unprotected buffer left: {} result: {}", callbacks_->connection(),
raw_read_buffer_.length(), tsi_result_to_string(status));
result.bytes_processed_ = read_size - raw_read_buffer_.length();
}

ENVOY_CONN_LOG(debug, "TSI: do read result action {} bytes {} end_stream {}",
callbacks_->connection(), enumToInt(result.action_), result.bytes_processed_,
result.end_stream_read_);
return result;
}

Network::IoResult TsiSocket::doWrite(Buffer::Instance& buffer, bool end_stream) {
if (!handshake_complete_) {
Network::PostIoAction action = doHandshake();
ASSERT(action == Network::PostIoAction::KeepOpen);
// TODO(lizan): Handle synchronous handshake when TsiHandshaker supports it.
}

if (handshake_complete_) {
ASSERT(frame_protector_);
ENVOY_CONN_LOG(debug, "TSI: protecting buffer size: {}", callbacks_->connection(),
buffer.length());
tsi_result status = frame_protector_->protect(buffer, raw_write_buffer_);
ENVOY_CONN_LOG(debug, "TSI: protected buffer left: {} result: {}", callbacks_->connection(),
buffer.length(), tsi_result_to_string(status));
}

if (raw_write_buffer_.length() > 0) {
ENVOY_CONN_LOG(debug, "TSI: raw_write length {} end_stream {}", callbacks_->connection(),
raw_write_buffer_.length(), end_stream);
return raw_buffer_socket_->doWrite(raw_write_buffer_, end_stream && (buffer.length() == 0));
}
return {Network::PostIoAction::KeepOpen, 0, false};
}

void TsiSocket::closeSocket(Network::ConnectionEvent) {
if (handshaker_) {
handshaker_.release()->deferredDelete();
}
}

void TsiSocket::onConnected() { ASSERT(!handshake_complete_); }

void TsiSocket::onNextDone(NextResultPtr&& result) {
handshaker_next_calling_ = false;

Network::PostIoAction action = doHandshakeNextDone(std::move(result));
if (action == Network::PostIoAction::Close) {
callbacks_->connection().close(Network::ConnectionCloseType::NoFlush);
}
}

TsiSocketFactory::TsiSocketFactory(HandshakerFactory handshaker_factory,
HandshakeValidator handshake_validator)
: handshaker_factory_(std::move(handshaker_factory)),
handshake_validator_(std::move(handshake_validator)) {}

bool TsiSocketFactory::implementsSecureTransport() const { return true; }

Network::TransportSocketPtr TsiSocketFactory::createTransportSocket() const {
return std::make_unique<TsiSocket>(handshaker_factory_, handshake_validator_);
}

} // namespace Alts
} // namespace TransportSockets
} // namespace Extensions
} // namespace Envoy
Loading