Skip to content

Commit

Permalink
Allow body readers to read form uploads (#11011)
Browse files Browse the repository at this point in the history
* Allow body readers to read form uploads

* Correct route info

* Improve
  • Loading branch information
dstepanov authored Jul 25, 2024
1 parent f3abc08 commit 1d09d95
Show file tree
Hide file tree
Showing 8 changed files with 152 additions and 61 deletions.
2 changes: 2 additions & 0 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ testcontainers = "1.19.8"
tomlj="1.1.1"
vertx = "4.5.9"
wiremock = "2.33.2"
mimepull = "1.10.0"

#
# Versions which start with managed- are managed by Micronaut in the sense
Expand Down Expand Up @@ -280,6 +281,7 @@ vertx = { module = "io.vertx:vertx-core", version.ref = "vertx" }
vertx-webclient = { module = "io.vertx:vertx-web-client", version.ref = "vertx" }
httpcomponents-client = { module = "org.apache.httpcomponents:httpclient", version.ref = "httpcomponents-client" }
httpcomponents-mime = { module = "org.apache.httpcomponents:httpmime", version.ref = "httpcomponents-client" }
mimepull = { module = "org.jvnet.mimepull:mimepull", version.ref = "mimepull" }

wiremock = { module = "com.github.tomakehurst:wiremock-jre8", version.ref = "wiremock" }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ public void writeTo(@NonNull HttpRequest<?> request, @NonNull MutableHttpRespons

@Override
public MessageBodyWriter<T> createSpecific(Argument<T> type) {
return new NettyJsonHandler<>((JsonMessageHandler<T>) jsonMessageHandler.createSpecific(type));
return new NettyJsonHandler<>(jsonMessageHandler.createSpecific(type));
}

@Override
Expand Down
2 changes: 2 additions & 0 deletions http-server-netty/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ dependencies {
// Adding these for now since micronaut-test isnt resolving correctly ... probably need to upgrade gradle there too
testImplementation libs.junit.jupiter.api
testImplementation project(":websocket")

testImplementation(libs.mimepull)
}

tasks.withType(Test).configureEach {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package io.micronaut.http.server.netty.binders;

import io.micronaut.core.annotation.Internal;
import io.micronaut.core.convert.ArgumentConversionContext;
import io.micronaut.core.convert.ConversionContext;
import io.micronaut.core.convert.ConversionError;
Expand All @@ -24,9 +25,8 @@
import io.micronaut.core.io.buffer.ByteBuffer;
import io.micronaut.core.io.buffer.ReferenceCounted;
import io.micronaut.core.propagation.PropagatedContext;
import io.micronaut.core.type.Argument;
import io.micronaut.core.util.CollectionUtils;
import io.micronaut.http.HttpAttributes;
import io.micronaut.http.HttpHeaders;
import io.micronaut.http.HttpRequest;
import io.micronaut.http.MediaType;
import io.micronaut.http.bind.binders.DefaultBodyAnnotationBinder;
Expand All @@ -50,14 +50,12 @@
import io.netty.handler.codec.http.DefaultLastHttpContent;
import io.netty.handler.codec.http.multipart.InterfaceHttpData;

import java.io.InputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.Set;

@Internal
final class NettyBodyAnnotationBinder<T> extends DefaultBodyAnnotationBinder<T> {
private static final Set<Class<?>> RAW_BODY_TYPES = CollectionUtils.setOf(String.class, byte[].class, ByteBuffer.class, InputStream.class);
final NettyHttpServerConfiguration httpServerConfiguration;
final MessageBodyHandlerRegistry bodyHandlerRegistry;

Expand All @@ -69,10 +67,6 @@ final class NettyBodyAnnotationBinder<T> extends DefaultBodyAnnotationBinder<T>
this.bodyHandlerRegistry = bodyHandlerRegistry;
}

public static boolean isRaw(Argument<?> bodyType) {
return RAW_BODY_TYPES.contains(bodyType.getType());
}

@Override
protected BindingResult<T> bindBodyPart(ArgumentConversionContext<T> context, HttpRequest<?> source, String bodyComponent) {
if (source instanceof NettyHttpRequest<?> nhr && nhr.isFormOrMultipartData()) {
Expand Down Expand Up @@ -149,56 +143,41 @@ public List<ConversionError> getConversionErrors() {
}

Optional<T> transform(NettyHttpRequest<?> nhr, ArgumentConversionContext<T> context, AvailableByteBody imm) throws Throwable {
if (!isRaw(context.getArgument())) {
if (nhr.isFormOrMultipartData()) {
FormDataHttpContentProcessor processor = new FormDataHttpContentProcessor(nhr, httpServerConfiguration);
ByteBuf buf = AvailableNettyByteBody.toByteBuf(imm);
List<InterfaceHttpData> result = new ArrayList<>();
if (buf.isReadable()) {
processor.add(new DefaultLastHttpContent(buf), result);
} else {
buf.release();
}
processor.complete(result);
Optional<T> converted = new ImmediateMultiObjectBody(result)
.single(httpServerConfiguration.getDefaultCharset(), nhr.getChannelHandlerContext().alloc())
.convert(conversionService, context)
.map(o -> (T) o.claimForExternal());
nhr.setLegacyBody(converted.orElse(null));
return converted;
}
MessageBodyReader<T> reader = null;
final RouteInfo<?> routeInfo = nhr.getAttribute(HttpAttributes.ROUTE_INFO, RouteInfo.class).orElse(null);
if (routeInfo != null) {
reader = (MessageBodyReader<T>) routeInfo.getMessageBodyReader();
}
MediaType mediaType = nhr.getContentType().orElse(null);
if (mediaType != null) {
if (reader == null) {
reader = bodyHandlerRegistry.findReader(context.getArgument(), List.of(mediaType)).orElse(null);
}
if (reader != null) {
ByteBuffer<?> byteBuffer = imm.toByteBuffer();
boolean success = false;
try {
T result = reader.read(context.getArgument(), mediaType, nhr.getHeaders(), byteBuffer);
success = true;
nhr.setLegacyBody(result);
return Optional.ofNullable(result);
} catch (CodecException ce) {
if (ce.getCause() instanceof Exception e) {
context.reject(e);
} else {
context.reject(ce);
}
return Optional.empty();
} finally {
if (!success && byteBuffer instanceof ReferenceCounted rc) {
rc.release();
}
}
}
MessageBodyReader<T> reader = null;
final RouteInfo<?> routeInfo = nhr.getAttribute(HttpAttributes.ROUTE_INFO, RouteInfo.class).orElse(null);
if (routeInfo != null) {
reader = (MessageBodyReader<T>) routeInfo.getMessageBodyReader();
}
MediaType mediaType = nhr.getContentType().orElse(null);
if (mediaType != null && (reader == null || !reader.isReadable(context.getArgument(), mediaType))) {
reader = bodyHandlerRegistry.findReader(context.getArgument(), List.of(mediaType)).orElse(null);
}
if (reader != null && context.getArgument().getType().equals(Object.class)) {
// Prevent random object convertors
reader = null;
}
if (reader == null && nhr.isFormOrMultipartData()) {
FormDataHttpContentProcessor processor = new FormDataHttpContentProcessor(nhr, httpServerConfiguration);
ByteBuf buf = AvailableNettyByteBody.toByteBuf(imm);
List<InterfaceHttpData> data = new ArrayList<>();
if (buf.isReadable()) {
processor.add(new DefaultLastHttpContent(buf), data);
} else {
buf.release();
}
processor.complete(data);
Optional<T> converted = new ImmediateMultiObjectBody(data)
.single(httpServerConfiguration.getDefaultCharset(), nhr.getChannelHandlerContext().alloc())
.convert(conversionService, context)
.map(o -> (T) o.claimForExternal());
nhr.setLegacyBody(converted.orElse(null));
return converted;
}
ByteBuffer<?> byteBuffer = imm.toByteBuffer();
if (reader != null) {
T result = read(context, reader, nhr.getHeaders(), mediaType, byteBuffer);
nhr.setLegacyBody(result);
return Optional.ofNullable(result);
}
//noinspection unchecked
Optional<T> converted = new ImmediateSingleObjectBody(imm.toByteBuffer().asNativeBuffer())
Expand All @@ -207,4 +186,24 @@ Optional<T> transform(NettyHttpRequest<?> nhr, ArgumentConversionContext<T> cont
nhr.setLegacyBody(converted.orElse(null));
return converted;
}

private T read(ArgumentConversionContext<T> context, MessageBodyReader<T> reader, HttpHeaders headers, MediaType mediaType, ByteBuffer<?> byteBuffer) {
boolean success = false;
try {
T result = reader.read(context.getArgument(), mediaType, headers, byteBuffer);
success = true;
return result;
} catch (CodecException ce) {
if (ce.getCause() instanceof Exception e) {
context.reject(e);
} else {
context.reject(ce);
}
return null;
} finally {
if (!success && byteBuffer instanceof ReferenceCounted rc) {
rc.release();
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package io.micronaut.http.server.netty

import io.micronaut.context.ApplicationContext
import io.micronaut.context.annotation.Requires
import io.micronaut.core.annotation.Nullable
import io.micronaut.core.type.Argument
import io.micronaut.core.type.Headers
import io.micronaut.http.HttpRequest
import io.micronaut.http.MediaType
import io.micronaut.http.annotation.Body
import io.micronaut.http.annotation.Consumes
import io.micronaut.http.annotation.Controller
import io.micronaut.http.annotation.Post
import io.micronaut.http.annotation.Produces
import io.micronaut.http.body.MessageBodyReader
import io.micronaut.http.client.HttpClient
import io.micronaut.http.client.multipart.MultipartBody
import io.micronaut.http.codec.CodecException
import io.micronaut.runtime.server.EmbeddedServer
import jakarta.inject.Singleton
import org.junit.Assert
import org.jvnet.mimepull.Header
import org.jvnet.mimepull.MIMEMessage
import spock.lang.Specification

import java.nio.charset.StandardCharsets
import java.util.stream.Collectors

class CustomFileUploadSpec extends Specification {

def 'custom file upload'() {
given:
def ctx = ApplicationContext.run(['spec.name': 'CustomFileUploadSpec'])
def server = ctx.getBean(EmbeddedServer)
server.start()
def client = ctx.createBean(HttpClient, server.URI).toBlocking()

when:
def response = client.exchange(HttpRequest.POST(
'/multipart/file-upload',
MultipartBody.builder().addPart('MyName', 'myFile', 'foo'.bytes).build())
.contentType(MediaType.MULTIPART_FORM_DATA_TYPE), String)
then:
response.body() == 'Uploaded content-disposition: form-data; name="MyName"; filename="myFile", content-length: 3, content-type: application/octet-stream, content-transfer-encoding: binary foo'

cleanup:
client.close()
server.stop()
}

@Controller('/multipart')
@Requires(property = 'spec.name', value = 'CustomFileUploadSpec')
@Produces(MediaType.TEXT_PLAIN)
static class MultipartController {
@Post(value = '/file-upload', consumes = MediaType.MULTIPART_FORM_DATA)
String completeFileUpload(@Body MyFileUpload data) {
return "Uploaded " + data.value()
}

}

record MyFileUpload(String value) {
}

@Singleton
@Consumes("multipart/form-data")
static class MyFileUpdateReader implements MessageBodyReader<MyFileUpload> {

@Override
MyFileUpload read(Argument<MyFileUpload> type, MediaType mediaType, Headers httpHeaders, InputStream inputStream) throws CodecException {
MIMEMessage mimeMessage = new MIMEMessage(inputStream, mediaType.getParameters().get("boundary").orElse(""))
mimeMessage.parseAll()
def attachments = mimeMessage.getAttachments()
Assert.assertEquals(1, attachments.size())
def part = attachments.get(0)
def headers = part.getAllHeaders().stream().map { Header h -> h.name + ": " + h.value }.collect(Collectors.joining(", "))
return new MyFileUpload(headers + " " + new String(part.read().readAllBytes(), StandardCharsets.UTF_8))
}
}

record Metadata(@Nullable String foo) {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ private ExecutionFlow<HttpResponse<?>> onErrorNoFilter(HttpRequest<?> request, T

RouteMatch<?> errorRoute = routeExecutor.findErrorRoute(cause, declaringType, request);
if (errorRoute != null) {
RouteExecutor.setRouteAttributes(request, errorRoute);
if (routeExecutor.serverConfiguration.isLogHandledExceptions()) {
routeExecutor.logException(cause);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,9 +192,13 @@ UriRouteMatch<Object, Object> findRouteMatch(HttpRequest<?> httpRequest) {
}

static void setRouteAttributes(HttpRequest<?> request, UriRouteMatch<Object, Object> route) {
setRouteAttributes(request, (RouteMatch<?>) route);
request.setAttribute(HttpAttributes.URI_TEMPLATE, route.getRouteInfo().getUriMatchTemplate().toString());
}

static void setRouteAttributes(HttpRequest<?> request, RouteMatch<?> route) {
request.setAttribute(HttpAttributes.ROUTE_MATCH, route);
request.setAttribute(HttpAttributes.ROUTE_INFO, route.getRouteInfo());
request.setAttribute(HttpAttributes.URI_TEMPLATE, route.getRouteInfo().getUriMatchTemplate().toString());
}

/**
Expand Down Expand Up @@ -347,6 +351,7 @@ RouteMatch<?> findErrorRoute(Throwable cause,
if (LOG.isDebugEnabled()) {
LOG.debug("Found matching exception handler for exception [{}]: {}", cause.getMessage(), errorRoute);
}
setRouteAttributes(httpRequest, errorRoute);
requestArgumentSatisfier.fulfillArgumentRequirementsBeforeFilters(errorRoute, httpRequest);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ private static CodecException decorateRead(Argument<?> type, IOException e) {
}

@Override
public MessageBodyWriter<T> createSpecific(Argument<T> type) {
public JsonMessageHandler<T> createSpecific(Argument<T> type) {
return new JsonMessageHandler<>(jsonMapper.createSpecific(type));
}

Expand Down

0 comments on commit 1d09d95

Please sign in to comment.