Skip to content

Commit

Permalink
Type pollution test (#10918)
Browse files Browse the repository at this point in the history
Add a test to the benchmarks module that runs certain JMH benchmarks with the redhat type pollution agent ( https://github.com/RedHatPerf/type-pollution-agent ) and triggers a test failure if there is excessive type pollution.

This works by adding a new test task which includes the agent, and then using JFRUnit to access the JFR events emitted by the agents. If there are too many thrash events for a single concrete type, details are logged and the test fails.

Additionally:
- Add a JMH benchmark that mimics the TechEmpower benchmark
- Fix various type pollution issues that were found using the FullHttpStackBenchmark
  • Loading branch information
yawkat authored Jul 19, 2024
1 parent 93a6957 commit bb68d7c
Show file tree
Hide file tree
Showing 11 changed files with 360 additions and 47 deletions.
32 changes: 32 additions & 0 deletions benchmarks/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,13 @@ plugins {
id "me.champeau.jmh" version "0.7.2"
}

sourceSets {
create("typeCheckTest") {
compileClasspath += sourceSets.jmh.output
runtimeClasspath += sourceSets.jmh.output
}
}

dependencies {
annotationProcessor project(":inject-java")
jmhAnnotationProcessor project(":inject-java")
Expand Down Expand Up @@ -33,7 +40,19 @@ dependencies {
}

jmh libs.jmh.core

typeCheckTestImplementation libs.junit.jupiter
typeCheckTestImplementation libs.micronaut.test.type.pollution
typeCheckTestImplementation ("net.bytebuddy:byte-buddy-agent:1.14.17")
typeCheckTestImplementation ("net.bytebuddy:byte-buddy:1.14.17")
typeCheckTestRuntimeOnly libs.junit.platform.engine
}

configurations {
typeCheckTestImplementation.extendsFrom(jmhImplementation, implementation)
typeCheckTestRuntimeOnly.extendsFrom(jmhRuntimeOnly, runtimeOnly)
}

jmh {
includes = ['io.micronaut.http.server.StartupBenchmark']
duplicateClassesStrategy = DuplicatesStrategy.WARN
Expand All @@ -42,6 +61,19 @@ jmh {
tasks.named("processJmhResources") {
duplicatesStrategy = DuplicatesStrategy.WARN
}

tasks.register('typeCheckTest', Test) {
description = "Runs type check tests."
group = "verification"

testClassesDirs = sourceSets.typeCheckTest.output.classesDirs
classpath = sourceSets.typeCheckTest.runtimeClasspath

useJUnitPlatform()
}

check.dependsOn typeCheckTest

['spotlessJavaCheck', 'checkstyleMain', 'checkstyleJmh'].each {
tasks.named(it) {
enabled = false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
import io.micronaut.http.server.netty.NettyHttpServer;
import io.micronaut.runtime.server.EmbeddedServer;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
Expand Down Expand Up @@ -37,6 +35,13 @@
import java.util.concurrent.TimeUnit;

public class FullHttpStackBenchmark {
/**
* If {@code true}, verify that the test is running on a netty {@link FastThreadLocalThread}.
* This is relevant for performance testing, but doesn't matter for type pollution tests. Don't
* turn this off for perf testing!
*/
public static boolean checkFtlThread = true;

@Benchmark
public void test(Holder holder) {
ByteBuf response = holder.exchange();
Expand Down Expand Up @@ -85,7 +90,7 @@ public static class Holder {

@Setup
public void setUp() {
if (!(Thread.currentThread() instanceof FastThreadLocalThread)) {
if (checkFtlThread && !(Thread.currentThread() instanceof FastThreadLocalThread)) {
throw new IllegalStateException("Should run on a netty FTL thread");
}

Expand All @@ -111,14 +116,7 @@ public void setUp() {
clientChannel.writeOutbound(request);
clientChannel.flushOutbound();

requestBytes = PooledByteBufAllocator.DEFAULT.buffer();
while (true) {
ByteBuf part = clientChannel.readOutbound();
if (part == null) {
break;
}
requestBytes.writeBytes(part);
}
requestBytes = NettyUtil.readAllOutboundContiguous(clientChannel);

// sanity check: run req/resp once and see that the response is correct
responseBytes = exchange();
Expand All @@ -128,7 +126,7 @@ public void setUp() {
//System.out.println(response.content().toString(StandardCharsets.UTF_8));
Assertions.assertEquals(HttpResponseStatus.OK, response.status());
Assertions.assertEquals("application/json", response.headers().get(HttpHeaderNames.CONTENT_TYPE));
Assertions.assertEquals("keep-alive", response.headers().get(HttpHeaderNames.CONNECTION));
Assertions.assertNull(response.headers().get(HttpHeaderNames.CONNECTION));
String expectedResponseBody = "{\"listIndex\":4,\"stringIndex\":0}";
Assertions.assertEquals(expectedResponseBody, response.content().toString(StandardCharsets.UTF_8));
Assertions.assertEquals(expectedResponseBody.length(), response.headers().getInt(HttpHeaderNames.CONTENT_LENGTH));
Expand All @@ -138,15 +136,7 @@ public void setUp() {
private ByteBuf exchange() {
channel.writeInbound(requestBytes.retainedDuplicate());
channel.runPendingTasks();
CompositeByteBuf response = PooledByteBufAllocator.DEFAULT.compositeBuffer();
while (true) {
ByteBuf part = channel.readOutbound();
if (part == null) {
break;
}
response.addComponent(true, part);
}
return response;
return NettyUtil.readAllOutboundComposite(channel);
}

@TearDown
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package io.micronaut.http.server.stack;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.channel.embedded.EmbeddedChannel;

final class NettyUtil {
static ByteBuf readAllOutboundContiguous(EmbeddedChannel clientChannel) {
ByteBuf requestBytes = PooledByteBufAllocator.DEFAULT.buffer();
while (true) {
ByteBuf part = clientChannel.readOutbound();
if (part == null) {
break;
}
requestBytes.writeBytes(part);
}
return requestBytes;
}

static ByteBuf readAllOutboundComposite(EmbeddedChannel channel) {
CompositeByteBuf response = PooledByteBufAllocator.DEFAULT.compositeBuffer();
while (true) {
ByteBuf part = channel.readOutbound();
if (part == null) {
break;
}
response.addComponent(true, part);
}
return response;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
package io.micronaut.http.server.stack;

import io.micronaut.context.ApplicationContext;
import io.micronaut.context.annotation.Requires;
import io.micronaut.http.MediaType;
import io.micronaut.http.annotation.Controller;
import io.micronaut.http.annotation.Get;
import io.micronaut.http.server.netty.NettyHttpServer;
import io.micronaut.runtime.server.EmbeddedServer;
import io.netty.buffer.ByteBuf;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpClientCodec;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;
import org.junit.jupiter.api.Assertions;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.TearDown;
import org.openjdk.jmh.runner.Runner;
import org.openjdk.jmh.runner.RunnerException;
import org.openjdk.jmh.runner.options.Options;
import org.openjdk.jmh.runner.options.OptionsBuilder;

import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.concurrent.TimeUnit;

/**
* JMH benchmark that mimics the TechEmpower framework benchmarks.
*/
public class TfbLikeBenchmark {
public static void main(String[] args) throws RunnerException {
Options opt = new OptionsBuilder()
.include(TfbLikeBenchmark.class.getName() + ".*")
.warmupIterations(20)
.measurementIterations(30)
.mode(Mode.AverageTime)
.timeUnit(TimeUnit.NANOSECONDS)
.forks(1)
.build();

new Runner(opt).run();
}

@Benchmark
public void test(Holder holder) {
ByteBuf response = holder.exchange();
if (!holder.responseBytes.equals(response)) {
throw new AssertionError("Response did not match");
}
response.release();
}

@State(Scope.Thread)
public static class Holder {
ApplicationContext ctx;
EmbeddedChannel channel;
ByteBuf requestBytes;
ByteBuf responseBytes;

@Setup
public void setUp() {
ctx = ApplicationContext.run(Map.of(
"spec.name", "TfbLikeBenchmark",
"micronaut.server.date-header", false // disabling this makes the response identical each time
));
EmbeddedServer server = ctx.getBean(EmbeddedServer.class);
channel = ((NettyHttpServer) server).buildEmbeddedChannel(false);

EmbeddedChannel clientChannel = new EmbeddedChannel();
clientChannel.pipeline().addLast(new HttpClientCodec());
clientChannel.pipeline().addLast(new HttpObjectAggregator(1000));

FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/plaintext");
request.headers().add(HttpHeaderNames.ACCEPT, "text/plain,text/html;q=0.9,application/xhtml+xml;q=0.9,application/xml;q=0.8,*/*;q=0.7");
clientChannel.writeOutbound(request);
clientChannel.flushOutbound();

requestBytes = NettyUtil.readAllOutboundContiguous(clientChannel);

// sanity check: run req/resp once and see that the response is correct
responseBytes = exchange();
clientChannel.writeInbound(responseBytes.retainedDuplicate());
FullHttpResponse response = clientChannel.readInbound();
Assertions.assertEquals(HttpResponseStatus.OK, response.status());
Assertions.assertEquals("text/plain", response.headers().get(HttpHeaderNames.CONTENT_TYPE));
String expectedResponseBody = "Hello, World!";
Assertions.assertEquals(expectedResponseBody, response.content().toString(StandardCharsets.UTF_8));
Assertions.assertEquals(expectedResponseBody.length(), response.headers().getInt(HttpHeaderNames.CONTENT_LENGTH));
response.release();
}

ByteBuf exchange() {
channel.writeInbound(requestBytes.retainedDuplicate());
channel.runPendingTasks();
return NettyUtil.readAllOutboundComposite(channel);
}

@TearDown
public void tearDown() {
ctx.close();
requestBytes.release();
responseBytes.release();
}
}

@Controller("/plaintext")
@Requires(property = "spec.name", value = "TfbLikeBenchmark")
static class PlainTextController {

private static final byte[] TEXT = "Hello, World!".getBytes(StandardCharsets.UTF_8);

@Get(value = "/", produces = MediaType.TEXT_PLAIN)
public byte[] getPlainText() {
return TEXT;
}
}
}
109 changes: 109 additions & 0 deletions benchmarks/src/typeCheckTest/java/example/TypeThrashingTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package example;

import io.micronaut.http.server.stack.FullHttpStackBenchmark;
import io.micronaut.http.server.stack.TfbLikeBenchmark;
import io.micronaut.test.typepollution.FocusListener;
import io.micronaut.test.typepollution.ThresholdFocusListener;
import io.micronaut.test.typepollution.TypePollutionTransformer;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.runner.Runner;
import org.openjdk.jmh.runner.RunnerException;
import org.openjdk.jmh.runner.options.Options;
import org.openjdk.jmh.runner.options.OptionsBuilder;

import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class TypeThrashingTest {
static final int THRESHOLD = 10_000;

static {
FullHttpStackBenchmark.checkFtlThread = false;
}

private ThresholdFocusListener focusListener;

@BeforeAll
static void setupAgent() {
TypePollutionTransformer.install(net.bytebuddy.agent.ByteBuddyAgent.install());
}

@BeforeEach
void setUp() {
focusListener = new ThresholdFocusListener();
FocusListener.setFocusListener((concreteType, interfaceType) -> {
if (concreteType == DefaultFullHttpResponse.class) {
String culprit = StackWalker.getInstance().walk(s -> s.skip(1).dropWhile(f -> f.getClassName().startsWith("io.micronaut.test.")).findFirst().map(StackWalker.StackFrame::getClassName).orElse(null));
if (culprit != null && (culprit.startsWith("io.netty") || culprit.equals("io.micronaut.http.server.netty.handler.Compressor"))) {
// these DefaultFullHttpResponse flips are false positives, fixed by franz
return;
}
}

focusListener.onFocus(concreteType, interfaceType);
});
}

@AfterEach
void verifyNoTypeThrashing() {
FocusListener.setFocusListener(null);
Assertions.assertTrue(focusListener.checkThresholds(THRESHOLD), "Threshold exceeded, check logs.");
}

/**
* This is a sample method that demonstrates the thrashing detection. This test should fail
* when enabled.
*/
@SuppressWarnings("ConstantValue")
@Test
@Disabled
public void sample() {
Object c = new Concrete();
int j = 0;
for (int i = 0; i < THRESHOLD * 2; i++) {
if (c instanceof A) {
j++;
}
if (c instanceof B) {
j++;
}
}
System.out.println(j);
}

interface A {
}

interface B {
}

static class Concrete implements A, B {
}

@Test
public void testFromJmh() throws RunnerException {
Options opt = new OptionsBuilder()
.include(Stream.of(FullHttpStackBenchmark.class, TfbLikeBenchmark.class)
.map(Class::getName)
.collect(Collectors.joining("|", "(", ")"))
+ ".*")
.warmupIterations(0)
.measurementIterations(1)
.mode(Mode.SingleShotTime)
.timeUnit(TimeUnit.NANOSECONDS)
.forks(0)
.measurementBatchSize(THRESHOLD * 2)
.shouldFailOnError(true)
.build();

new Runner(opt).run();
}
}
Loading

0 comments on commit bb68d7c

Please sign in to comment.