Skip to content

Commit

Permalink
[promises] Migrate http server filter to new API (grpc#35197)
Browse files Browse the repository at this point in the history
Closes grpc#35197

COPYBARA_INTEGRATE_REVIEW=grpc#35197 from ctiller:cg-http-svr cdde418
PiperOrigin-RevId: 587178983
  • Loading branch information
ctiller authored and copybara-github committed Dec 2, 2023
1 parent addd18b commit 7047cc1
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 45 deletions.
81 changes: 40 additions & 41 deletions src/core/ext/filters/http/server/http_server_filter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@

namespace grpc_core {

const NoInterceptor HttpServerFilter::Call::OnClientToServerMessage;
const NoInterceptor HttpServerFilter::Call::OnServerToClientMessage;

const grpc_channel_filter HttpServerFilter::kFilter =
MakePromiseBasedFilter<HttpServerFilter, FilterEndpoint::kServer,
kFilterExaminesServerInitialMetadata>("http-server");
Expand All @@ -71,85 +74,81 @@ ServerMetadataHandle MalformedRequest(absl::string_view explanation) {
}
} // namespace

ArenaPromise<ServerMetadataHandle> HttpServerFilter::MakeCallPromise(
CallArgs call_args, NextPromiseFactory next_promise_factory) {
const auto& md = call_args.client_initial_metadata;

auto method = md->get(HttpMethodMetadata());
ServerMetadataHandle HttpServerFilter::Call::OnClientInitialMetadata(
ClientMetadata& md, HttpServerFilter* filter) {
auto method = md.get(HttpMethodMetadata());
if (method.has_value()) {
switch (*method) {
case HttpMethodMetadata::kPost:
break;
case HttpMethodMetadata::kPut:
if (allow_put_requests_) {
if (filter->allow_put_requests_) {
break;
}
ABSL_FALLTHROUGH_INTENDED;
case HttpMethodMetadata::kInvalid:
case HttpMethodMetadata::kGet:
return Immediate(MalformedRequest("Bad method header"));
return MalformedRequest("Bad method header");
}
} else {
return Immediate(MalformedRequest("Missing :method header"));
return MalformedRequest("Missing :method header");
}

auto te = md->Take(TeMetadata());
auto te = md.Take(TeMetadata());
if (te == TeMetadata::kTrailers) {
// Do nothing, ok.
} else if (!te.has_value()) {
return Immediate(MalformedRequest("Missing :te header"));
return MalformedRequest("Missing :te header");
} else {
return Immediate(MalformedRequest("Bad :te header"));
return MalformedRequest("Bad :te header");
}

auto scheme = md->Take(HttpSchemeMetadata());
auto scheme = md.Take(HttpSchemeMetadata());
if (scheme.has_value()) {
if (*scheme == HttpSchemeMetadata::kInvalid) {
return Immediate(MalformedRequest("Bad :scheme header"));
return MalformedRequest("Bad :scheme header");
}
} else {
return Immediate(MalformedRequest("Missing :scheme header"));
return MalformedRequest("Missing :scheme header");
}

md->Remove(ContentTypeMetadata());
md.Remove(ContentTypeMetadata());

Slice* path_slice = md->get_pointer(HttpPathMetadata());
Slice* path_slice = md.get_pointer(HttpPathMetadata());
if (path_slice == nullptr) {
return Immediate(MalformedRequest("Missing :path header"));
return MalformedRequest("Missing :path header");
}

if (md->get_pointer(HttpAuthorityMetadata()) == nullptr) {
absl::optional<Slice> host = md->Take(HostMetadata());
if (md.get_pointer(HttpAuthorityMetadata()) == nullptr) {
absl::optional<Slice> host = md.Take(HostMetadata());
if (host.has_value()) {
md->Set(HttpAuthorityMetadata(), std::move(*host));
md.Set(HttpAuthorityMetadata(), std::move(*host));
}
}

if (md->get_pointer(HttpAuthorityMetadata()) == nullptr) {
return Immediate(MalformedRequest("Missing :authority header"));
if (md.get_pointer(HttpAuthorityMetadata()) == nullptr) {
return MalformedRequest("Missing :authority header");
}

if (!surface_user_agent_) {
md->Remove(UserAgentMetadata());
if (!filter->surface_user_agent_) {
md.Remove(UserAgentMetadata());
}

call_args.server_initial_metadata->InterceptAndMap(
[](ServerMetadataHandle md) {
if (grpc_call_trace.enabled()) {
gpr_log(GPR_INFO, "%s[http-server] Write metadata",
Activity::current()->DebugTag().c_str());
}
FilterOutgoingMetadata(md.get());
md->Set(HttpStatusMetadata(), 200);
md->Set(ContentTypeMetadata(), ContentTypeMetadata::kApplicationGrpc);
return md;
});

return Map(next_promise_factory(std::move(call_args)),
[](ServerMetadataHandle md) -> ServerMetadataHandle {
FilterOutgoingMetadata(md.get());
return md;
});
return nullptr;
}

void HttpServerFilter::Call::OnServerInitialMetadata(ServerMetadata& md) {
if (grpc_call_trace.enabled()) {
gpr_log(GPR_INFO, "%s[http-server] Write metadata",
Activity::current()->DebugTag().c_str());
}
FilterOutgoingMetadata(&md);
md.Set(HttpStatusMetadata(), 200);
md.Set(ContentTypeMetadata(), ContentTypeMetadata::kApplicationGrpc);
}

void HttpServerFilter::Call::OnServerTrailingMetadata(ServerMetadata& md) {
FilterOutgoingMetadata(&md);
}

absl::StatusOr<HttpServerFilter> HttpServerFilter::Create(
Expand Down
14 changes: 10 additions & 4 deletions src/core/ext/filters/http/server/http_server_filter.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,22 @@
namespace grpc_core {

// Processes metadata on the server side for HTTP2 transports
class HttpServerFilter : public ChannelFilter {
class HttpServerFilter : public ImplementChannelFilter<HttpServerFilter> {
public:
static const grpc_channel_filter kFilter;

static absl::StatusOr<HttpServerFilter> Create(
const ChannelArgs& args, ChannelFilter::Args filter_args);

// Construct a promise for one call.
ArenaPromise<ServerMetadataHandle> MakeCallPromise(
CallArgs call_args, NextPromiseFactory next_promise_factory) override;
class Call {
public:
ServerMetadataHandle OnClientInitialMetadata(ClientMetadata& md,
HttpServerFilter* filter);
void OnServerInitialMetadata(ServerMetadata& md);
void OnServerTrailingMetadata(ServerMetadata& md);
static const NoInterceptor OnClientToServerMessage;
static const NoInterceptor OnServerToClientMessage;
};

private:
HttpServerFilter(bool surface_user_agent, bool allow_put_requests)
Expand Down
65 changes: 65 additions & 0 deletions src/core/lib/channel/promise_based_filter.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
#include "src/core/lib/promise/context.h"
#include "src/core/lib/promise/pipe.h"
#include "src/core/lib/promise/poll.h"
#include "src/core/lib/promise/promise.h"
#include "src/core/lib/promise/race.h"
#include "src/core/lib/resource_quota/arena.h"
#include "src/core/lib/slice/slice_buffer.h"
Expand Down Expand Up @@ -143,6 +144,12 @@ inline constexpr bool HasAsyncErrorInterceptor(absl::Status (T::*)(A...)) {
return true;
}

template <typename T, typename... A>
inline constexpr bool HasAsyncErrorInterceptor(
ServerMetadataHandle (T::*)(A...)) {
return true;
}

template <typename T, typename... A>
inline constexpr bool HasAsyncErrorInterceptor(void (T::*)(A...)) {
return false;
Expand Down Expand Up @@ -277,6 +284,16 @@ auto MapResult(absl::Status (Derived::Call::*fn)(ServerMetadata&), Promise x,
});
}

template <typename Promise, typename Derived>
auto MapResult(void (Derived::Call::*fn)(ServerMetadata&), Promise x,
FilterCallData<Derived>* call_data) {
GPR_DEBUG_ASSERT(fn == &Derived::Call::OnServerTrailingMetadata);
return Map(std::move(x), [call_data](ServerMetadataHandle md) {
call_data->call.OnServerTrailingMetadata(*md);
return md;
});
}

inline auto RunCall(const NoInterceptor*, CallArgs call_args,
NextPromiseFactory next_promise_factory, void*) {
return next_promise_factory(std::move(call_args));
Expand All @@ -291,6 +308,31 @@ inline auto RunCall(void (Derived::Call::*fn)(ClientMetadata& md),
return next_promise_factory(std::move(call_args));
}

template <typename Derived>
inline auto RunCall(
ServerMetadataHandle (Derived::Call::*fn)(ClientMetadata& md),
CallArgs call_args, NextPromiseFactory next_promise_factory,
FilterCallData<Derived>* call_data) -> ArenaPromise<ServerMetadataHandle> {
GPR_DEBUG_ASSERT(fn == &Derived::Call::OnClientInitialMetadata);
auto return_md = call_data->call.OnClientInitialMetadata(
*call_args.client_initial_metadata);
if (return_md == nullptr) return next_promise_factory(std::move(call_args));
return Immediate(std::move(return_md));
}

template <typename Derived>
inline auto RunCall(ServerMetadataHandle (Derived::Call::*fn)(
ClientMetadata& md, Derived* channel),
CallArgs call_args, NextPromiseFactory next_promise_factory,
FilterCallData<Derived>* call_data)
-> ArenaPromise<ServerMetadataHandle> {
GPR_DEBUG_ASSERT(fn == &Derived::Call::OnClientInitialMetadata);
auto return_md = call_data->call.OnClientInitialMetadata(
*call_args.client_initial_metadata, call_data->channel);
if (return_md == nullptr) return next_promise_factory(std::move(call_args));
return Immediate(std::move(return_md));
}

template <typename Derived>
inline auto RunCall(void (Derived::Call::*fn)(ClientMetadata& md,
Derived* channel),
Expand All @@ -308,6 +350,18 @@ inline void InterceptClientToServerMessage(const NoInterceptor*, void*,
inline void InterceptServerInitialMetadata(const NoInterceptor*, void*,
CallArgs&) {}

template <typename Derived>
inline void InterceptServerInitialMetadata(
void (Derived::Call::*fn)(ServerMetadata&),
FilterCallData<Derived>* call_data, CallArgs& call_args) {
GPR_DEBUG_ASSERT(fn == &Derived::Call::OnServerInitialMetadata);
call_args.server_initial_metadata->InterceptAndMap(
[call_data](ServerMetadataHandle md) {
call_data->call.OnServerInitialMetadata(*md);
return md;
});
}

template <typename Derived>
inline void InterceptServerInitialMetadata(
absl::Status (Derived::Call::*fn)(ServerMetadata&),
Expand Down Expand Up @@ -373,6 +427,11 @@ MakeFilterCall(Derived* derived) {
// - absl::Status $INTERCEPTOR_NAME($VALUE_TYPE&):
// the filter intercepts this event, and can modify the value.
// it can fail, in which case the call will be aborted.
// - ServerMetadataHandle $INTERCEPTOR_NAME($VALUE_TYPE&)
// the filter intercepts this event, and can modify the value.
// the filter can return nullptr for success, or a metadata handle for
// failure (in which case the call will be aborted).
// useful for cases where the exact metadata returned needs to be customized.
// - void $INTERCEPTOR_NAME($VALUE_TYPE&, Derived*):
// the filter intercepts this event, and can modify the value.
// it can access the channel via the second argument.
Expand All @@ -381,6 +440,12 @@ MakeFilterCall(Derived* derived) {
// the filter intercepts this event, and can modify the value.
// it can access the channel via the second argument.
// it can fail, in which case the call will be aborted.
// - ServerMetadataHandle $INTERCEPTOR_NAME($VALUE_TYPE&, Derived*)
// the filter intercepts this event, and can modify the value.
// it can access the channel via the second argument.
// the filter can return nullptr for success, or a metadata handle for
// failure (in which case the call will be aborted).
// useful for cases where the exact metadata returned needs to be customized.
template <typename Derived>
class ImplementChannelFilter : public ChannelFilter {
public:
Expand Down

0 comments on commit 7047cc1

Please sign in to comment.