From 072f7ee918d4d155e04320bd1785cd3f5fe6583d Mon Sep 17 00:00:00 2001 From: Rasmus Larsen Date: Fri, 14 Apr 2023 19:20:38 +0200 Subject: [PATCH] Serialize debug_data when present in GOAWAY frames --- src/frame/go_away.rs | 13 +++++++++++-- src/proto/connection.rs | 23 ++++++++++++++++++++++ src/proto/go_away.rs | 4 ---- src/server.rs | 6 ++++++ tests/h2-support/src/frames.rs | 7 +++++++ tests/h2-tests/tests/server.rs | 35 ++++++++++++++++++++++++++++++++++ 6 files changed, 82 insertions(+), 6 deletions(-) diff --git a/src/frame/go_away.rs b/src/frame/go_away.rs index 91d9c4c6..4ab28d51 100644 --- a/src/frame/go_away.rs +++ b/src/frame/go_away.rs @@ -8,7 +8,6 @@ use crate::frame::{self, Error, Head, Kind, Reason, StreamId}; pub struct GoAway { last_stream_id: StreamId, error_code: Reason, - #[allow(unused)] debug_data: Bytes, } @@ -21,6 +20,15 @@ impl GoAway { } } + #[doc(hidden)] + #[cfg(feature = "unstable")] + pub fn with_debug_data(self, debug_data: impl Into) -> Self { + Self { + debug_data: debug_data.into(), + ..self + } + } + pub fn last_stream_id(&self) -> StreamId { self.last_stream_id } @@ -52,9 +60,10 @@ impl GoAway { pub fn encode(&self, dst: &mut B) { tracing::trace!("encoding GO_AWAY; code={:?}", self.error_code); let head = Head::new(Kind::GoAway, 0, StreamId::zero()); - head.encode(8, dst); + head.encode(8 + self.debug_data.len(), dst); dst.put_u32(self.last_stream_id.into()); dst.put_u32(self.error_code.into()); + dst.put(self.debug_data.slice(..)); } } diff --git a/src/proto/connection.rs b/src/proto/connection.rs index 619973df..7ea124e4 100644 --- a/src/proto/connection.rs +++ b/src/proto/connection.rs @@ -398,6 +398,18 @@ where self.go_away.go_away_now(frame); } + #[doc(hidden)] + #[cfg(feature = "unstable")] + fn go_away_now_debug_data(&mut self) { + let last_processed_id = self.streams.last_processed_id(); + + let frame = frame::GoAway::new(last_processed_id, Reason::NO_ERROR) + .with_debug_data("something went wrong"); + + self.streams.send_go_away(last_processed_id); + self.go_away.go_away(frame); + } + fn go_away_from_user(&mut self, e: Reason) { let last_processed_id = self.streams.last_processed_id(); let frame = frame::GoAway::new(last_processed_id, e); @@ -576,6 +588,17 @@ where // for a pong before proceeding. self.inner.ping_pong.ping_shutdown(); } + + #[doc(hidden)] + #[cfg(feature = "unstable")] + pub fn go_away_debug_data(&mut self) { + if self.inner.go_away.is_going_away() { + return; + } + + self.inner.as_dyn().go_away_now_debug_data(); + self.inner.ping_pong.ping_shutdown(); + } } impl Drop for Connection diff --git a/src/proto/go_away.rs b/src/proto/go_away.rs index 75942787..d52252cd 100644 --- a/src/proto/go_away.rs +++ b/src/proto/go_away.rs @@ -26,10 +26,6 @@ pub(super) struct GoAway { /// were a `frame::GoAway`, it might appear like we eventually wanted to /// serialize it. We **only** want to be able to look up these fields at a /// later time. -/// -/// (Technically, `frame::GoAway` should gain an opaque_debug_data field as -/// well, and we wouldn't want to save that here to accidentally dump in logs, -/// or waste struct space.) #[derive(Debug)] pub(crate) struct GoingAway { /// Stores the highest stream ID of a GOAWAY that has been sent. diff --git a/src/server.rs b/src/server.rs index f1f4cf47..032c0d17 100644 --- a/src/server.rs +++ b/src/server.rs @@ -544,6 +544,12 @@ where self.connection.go_away_gracefully(); } + #[doc(hidden)] + #[cfg(feature = "unstable")] + pub fn debug_data_shutdown(&mut self) { + self.connection.go_away_debug_data(); + } + /// Takes a `PingPong` instance from the connection. /// /// # Note diff --git a/tests/h2-support/src/frames.rs b/tests/h2-support/src/frames.rs index bc4e2e70..4ee20dd7 100644 --- a/tests/h2-support/src/frames.rs +++ b/tests/h2-support/src/frames.rs @@ -305,6 +305,13 @@ impl Mock { self.reason(frame::Reason::NO_ERROR) } + pub fn data(self, debug_data: I) -> Self + where + I: Into, + { + Mock(self.0.with_debug_data(debug_data.into())) + } + pub fn reason(self, reason: frame::Reason) -> Self { Mock(frame::GoAway::new(self.0.last_stream_id(), reason)) } diff --git a/tests/h2-tests/tests/server.rs b/tests/h2-tests/tests/server.rs index c8c1c9d1..78f4891a 100644 --- a/tests/h2-tests/tests/server.rs +++ b/tests/h2-tests/tests/server.rs @@ -705,6 +705,41 @@ async fn graceful_shutdown() { join(client, srv).await; } +#[tokio::test] +async fn go_away_sends_debug_data() { + h2_support::trace_init!(); + + let (io, mut client) = mock::new(); + + let client = async move { + let settings = client.assert_server_handshake().await; + assert_default_settings!(settings); + client + .send_frame(frames::headers(1).request("POST", "https://example.com/")) + .await; + client + .recv_frame(frames::go_away(1).no_error().data("something went wrong")) + .await; + }; + + let src = async move { + let mut srv = server::handshake(io).await.expect("handshake"); + let (_req, _tx) = srv.next().await.unwrap().expect("server receives request"); + + srv.debug_data_shutdown(); + + let srv_fut = async move { + poll_fn(move |cx| srv.poll_closed(cx)) + .await + .expect("server"); + }; + + srv_fut.await + }; + + join(client, src).await; +} + #[tokio::test] async fn goaway_even_if_client_sent_goaway() { h2_support::trace_init!();