diff --git a/include/envoy/network/connection.h b/include/envoy/network/connection.h index 01d8313bf2d7..7efc6ada8175 100644 --- a/include/envoy/network/connection.h +++ b/include/envoy/network/connection.h @@ -101,6 +101,11 @@ class Connection : public Event::DeferredDeletable, public FilterManager { */ virtual void addConnectionCallbacks(ConnectionCallbacks& cb) PURE; + /** + * Unregister callbacks which previously fired when connection events occur. + */ + virtual void removeConnectionCallbacks(ConnectionCallbacks& cb) PURE; + /** * Register for callback every time bytes are written to the underlying TransportSocket. */ @@ -241,6 +246,12 @@ class Connection : public Event::DeferredDeletable, public FilterManager { */ virtual State state() const PURE; + /** + * @return true if the connection has not completed connecting, false if the connection is + * established. + */ + virtual bool connecting() const PURE; + /** * Write data to the connection. Will iterate through downstream filters with the buffer if any * are installed. diff --git a/include/envoy/network/filter.h b/include/envoy/network/filter.h index 18f7bda54ad4..b8bec913b462 100644 --- a/include/envoy/network/filter.h +++ b/include/envoy/network/filter.h @@ -227,6 +227,11 @@ class FilterManager { */ virtual void addReadFilter(ReadFilterSharedPtr filter) PURE; + /** + * Remove a read filter from the connection. + */ + virtual void removeReadFilter(ReadFilterSharedPtr filter) PURE; + /** * Initialize all of the installed read filters. This effectively calls onNewConnection() on * each of them. diff --git a/source/common/network/connection_impl.cc b/source/common/network/connection_impl.cc index 46d38c8c5cdf..8f8e9e759662 100644 --- a/source/common/network/connection_impl.cc +++ b/source/common/network/connection_impl.cc @@ -101,6 +101,10 @@ void ConnectionImpl::addReadFilter(ReadFilterSharedPtr filter) { filter_manager_.addReadFilter(filter); } +void ConnectionImpl::removeReadFilter(ReadFilterSharedPtr filter) { + filter_manager_.removeReadFilter(filter); +} + bool ConnectionImpl::initializeReadFilters() { return filter_manager_.initializeReadFilters(); } void ConnectionImpl::close(ConnectionCloseType type) { @@ -485,7 +489,9 @@ void ConnectionImpl::onWriteBufferLowWatermark() { ASSERT(write_buffer_above_high_watermark_); write_buffer_above_high_watermark_ = false; for (ConnectionCallbacks* callback : callbacks_) { - callback->onBelowWriteBufferLowWatermark(); + if (callback) { + callback->onBelowWriteBufferLowWatermark(); + } } } @@ -494,7 +500,9 @@ void ConnectionImpl::onWriteBufferHighWatermark() { ASSERT(!write_buffer_above_high_watermark_); write_buffer_above_high_watermark_ = true; for (ConnectionCallbacks* callback : callbacks_) { - callback->onAboveWriteBufferHighWatermark(); + if (callback) { + callback->onAboveWriteBufferHighWatermark(); + } } } diff --git a/source/common/network/connection_impl.h b/source/common/network/connection_impl.h index 9fe6d429dc55..ed32c30f24e4 100644 --- a/source/common/network/connection_impl.h +++ b/source/common/network/connection_impl.h @@ -55,6 +55,7 @@ class ConnectionImpl : public ConnectionImplBase, public TransportSocketCallback void addWriteFilter(WriteFilterSharedPtr filter) override; void addFilter(FilterSharedPtr filter) override; void addReadFilter(ReadFilterSharedPtr filter) override; + void removeReadFilter(ReadFilterSharedPtr filter) override; bool initializeReadFilters() override; // Network::Connection @@ -78,6 +79,7 @@ class ConnectionImpl : public ConnectionImplBase, public TransportSocketCallback absl::optional unixSocketPeerCredentials() const override; Ssl::ConnectionInfoConstSharedPtr ssl() const override { return transport_socket_->ssl(); } State state() const override; + bool connecting() const override { return connecting_; } void write(Buffer::Instance& data, bool end_stream) override; void setBufferLimits(uint32_t limit) override; uint32_t bufferLimit() const override { return read_buffer_limit_; } diff --git a/source/common/network/connection_impl_base.cc b/source/common/network/connection_impl_base.cc index e048465a4b35..775b09be13e4 100644 --- a/source/common/network/connection_impl_base.cc +++ b/source/common/network/connection_impl_base.cc @@ -18,6 +18,16 @@ void ConnectionImplBase::addConnectionCallbacks(ConnectionCallbacks& cb) { callbacks_.push_back(&cb); } +void ConnectionImplBase::removeConnectionCallbacks(ConnectionCallbacks& callbacks) { + // For performance/safety reasons we just clear the callback and do not resize the list + for (auto& callback : callbacks_) { + if (callback == &callbacks) { + callback = nullptr; + return; + } + } +} + void ConnectionImplBase::hashKey(std::vector& hash) const { addIdToHashKey(hash, id()); } void ConnectionImplBase::setConnectionStats(const ConnectionStats& stats) { @@ -45,7 +55,9 @@ void ConnectionImplBase::raiseConnectionEvent(ConnectionEvent event) { for (ConnectionCallbacks* callback : callbacks_) { // TODO(mattklein123): If we close while raising a connected event we should not raise further // connected events. - callback->onEvent(event); + if (callback != nullptr) { + callback->onEvent(event); + } } } diff --git a/source/common/network/connection_impl_base.h b/source/common/network/connection_impl_base.h index d0bf93670cff..5bb12eea5a7d 100644 --- a/source/common/network/connection_impl_base.h +++ b/source/common/network/connection_impl_base.h @@ -22,6 +22,7 @@ class ConnectionImplBase : public FilterManagerConnection, // Network::Connection void addConnectionCallbacks(ConnectionCallbacks& cb) override; + void removeConnectionCallbacks(ConnectionCallbacks& cb) override; Event::Dispatcher& dispatcher() override { return dispatcher_; } uint64_t id() const override { return id_; } void hashKey(std::vector& hash) const override; diff --git a/source/common/network/filter_manager_impl.cc b/source/common/network/filter_manager_impl.cc index 593abc098095..5bec16f5f75a 100644 --- a/source/common/network/filter_manager_impl.cc +++ b/source/common/network/filter_manager_impl.cc @@ -28,6 +28,15 @@ void FilterManagerImpl::addReadFilter(ReadFilterSharedPtr filter) { LinkedList::moveIntoListBack(std::move(new_filter), upstream_filters_); } +void FilterManagerImpl::removeReadFilter(ReadFilterSharedPtr filter_to_remove) { + // For perf/safety reasons, null this out rather than removing. + for (auto& filter : upstream_filters_) { + if (filter->filter_ == filter_to_remove) { + filter->filter_ = nullptr; + } + } +} + bool FilterManagerImpl::initializeReadFilters() { if (upstream_filters_.empty()) { return false; @@ -53,6 +62,9 @@ void FilterManagerImpl::onContinueReading(ActiveReadFilter* filter, } for (; entry != upstream_filters_.end(); entry++) { + if (!(*entry)->filter_) { + continue; + } if (!(*entry)->initialized_) { (*entry)->initialized_ = true; FilterStatus status = (*entry)->filter_->onNewConnection(); diff --git a/source/common/network/filter_manager_impl.h b/source/common/network/filter_manager_impl.h index 0975c2ecd7ed..a74ba02c56c5 100644 --- a/source/common/network/filter_manager_impl.h +++ b/source/common/network/filter_manager_impl.h @@ -105,6 +105,7 @@ class FilterManagerImpl { void addWriteFilter(WriteFilterSharedPtr filter); void addFilter(FilterSharedPtr filter); void addReadFilter(ReadFilterSharedPtr filter); + void removeReadFilter(ReadFilterSharedPtr filter); bool initializeReadFilters(); void onRead(); FilterStatus onWrite(); diff --git a/source/extensions/quic_listeners/quiche/quic_filter_manager_connection_impl.cc b/source/extensions/quic_listeners/quiche/quic_filter_manager_connection_impl.cc index e005a3dd7691..3e30e6ec5779 100644 --- a/source/extensions/quic_listeners/quiche/quic_filter_manager_connection_impl.cc +++ b/source/extensions/quic_listeners/quiche/quic_filter_manager_connection_impl.cc @@ -30,6 +30,10 @@ void QuicFilterManagerConnectionImpl::addReadFilter(Network::ReadFilterSharedPtr filter_manager_.addReadFilter(filter); } +void QuicFilterManagerConnectionImpl::removeReadFilter(Network::ReadFilterSharedPtr filter) { + filter_manager_.removeReadFilter(filter); +} + bool QuicFilterManagerConnectionImpl::initializeReadFilters() { return filter_manager_.initializeReadFilters(); } diff --git a/source/extensions/quic_listeners/quiche/quic_filter_manager_connection_impl.h b/source/extensions/quic_listeners/quiche/quic_filter_manager_connection_impl.h index cf049ab5ac52..8f01d03ca6b9 100644 --- a/source/extensions/quic_listeners/quiche/quic_filter_manager_connection_impl.h +++ b/source/extensions/quic_listeners/quiche/quic_filter_manager_connection_impl.h @@ -25,6 +25,7 @@ class QuicFilterManagerConnectionImpl : public Network::ConnectionImplBase { void addWriteFilter(Network::WriteFilterSharedPtr filter) override; void addFilter(Network::FilterSharedPtr filter) override; void addReadFilter(Network::ReadFilterSharedPtr filter) override; + void removeReadFilter(Network::ReadFilterSharedPtr filter) override; bool initializeReadFilters() override; // Network::Connection @@ -63,6 +64,12 @@ class QuicFilterManagerConnectionImpl : public Network::ConnectionImplBase { } return Network::Connection::State::Closed; } + bool connecting() const override { + if (quic_connection_ != nullptr && quic_connection_->connected()) { + return false; + } + return true; + } void write(Buffer::Instance& /*data*/, bool /*end_stream*/) override { // All writes should be handled by Quic internally. NOT_REACHED_GCOVR_EXCL_LINE; diff --git a/source/server/api_listener_impl.h b/source/server/api_listener_impl.h index 4731ea90ca54..b0dd0ef701c3 100644 --- a/source/server/api_listener_impl.h +++ b/source/server/api_listener_impl.h @@ -83,12 +83,18 @@ class ApiListenerImplBase : public ApiListener, } void addFilter(Network::FilterSharedPtr) override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } void addReadFilter(Network::ReadFilterSharedPtr) override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } + void removeReadFilter(Network::ReadFilterSharedPtr) override { + NOT_IMPLEMENTED_GCOVR_EXCL_LINE; + } bool initializeReadFilters() override { return true; } // Network::Connection void addConnectionCallbacks(Network::ConnectionCallbacks& cb) override { callbacks_.push_back(&cb); } + void removeConnectionCallbacks(Network::ConnectionCallbacks& cb) override { + callbacks_.remove(&cb); + } void addBytesSentCallback(Network::Connection::BytesSentCb) override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } @@ -121,6 +127,7 @@ class ApiListenerImplBase : public ApiListener, Ssl::ConnectionInfoConstSharedPtr ssl() const override { return nullptr; } absl::string_view requestedServerName() const override { return EMPTY_STRING; } State state() const override { return Network::Connection::State::Open; } + bool connecting() const override { return false; } void write(Buffer::Instance&, bool) override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } void setBufferLimits(uint32_t) override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } uint32_t bufferLimit() const override { return 65000; } diff --git a/test/common/network/connection_impl_test.cc b/test/common/network/connection_impl_test.cc index 0908d09e5079..821e73ac0d3e 100644 --- a/test/common/network/connection_impl_test.cc +++ b/test/common/network/connection_impl_test.cc @@ -287,6 +287,15 @@ TEST_P(ConnectionImplTest, CloseDuringConnectCallback) { Buffer::OwnedImpl buffer("hello world"); client_connection_->write(buffer, false); client_connection_->connect(); + EXPECT_TRUE(client_connection_->connecting()); + + StrictMock added_and_removed_callbacks; + // Make sure removed connections don't get events. + client_connection_->addConnectionCallbacks(added_and_removed_callbacks); + client_connection_->removeConnectionCallbacks(added_and_removed_callbacks); + + std::shared_ptr add_and_remove_filter = + std::make_shared>(); EXPECT_CALL(client_callbacks_, onEvent(ConnectionEvent::Connected)) .WillOnce(Invoke([&](Network::ConnectionEvent) -> void { @@ -302,6 +311,8 @@ TEST_P(ConnectionImplTest, CloseDuringConnectCallback) { std::move(socket), Network::Test::createRawBufferSocket(), stream_info_); server_connection_->addConnectionCallbacks(server_callbacks_); server_connection_->addReadFilter(read_filter_); + server_connection_->addReadFilter(add_and_remove_filter); + server_connection_->removeReadFilter(add_and_remove_filter); })); EXPECT_CALL(server_callbacks_, onEvent(ConnectionEvent::RemoteClose)) @@ -537,13 +548,22 @@ TEST_P(ConnectionImplTest, ConnectionStats) { MockConnectionStats client_connection_stats; client_connection_->setConnectionStats(client_connection_stats.toBufferStats()); + EXPECT_TRUE(client_connection_->connecting()); client_connection_->connect(); + // The Network::Connection class oddly uses onWrite as its indicator of if + // it's done connection, rather than the Connected event. + EXPECT_TRUE(client_connection_->connecting()); std::shared_ptr write_filter(new MockWriteFilter()); std::shared_ptr filter(new MockFilter()); client_connection_->addFilter(filter); client_connection_->addWriteFilter(write_filter); + // Make sure removed filters don't get callbacks. + std::shared_ptr read_filter(new StrictMock()); + client_connection_->addReadFilter(read_filter); + client_connection_->removeReadFilter(read_filter); + Sequence s1; EXPECT_CALL(*write_filter, onWrite(_, _)) .InSequence(s1) @@ -854,6 +874,11 @@ TEST_P(ConnectionImplTest, WriteWatermarks) { setUpBasicConnection(); EXPECT_FALSE(client_connection_->aboveHighWatermark()); + StrictMock added_and_removed_callbacks; + // Make sure removed connections don't get events. + client_connection_->addConnectionCallbacks(added_and_removed_callbacks); + client_connection_->removeConnectionCallbacks(added_and_removed_callbacks); + // Stick 5 bytes in the connection buffer. std::unique_ptr buffer(new Buffer::OwnedImpl("hello")); int buffer_len = buffer->length(); diff --git a/test/mocks/network/connection.h b/test/mocks/network/connection.h index f1afd5b87eca..d33d4797f992 100644 --- a/test/mocks/network/connection.h +++ b/test/mocks/network/connection.h @@ -48,10 +48,12 @@ class MockConnectionBase { #define DEFINE_MOCK_CONNECTION_MOCK_METHODS \ /* Network::Connection */ \ MOCK_METHOD(void, addConnectionCallbacks, (ConnectionCallbacks & cb)); \ + MOCK_METHOD(void, removeConnectionCallbacks, (ConnectionCallbacks & cb)); \ MOCK_METHOD(void, addBytesSentCallback, (BytesSentCb cb)); \ MOCK_METHOD(void, addWriteFilter, (WriteFilterSharedPtr filter)); \ MOCK_METHOD(void, addFilter, (FilterSharedPtr filter)); \ MOCK_METHOD(void, addReadFilter, (ReadFilterSharedPtr filter)); \ + MOCK_METHOD(void, removeReadFilter, (ReadFilterSharedPtr filter)); \ MOCK_METHOD(void, enableHalfClose, (bool enabled)); \ MOCK_METHOD(void, close, (ConnectionCloseType type)); \ MOCK_METHOD(Event::Dispatcher&, dispatcher, ()); \ @@ -72,6 +74,7 @@ class MockConnectionBase { MOCK_METHOD(Ssl::ConnectionInfoConstSharedPtr, ssl, (), (const)); \ MOCK_METHOD(absl::string_view, requestedServerName, (), (const)); \ MOCK_METHOD(State, state, (), (const)); \ + MOCK_METHOD(bool, connecting, (), (const)); \ MOCK_METHOD(void, write, (Buffer::Instance & data, bool end_stream)); \ MOCK_METHOD(void, setBufferLimits, (uint32_t limit)); \ MOCK_METHOD(uint32_t, bufferLimit, (), (const)); \ @@ -128,10 +131,12 @@ class MockFilterManagerConnection : public FilterManagerConnection, public MockC // Network::Connection MOCK_METHOD(void, addConnectionCallbacks, (ConnectionCallbacks & cb)); + MOCK_METHOD(void, removeConnectionCallbacks, (ConnectionCallbacks & cb)); MOCK_METHOD(void, addBytesSentCallback, (BytesSentCb cb)); MOCK_METHOD(void, addWriteFilter, (WriteFilterSharedPtr filter)); MOCK_METHOD(void, addFilter, (FilterSharedPtr filter)); MOCK_METHOD(void, addReadFilter, (ReadFilterSharedPtr filter)); + MOCK_METHOD(void, removeReadFilter, (ReadFilterSharedPtr filter)); MOCK_METHOD(void, enableHalfClose, (bool enabled)); MOCK_METHOD(void, close, (ConnectionCloseType type)); MOCK_METHOD(Event::Dispatcher&, dispatcher, ()); @@ -152,6 +157,7 @@ class MockFilterManagerConnection : public FilterManagerConnection, public MockC MOCK_METHOD(Ssl::ConnectionInfoConstSharedPtr, ssl, (), (const)); MOCK_METHOD(absl::string_view, requestedServerName, (), (const)); MOCK_METHOD(State, state, (), (const)); + MOCK_METHOD(bool, connecting, (), (const)); MOCK_METHOD(void, write, (Buffer::Instance & data, bool end_stream)); MOCK_METHOD(void, setBufferLimits, (uint32_t limit)); MOCK_METHOD(uint32_t, bufferLimit, (), (const));