Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added test for JerseyChunkedInputStreamClose #5759

Merged
merged 1 commit into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions connectors/netty-connector/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -81,4 +81,25 @@
</plugins>
</build>

<profiles>
<profile>
<id>InaccessibleObjectException</id>
<activation><jdk>[12,)</jdk></activation>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<configuration>
<argLine>
--add-opens java.base/java.lang=ALL-UNNAMED
--add-opens java.base/java.lang.reflect=ALL-UNNAMED
</argLine>
</configuration>
</plugin>
</plugins>
</build>
</profile>
</profiles>

</project>
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ public void operationComplete(io.netty.util.concurrent.Future<? super Void> futu
};
ch.closeFuture().addListener(closeListener);

final NettyEntityWriter entityWriter = NettyEntityWriter.getInstance(jerseyRequest, ch);
final NettyEntityWriter entityWriter = nettyEntityWriter(jerseyRequest, ch);
switch (entityWriter.getType()) {
case CHUNKED:
HttpUtil.setTransferEncodingChunked(nettyRequest, true);
Expand Down Expand Up @@ -523,6 +523,10 @@ public void run() {
}
}

/* package */ NettyEntityWriter nettyEntityWriter(ClientRequest clientRequest, Channel channel) {
return NettyEntityWriter.getInstance(clientRequest, channel);
}

