Skip to content

Commit

Permalink
support http2 for java client (#543)
Browse files Browse the repository at this point in the history
* support http2 for java client

* add example for use http2 and ssl

* update example

* format code style

* add ping failed stack

* fix socket open for tls

* revert ping failed stack

* fix comment:add log for close error & add isOpen logic
  • Loading branch information
Nicole00 committed Sep 6, 2023
1 parent 3429ca1 commit f474ecf
Show file tree
Hide file tree
Showing 17 changed files with 670 additions and 25 deletions.
18 changes: 16 additions & 2 deletions client/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
<plugin>
<groupId>org.sonatype.plugins</groupId>
<artifactId>nexus-staging-maven-plugin</artifactId>
<version>1.6.8</version>
<extensions>true</extensions>
<configuration>
<serverId>ossrh</serverId>
Expand Down Expand Up @@ -90,8 +91,8 @@
<artifactId>maven-compiler-plugin</artifactId>
<version>3.8.1</version>
<configuration>
<source>1.8</source>
<target>1.8</target>
<source>8</source>
<target>8</target>
<generatedSourcesDirectory>src/main/generated</generatedSourcesDirectory>
</configuration>
</plugin>
Expand Down Expand Up @@ -262,5 +263,18 @@
<artifactId>jts-core</artifactId>
<version>1.16.1</version>
</dependency>
<dependency>
<groupId>com.squareup.okhttp3</groupId>
<artifactId>okhttp</artifactId>
<version>3.14.0</version>
</dependency>

<!-- https://mvnrepository.com/artifact/org.mortbay.jetty.alpn/alpn-boot -->
<dependency>
<groupId>org.mortbay.jetty.alpn</groupId>
<artifactId>alpn-boot</artifactId>
<version>8.1.13.v20181017</version>
</dependency>

</dependencies>
</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/* Copyright (c) 2023 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License.
*/

package com.facebook.thrift.transport;

import java.util.Arrays;
import java.util.concurrent.TimeUnit;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;
import okhttp3.OkHttpClient;
import okhttp3.Protocol;

public class OkHttp3Util {
private static OkHttpClient client;

private OkHttp3Util() {
}

public static OkHttpClient getClient(int connectTimeout, int readTimeout,
SSLSocketFactory sslFactory,
TrustManager trustManager) {
if (client == null) {
synchronized (OkHttp3Util.class) {
if (client == null) {
// Create OkHttpClient builder
OkHttpClient.Builder clientBuilder = new OkHttpClient.Builder()
.connectTimeout(connectTimeout, TimeUnit.MILLISECONDS)
.writeTimeout(readTimeout, TimeUnit.MILLISECONDS)
.readTimeout(readTimeout, TimeUnit.MILLISECONDS);
if (sslFactory != null) {
clientBuilder.sslSocketFactory(sslFactory, (X509TrustManager) trustManager);
clientBuilder.protocols(Arrays.asList(Protocol.HTTP_2, Protocol.HTTP_1_1));
} else {
// config the http2 prior knowledge
clientBuilder.protocols(Arrays.asList(Protocol.H2_PRIOR_KNOWLEDGE));
}
client = clientBuilder.build();
}
}
}
return client;
}

public static void close(){
if (client != null) {
client.connectionPool().evictAll();
client.dispatcher().executorService().shutdown();
client = null;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
/* Copyright (c) 2023 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License.
*/

package com.facebook.thrift.transport;

import com.facebook.thrift.utils.Logger;
import java.io.ByteArrayOutputStream;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
import okhttp3.MediaType;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.Response;
import okhttp3.ResponseBody;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

public class THttp2Client extends TTransport {
private static final Logger LOGGER = Logger.getLogger(THttp2Client.class.getName());

private final ByteArrayOutputStream requestBuffer = new ByteArrayOutputStream();
private ResponseBody responseBody = null;
private Map<String, String> customHeaders = null;
private static final Map<String, String> defaultHeaders = getDefaultHeaders();

private OkHttpClient client;
private final SSLSocketFactory sslFactory;

private final TrustManager trustManager;
private final String url;
private int connectTimeout = 0;
private int readTimeout = 0;


public THttp2Client(String url) throws TTransportException {
this(url, null, null);
}

public THttp2Client(String url, SSLSocketFactory sslFactory, TrustManager trustManager) throws TTransportException {
this.url = url;
this.sslFactory = sslFactory;
this.trustManager = trustManager;
}

public THttp2Client setConnectTimeout(int timeout) {
connectTimeout = timeout;
return this;
}

public THttp2Client setReadTimeout(int timeout) {
readTimeout = timeout;
return this;
}

public THttp2Client setCustomHeaders(Map<String, String> headers) {
customHeaders = headers;
return this;
}

public THttp2Client setCustomHeader(String key, String value) {
if (customHeaders == null) {
customHeaders = new HashMap<>();
}
customHeaders.put(key, value);
return this;
}

public void open() {
client = OkHttp3Util.getClient(connectTimeout, readTimeout, sslFactory, trustManager);
}

public void close() {
try {
if (responseBody != null) {
responseBody.close();
responseBody = null;
}

requestBuffer.close();
} catch (IOException e) {
LOGGER.warn(e.getMessage());
}
OkHttp3Util.close();
}

public boolean isOpen() {
return client != null;
}

public int read(byte[] buf, int off, int len) throws TTransportException {
if (responseBody == null) {
throw new TTransportException("Response buffer is empty, no request.");
}
try {
int ret = responseBody.byteStream().read(buf, off, len);
if (ret == -1) {
throw new TTransportException("No more data available.");
}
return ret;
} catch (IOException iox) {
throw new TTransportException(iox);
}
}

public void write(byte[] buf, int off, int len) {
requestBuffer.write(buf, off, len);
}

public void flush() throws TTransportException {
if (null == client) {
throw new TTransportException("Null HttpClient, aborting.");
}

// Extract request and reset buffer
byte[] data = requestBuffer.toByteArray();
requestBuffer.reset();
try {

// Create request object
Request.Builder requestBuilder = new Request.Builder()
.url(url)
.post(RequestBody.create(MediaType.parse("application/x-thrift"), data));

defaultHeaders.forEach(requestBuilder::header);
if (customHeaders != null) {
customHeaders.forEach(requestBuilder::header);
}

Request request = requestBuilder.build();

// Make the request
Response response = client.newCall(request).execute();
if (!response.isSuccessful()) {
throw new TTransportException("HTTP Response code: " + response.code());
}

// Read the response
responseBody = response.body();
} catch (IOException iox) {
throw new TTransportException(iox);
}
}

private static Map<String, String> getDefaultHeaders() {
Map<String, String> headers = new HashMap<>();
headers.put("Content-Type", "application/x-thrift");
headers.put("Accept", "application/x-thrift");
headers.put("User-Agent", "Java/THttpClient");
return headers;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
package com.vesoft.nebula.client.graph;

import com.vesoft.nebula.client.graph.data.SSLParam;
import com.vesoft.nebula.client.graph.net.NebulaPool;
import java.io.Serializable;

public class NebulaPoolConfig implements Serializable {
Expand Down Expand Up @@ -43,6 +42,9 @@ public class NebulaPoolConfig implements Serializable {
// SSL param is required if ssl is turned on
private SSLParam sslParam = null;

// Set if use http2 protocol
private boolean useHttp2 = false;

public boolean isEnableSsl() {
return enableSsl;
}
Expand Down Expand Up @@ -121,4 +123,13 @@ public NebulaPoolConfig setMinClusterHealthRate(double minClusterHealthRate) {
this.minClusterHealthRate = minClusterHealthRate;
return this;
}

public boolean isUseHttp2() {
return useHttp2;
}

public NebulaPoolConfig setUseHttp2(boolean useHttp2) {
this.useHttp2 = useHttp2;
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,14 @@ private NebulaSession createSessionObject(SessionState state)
// reconnect with all available address
while (tryConnect-- > 0) {
try {
connection.open(getAddress(), sessionPoolConfig.getTimeout());
if (sessionPoolConfig.isEnableSsl()) {
connection.open(getAddress(), sessionPoolConfig.getTimeout(),
sessionPoolConfig.getSslParam(),
sessionPoolConfig.isUseHttp2());
} else {
connection.open(getAddress(), sessionPoolConfig.getTimeout(),
sessionPoolConfig.isUseHttp2());
}
break;
} catch (Exception e) {
if (tryConnect == 0 || !reconnect) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package com.vesoft.nebula.client.graph;

import com.vesoft.nebula.client.graph.data.HostAddress;
import com.vesoft.nebula.client.graph.data.SSLParam;
import java.io.Serializable;
import java.util.List;

Expand Down Expand Up @@ -50,6 +51,14 @@ public class SessionPoolConfig implements Serializable {
// whether reconnect when create session using a broken graphd server
private boolean reconnect = false;

// Set to true to turn on ssl encrypted traffic
private boolean enableSsl = false;

// SSL param is required if ssl is turned on
private SSLParam sslParam = null;

private boolean useHttp2 = false;


public SessionPoolConfig(List<HostAddress> addresses,
String spaceName,
Expand Down Expand Up @@ -207,6 +216,33 @@ public SessionPoolConfig setReconnect(boolean reconnect) {
return this;
}

public boolean isEnableSsl() {
return enableSsl;
}

public SessionPoolConfig setEnableSsl(boolean enableSsl) {
this.enableSsl = enableSsl;
return this;
}

public SSLParam getSslParam() {
return sslParam;
}

public SessionPoolConfig setSslParam(SSLParam sslParam) {
this.sslParam = sslParam;
return this;
}

public boolean isUseHttp2() {
return useHttp2;
}

public SessionPoolConfig setUseHttp2(boolean useHttp2) {
this.useHttp2 = useHttp2;
return this;
}

@Override
public String toString() {
return "SessionPoolConfig{"
Expand All @@ -222,6 +258,9 @@ public String toString() {
+ ", retryTimes=" + retryTimes
+ ", intervalTIme=" + intervalTime
+ ", reconnect=" + reconnect
+ ", enableSsl=" + enableSsl
+ ",sslParam=" + sslParam
+ ", useHttp2=" + useHttp2
+ '}';
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,10 @@ public SyncConnection create() throws IOErrorException, ClientServerIncompatible
throw new IllegalArgumentException("SSL Param is required when enableSsl "
+ "is set to true");
}
conn.open(address, config.getTimeout(), config.getSslParam());
conn.open(address, config.getTimeout(),
config.getSslParam(), config.isUseHttp2());
} else {
conn.open(address, config.getTimeout());
conn.open(address, config.getTimeout(), config.isUseHttp2());
}
return conn;
} catch (IOErrorException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,17 @@ public HostAddress getServerAddress() {
public abstract void open(HostAddress address, int timeout, SSLParam sslParam)
throws IOErrorException, ClientServerIncompatibleException;

public abstract void open(HostAddress address, int timeout,
SSLParam sslParam, boolean isUseHttp2)
throws IOErrorException, ClientServerIncompatibleException;


public abstract void open(HostAddress address, int timeout) throws IOErrorException,
ClientServerIncompatibleException;

public abstract void open(HostAddress address, int timeout, boolean isUseHttp2)
throws IOErrorException, ClientServerIncompatibleException;

public abstract void reopen() throws IOErrorException, ClientServerIncompatibleException;

public abstract void close();
Expand Down
Loading

0 comments on commit f474ecf

Please sign in to comment.