Skip to content

Commit

Permalink
network: adding some accessors for ALPN work. (#13785)
Browse files Browse the repository at this point in the history
Adding some basic functionality to the Network::Connection needed to disassociate TCP connections from their connection pool before handing them off to a codec client.

Additional Description:
I really wanted to add Network::Connection Connecting as a state, and I did a PR to that effect and cleaned up the 30 or so call sites in Envoy, but then realized that because it broke nearly every use of state, we either needed to have a separate boolean, or change all the enum values so folks would have to fix their downstream code. I chose the former as the lesser of two evils.

I also waffled a bit about removing the callbacks or sticking with the remove pattern in codec_helper.h and went with consistency over cleanup.

Risk Level: low (not yet used)
Testing: unit tests
Docs Changes: n/a
Release Notes: n/a
Platform Specific Features: n/a
Part of #3431

Signed-off-by: Alyssa Wilk <alyssar@chromium.org>
  • Loading branch information
alyssawilk authored Oct 28, 2020
1 parent d9ea34f commit db756af
Show file tree
Hide file tree
Showing 13 changed files with 104 additions and 3 deletions.
11 changes: 11 additions & 0 deletions include/envoy/network/connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,11 @@ class Connection : public Event::DeferredDeletable, public FilterManager {
*/
virtual void addConnectionCallbacks(ConnectionCallbacks& cb) PURE;

/**
* Unregister callbacks which previously fired when connection events occur.
*/
virtual void removeConnectionCallbacks(ConnectionCallbacks& cb) PURE;

/**
* Register for callback every time bytes are written to the underlying TransportSocket.
*/
Expand Down Expand Up @@ -241,6 +246,12 @@ class Connection : public Event::DeferredDeletable, public FilterManager {
*/
virtual State state() const PURE;

/**
* @return true if the connection has not completed connecting, false if the connection is
* established.
*/
virtual bool connecting() const PURE;

/**
* Write data to the connection. Will iterate through downstream filters with the buffer if any
* are installed.
Expand Down
5 changes: 5 additions & 0 deletions include/envoy/network/filter.h
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,11 @@ class FilterManager {
*/
virtual void addReadFilter(ReadFilterSharedPtr filter) PURE;

/**
* Remove a read filter from the connection.
*/
virtual void removeReadFilter(ReadFilterSharedPtr filter) PURE;

/**
* Initialize all of the installed read filters. This effectively calls onNewConnection() on
* each of them.
Expand Down
12 changes: 10 additions & 2 deletions source/common/network/connection_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ void ConnectionImpl::addReadFilter(ReadFilterSharedPtr filter) {
filter_manager_.addReadFilter(filter);
}

void ConnectionImpl::removeReadFilter(ReadFilterSharedPtr filter) {
filter_manager_.removeReadFilter(filter);
}

bool ConnectionImpl::initializeReadFilters() { return filter_manager_.initializeReadFilters(); }

void ConnectionImpl::close(ConnectionCloseType type) {
Expand Down Expand Up @@ -485,7 +489,9 @@ void ConnectionImpl::onWriteBufferLowWatermark() {
ASSERT(write_buffer_above_high_watermark_);
write_buffer_above_high_watermark_ = false;
for (ConnectionCallbacks* callback : callbacks_) {
callback->onBelowWriteBufferLowWatermark();
if (callback) {
callback->onBelowWriteBufferLowWatermark();
}
}
}

Expand All @@ -494,7 +500,9 @@ void ConnectionImpl::onWriteBufferHighWatermark() {
ASSERT(!write_buffer_above_high_watermark_);
write_buffer_above_high_watermark_ = true;
for (ConnectionCallbacks* callback : callbacks_) {
callback->onAboveWriteBufferHighWatermark();
if (callback) {
callback->onAboveWriteBufferHighWatermark();
}
}
}

Expand Down
2 changes: 2 additions & 0 deletions source/common/network/connection_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class ConnectionImpl : public ConnectionImplBase, public TransportSocketCallback
void addWriteFilter(WriteFilterSharedPtr filter) override;
void addFilter(FilterSharedPtr filter) override;
void addReadFilter(ReadFilterSharedPtr filter) override;
void removeReadFilter(ReadFilterSharedPtr filter) override;
bool initializeReadFilters() override;

// Network::Connection
Expand All @@ -78,6 +79,7 @@ class ConnectionImpl : public ConnectionImplBase, public TransportSocketCallback
absl::optional<UnixDomainSocketPeerCredentials> unixSocketPeerCredentials() const override;
Ssl::ConnectionInfoConstSharedPtr ssl() const override { return transport_socket_->ssl(); }
State state() const override;
bool connecting() const override { return connecting_; }
void write(Buffer::Instance& data, bool end_stream) override;
void setBufferLimits(uint32_t limit) override;
uint32_t bufferLimit() const override { return read_buffer_limit_; }
Expand Down
14 changes: 13 additions & 1 deletion source/common/network/connection_impl_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,16 @@ void ConnectionImplBase::addConnectionCallbacks(ConnectionCallbacks& cb) {
callbacks_.push_back(&cb);
}

void ConnectionImplBase::removeConnectionCallbacks(ConnectionCallbacks& callbacks) {
// For performance/safety reasons we just clear the callback and do not resize the list
for (auto& callback : callbacks_) {
if (callback == &callbacks) {
callback = nullptr;
return;
}
}
}

void ConnectionImplBase::hashKey(std::vector<uint8_t>& hash) const { addIdToHashKey(hash, id()); }

void ConnectionImplBase::setConnectionStats(const ConnectionStats& stats) {
Expand Down Expand Up @@ -45,7 +55,9 @@ void ConnectionImplBase::raiseConnectionEvent(ConnectionEvent event) {
for (ConnectionCallbacks* callback : callbacks_) {
// TODO(mattklein123): If we close while raising a connected event we should not raise further
// connected events.
callback->onEvent(event);
if (callback != nullptr) {
callback->onEvent(event);
}
}
}

Expand Down
1 change: 1 addition & 0 deletions source/common/network/connection_impl_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class ConnectionImplBase : public FilterManagerConnection,

// Network::Connection
void addConnectionCallbacks(ConnectionCallbacks& cb) override;
void removeConnectionCallbacks(ConnectionCallbacks& cb) override;
Event::Dispatcher& dispatcher() override { return dispatcher_; }
uint64_t id() const override { return id_; }
void hashKey(std::vector<uint8_t>& hash) const override;
Expand Down
12 changes: 12 additions & 0 deletions source/common/network/filter_manager_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,15 @@ void FilterManagerImpl::addReadFilter(ReadFilterSharedPtr filter) {
LinkedList::moveIntoListBack(std::move(new_filter), upstream_filters_);
}

void FilterManagerImpl::removeReadFilter(ReadFilterSharedPtr filter_to_remove) {
// For perf/safety reasons, null this out rather than removing.
for (auto& filter : upstream_filters_) {
if (filter->filter_ == filter_to_remove) {
filter->filter_ = nullptr;
}
}
}

bool FilterManagerImpl::initializeReadFilters() {
if (upstream_filters_.empty()) {
return false;
Expand All @@ -53,6 +62,9 @@ void FilterManagerImpl::onContinueReading(ActiveReadFilter* filter,
}

for (; entry != upstream_filters_.end(); entry++) {
if (!(*entry)->filter_) {
continue;
}
if (!(*entry)->initialized_) {
(*entry)->initialized_ = true;
FilterStatus status = (*entry)->filter_->onNewConnection();
Expand Down
1 change: 1 addition & 0 deletions source/common/network/filter_manager_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ class FilterManagerImpl {
void addWriteFilter(WriteFilterSharedPtr filter);
void addFilter(FilterSharedPtr filter);
void addReadFilter(ReadFilterSharedPtr filter);
void removeReadFilter(ReadFilterSharedPtr filter);
bool initializeReadFilters();
void onRead();
FilterStatus onWrite();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ void QuicFilterManagerConnectionImpl::addReadFilter(Network::ReadFilterSharedPtr
filter_manager_.addReadFilter(filter);
}

void QuicFilterManagerConnectionImpl::removeReadFilter(Network::ReadFilterSharedPtr filter) {
filter_manager_.removeReadFilter(filter);
}

bool QuicFilterManagerConnectionImpl::initializeReadFilters() {
return filter_manager_.initializeReadFilters();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class QuicFilterManagerConnectionImpl : public Network::ConnectionImplBase {
void addWriteFilter(Network::WriteFilterSharedPtr filter) override;
void addFilter(Network::FilterSharedPtr filter) override;
void addReadFilter(Network::ReadFilterSharedPtr filter) override;
void removeReadFilter(Network::ReadFilterSharedPtr filter) override;
bool initializeReadFilters() override;

// Network::Connection
Expand Down Expand Up @@ -63,6 +64,12 @@ class QuicFilterManagerConnectionImpl : public Network::ConnectionImplBase {
}
return Network::Connection::State::Closed;
}
bool connecting() const override {
if (quic_connection_ != nullptr && quic_connection_->connected()) {
return false;
}
return true;
}
void write(Buffer::Instance& /*data*/, bool /*end_stream*/) override {
// All writes should be handled by Quic internally.
NOT_REACHED_GCOVR_EXCL_LINE;
Expand Down
7 changes: 7 additions & 0 deletions source/server/api_listener_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,18 @@ class ApiListenerImplBase : public ApiListener,
}
void addFilter(Network::FilterSharedPtr) override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; }
void addReadFilter(Network::ReadFilterSharedPtr) override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; }
void removeReadFilter(Network::ReadFilterSharedPtr) override {
NOT_IMPLEMENTED_GCOVR_EXCL_LINE;
}
bool initializeReadFilters() override { return true; }

