diff --git a/api/envoy/service/ratelimit/v3/rls.proto b/api/envoy/service/ratelimit/v3/rls.proto index 42f24cfb0805..7379368fd4c1 100644 --- a/api/envoy/service/ratelimit/v3/rls.proto +++ b/api/envoy/service/ratelimit/v3/rls.proto @@ -51,6 +51,7 @@ message RateLimitRequest { } // A response from a ShouldRateLimit call. +// [#next-free-field: 6] message RateLimitResponse { option (udpa.annotations.versioning).previous_message_type = "envoy.service.ratelimit.v2.RateLimitResponse"; @@ -131,4 +132,7 @@ message RateLimitResponse { // A list of headers to add to the request when forwarded repeated config.core.v3.HeaderValue request_headers_to_add = 4; + + // A response body to send to the downstream client when the response code is not OK. + bytes raw_body = 5; } diff --git a/docs/root/version_history/current.rst b/docs/root/version_history/current.rst index d29d29648df2..04177ac1ff68 100644 --- a/docs/root/version_history/current.rst +++ b/docs/root/version_history/current.rst @@ -75,6 +75,7 @@ New Features * overload: add :ref:`envoy.overload_actions.reduce_timeouts ` overload action to enable scaling timeouts down with load. Scaling support :ref:`is limited ` to the HTTP connection and stream idle timeouts. * ratelimit: added support for use of various :ref:`metadata ` as a ratelimit action. * ratelimit: added :ref:`disable_x_envoy_ratelimited_header ` option to disable `X-Envoy-RateLimited` header. +* ratelimit: added :ref:`body ` field to support custom response bodies for non-OK responses from the external ratelimit service. * router: added support for regex rewrites during HTTP redirects using :ref:`regex_rewrite `. * sds: improved support for atomic :ref:`key rotations ` and added configurable rotation triggers for :ref:`TlsCertificate ` and diff --git a/generated_api_shadow/envoy/service/ratelimit/v3/rls.proto b/generated_api_shadow/envoy/service/ratelimit/v3/rls.proto index 42f24cfb0805..7379368fd4c1 100644 --- a/generated_api_shadow/envoy/service/ratelimit/v3/rls.proto +++ b/generated_api_shadow/envoy/service/ratelimit/v3/rls.proto @@ -51,6 +51,7 @@ message RateLimitRequest { } // A response from a ShouldRateLimit call. +// [#next-free-field: 6] message RateLimitResponse { option (udpa.annotations.versioning).previous_message_type = "envoy.service.ratelimit.v2.RateLimitResponse"; @@ -131,4 +132,7 @@ message RateLimitResponse { // A list of headers to add to the request when forwarded repeated config.core.v3.HeaderValue request_headers_to_add = 4; + + // A response body to send to the downstream client when the response code is not OK. + bytes raw_body = 5; } diff --git a/source/extensions/filters/common/ratelimit/ratelimit.h b/source/extensions/filters/common/ratelimit/ratelimit.h index 068cd369b643..964cc25a7bd6 100644 --- a/source/extensions/filters/common/ratelimit/ratelimit.h +++ b/source/extensions/filters/common/ratelimit/ratelimit.h @@ -44,12 +44,20 @@ class RequestCallbacks { virtual ~RequestCallbacks() = default; /** - * Called when a limit request is complete. The resulting status, - * response headers and request headers to be forwarded to the upstream are supplied. + * Called when a limit request is complete. The resulting status, response headers + * and request headers to be forwarded to the upstream are supplied. + * + * @status The ratelimit status + * @descriptor_statuses The descriptor statuses + * @response_headers_to_add The headers to add to the downstream response, for non-OK statuses + * @request_headers_to_add The headers to add to the upstream request, if not ratelimited + * @response_body The response body to use for the downstream response, for non-OK statuses. May + * contain non UTF-8 values (e.g. binary data). */ virtual void complete(LimitStatus status, DescriptorStatusListPtr&& descriptor_statuses, Http::ResponseHeaderMapPtr&& response_headers_to_add, - Http::RequestHeaderMapPtr&& request_headers_to_add) PURE; + Http::RequestHeaderMapPtr&& request_headers_to_add, + const std::string& response_body) PURE; }; /** diff --git a/source/extensions/filters/common/ratelimit/ratelimit_impl.cc b/source/extensions/filters/common/ratelimit/ratelimit_impl.cc index d4c3f5afdaa3..4aef806f60d1 100644 --- a/source/extensions/filters/common/ratelimit/ratelimit_impl.cc +++ b/source/extensions/filters/common/ratelimit/ratelimit_impl.cc @@ -106,14 +106,14 @@ void GrpcClientImpl::onSuccess( DescriptorStatusListPtr descriptor_statuses = std::make_unique( response->statuses().begin(), response->statuses().end()); callbacks_->complete(status, std::move(descriptor_statuses), std::move(response_headers_to_add), - std::move(request_headers_to_add)); + std::move(request_headers_to_add), response->raw_body()); callbacks_ = nullptr; } void GrpcClientImpl::onFailure(Grpc::Status::GrpcStatus status, const std::string&, Tracing::Span&) { ASSERT(status != Grpc::Status::WellKnownGrpcStatus::Ok); - callbacks_->complete(LimitStatus::Error, nullptr, nullptr, nullptr); + callbacks_->complete(LimitStatus::Error, nullptr, nullptr, nullptr, EMPTY_STRING); callbacks_ = nullptr; } diff --git a/source/extensions/filters/http/ratelimit/ratelimit.cc b/source/extensions/filters/http/ratelimit/ratelimit.cc index 8430f47243a8..b72bb82c4b75 100644 --- a/source/extensions/filters/http/ratelimit/ratelimit.cc +++ b/source/extensions/filters/http/ratelimit/ratelimit.cc @@ -115,7 +115,7 @@ Http::FilterHeadersStatus Filter::encode100ContinueHeaders(Http::ResponseHeaderM } Http::FilterHeadersStatus Filter::encodeHeaders(Http::ResponseHeaderMap& headers, bool) { - populateResponseHeaders(headers); + populateResponseHeaders(headers, /*from_local_reply=*/false); return Http::FilterHeadersStatus::Continue; } @@ -143,7 +143,8 @@ void Filter::onDestroy() { void Filter::complete(Filters::Common::RateLimit::LimitStatus status, Filters::Common::RateLimit::DescriptorStatusListPtr&& descriptor_statuses, Http::ResponseHeaderMapPtr&& response_headers_to_add, - Http::RequestHeaderMapPtr&& request_headers_to_add) { + Http::RequestHeaderMapPtr&& request_headers_to_add, + const std::string& response_body) { state_ = State::Complete; response_headers_to_add_ = std::move(response_headers_to_add); Http::HeaderMapPtr req_headers_to_add = std::move(request_headers_to_add); @@ -195,8 +196,10 @@ void Filter::complete(Filters::Common::RateLimit::LimitStatus status, config_->runtime().snapshot().featureEnabled("ratelimit.http_filter_enforcing", 100)) { state_ = State::Responded; callbacks_->sendLocalReply( - Http::Code::TooManyRequests, "", - [this](Http::HeaderMap& headers) { populateResponseHeaders(headers); }, + Http::Code::TooManyRequests, response_body, + [this](Http::HeaderMap& headers) { + populateResponseHeaders(headers, /*from_local_reply=*/true); + }, config_->rateLimitedGrpcStatus(), RcDetails::get().RateLimited); callbacks_->streamInfo().setResponseFlag(StreamInfo::ResponseFlag::RateLimited); } else if (status == Filters::Common::RateLimit::LimitStatus::Error) { @@ -208,8 +211,8 @@ void Filter::complete(Filters::Common::RateLimit::LimitStatus status, } } else { state_ = State::Responded; - callbacks_->sendLocalReply(Http::Code::InternalServerError, "", nullptr, absl::nullopt, - RcDetails::get().RateLimitError); + callbacks_->sendLocalReply(Http::Code::InternalServerError, response_body, nullptr, + absl::nullopt, RcDetails::get().RateLimitError); callbacks_->streamInfo().setResponseFlag(StreamInfo::ResponseFlag::RateLimitServiceError); } } else if (!initiating_call_) { @@ -236,8 +239,18 @@ void Filter::populateRateLimitDescriptors(const Router::RateLimitPolicy& rate_li } } -void Filter::populateResponseHeaders(Http::HeaderMap& response_headers) { +void Filter::populateResponseHeaders(Http::HeaderMap& response_headers, bool from_local_reply) { if (response_headers_to_add_) { + // If the ratelimit service is sending back the content-type header and we're + // populating response headers for a local reply, overwrite the existing + // content-type header. + // + // We do this because sendLocalReply initially sets content-type to text/plain + // whenever the response body is non-empty, but we want the content-type coming + // from the ratelimit service to be authoritative in this case. + if (from_local_reply && !response_headers_to_add_->getContentTypeValue().empty()) { + response_headers.remove(Http::Headers::get().ContentType); + } Http::HeaderUtility::addHeaders(response_headers, *response_headers_to_add_); response_headers_to_add_ = nullptr; } diff --git a/source/extensions/filters/http/ratelimit/ratelimit.h b/source/extensions/filters/http/ratelimit/ratelimit.h index 058eb793569a..5623cd2a9840 100644 --- a/source/extensions/filters/http/ratelimit/ratelimit.h +++ b/source/extensions/filters/http/ratelimit/ratelimit.h @@ -148,7 +148,8 @@ class Filter : public Http::StreamFilter, public Filters::Common::RateLimit::Req void complete(Filters::Common::RateLimit::LimitStatus status, Filters::Common::RateLimit::DescriptorStatusListPtr&& descriptor_statuses, Http::ResponseHeaderMapPtr&& response_headers_to_add, - Http::RequestHeaderMapPtr&& request_headers_to_add) override; + Http::RequestHeaderMapPtr&& request_headers_to_add, + const std::string& response_body) override; private: void initiateCall(const Http::RequestHeaderMap& headers); @@ -156,7 +157,7 @@ class Filter : public Http::StreamFilter, public Filters::Common::RateLimit::Req std::vector& descriptors, const Router::RouteEntry* route_entry, const Http::HeaderMap& headers) const; - void populateResponseHeaders(Http::HeaderMap& response_headers); + void populateResponseHeaders(Http::HeaderMap& response_headers, bool from_local_reply); void appendRequestHeaders(Http::HeaderMapPtr& request_headers_to_add); VhRateLimitOptions getVirtualHostRateLimitOption(const Router::RouteConstSharedPtr& route); diff --git a/source/extensions/filters/network/ratelimit/ratelimit.cc b/source/extensions/filters/network/ratelimit/ratelimit.cc index 00ed50a9f60c..01d1ed73324e 100644 --- a/source/extensions/filters/network/ratelimit/ratelimit.cc +++ b/source/extensions/filters/network/ratelimit/ratelimit.cc @@ -72,7 +72,8 @@ void Filter::onEvent(Network::ConnectionEvent event) { void Filter::complete(Filters::Common::RateLimit::LimitStatus status, Filters::Common::RateLimit::DescriptorStatusListPtr&&, - Http::ResponseHeaderMapPtr&&, Http::RequestHeaderMapPtr&&) { + Http::ResponseHeaderMapPtr&&, Http::RequestHeaderMapPtr&&, + const std::string&) { status_ = Status::Complete; config_->stats().active_.dec(); diff --git a/source/extensions/filters/network/ratelimit/ratelimit.h b/source/extensions/filters/network/ratelimit/ratelimit.h index eba34f434867..d1029bfbf295 100644 --- a/source/extensions/filters/network/ratelimit/ratelimit.h +++ b/source/extensions/filters/network/ratelimit/ratelimit.h @@ -94,7 +94,8 @@ class Filter : public Network::ReadFilter, void complete(Filters::Common::RateLimit::LimitStatus status, Filters::Common::RateLimit::DescriptorStatusListPtr&& descriptor_statuses, Http::ResponseHeaderMapPtr&& response_headers_to_add, - Http::RequestHeaderMapPtr&& request_headers_to_add) override; + Http::RequestHeaderMapPtr&& request_headers_to_add, + const std::string& response_body) override; private: enum class Status { NotStarted, Calling, Complete }; diff --git a/source/extensions/filters/network/thrift_proxy/filters/ratelimit/ratelimit.cc b/source/extensions/filters/network/thrift_proxy/filters/ratelimit/ratelimit.cc index c775dff93534..a9cf61aca3e3 100644 --- a/source/extensions/filters/network/thrift_proxy/filters/ratelimit/ratelimit.cc +++ b/source/extensions/filters/network/thrift_proxy/filters/ratelimit/ratelimit.cc @@ -62,7 +62,7 @@ void Filter::onDestroy() { void Filter::complete(Filters::Common::RateLimit::LimitStatus status, Filters::Common::RateLimit::DescriptorStatusListPtr&& descriptor_statuses, Http::ResponseHeaderMapPtr&& response_headers_to_add, - Http::RequestHeaderMapPtr&& request_headers_to_add) { + Http::RequestHeaderMapPtr&& request_headers_to_add, const std::string&) { // TODO(zuercher): Store headers to append to a response. Adding them to a local reply (over // limit or error) is a matter of modifying the callbacks to allow it. Adding them to an upstream // response requires either response (aka encoder) filters or some other mechanism. diff --git a/source/extensions/filters/network/thrift_proxy/filters/ratelimit/ratelimit.h b/source/extensions/filters/network/thrift_proxy/filters/ratelimit/ratelimit.h index caa5333cda65..abecfb38ac56 100644 --- a/source/extensions/filters/network/thrift_proxy/filters/ratelimit/ratelimit.h +++ b/source/extensions/filters/network/thrift_proxy/filters/ratelimit/ratelimit.h @@ -79,7 +79,8 @@ class Filter : public ThriftProxy::ThriftFilters::PassThroughDecoderFilter, void complete(Filters::Common::RateLimit::LimitStatus status, Filters::Common::RateLimit::DescriptorStatusListPtr&& descriptor_statuses, Http::ResponseHeaderMapPtr&& response_headers_to_add, - Http::RequestHeaderMapPtr&& request_headers_to_add) override; + Http::RequestHeaderMapPtr&& request_headers_to_add, + const std::string& response_body) override; private: void initiateCall(const ThriftProxy::MessageMetadata& metadata); diff --git a/test/common/network/filter_manager_impl_test.cc b/test/common/network/filter_manager_impl_test.cc index b2e8d4f809a2..6655c7b6388d 100644 --- a/test/common/network/filter_manager_impl_test.cc +++ b/test/common/network/filter_manager_impl_test.cc @@ -416,7 +416,7 @@ stat_prefix: name .WillOnce(Return(&conn_pool)); request_callbacks->complete(Extensions::Filters::Common::RateLimit::LimitStatus::OK, nullptr, - nullptr, nullptr); + nullptr, nullptr, ""); conn_pool.poolReady(upstream_connection); diff --git a/test/extensions/filters/common/ratelimit/ratelimit_impl_test.cc b/test/extensions/filters/common/ratelimit/ratelimit_impl_test.cc index 319596f436a9..1d514e0673c1 100644 --- a/test/extensions/filters/common/ratelimit/ratelimit_impl_test.cc +++ b/test/extensions/filters/common/ratelimit/ratelimit_impl_test.cc @@ -38,15 +38,17 @@ class MockRequestCallbacks : public RequestCallbacks { public: void complete(LimitStatus status, DescriptorStatusListPtr&& descriptor_statuses, Http::ResponseHeaderMapPtr&& response_headers_to_add, - Http::RequestHeaderMapPtr&& request_headers_to_add) override { + Http::RequestHeaderMapPtr&& request_headers_to_add, + const std::string& response_body) override { complete_(status, descriptor_statuses.get(), response_headers_to_add.get(), - request_headers_to_add.get()); + request_headers_to_add.get(), response_body); } MOCK_METHOD(void, complete_, (LimitStatus status, const DescriptorStatusList* descriptor_statuses, const Http::ResponseHeaderMap* response_headers_to_add, - const Http::RequestHeaderMap* request_headers_to_add)); + const Http::RequestHeaderMap* request_headers_to_add, + const std::string& response_body)); }; class RateLimitGrpcClientTest : public testing::Test { @@ -91,7 +93,7 @@ TEST_F(RateLimitGrpcClientTest, Basic) { response = std::make_unique(); response->set_overall_code(envoy::service::ratelimit::v3::RateLimitResponse::OVER_LIMIT); EXPECT_CALL(span_, setTag(Eq("ratelimit_status"), Eq("over_limit"))); - EXPECT_CALL(request_callbacks_, complete_(LimitStatus::OverLimit, _, _, _)); + EXPECT_CALL(request_callbacks_, complete_(LimitStatus::OverLimit, _, _, _, _)); client_.onSuccess(std::move(response), span_); } @@ -110,7 +112,7 @@ TEST_F(RateLimitGrpcClientTest, Basic) { response = std::make_unique(); response->set_overall_code(envoy::service::ratelimit::v3::RateLimitResponse::OK); EXPECT_CALL(span_, setTag(Eq("ratelimit_status"), Eq("ok"))); - EXPECT_CALL(request_callbacks_, complete_(LimitStatus::OK, _, _, _)); + EXPECT_CALL(request_callbacks_, complete_(LimitStatus::OK, _, _, _, _)); client_.onSuccess(std::move(response), span_); } @@ -127,7 +129,7 @@ TEST_F(RateLimitGrpcClientTest, Basic) { Tracing::NullSpan::instance(), stream_info_); response = std::make_unique(); - EXPECT_CALL(request_callbacks_, complete_(LimitStatus::Error, _, _, _)); + EXPECT_CALL(request_callbacks_, complete_(LimitStatus::Error, _, _, _, _)); client_.onFailure(Grpc::Status::Unknown, "", span_); } @@ -150,7 +152,7 @@ TEST_F(RateLimitGrpcClientTest, Basic) { response = std::make_unique(); response->set_overall_code(envoy::service::ratelimit::v3::RateLimitResponse::OK); EXPECT_CALL(span_, setTag(Eq("ratelimit_status"), Eq("ok"))); - EXPECT_CALL(request_callbacks_, complete_(LimitStatus::OK, _, _, _)); + EXPECT_CALL(request_callbacks_, complete_(LimitStatus::OK, _, _, _, _)); client_.onSuccess(std::move(response), span_); } } diff --git a/test/extensions/filters/http/ratelimit/ratelimit_test.cc b/test/extensions/filters/http/ratelimit/ratelimit_test.cc index a14a2c4463d6..25462a3ce725 100644 --- a/test/extensions/filters/http/ratelimit/ratelimit_test.cc +++ b/test/extensions/filters/http/ratelimit/ratelimit_test.cc @@ -234,7 +234,7 @@ TEST_F(HttpRateLimitFilterTest, OkResponse) { setResponseFlag(StreamInfo::ResponseFlag::RateLimited)) .Times(0); request_callbacks_->complete(Filters::Common::RateLimit::LimitStatus::OK, nullptr, nullptr, - nullptr); + nullptr, ""); EXPECT_EQ( 1U, filter_callbacks_.clusterInfo()->statsScope().counterFromStatName(ratelimit_ok_).value()); @@ -285,7 +285,7 @@ TEST_F(HttpRateLimitFilterTest, OkResponseWithHeaders) { request_callbacks_->complete( Filters::Common::RateLimit::LimitStatus::OK, nullptr, Http::ResponseHeaderMapPtr{new Http::TestResponseHeaderMapImpl(*rl_headers)}, - Http::RequestHeaderMapPtr{new Http::TestRequestHeaderMapImpl(*request_headers_to_add)}); + Http::RequestHeaderMapPtr{new Http::TestRequestHeaderMapImpl(*request_headers_to_add)}, ""); Http::TestResponseHeaderMapImpl expected_headers(*rl_headers); Http::TestResponseHeaderMapImpl response_headers; EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->encodeHeaders(response_headers, false)); @@ -341,7 +341,7 @@ TEST_F(HttpRateLimitFilterTest, OkResponseWithFilterHeaders) { auto descriptor_statuses_ptr = std::make_unique(descriptor_statuses); request_callbacks_->complete(Filters::Common::RateLimit::LimitStatus::OK, - std::move(descriptor_statuses_ptr), nullptr, nullptr); + std::move(descriptor_statuses_ptr), nullptr, nullptr, ""); Http::TestResponseHeaderMapImpl expected_headers{ {"x-ratelimit-limit", "1, 1;w=60;name=\"first\", 4;w=3600;name=\"second\""}, @@ -368,7 +368,7 @@ TEST_F(HttpRateLimitFilterTest, ImmediateOkResponse) { .WillOnce( WithArgs<0>(Invoke([&](Filters::Common::RateLimit::RequestCallbacks& callbacks) -> void { callbacks.complete(Filters::Common::RateLimit::LimitStatus::OK, nullptr, nullptr, - nullptr); + nullptr, ""); }))); EXPECT_CALL(filter_callbacks_, continueDecoding()).Times(0); @@ -399,7 +399,7 @@ TEST_F(HttpRateLimitFilterTest, ImmediateErrorResponse) { .WillOnce( WithArgs<0>(Invoke([&](Filters::Common::RateLimit::RequestCallbacks& callbacks) -> void { callbacks.complete(Filters::Common::RateLimit::LimitStatus::Error, nullptr, nullptr, - nullptr); + nullptr, ""); }))); EXPECT_CALL(filter_callbacks_, continueDecoding()).Times(0); @@ -438,7 +438,7 @@ TEST_F(HttpRateLimitFilterTest, ErrorResponse) { EXPECT_CALL(filter_callbacks_, continueDecoding()); request_callbacks_->complete(Filters::Common::RateLimit::LimitStatus::Error, nullptr, nullptr, - nullptr); + nullptr, ""); EXPECT_EQ(Http::FilterDataStatus::Continue, filter_->decodeData(data_, false)); EXPECT_EQ(Http::FilterTrailersStatus::Continue, filter_->decodeTrailers(request_trailers_)); @@ -471,7 +471,7 @@ TEST_F(HttpRateLimitFilterTest, ErrorResponseWithFailureModeAllowOff) { filter_->decodeHeaders(request_headers_, false)); request_callbacks_->complete(Filters::Common::RateLimit::LimitStatus::Error, nullptr, nullptr, - nullptr); + nullptr, ""); EXPECT_CALL(filter_callbacks_.stream_info_, setResponseFlag(StreamInfo::ResponseFlag::RateLimitServiceError)) @@ -512,7 +512,7 @@ TEST_F(HttpRateLimitFilterTest, LimitResponse) { setResponseFlag(StreamInfo::ResponseFlag::RateLimited)); request_callbacks_->complete(Filters::Common::RateLimit::LimitStatus::OverLimit, nullptr, - std::move(h), nullptr); + std::move(h), nullptr, ""); EXPECT_EQ(1U, filter_callbacks_.clusterInfo() ->statsScope() @@ -564,7 +564,139 @@ TEST_F(HttpRateLimitFilterTest, LimitResponseWithHeaders) { Http::ResponseHeaderMapPtr h{new Http::TestResponseHeaderMapImpl(*rl_headers)}; Http::RequestHeaderMapPtr uh{new Http::TestRequestHeaderMapImpl(*request_headers_to_add)}; request_callbacks_->complete(Filters::Common::RateLimit::LimitStatus::OverLimit, nullptr, - std::move(h), std::move(uh)); + std::move(h), std::move(uh), ""); + + EXPECT_THAT(*request_headers_to_add, Not(IsSubsetOfHeaders(request_headers_))); + EXPECT_EQ(1U, filter_callbacks_.clusterInfo() + ->statsScope() + .counterFromStatName(ratelimit_over_limit_) + .value()); + EXPECT_EQ( + 1U, + filter_callbacks_.clusterInfo()->statsScope().counterFromStatName(upstream_rq_4xx_).value()); + EXPECT_EQ( + 1U, + filter_callbacks_.clusterInfo()->statsScope().counterFromStatName(upstream_rq_429_).value()); +} + +TEST_F(HttpRateLimitFilterTest, LimitResponseWithBody) { + SetUpTest(filter_config_); + InSequence s; + + EXPECT_CALL(route_rate_limit_, populateDescriptors(_, _, _, _, _, _)) + .WillOnce(SetArgReferee<1>(descriptor_)); + EXPECT_CALL(*client_, limit(_, _, _, _, _)) + .WillOnce( + WithArgs<0>(Invoke([&](Filters::Common::RateLimit::RequestCallbacks& callbacks) -> void { + request_callbacks_ = &callbacks; + }))); + + EXPECT_EQ(Http::FilterHeadersStatus::StopIteration, + filter_->decodeHeaders(request_headers_, false)); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, + filter_->encode100ContinueHeaders(response_headers_)); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->encodeHeaders(response_headers_, false)); + EXPECT_EQ(Http::FilterDataStatus::Continue, filter_->encodeData(response_data_, false)); + EXPECT_EQ(Http::FilterTrailersStatus::Continue, filter_->encodeTrailers(response_trailers_)); + + const std::string response_body = "this is a custom over limit response body."; + const std::string content_length = std::to_string(response_body.length()); + Http::HeaderMapPtr rl_headers{new Http::TestResponseHeaderMapImpl{ + {"x-ratelimit-limit", "1000"}, {"x-ratelimit-remaining", "0"}, {"retry-after", "33"}}}; + Http::TestResponseHeaderMapImpl expected_headers{}; + // We construct the expected_headers map in careful order, because HeaderMapEqualRef below + // compares two header maps in order. In practice, content-length and content-type headers + // are added before additional ratelimit headers and the final x-envoy-ratelimited header. + expected_headers.addCopy(":status", "429"); + expected_headers.addCopy("content-length", std::string(content_length)); + expected_headers.addCopy("content-type", "text/plain"); + expected_headers.copyFrom(*rl_headers); + expected_headers.addCopy("x-envoy-ratelimited", Http::Headers::get().EnvoyRateLimitedValues.True); + + EXPECT_CALL(filter_callbacks_, encodeHeaders_(HeaderMapEqualRef(&expected_headers), false)); + EXPECT_CALL(filter_callbacks_, continueDecoding()).Times(0); + EXPECT_CALL(filter_callbacks_, encodeData(_, true)) + .WillOnce( + Invoke([&](Buffer::Instance& data, bool) { EXPECT_EQ(data.toString(), response_body); })); + EXPECT_CALL(filter_callbacks_.stream_info_, + setResponseFlag(StreamInfo::ResponseFlag::RateLimited)); + + Http::HeaderMapPtr request_headers_to_add{ + new Http::TestRequestHeaderMapImpl{{"x-rls-rate-limited", "true"}}}; + + Http::ResponseHeaderMapPtr h{new Http::TestResponseHeaderMapImpl(*rl_headers)}; + Http::RequestHeaderMapPtr uh{new Http::TestRequestHeaderMapImpl(*request_headers_to_add)}; + request_callbacks_->complete(Filters::Common::RateLimit::LimitStatus::OverLimit, nullptr, + std::move(h), std::move(uh), response_body); + + EXPECT_THAT(*request_headers_to_add, Not(IsSubsetOfHeaders(request_headers_))); + EXPECT_EQ(1U, filter_callbacks_.clusterInfo() + ->statsScope() + .counterFromStatName(ratelimit_over_limit_) + .value()); + EXPECT_EQ( + 1U, + filter_callbacks_.clusterInfo()->statsScope().counterFromStatName(upstream_rq_4xx_).value()); + EXPECT_EQ( + 1U, + filter_callbacks_.clusterInfo()->statsScope().counterFromStatName(upstream_rq_429_).value()); +} + +TEST_F(HttpRateLimitFilterTest, LimitResponseWithBodyAndContentType) { + SetUpTest(filter_config_); + InSequence s; + + EXPECT_CALL(route_rate_limit_, populateDescriptors(_, _, _, _, _, _)) + .WillOnce(SetArgReferee<1>(descriptor_)); + EXPECT_CALL(*client_, limit(_, _, _, _, _)) + .WillOnce( + WithArgs<0>(Invoke([&](Filters::Common::RateLimit::RequestCallbacks& callbacks) -> void { + request_callbacks_ = &callbacks; + }))); + + EXPECT_EQ(Http::FilterHeadersStatus::StopIteration, + filter_->decodeHeaders(request_headers_, false)); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, + filter_->encode100ContinueHeaders(response_headers_)); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->encodeHeaders(response_headers_, false)); + EXPECT_EQ(Http::FilterDataStatus::Continue, filter_->encodeData(response_data_, false)); + EXPECT_EQ(Http::FilterTrailersStatus::Continue, filter_->encodeTrailers(response_trailers_)); + + const std::string response_body = R"EOF( + { "message": "this is a custom over limit response body as json.", "retry-after": "33" } + )EOF"; + const std::string content_length = std::to_string(response_body.length()); + Http::HeaderMapPtr rl_headers{ + new Http::TestResponseHeaderMapImpl{{"content-type", "application/json"}, + {"x-ratelimit-limit", "1000"}, + {"x-ratelimit-remaining", "0"}, + {"retry-after", "33"}}}; + Http::TestResponseHeaderMapImpl expected_headers{}; + // We construct the expected_headers map in careful order, because HeaderMapEqualRef below + // compares two header maps in order. In practice, content-length and content-type headers + // are added before additional ratelimit headers and the final x-envoy-ratelimited header. + // Additionally, we skip explicitly adding content-type here because it's already part of + // `rl_headers` above. + expected_headers.addCopy(":status", "429"); + expected_headers.addCopy("content-length", std::string(content_length)); + expected_headers.copyFrom(*rl_headers); + expected_headers.addCopy("x-envoy-ratelimited", Http::Headers::get().EnvoyRateLimitedValues.True); + + EXPECT_CALL(filter_callbacks_, encodeHeaders_(HeaderMapEqualRef(&expected_headers), false)); + EXPECT_CALL(filter_callbacks_, continueDecoding()).Times(0); + EXPECT_CALL(filter_callbacks_, encodeData(_, true)) + .WillOnce( + Invoke([&](Buffer::Instance& data, bool) { EXPECT_EQ(data.toString(), response_body); })); + EXPECT_CALL(filter_callbacks_.stream_info_, + setResponseFlag(StreamInfo::ResponseFlag::RateLimited)); + + Http::HeaderMapPtr request_headers_to_add{ + new Http::TestRequestHeaderMapImpl{{"x-rls-rate-limited", "true"}}}; + + Http::ResponseHeaderMapPtr h{new Http::TestResponseHeaderMapImpl(*rl_headers)}; + Http::RequestHeaderMapPtr uh{new Http::TestRequestHeaderMapImpl(*request_headers_to_add)}; + request_callbacks_->complete(Filters::Common::RateLimit::LimitStatus::OverLimit, nullptr, + std::move(h), std::move(uh), response_body); EXPECT_THAT(*request_headers_to_add, Not(IsSubsetOfHeaders(request_headers_))); EXPECT_EQ(1U, filter_callbacks_.clusterInfo() @@ -618,7 +750,7 @@ TEST_F(HttpRateLimitFilterTest, LimitResponseWithFilterHeaders) { auto descriptor_statuses_ptr = std::make_unique(descriptor_statuses); request_callbacks_->complete(Filters::Common::RateLimit::LimitStatus::OverLimit, - std::move(descriptor_statuses_ptr), nullptr, nullptr); + std::move(descriptor_statuses_ptr), nullptr, nullptr, ""); EXPECT_EQ(1U, filter_callbacks_.clusterInfo() ->statsScope() .counterFromStatName(ratelimit_over_limit_) @@ -654,7 +786,7 @@ TEST_F(HttpRateLimitFilterTest, LimitResponseWithoutEnvoyRateLimitedHeader) { setResponseFlag(StreamInfo::ResponseFlag::RateLimited)); request_callbacks_->complete(Filters::Common::RateLimit::LimitStatus::OverLimit, nullptr, - std::move(h), nullptr); + std::move(h), nullptr, ""); EXPECT_EQ(1U, filter_callbacks_.clusterInfo() ->statsScope() @@ -689,7 +821,7 @@ TEST_F(HttpRateLimitFilterTest, LimitResponseRuntimeDisabled) { EXPECT_CALL(filter_callbacks_, continueDecoding()); Http::ResponseHeaderMapPtr h{new Http::TestResponseHeaderMapImpl()}; request_callbacks_->complete(Filters::Common::RateLimit::LimitStatus::OverLimit, nullptr, - std::move(h), nullptr); + std::move(h), nullptr, ""); EXPECT_EQ(Http::FilterDataStatus::Continue, filter_->decodeData(data_, false)); EXPECT_EQ(Http::FilterTrailersStatus::Continue, filter_->decodeTrailers(request_trailers_)); @@ -839,7 +971,7 @@ TEST_F(HttpRateLimitFilterTest, InternalRequestType) { .WillOnce( WithArgs<0>(Invoke([&](Filters::Common::RateLimit::RequestCallbacks& callbacks) -> void { callbacks.complete(Filters::Common::RateLimit::LimitStatus::OK, nullptr, nullptr, - nullptr); + nullptr, ""); }))); EXPECT_CALL(filter_callbacks_, continueDecoding()).Times(0); @@ -883,7 +1015,7 @@ TEST_F(HttpRateLimitFilterTest, ExternalRequestType) { .WillOnce( WithArgs<0>(Invoke([&](Filters::Common::RateLimit::RequestCallbacks& callbacks) -> void { callbacks.complete(Filters::Common::RateLimit::LimitStatus::OK, nullptr, nullptr, - nullptr); + nullptr, ""); }))); EXPECT_CALL(filter_callbacks_, continueDecoding()).Times(0); @@ -937,7 +1069,7 @@ TEST_F(HttpRateLimitFilterTest, DEPRECATED_FEATURE_TEST(ExcludeVirtualHost)) { .WillOnce( WithArgs<0>(Invoke([&](Filters::Common::RateLimit::RequestCallbacks& callbacks) -> void { callbacks.complete(Filters::Common::RateLimit::LimitStatus::OK, nullptr, nullptr, - nullptr); + nullptr, ""); }))); EXPECT_CALL(filter_callbacks_, continueDecoding()).Times(0); @@ -988,7 +1120,7 @@ TEST_F(HttpRateLimitFilterTest, OverrideVHRateLimitOptionWithRouteRateLimitSet) .WillOnce( WithArgs<0>(Invoke([&](Filters::Common::RateLimit::RequestCallbacks& callbacks) -> void { callbacks.complete(Filters::Common::RateLimit::LimitStatus::OK, nullptr, nullptr, - nullptr); + nullptr, ""); }))); EXPECT_CALL(filter_callbacks_, continueDecoding()).Times(0); @@ -1039,7 +1171,7 @@ TEST_F(HttpRateLimitFilterTest, OverrideVHRateLimitOptionWithoutRouteRateLimit) .WillOnce( WithArgs<0>(Invoke([&](Filters::Common::RateLimit::RequestCallbacks& callbacks) -> void { callbacks.complete(Filters::Common::RateLimit::LimitStatus::OK, nullptr, nullptr, - nullptr); + nullptr, ""); }))); EXPECT_CALL(filter_callbacks_, continueDecoding()).Times(0); @@ -1087,7 +1219,7 @@ TEST_F(HttpRateLimitFilterTest, IncludeVHRateLimitOptionWithOnlyVHRateLimitSet) .WillOnce( WithArgs<0>(Invoke([&](Filters::Common::RateLimit::RequestCallbacks& callbacks) -> void { callbacks.complete(Filters::Common::RateLimit::LimitStatus::OK, nullptr, nullptr, - nullptr); + nullptr, ""); }))); EXPECT_CALL(filter_callbacks_, continueDecoding()).Times(0); @@ -1137,7 +1269,7 @@ TEST_F(HttpRateLimitFilterTest, IncludeVHRateLimitOptionWithRouteAndVHRateLimitS .WillOnce( WithArgs<0>(Invoke([&](Filters::Common::RateLimit::RequestCallbacks& callbacks) -> void { callbacks.complete(Filters::Common::RateLimit::LimitStatus::OK, nullptr, nullptr, - nullptr); + nullptr, ""); }))); EXPECT_CALL(filter_callbacks_, continueDecoding()).Times(0); @@ -1185,7 +1317,7 @@ TEST_F(HttpRateLimitFilterTest, IgnoreVHRateLimitOptionWithRouteRateLimitSet) { .WillOnce( WithArgs<0>(Invoke([&](Filters::Common::RateLimit::RequestCallbacks& callbacks) -> void { callbacks.complete(Filters::Common::RateLimit::LimitStatus::OK, nullptr, nullptr, - nullptr); + nullptr, ""); }))); EXPECT_CALL(filter_callbacks_, continueDecoding()).Times(0); diff --git a/test/extensions/filters/network/ratelimit/ratelimit_test.cc b/test/extensions/filters/network/ratelimit/ratelimit_test.cc index f2255e356cd2..019ad7000909 100644 --- a/test/extensions/filters/network/ratelimit/ratelimit_test.cc +++ b/test/extensions/filters/network/ratelimit/ratelimit_test.cc @@ -115,7 +115,7 @@ TEST_F(RateLimitFilterTest, OK) { EXPECT_CALL(filter_callbacks_, continueReading()); request_callbacks_->complete(Filters::Common::RateLimit::LimitStatus::OK, nullptr, nullptr, - nullptr); + nullptr, ""); EXPECT_EQ(Network::FilterStatus::Continue, filter_->onData(data, false)); @@ -143,7 +143,7 @@ TEST_F(RateLimitFilterTest, OverLimit) { EXPECT_CALL(filter_callbacks_.connection_, close(Network::ConnectionCloseType::NoFlush)); EXPECT_CALL(*client_, cancel()).Times(0); request_callbacks_->complete(Filters::Common::RateLimit::LimitStatus::OverLimit, nullptr, nullptr, - nullptr); + nullptr, ""); EXPECT_EQ(Network::FilterStatus::Continue, filter_->onData(data, false)); @@ -172,7 +172,7 @@ TEST_F(RateLimitFilterTest, OverLimitNotEnforcing) { EXPECT_CALL(*client_, cancel()).Times(0); EXPECT_CALL(filter_callbacks_, continueReading()); request_callbacks_->complete(Filters::Common::RateLimit::LimitStatus::OverLimit, nullptr, nullptr, - nullptr); + nullptr, ""); EXPECT_EQ(Network::FilterStatus::Continue, filter_->onData(data, false)); @@ -197,7 +197,7 @@ TEST_F(RateLimitFilterTest, Error) { EXPECT_CALL(filter_callbacks_, continueReading()); request_callbacks_->complete(Filters::Common::RateLimit::LimitStatus::Error, nullptr, nullptr, - nullptr); + nullptr, ""); EXPECT_EQ(Network::FilterStatus::Continue, filter_->onData(data, false)); @@ -238,7 +238,7 @@ TEST_F(RateLimitFilterTest, ImmediateOK) { .WillOnce( WithArgs<0>(Invoke([&](Filters::Common::RateLimit::RequestCallbacks& callbacks) -> void { callbacks.complete(Filters::Common::RateLimit::LimitStatus::OK, nullptr, nullptr, - nullptr); + nullptr, ""); }))); EXPECT_EQ(Network::FilterStatus::Continue, filter_->onNewConnection()); @@ -262,7 +262,7 @@ TEST_F(RateLimitFilterTest, ImmediateError) { .WillOnce( WithArgs<0>(Invoke([&](Filters::Common::RateLimit::RequestCallbacks& callbacks) -> void { callbacks.complete(Filters::Common::RateLimit::LimitStatus::Error, nullptr, nullptr, - nullptr); + nullptr, ""); }))); EXPECT_EQ(Network::FilterStatus::Continue, filter_->onNewConnection()); @@ -305,7 +305,7 @@ TEST_F(RateLimitFilterTest, ErrorResponseWithFailureModeAllowOff) { Buffer::OwnedImpl data("hello"); EXPECT_EQ(Network::FilterStatus::StopIteration, filter_->onData(data, false)); request_callbacks_->complete(Filters::Common::RateLimit::LimitStatus::Error, nullptr, nullptr, - nullptr); + nullptr, ""); EXPECT_EQ(Network::FilterStatus::Continue, filter_->onData(data, false)); diff --git a/test/extensions/filters/network/thrift_proxy/filters/ratelimit/ratelimit_test.cc b/test/extensions/filters/network/thrift_proxy/filters/ratelimit/ratelimit_test.cc index 131bb6513f58..c56c9383b330 100644 --- a/test/extensions/filters/network/thrift_proxy/filters/ratelimit/ratelimit_test.cc +++ b/test/extensions/filters/network/thrift_proxy/filters/ratelimit/ratelimit_test.cc @@ -227,7 +227,7 @@ TEST_F(ThriftRateLimitFilterTest, OkResponse) { setResponseFlag(StreamInfo::ResponseFlag::RateLimited)) .Times(0); request_callbacks_->complete(Filters::Common::RateLimit::LimitStatus::OK, nullptr, nullptr, - nullptr); + nullptr, ""); EXPECT_EQ(1U, cm_.thread_local_cluster_.cluster_.info_->stats_store_.counter("ratelimit.ok").value()); @@ -247,7 +247,7 @@ TEST_F(ThriftRateLimitFilterTest, ImmediateOkResponse) { .WillOnce( WithArgs<0>(Invoke([&](Filters::Common::RateLimit::RequestCallbacks& callbacks) -> void { callbacks.complete(Filters::Common::RateLimit::LimitStatus::OK, nullptr, nullptr, - nullptr); + nullptr, ""); }))); EXPECT_CALL(filter_callbacks_, continueDecoding()).Times(0); @@ -271,7 +271,7 @@ TEST_F(ThriftRateLimitFilterTest, ImmediateErrorResponse) { .WillOnce( WithArgs<0>(Invoke([&](Filters::Common::RateLimit::RequestCallbacks& callbacks) -> void { callbacks.complete(Filters::Common::RateLimit::LimitStatus::Error, nullptr, nullptr, - nullptr); + nullptr, ""); }))); EXPECT_CALL(filter_callbacks_, continueDecoding()).Times(0); @@ -301,7 +301,7 @@ TEST_F(ThriftRateLimitFilterTest, ErrorResponse) { EXPECT_CALL(filter_callbacks_, continueDecoding()); request_callbacks_->complete(Filters::Common::RateLimit::LimitStatus::Error, nullptr, nullptr, - nullptr); + nullptr, ""); EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->messageEnd()); EXPECT_CALL(filter_callbacks_.stream_info_, @@ -339,7 +339,7 @@ TEST_F(ThriftRateLimitFilterTest, ErrorResponseWithFailureModeAllowOff) { EXPECT_CALL(filter_callbacks_.stream_info_, setResponseFlag(StreamInfo::ResponseFlag::RateLimitServiceError)); request_callbacks_->complete(Filters::Common::RateLimit::LimitStatus::Error, nullptr, nullptr, - nullptr); + nullptr, ""); EXPECT_EQ( 1U, @@ -373,7 +373,7 @@ TEST_F(ThriftRateLimitFilterTest, LimitResponse) { EXPECT_CALL(filter_callbacks_.stream_info_, setResponseFlag(StreamInfo::ResponseFlag::RateLimited)); request_callbacks_->complete(Filters::Common::RateLimit::LimitStatus::OverLimit, nullptr, nullptr, - nullptr); + nullptr, ""); EXPECT_EQ(1U, cm_.thread_local_cluster_.cluster_.info_->stats_store_.counter("ratelimit.over_limit") @@ -406,7 +406,7 @@ TEST_F(ThriftRateLimitFilterTest, LimitResponseWithHeaders) { Http::ResponseHeaderMapPtr h{new Http::TestResponseHeaderMapImpl(*rl_headers)}; request_callbacks_->complete(Filters::Common::RateLimit::LimitStatus::OverLimit, nullptr, - std::move(h), nullptr); + std::move(h), nullptr, ""); EXPECT_EQ(1U, cm_.thread_local_cluster_.cluster_.info_->stats_store_.counter("ratelimit.over_limit") @@ -431,7 +431,7 @@ TEST_F(ThriftRateLimitFilterTest, LimitResponseRuntimeDisabled) { .WillOnce(Return(false)); EXPECT_CALL(filter_callbacks_, continueDecoding()); request_callbacks_->complete(Filters::Common::RateLimit::LimitStatus::OverLimit, nullptr, nullptr, - nullptr); + nullptr, ""); EXPECT_EQ(1U, cm_.thread_local_cluster_.cluster_.info_->stats_store_.counter("ratelimit.over_limit") diff --git a/test/test_common/utility.h b/test/test_common/utility.h index 6b4abcd7c63c..7580a576e56d 100644 --- a/test/test_common/utility.h +++ b/test/test_common/utility.h @@ -857,6 +857,11 @@ template class TestHeaderMapImplBase : public Inte HeaderMapImpl::copyFrom(*header_map_, rhs); header_map_->verifyByteSizeInternalForTest(); } + void copyFrom(const TestHeaderMapImplBase& rhs) { copyFrom(*rhs.header_map_); } + void copyFrom(const HeaderMap& rhs) { + HeaderMapImpl::copyFrom(*header_map_, rhs); + header_map_->verifyByteSizeInternalForTest(); + } TestHeaderMapImplBase& operator=(const TestHeaderMapImplBase& rhs) { if (this == &rhs) { return *this;