private SSLContext getSslContext(Client client, ClientRequest request) {
Supplier<SSLContext> supplier = request.resolveProperty(ClientProperties.SSL_CONTEXT_SUPPLIER, Supplier.class);
return supplier == null ? client.getSslContext() : supplier.get();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,15 @@ public ByteBuf readChunk(ChannelHandlerContext ctx) throws Exception {

@Override
public ByteBuf readChunk(ByteBufAllocator allocator) throws Exception {
try {
return readChunk0(allocator);
} catch (Exception e) {
closeOnThrowable();
throw e;
}
}

private ByteBuf readChunk0(ByteBufAllocator allocator) throws Exception {
if (!open) {
return null;
}
Expand Down Expand Up @@ -143,6 +151,14 @@ public long progress() {
return offset;
}

private void closeOnThrowable() {
try {
close();
} catch (Throwable t) {
// do not throw other throwable
}
}

@Override
public void close() throws IOException {

Expand Down Expand Up @@ -208,12 +224,12 @@ private void write(Provider<ByteBuffer> bufferSupplier) throws IOException {
try {
boolean queued = queue.offer(bufferSupplier.get(), WRITE_TIMEOUT, TimeUnit.MILLISECONDS);
if (!queued) {
close();
closeOnThrowable();
throw new IOException("Buffer overflow.");
}

} catch (InterruptedException e) {
close();
closeOnThrowable();
throw new IOException(e);
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,298 @@
/*
* Copyright (c) 2024 Oracle and/or its affiliates. All rights reserved.
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License v. 2.0, which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* This Source Code may also be made available under the following Secondary
* Licenses when the conditions for such availability set forth in the
* Eclipse Public License v. 2.0 are satisfied: GNU General Public License,
* version 2 with the GNU Classpath Exception, which is available at
* https://www.gnu.org/software/classpath/license.html.
*
* SPDX-License-Identifier: EPL-2.0 OR GPL-2.0 WITH Classpath-exception-2.0
*/

package org.glassfish.jersey.netty.connector;

import io.netty.channel.Channel;
import org.glassfish.jersey.client.ClientConfig;
import org.glassfish.jersey.client.ClientProperties;
import org.glassfish.jersey.client.ClientRequest;
import org.glassfish.jersey.client.spi.Connector;
import org.glassfish.jersey.client.spi.ConnectorProvider;
import org.glassfish.jersey.netty.connector.internal.JerseyChunkedInput;
import org.glassfish.jersey.netty.connector.internal.NettyEntityWriter;
import org.glassfish.jersey.server.ResourceConfig;
import org.glassfish.jersey.test.JerseyTest;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

import javax.ws.rs.POST;
import javax.ws.rs.Path;
import javax.ws.rs.client.Client;
import javax.ws.rs.client.ClientBuilder;
import javax.ws.rs.client.Entity;
import javax.ws.rs.client.Invocation;
import javax.ws.rs.client.WebTarget;
import javax.ws.rs.core.Application;
import javax.ws.rs.core.Configuration;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.MultivaluedHashMap;
import javax.ws.rs.core.Response;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.lang.reflect.Proxy;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.logging.Level;
import java.util.logging.Logger;

public class ChunkedInputWriteErrorSimulationTest extends JerseyTest {
private static final String EXCEPTION_MSG = "BOGUS BUFFER OVERFLOW";
private static final AtomicReference<Throwable> caught = new AtomicReference<>(null);

public static class ClientThread extends Thread {

public static AtomicInteger count = new AtomicInteger();
public static String url;
public static int nLoops;

private static Client client;

public static void main(DequeOffer offer, String[] args) throws InterruptedException {
url = args[0];
int nThreads = Integer.parseInt(args[1]);
nLoops = Integer.parseInt(args[2]);
initClient(offer);
Thread[] threads = new Thread[nThreads];
for (int i = 0; i < nThreads; i++) {
threads[i] = new ClientThread();
threads[i].start();
}

for (int i = 0; i < nThreads; i++) {
threads[i].join();
}
// System.out.println("Processed calls: " + count);
}

private static void initClient(DequeOffer offer) {
ClientConfig defaultConfig = new ClientConfig();
defaultConfig.property(ClientProperties.CONNECT_TIMEOUT, 10 * 1000);
defaultConfig.property(ClientProperties.READ_TIMEOUT, 10 * 1000);
defaultConfig.connectorProvider(getJerseyChunkedInputModifiedNettyConnector(offer));
client = ClientBuilder.newBuilder()
.withConfig(defaultConfig)
.build();
}

public void doCall() {
CompletableFuture<Response> cf = invokeResponse().toCompletableFuture()
.whenComplete((rsp, t) -> {
if (t != null) {
// System.out.println(Thread.currentThread() + " async complete. Caught exception " + t);
// t.printStackTrace();
while (t.getCause() != null) {
t = t.getCause();
}
caught.set(t);
}
})
.handle((rsp, t) -> {
if (rsp != null) {
rsp.readEntity(String.class);
} else {
System.out.println(Thread.currentThread().getName() + " response is null");
}
return rsp;
}).exceptionally(t -> {
System.out.println("async complete. completed exceptionally " + t);
throw new RuntimeException(t);
});

try {
cf.get();
System.out.println("Done call " + count.incrementAndGet());
} catch (InterruptedException | ExecutionException ex) {
Logger.getLogger(ClientThread.class.getName()).log(Level.SEVERE, null, ex);
}
}

private static CompletionStage<Response> invokeResponse() {
WebTarget target = client.target(url);
MultivaluedHashMap hdrs = new MultivaluedHashMap<>();
StringBuilder sb = new StringBuilder("{");
for (int i = 0; i < 10000; i++) {
sb.append("\"fname\":\"foo\", \"lname\":\"bar\"");
}
sb.append("}");
String jsonPayload = sb.toString();
Invocation.Builder builder = ((WebTarget) target).request().headers(hdrs);
return builder.rx().method("POST", Entity.entity(jsonPayload, MediaType.APPLICATION_JSON_TYPE));
}

@Override
public void run() {
for (int i = 0; i < nLoops; i++) {
try {
doCall();
} catch (Throwable t) {
throw new RuntimeException(t);
}
}
}
}

@Path("/console")
public static class HangingEndpoint {
@Path("/login")
@POST
public String post(String entity) {
return "Welcome";
}
}

@Override
protected Application configure() {
return new ResourceConfig(HangingEndpoint.class);
}

@Test
public void testNoHangOnOfferInterrupt() throws InterruptedException {
String path = getBaseUri() + "console/login";
ClientThread.main(new InterruptedExceptionOffer(), new String[] {path, "5", "10"});
Assertions.assertTrue(caught.get().getMessage().contains(EXCEPTION_MSG));
}

@Test
public void testNoHangOnPollInterrupt() throws InterruptedException {
String path = getBaseUri() + "console/login";
ClientThread.main(new DequePoll(), new String[] {path, "5", "10"});
Assertions.assertNotNull(caught.get());
}

@Test
public void testNoHangOnOfferNoData() throws InterruptedException {
String path = getBaseUri() + "console/login";
ClientThread.main(new ReturnFalseOffer(), new String[] {path, "5", "10"});
Assertions.assertTrue(caught.get().getMessage().contains("Buffer overflow")); //JerseyChunkedInput
Thread.sleep(1_000L); // Sleep for the server to finish
}

private interface DequeOffer {
public boolean offer(ByteBuffer e, long timeout, TimeUnit unit) throws InterruptedException;
}

private static class InterruptedExceptionOffer implements DequeOffer {
private AtomicInteger ai = new AtomicInteger(0);

@Override
public boolean offer(ByteBuffer e, long timeout, TimeUnit unit) throws InterruptedException {
if ((ai.getAndIncrement() % 10) == 0) {
throw new InterruptedException(EXCEPTION_MSG);
}
return true;
}
}

private static class ReturnFalseOffer implements DequeOffer {
private AtomicInteger ai = new AtomicInteger(0);
@Override
public boolean offer(ByteBuffer e, long timeout, TimeUnit unit) throws InterruptedException {
return !((ai.getAndIncrement() % 10) == 1);
}
}

private static class DequePoll extends InterruptedExceptionOffer {
}


private static ConnectorProvider getJerseyChunkedInputModifiedNettyConnector(DequeOffer offer) {
return new ConnectorProvider() {
@Override
public Connector getConnector(Client client, Configuration runtimeConfig) {
return new NettyConnector(client) {
NettyEntityWriter nettyEntityWriter(ClientRequest clientRequest, Channel channel) {
NettyEntityWriter wrapped = NettyEntityWriter.getInstance(clientRequest, channel);

JerseyChunkedInput chunkedInput = (JerseyChunkedInput) wrapped.getChunkedInput();
try {
Field field = JerseyChunkedInput.class.getDeclaredField("queue");
field.setAccessible(true);

removeFinal(field);

field.set(chunkedInput, new LinkedBlockingDeque<ByteBuffer>() {
@Override
public boolean offer(ByteBuffer e, long timeout, TimeUnit unit) throws InterruptedException {
if (!DequePoll.class.isInstance(offer) && !offer.offer(e, timeout, unit)) {
return false;
}
return super.offer(e, timeout, unit);
}

@Override
public ByteBuffer poll(long timeout, TimeUnit unit) throws InterruptedException {
if (DequePoll.class.isInstance(offer)) {
offer.offer(null, timeout, unit);
}
return super.poll(timeout, unit);
}
});

} catch (Exception e) {
throw new RuntimeException(e);
}

NettyEntityWriter proxy = (NettyEntityWriter) Proxy.newProxyInstance(
ConnectorProvider.class.getClassLoader(), new Class[]{NettyEntityWriter.class},
(proxy1, method, args) -> {
if (method.getName().equals("readChunk")) {
try {
return method.invoke(wrapped, args);
} catch (RuntimeException e) {
// consume
}
}
return method.invoke(wrapped, args);
});
return proxy;
}
};
}
};
}

public static void removeFinal(Field field) throws RuntimeException {
try {
Method[] classMethods = Class.class.getDeclaredMethods();
Method declaredFieldMethod = Arrays
.stream(classMethods).filter(x -> Objects.equals(x.getName(), "getDeclaredFields0"))
.findAny().orElseThrow(() -> new NoSuchElementException("No value present"));
declaredFieldMethod.setAccessible(true);
Field[] declaredFieldsOfField = (Field[]) declaredFieldMethod.invoke(Field.class, false);
Field modifiersField = Arrays
.stream(declaredFieldsOfField).filter(x -> Objects.equals(x.getName(), "modifiers"))
.findAny().orElseThrow(() -> new NoSuchElementException("No value present"));
modifiersField.setAccessible(true);
modifiersField.setInt(field, field.getModifiers() & ~Modifier.FINAL);
} catch (RuntimeException re) {
throw re;
} catch (Exception e) {
throw new RuntimeException(e);
}
}

}