// Network::Connection
void addConnectionCallbacks(Network::ConnectionCallbacks& cb) override {
callbacks_.push_back(&cb);
}
void removeConnectionCallbacks(Network::ConnectionCallbacks& cb) override {
callbacks_.remove(&cb);
}
void addBytesSentCallback(Network::Connection::BytesSentCb) override {
NOT_IMPLEMENTED_GCOVR_EXCL_LINE;
}
Expand Down Expand Up @@ -121,6 +127,7 @@ class ApiListenerImplBase : public ApiListener,
Ssl::ConnectionInfoConstSharedPtr ssl() const override { return nullptr; }
absl::string_view requestedServerName() const override { return EMPTY_STRING; }
State state() const override { return Network::Connection::State::Open; }
bool connecting() const override { return false; }
void write(Buffer::Instance&, bool) override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; }
void setBufferLimits(uint32_t) override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; }
uint32_t bufferLimit() const override { return 65000; }
Expand Down
25 changes: 25 additions & 0 deletions test/common/network/connection_impl_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,15 @@ TEST_P(ConnectionImplTest, CloseDuringConnectCallback) {
Buffer::OwnedImpl buffer("hello world");
client_connection_->write(buffer, false);
client_connection_->connect();
EXPECT_TRUE(client_connection_->connecting());

StrictMock<MockConnectionCallbacks> added_and_removed_callbacks;
// Make sure removed connections don't get events.
client_connection_->addConnectionCallbacks(added_and_removed_callbacks);
client_connection_->removeConnectionCallbacks(added_and_removed_callbacks);

std::shared_ptr<MockReadFilter> add_and_remove_filter =
std::make_shared<StrictMock<MockReadFilter>>();

EXPECT_CALL(client_callbacks_, onEvent(ConnectionEvent::Connected))
.WillOnce(Invoke([&](Network::ConnectionEvent) -> void {
Expand All @@ -302,6 +311,8 @@ TEST_P(ConnectionImplTest, CloseDuringConnectCallback) {
std::move(socket), Network::Test::createRawBufferSocket(), stream_info_);
server_connection_->addConnectionCallbacks(server_callbacks_);
server_connection_->addReadFilter(read_filter_);
server_connection_->addReadFilter(add_and_remove_filter);
server_connection_->removeReadFilter(add_and_remove_filter);
}));

EXPECT_CALL(server_callbacks_, onEvent(ConnectionEvent::RemoteClose))
Expand Down Expand Up @@ -537,13 +548,22 @@ TEST_P(ConnectionImplTest, ConnectionStats) {

MockConnectionStats client_connection_stats;
client_connection_->setConnectionStats(client_connection_stats.toBufferStats());
EXPECT_TRUE(client_connection_->connecting());
client_connection_->connect();
// The Network::Connection class oddly uses onWrite as its indicator of if
// it's done connection, rather than the Connected event.
EXPECT_TRUE(client_connection_->connecting());

std::shared_ptr<MockWriteFilter> write_filter(new MockWriteFilter());
std::shared_ptr<MockFilter> filter(new MockFilter());
client_connection_->addFilter(filter);
client_connection_->addWriteFilter(write_filter);

// Make sure removed filters don't get callbacks.
std::shared_ptr<MockReadFilter> read_filter(new StrictMock<MockReadFilter>());
client_connection_->addReadFilter(read_filter);
client_connection_->removeReadFilter(read_filter);

Sequence s1;
EXPECT_CALL(*write_filter, onWrite(_, _))
.InSequence(s1)
Expand Down Expand Up @@ -854,6 +874,11 @@ TEST_P(ConnectionImplTest, WriteWatermarks) {
setUpBasicConnection();
EXPECT_FALSE(client_connection_->aboveHighWatermark());

StrictMock<MockConnectionCallbacks> added_and_removed_callbacks;
// Make sure removed connections don't get events.
client_connection_->addConnectionCallbacks(added_and_removed_callbacks);
client_connection_->removeConnectionCallbacks(added_and_removed_callbacks);

// Stick 5 bytes in the connection buffer.
std::unique_ptr<Buffer::OwnedImpl> buffer(new Buffer::OwnedImpl("hello"));
int buffer_len = buffer->length();
Expand Down
6 changes: 6 additions & 0 deletions test/mocks/network/connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,12 @@ class MockConnectionBase {
#define DEFINE_MOCK_CONNECTION_MOCK_METHODS \
/* Network::Connection */ \
MOCK_METHOD(void, addConnectionCallbacks, (ConnectionCallbacks & cb)); \
MOCK_METHOD(void, removeConnectionCallbacks, (ConnectionCallbacks & cb)); \
MOCK_METHOD(void, addBytesSentCallback, (BytesSentCb cb)); \
MOCK_METHOD(void, addWriteFilter, (WriteFilterSharedPtr filter)); \
MOCK_METHOD(void, addFilter, (FilterSharedPtr filter)); \
MOCK_METHOD(void, addReadFilter, (ReadFilterSharedPtr filter)); \
MOCK_METHOD(void, removeReadFilter, (ReadFilterSharedPtr filter)); \
MOCK_METHOD(void, enableHalfClose, (bool enabled)); \
MOCK_METHOD(void, close, (ConnectionCloseType type)); \
MOCK_METHOD(Event::Dispatcher&, dispatcher, ()); \
Expand All @@ -72,6 +74,7 @@ class MockConnectionBase {
MOCK_METHOD(Ssl::ConnectionInfoConstSharedPtr, ssl, (), (const)); \
MOCK_METHOD(absl::string_view, requestedServerName, (), (const)); \
MOCK_METHOD(State, state, (), (const)); \
MOCK_METHOD(bool, connecting, (), (const)); \
MOCK_METHOD(void, write, (Buffer::Instance & data, bool end_stream)); \
MOCK_METHOD(void, setBufferLimits, (uint32_t limit)); \
MOCK_METHOD(uint32_t, bufferLimit, (), (const)); \
Expand Down Expand Up @@ -128,10 +131,12 @@ class MockFilterManagerConnection : public FilterManagerConnection, public MockC

// Network::Connection
MOCK_METHOD(void, addConnectionCallbacks, (ConnectionCallbacks & cb));
MOCK_METHOD(void, removeConnectionCallbacks, (ConnectionCallbacks & cb));
MOCK_METHOD(void, addBytesSentCallback, (BytesSentCb cb));
MOCK_METHOD(void, addWriteFilter, (WriteFilterSharedPtr filter));
MOCK_METHOD(void, addFilter, (FilterSharedPtr filter));
MOCK_METHOD(void, addReadFilter, (ReadFilterSharedPtr filter));
MOCK_METHOD(void, removeReadFilter, (ReadFilterSharedPtr filter));
MOCK_METHOD(void, enableHalfClose, (bool enabled));
MOCK_METHOD(void, close, (ConnectionCloseType type));
MOCK_METHOD(Event::Dispatcher&, dispatcher, ());
Expand All @@ -152,6 +157,7 @@ class MockFilterManagerConnection : public FilterManagerConnection, public MockC
MOCK_METHOD(Ssl::ConnectionInfoConstSharedPtr, ssl, (), (const));
MOCK_METHOD(absl::string_view, requestedServerName, (), (const));
MOCK_METHOD(State, state, (), (const));
MOCK_METHOD(bool, connecting, (), (const));
MOCK_METHOD(void, write, (Buffer::Instance & data, bool end_stream));
MOCK_METHOD(void, setBufferLimits, (uint32_t limit));
MOCK_METHOD(uint32_t, bufferLimit, (), (const));
Expand Down

0 comments on commit db756af

Please sign in to comment.