Skip to content

Commit

Permalink
Server side trailers helidon-io#7647
Browse files Browse the repository at this point in the history
  • Loading branch information
danielkec committed Sep 21, 2023
1 parent 85f5202 commit 57089ec
Show file tree
Hide file tree
Showing 6 changed files with 449 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,16 @@ class Http2ServerResponse extends ServerResponseBase<Http2ServerResponse> {
private final Http2StreamWriter writer;
private final int streamId;
private final ServerResponseHeaders headers;
private final ServerResponseHeaders trailers;
private final FlowControl.Outbound flowControl;
private final Http2ServerRequest request;

private boolean isSent;
private boolean streamingEntity;
private long bytesWritten;
private BlockingOutputStream outputStream;
private UnaryOperator<OutputStream> outputStreamFilter;
private String streamResult = null;

Http2ServerResponse(ConnectionContext ctx,
Http2ServerRequest request,
Expand All @@ -62,10 +65,12 @@ class Http2ServerResponse extends ServerResponseBase<Http2ServerResponse> {
FlowControl.Outbound flowControl) {
super(ctx, request);
this.ctx = ctx;
this.request = request;
this.writer = writer;
this.streamId = streamId;
this.flowControl = flowControl;
this.headers = ServerResponseHeaders.create();
this.trailers = ServerResponseHeaders.create();
}

@Override
Expand Down Expand Up @@ -111,18 +116,30 @@ public void send(byte[] entityBytes) {
"Status must be configured on response, "
+ "do not set HTTP/2 pseudo headers"));

Http2FrameData frameData = new Http2FrameData(Http2FrameHeader.create(bytes.length,
Http2FrameTypes.DATA,
DataFlags.create(Http2Flag.END_OF_STREAM),
streamId),
BufferData.create(bytes));
boolean sendTrailers = request.headers().contains(HeaderValues.TE_TRAILERS) || headers.contains(HeaderNames.TRAILER);

Http2FrameData frameData =
new Http2FrameData(Http2FrameHeader.create(bytes.length,
Http2FrameTypes.DATA,
DataFlags.create(sendTrailers ? 0 : Http2Flag.END_OF_STREAM),
streamId),
BufferData.create(bytes));

http2Headers.validateResponse();
bytesWritten = writer.writeHeaders(http2Headers,
streamId,
Http2Flag.HeaderFlags.create(Http2Flag.END_OF_HEADERS),
frameData, flowControl);

if (sendTrailers) {
Http2Headers http2trailers = Http2Headers.create(trailers);
int written = writer.writeHeaders(http2trailers,
streamId,
Http2Flag.HeaderFlags.create(Http2Flag.END_OF_HEADERS | Http2Flag.END_OF_STREAM),
flowControl);
bytesWritten += written;
}

afterSend();
}

Expand All @@ -141,7 +158,11 @@ public OutputStream outputStream() {
}
streamingEntity = true;

outputStream = new BlockingOutputStream(headers, writer, streamId, flowControl, status(), () -> {
if (request.headers().contains(HeaderValues.TE_TRAILERS)) {
headers.add(STREAM_TRAILERS);
}

outputStream = new BlockingOutputStream(request, this, () -> {
this.isSent = true;
afterSend();
});
Expand All @@ -161,10 +182,19 @@ public long bytesWritten() {
public ServerResponseHeaders headers() {
return headers;
}
@Override
public ServerResponseHeaders trailers() {
if (request.headers().contains(HeaderValues.TE_TRAILERS) || headers.contains(HeaderNames.TRAILER)) {
return trailers;
}
throw new IllegalStateException(
"Trailers are supported only when request came with 'TE: trailers' header or "
+ "response headers have trailer names definition 'Trailer: <trailer-name>'");
}

@Override
public void streamResult(String result) {
// TODO use this when closing the stream
this.streamResult = result;
}

@Override
Expand Down Expand Up @@ -210,30 +240,32 @@ public void streamFilter(UnaryOperator<OutputStream> filterFunction) {

private static class BlockingOutputStream extends OutputStream {

private final Http2ServerRequest request;
private final ServerResponseHeaders headers;
private final ServerResponseHeaders trailers;
private final Http2StreamWriter writer;
private final int streamId;
private final FlowControl.Outbound flowControl;
private final Status status;
private final Runnable responseCloseRunnable;
private final Http2ServerResponse response;

private BufferData firstBuffer;
private boolean closed;
private boolean firstByte = true;
private long bytesWritten;

private BlockingOutputStream(ServerResponseHeaders headers,
Http2StreamWriter writer,
int streamId,
FlowControl.Outbound flowControl,
Status status,
private BlockingOutputStream(Http2ServerRequest request,
Http2ServerResponse response,
Runnable responseCloseRunnable) {

this.headers = headers;
this.writer = writer;
this.streamId = streamId;
this.flowControl = flowControl;
this.status = status;
this.request = request;
this.response = response;
this.headers = response.headers;
this.trailers = response.trailers;
this.writer = response.writer;
this.streamId = response.streamId;
this.flowControl = response.flowControl;
this.status = response.status();
this.responseCloseRunnable = responseCloseRunnable;
}

Expand Down Expand Up @@ -269,8 +301,12 @@ void commit() {
return;
}
this.closed = true;
boolean sendTrailers =
request.headers().contains(HeaderValues.TE_TRAILERS) || headers.contains(HeaderNames.TRAILER);
if (firstByte) {
sendFirstChunkOnly();
sendFirstChunkOnly(sendTrailers);
} else if (sendTrailers) {
sendTrailers();
} else {
sendEndOfStream();
}
Expand Down Expand Up @@ -301,7 +337,7 @@ private void write(BufferData buffer) throws IOException {
}
}

private void sendFirstChunkOnly() {
private void sendFirstChunkOnly(boolean sendTrailers) {
int contentLength;
if (firstBuffer == null) {
headers.set(HeaderValues.CONTENT_LENGTH_ZERO);
Expand All @@ -323,16 +359,21 @@ private void sendFirstChunkOnly() {
if (contentLength == 0) {
int written = writer.writeHeaders(http2Headers,
streamId,
Http2Flag.HeaderFlags.create(Http2Flag.END_OF_HEADERS
| Http2Flag.END_OF_STREAM),
Http2Flag.HeaderFlags.create(
sendTrailers
? Http2Flag.END_OF_HEADERS
: Http2Flag.END_OF_HEADERS | Http2Flag.END_OF_STREAM),
flowControl);
bytesWritten += written;
} else {
Http2FrameData frameData = new Http2FrameData(Http2FrameHeader.create(contentLength,
Http2FrameTypes.DATA,
DataFlags.create(Http2Flag.END_OF_STREAM),
streamId),
firstBuffer);
Http2FrameData frameData =
new Http2FrameData(Http2FrameHeader.create(contentLength,
Http2FrameTypes.DATA,
DataFlags.create(sendTrailers
? 0
: Http2Flag.END_OF_STREAM),
streamId),
firstBuffer);
int written = writer.writeHeaders(http2Headers,
streamId,
Http2Flag.HeaderFlags.create(Http2Flag.END_OF_HEADERS),
Expand Down Expand Up @@ -379,5 +420,20 @@ private void sendEndOfStream() {
bytesWritten += Http2FrameHeader.LENGTH;
writer.writeData(frameData, flowControl);
}

private void sendTrailers(){
if (response.streamResult != null) {
trailers.set(STREAM_RESULT_NAME, response.streamResult);
}
trailers.set(STREAM_STATUS_NAME, status.code());

Http2Headers http2Headers = Http2Headers.create(trailers);
int written = writer.writeHeaders(http2Headers,
streamId,
Http2Flag.HeaderFlags.create(Http2Flag.END_OF_HEADERS
| Http2Flag.END_OF_STREAM),
flowControl);
bytesWritten += written;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package io.helidon.webserver.tests.http2;

import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
Expand All @@ -27,6 +28,13 @@
import java.util.Set;
import java.util.stream.Collectors;

import io.helidon.http.Header;
import io.helidon.http.HeaderNames;
import io.helidon.http.HeaderValues;
import io.helidon.http.Status;
import io.helidon.webclient.api.ClientResponseTyped;
import io.helidon.webclient.http2.Http2Client;
import io.helidon.webclient.http2.Http2ClientProtocolConfig;
import io.helidon.webserver.WebServer;
import io.helidon.webserver.WebServerConfig;
import io.helidon.webserver.http.HttpRouting;
Expand All @@ -44,6 +52,7 @@
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

import static io.helidon.common.testing.http.junit5.HttpHeaderMatcher.hasHeader;
import static io.helidon.http.Method.GET;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
Expand All @@ -53,6 +62,8 @@ public class HeadersTest {

private static final Duration TIMEOUT = Duration.ofSeconds(10);
private static final String DATA = "Helidon!!!".repeat(10);
private static final Header TEST_TRAILER_HEADER = HeaderValues.create("test-trailer", "trailer-value");
private final Http2Client client;

@SetUpServer
static void setUpServer(WebServerConfig.Builder serverBuilder) {
Expand All @@ -76,25 +87,67 @@ static void setUpServer(WebServerConfig.Builder serverBuilder) {

@SetUpRoute
static void router(HttpRouting.Builder router) {
router.error(IllegalStateException.class, (req, res, t) -> res.status(500).send(t.getMessage()));
router.route(Http2Route.route(GET, "/ping", (req, res) -> res.send("pong")));
router.route(Http2Route.route(GET, "/cont-out",
(req, res) -> {
for (int i = 0; i < 500; i++) {
res.header("test-header-" + i, DATA + i);
}
res.send();
}
(req, res) -> {
for (int i = 0; i < 500; i++) {
res.header("test-header-" + i, DATA + i);
}
res.send();
}
));
router.route(Http2Route.route(GET, "/cont-in",
(req, res) -> {
String joinedHeaders = req.headers()
.stream()
.filter(h -> h.name().startsWith("test-header-"))
.map(h -> h.name() + "=" + h.get())
.collect(Collectors.joining("\n"));
res.send(joinedHeaders);
}
(req, res) -> {
String joinedHeaders = req.headers()
.stream()
.filter(h -> h.name().startsWith("test-header-"))
.map(h -> h.name() + "=" + h.get())
.collect(Collectors.joining("\n"));
res.send(joinedHeaders);
}
));
router.route(Http2Route.route(GET, "/trailers-stream",
(req, res) -> {
res.header(HeaderNames.TRAILER, TEST_TRAILER_HEADER.name());
try (var os = res.outputStream()) {
os.write(DATA.getBytes());
os.write(DATA.getBytes());
os.write(DATA.getBytes());
res.trailers().add(TEST_TRAILER_HEADER);
}
}
));
router.route(Http2Route.route(GET, "/trailers-stream-result",
(req, res) -> {
try (var os = res.outputStream()) {
os.write(DATA.getBytes());
os.write(DATA.getBytes());
os.write(DATA.getBytes());
res.streamResult("Kaboom!");
}
}
));
router.route(Http2Route.route(GET, "/trailers",
(req, res) -> {
res.header(HeaderNames.TRAILER, TEST_TRAILER_HEADER.name());
res.trailers().add(TEST_TRAILER_HEADER);
res.send(DATA.repeat(3));
}
));
router.route(Http2Route.route(GET, "/trailers-no-trailers",
(req, res) -> {
res.trailers().add(TEST_TRAILER_HEADER);
res.send(DATA);
}
));
}

HeadersTest(WebServer server) {
client = Http2Client.builder()
.baseUri("http://localhost:" + server.port())
.protocolConfig(Http2ClientProtocolConfig.builder().priorKnowledge(true).build())
.build();
}

@Test
Expand Down Expand Up @@ -166,6 +219,52 @@ void serverInboundTooLarge(WebServer server) throws IOException, InterruptedExce
HttpResponse.BodyHandlers.ofString()));
}

@Test
void trailersEntity() throws IOException {
ClientResponseTyped<InputStream> res = client
.get("/trailers")
.request(InputStream.class);
try (var is = res.entity()) {
is.readAllBytes();
}
assertThat(res.trailers(), hasHeader(TEST_TRAILER_HEADER));
}

@Test
void trailersStream() throws IOException {
ClientResponseTyped<InputStream> res = client
.get("/trailers-stream")
.request(InputStream.class);
try (var is = res.entity()) {
is.readAllBytes();
}
assertThat(res.trailers(), hasHeader(TEST_TRAILER_HEADER));
}

@Test
void trailersStreamResult() throws IOException {
ClientResponseTyped<InputStream> res = client
.get("/trailers-stream-result")
.header(HeaderValues.TE_TRAILERS)
.request(InputStream.class);
try (var is = res.entity()) {
is.readAllBytes();
}
assertThat(res.trailers(), hasHeader(HeaderValues.create("stream-result", "Kaboom!")));
}

@Test
void trailersNoTrailers() {
ClientResponseTyped<String> res = client
.get("/trailers-no-trailers")
.request(String.class);

assertThat(res.status(), is(Status.INTERNAL_SERVER_ERROR_500));
assertThat(res.entity(), is(
"Trailers are supported only when request came with 'TE: trailers' header or "
+ "response headers have trailer names definition 'Trailer: <trailer-name>'"));
}

private HttpClient http2Client(URI base) throws IOException, InterruptedException {
HttpClient client = HttpClient.newBuilder()
.version(HttpClient.Version.HTTP_2)
Expand Down
Loading

0 comments on commit 57089ec

Please sign in to comment.