From f474ecf63dacf00fed74638d14e809c0b99f8e01 Mon Sep 17 00:00:00 2001 From: Anqi Date: Thu, 7 Sep 2023 07:56:19 +0800 Subject: [PATCH] support http2 for java client (#543) * support http2 for java client * add example for use http2 and ssl * update example * format code style * add ping failed stack * fix socket open for tls * revert ping failed stack * fix comment:add log for close error & add isOpen logic --- client/pom.xml | 18 +- .../thrift/transport/OkHttp3Util.java | 54 ++++++ .../thrift/transport/THttp2Client.java | 155 +++++++++++++++++ .../nebula/client/graph/NebulaPoolConfig.java | 13 +- .../nebula/client/graph/SessionPool.java | 9 +- .../client/graph/SessionPoolConfig.java | 39 +++++ .../client/graph/net/ConnObjectPool.java | 5 +- .../nebula/client/graph/net/Connection.java | 7 + .../nebula/client/graph/net/NebulaPool.java | 6 +- .../graph/net/RoundRobinLoadBalancer.java | 21 ++- .../client/graph/net/SyncConnection.java | 97 +++++++++-- .../java/com/vesoft/nebula/util/SslUtil.java | 13 +- examples/pom.xml | 33 ++++ .../GraphSessionPoolWithHttp2Example.java | 158 ++++++++++++++++++ examples/src/main/resources/ssl/root.crt | 21 +++ examples/src/main/resources/ssl/server.crt | 19 +++ examples/src/main/resources/ssl/server.key | 27 +++ 17 files changed, 670 insertions(+), 25 deletions(-) create mode 100644 client/src/main/fbthrift/com/facebook/thrift/transport/OkHttp3Util.java create mode 100644 client/src/main/fbthrift/com/facebook/thrift/transport/THttp2Client.java create mode 100644 examples/src/main/java/com/vesoft/nebula/examples/GraphSessionPoolWithHttp2Example.java create mode 100644 examples/src/main/resources/ssl/root.crt create mode 100644 examples/src/main/resources/ssl/server.crt create mode 100755 examples/src/main/resources/ssl/server.key diff --git a/client/pom.xml b/client/pom.xml index 28cd14ae9..08ba5119d 100644 --- a/client/pom.xml +++ b/client/pom.xml @@ -42,6 +42,7 @@ org.sonatype.plugins nexus-staging-maven-plugin + 1.6.8 true ossrh @@ -90,8 +91,8 @@ maven-compiler-plugin 3.8.1 - 1.8 - 1.8 + 8 + 8 src/main/generated @@ -262,5 +263,18 @@ jts-core 1.16.1 + + com.squareup.okhttp3 + okhttp + 3.14.0 + + + + + org.mortbay.jetty.alpn + alpn-boot + 8.1.13.v20181017 + + diff --git a/client/src/main/fbthrift/com/facebook/thrift/transport/OkHttp3Util.java b/client/src/main/fbthrift/com/facebook/thrift/transport/OkHttp3Util.java new file mode 100644 index 000000000..7e29057b5 --- /dev/null +++ b/client/src/main/fbthrift/com/facebook/thrift/transport/OkHttp3Util.java @@ -0,0 +1,54 @@ +/* Copyright (c) 2023 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License. + */ + +package com.facebook.thrift.transport; + +import java.util.Arrays; +import java.util.concurrent.TimeUnit; +import javax.net.ssl.SSLSocketFactory; +import javax.net.ssl.TrustManager; +import javax.net.ssl.X509TrustManager; +import okhttp3.OkHttpClient; +import okhttp3.Protocol; + +public class OkHttp3Util { + private static OkHttpClient client; + + private OkHttp3Util() { + } + + public static OkHttpClient getClient(int connectTimeout, int readTimeout, + SSLSocketFactory sslFactory, + TrustManager trustManager) { + if (client == null) { + synchronized (OkHttp3Util.class) { + if (client == null) { + // Create OkHttpClient builder + OkHttpClient.Builder clientBuilder = new OkHttpClient.Builder() + .connectTimeout(connectTimeout, TimeUnit.MILLISECONDS) + .writeTimeout(readTimeout, TimeUnit.MILLISECONDS) + .readTimeout(readTimeout, TimeUnit.MILLISECONDS); + if (sslFactory != null) { + clientBuilder.sslSocketFactory(sslFactory, (X509TrustManager) trustManager); + clientBuilder.protocols(Arrays.asList(Protocol.HTTP_2, Protocol.HTTP_1_1)); + } else { + // config the http2 prior knowledge + clientBuilder.protocols(Arrays.asList(Protocol.H2_PRIOR_KNOWLEDGE)); + } + client = clientBuilder.build(); + } + } + } + return client; + } + + public static void close(){ + if (client != null) { + client.connectionPool().evictAll(); + client.dispatcher().executorService().shutdown(); + client = null; + } + } +} diff --git a/client/src/main/fbthrift/com/facebook/thrift/transport/THttp2Client.java b/client/src/main/fbthrift/com/facebook/thrift/transport/THttp2Client.java new file mode 100644 index 000000000..334be15d1 --- /dev/null +++ b/client/src/main/fbthrift/com/facebook/thrift/transport/THttp2Client.java @@ -0,0 +1,155 @@ +/* Copyright (c) 2023 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License. + */ + +package com.facebook.thrift.transport; + +import com.facebook.thrift.utils.Logger; +import java.io.ByteArrayOutputStream; +import javax.net.ssl.SSLSocketFactory; +import javax.net.ssl.TrustManager; +import okhttp3.MediaType; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.RequestBody; +import okhttp3.Response; +import okhttp3.ResponseBody; +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +public class THttp2Client extends TTransport { + private static final Logger LOGGER = Logger.getLogger(THttp2Client.class.getName()); + + private final ByteArrayOutputStream requestBuffer = new ByteArrayOutputStream(); + private ResponseBody responseBody = null; + private Map customHeaders = null; + private static final Map defaultHeaders = getDefaultHeaders(); + + private OkHttpClient client; + private final SSLSocketFactory sslFactory; + + private final TrustManager trustManager; + private final String url; + private int connectTimeout = 0; + private int readTimeout = 0; + + + public THttp2Client(String url) throws TTransportException { + this(url, null, null); + } + + public THttp2Client(String url, SSLSocketFactory sslFactory, TrustManager trustManager) throws TTransportException { + this.url = url; + this.sslFactory = sslFactory; + this.trustManager = trustManager; + } + + public THttp2Client setConnectTimeout(int timeout) { + connectTimeout = timeout; + return this; + } + + public THttp2Client setReadTimeout(int timeout) { + readTimeout = timeout; + return this; + } + + public THttp2Client setCustomHeaders(Map headers) { + customHeaders = headers; + return this; + } + + public THttp2Client setCustomHeader(String key, String value) { + if (customHeaders == null) { + customHeaders = new HashMap<>(); + } + customHeaders.put(key, value); + return this; + } + + public void open() { + client = OkHttp3Util.getClient(connectTimeout, readTimeout, sslFactory, trustManager); + } + + public void close() { + try { + if (responseBody != null) { + responseBody.close(); + responseBody = null; + } + + requestBuffer.close(); + } catch (IOException e) { + LOGGER.warn(e.getMessage()); + } + OkHttp3Util.close(); + } + + public boolean isOpen() { + return client != null; + } + + public int read(byte[] buf, int off, int len) throws TTransportException { + if (responseBody == null) { + throw new TTransportException("Response buffer is empty, no request."); + } + try { + int ret = responseBody.byteStream().read(buf, off, len); + if (ret == -1) { + throw new TTransportException("No more data available."); + } + return ret; + } catch (IOException iox) { + throw new TTransportException(iox); + } + } + + public void write(byte[] buf, int off, int len) { + requestBuffer.write(buf, off, len); + } + + public void flush() throws TTransportException { + if (null == client) { + throw new TTransportException("Null HttpClient, aborting."); + } + + // Extract request and reset buffer + byte[] data = requestBuffer.toByteArray(); + requestBuffer.reset(); + try { + + // Create request object + Request.Builder requestBuilder = new Request.Builder() + .url(url) + .post(RequestBody.create(MediaType.parse("application/x-thrift"), data)); + + defaultHeaders.forEach(requestBuilder::header); + if (customHeaders != null) { + customHeaders.forEach(requestBuilder::header); + } + + Request request = requestBuilder.build(); + + // Make the request + Response response = client.newCall(request).execute(); + if (!response.isSuccessful()) { + throw new TTransportException("HTTP Response code: " + response.code()); + } + + // Read the response + responseBody = response.body(); + } catch (IOException iox) { + throw new TTransportException(iox); + } + } + + private static Map getDefaultHeaders() { + Map headers = new HashMap<>(); + headers.put("Content-Type", "application/x-thrift"); + headers.put("Accept", "application/x-thrift"); + headers.put("User-Agent", "Java/THttpClient"); + return headers; + } +} diff --git a/client/src/main/java/com/vesoft/nebula/client/graph/NebulaPoolConfig.java b/client/src/main/java/com/vesoft/nebula/client/graph/NebulaPoolConfig.java index 3682063fa..0189366a5 100644 --- a/client/src/main/java/com/vesoft/nebula/client/graph/NebulaPoolConfig.java +++ b/client/src/main/java/com/vesoft/nebula/client/graph/NebulaPoolConfig.java @@ -6,7 +6,6 @@ package com.vesoft.nebula.client.graph; import com.vesoft.nebula.client.graph.data.SSLParam; -import com.vesoft.nebula.client.graph.net.NebulaPool; import java.io.Serializable; public class NebulaPoolConfig implements Serializable { @@ -43,6 +42,9 @@ public class NebulaPoolConfig implements Serializable { // SSL param is required if ssl is turned on private SSLParam sslParam = null; + // Set if use http2 protocol + private boolean useHttp2 = false; + public boolean isEnableSsl() { return enableSsl; } @@ -121,4 +123,13 @@ public NebulaPoolConfig setMinClusterHealthRate(double minClusterHealthRate) { this.minClusterHealthRate = minClusterHealthRate; return this; } + + public boolean isUseHttp2() { + return useHttp2; + } + + public NebulaPoolConfig setUseHttp2(boolean useHttp2) { + this.useHttp2 = useHttp2; + return this; + } } diff --git a/client/src/main/java/com/vesoft/nebula/client/graph/SessionPool.java b/client/src/main/java/com/vesoft/nebula/client/graph/SessionPool.java index d6b54e0af..f1130150e 100644 --- a/client/src/main/java/com/vesoft/nebula/client/graph/SessionPool.java +++ b/client/src/main/java/com/vesoft/nebula/client/graph/SessionPool.java @@ -349,7 +349,14 @@ private NebulaSession createSessionObject(SessionState state) // reconnect with all available address while (tryConnect-- > 0) { try { - connection.open(getAddress(), sessionPoolConfig.getTimeout()); + if (sessionPoolConfig.isEnableSsl()) { + connection.open(getAddress(), sessionPoolConfig.getTimeout(), + sessionPoolConfig.getSslParam(), + sessionPoolConfig.isUseHttp2()); + } else { + connection.open(getAddress(), sessionPoolConfig.getTimeout(), + sessionPoolConfig.isUseHttp2()); + } break; } catch (Exception e) { if (tryConnect == 0 || !reconnect) { diff --git a/client/src/main/java/com/vesoft/nebula/client/graph/SessionPoolConfig.java b/client/src/main/java/com/vesoft/nebula/client/graph/SessionPoolConfig.java index 898739c31..5a9fb3fcf 100644 --- a/client/src/main/java/com/vesoft/nebula/client/graph/SessionPoolConfig.java +++ b/client/src/main/java/com/vesoft/nebula/client/graph/SessionPoolConfig.java @@ -6,6 +6,7 @@ package com.vesoft.nebula.client.graph; import com.vesoft.nebula.client.graph.data.HostAddress; +import com.vesoft.nebula.client.graph.data.SSLParam; import java.io.Serializable; import java.util.List; @@ -50,6 +51,14 @@ public class SessionPoolConfig implements Serializable { // whether reconnect when create session using a broken graphd server private boolean reconnect = false; + // Set to true to turn on ssl encrypted traffic + private boolean enableSsl = false; + + // SSL param is required if ssl is turned on + private SSLParam sslParam = null; + + private boolean useHttp2 = false; + public SessionPoolConfig(List addresses, String spaceName, @@ -207,6 +216,33 @@ public SessionPoolConfig setReconnect(boolean reconnect) { return this; } + public boolean isEnableSsl() { + return enableSsl; + } + + public SessionPoolConfig setEnableSsl(boolean enableSsl) { + this.enableSsl = enableSsl; + return this; + } + + public SSLParam getSslParam() { + return sslParam; + } + + public SessionPoolConfig setSslParam(SSLParam sslParam) { + this.sslParam = sslParam; + return this; + } + + public boolean isUseHttp2() { + return useHttp2; + } + + public SessionPoolConfig setUseHttp2(boolean useHttp2) { + this.useHttp2 = useHttp2; + return this; + } + @Override public String toString() { return "SessionPoolConfig{" @@ -222,6 +258,9 @@ public String toString() { + ", retryTimes=" + retryTimes + ", intervalTIme=" + intervalTime + ", reconnect=" + reconnect + + ", enableSsl=" + enableSsl + + ",sslParam=" + sslParam + + ", useHttp2=" + useHttp2 + '}'; } } diff --git a/client/src/main/java/com/vesoft/nebula/client/graph/net/ConnObjectPool.java b/client/src/main/java/com/vesoft/nebula/client/graph/net/ConnObjectPool.java index bb5eb383e..dba3a02b9 100644 --- a/client/src/main/java/com/vesoft/nebula/client/graph/net/ConnObjectPool.java +++ b/client/src/main/java/com/vesoft/nebula/client/graph/net/ConnObjectPool.java @@ -39,9 +39,10 @@ public SyncConnection create() throws IOErrorException, ClientServerIncompatible throw new IllegalArgumentException("SSL Param is required when enableSsl " + "is set to true"); } - conn.open(address, config.getTimeout(), config.getSslParam()); + conn.open(address, config.getTimeout(), + config.getSslParam(), config.isUseHttp2()); } else { - conn.open(address, config.getTimeout()); + conn.open(address, config.getTimeout(), config.isUseHttp2()); } return conn; } catch (IOErrorException e) { diff --git a/client/src/main/java/com/vesoft/nebula/client/graph/net/Connection.java b/client/src/main/java/com/vesoft/nebula/client/graph/net/Connection.java index 7fc9ba131..47abf3524 100644 --- a/client/src/main/java/com/vesoft/nebula/client/graph/net/Connection.java +++ b/client/src/main/java/com/vesoft/nebula/client/graph/net/Connection.java @@ -19,10 +19,17 @@ public HostAddress getServerAddress() { public abstract void open(HostAddress address, int timeout, SSLParam sslParam) throws IOErrorException, ClientServerIncompatibleException; + public abstract void open(HostAddress address, int timeout, + SSLParam sslParam, boolean isUseHttp2) + throws IOErrorException, ClientServerIncompatibleException; + public abstract void open(HostAddress address, int timeout) throws IOErrorException, ClientServerIncompatibleException; + public abstract void open(HostAddress address, int timeout, boolean isUseHttp2) + throws IOErrorException, ClientServerIncompatibleException; + public abstract void reopen() throws IOErrorException, ClientServerIncompatibleException; public abstract void close(); diff --git a/client/src/main/java/com/vesoft/nebula/client/graph/net/NebulaPool.java b/client/src/main/java/com/vesoft/nebula/client/graph/net/NebulaPool.java index 67362cbc6..6090da72b 100644 --- a/client/src/main/java/com/vesoft/nebula/client/graph/net/NebulaPool.java +++ b/client/src/main/java/com/vesoft/nebula/client/graph/net/NebulaPool.java @@ -86,9 +86,9 @@ public boolean init(List addresses, NebulaPoolConfig config) this.waitTime = config.getWaitTime(); this.loadBalancer = config.isEnableSsl() ? new RoundRobinLoadBalancer(addresses, config.getTimeout(), config.getSslParam(), - config.getMinClusterHealthRate()) + config.getMinClusterHealthRate(), config.isUseHttp2()) : new RoundRobinLoadBalancer(addresses, config.getTimeout(), - config.getMinClusterHealthRate()); + config.getMinClusterHealthRate(),config.isUseHttp2()); ConnObjectPool objectPool = new ConnObjectPool(this.loadBalancer, config); this.objectPool = new GenericObjectPool<>(objectPool); GenericObjectPoolConfig objConfig = new GenericObjectPoolConfig(); @@ -188,7 +188,7 @@ public int getWaitersNum() { protected void updateServerStatus() { checkNoInitAndClosed(); if (objectPool.getFactory() instanceof ConnObjectPool) { - ((ConnObjectPool)objectPool.getFactory()).updateServerStatus(); + ((ConnObjectPool) objectPool.getFactory()).updateServerStatus(); } } diff --git a/client/src/main/java/com/vesoft/nebula/client/graph/net/RoundRobinLoadBalancer.java b/client/src/main/java/com/vesoft/nebula/client/graph/net/RoundRobinLoadBalancer.java index e245141cf..c4da59766 100644 --- a/client/src/main/java/com/vesoft/nebula/client/graph/net/RoundRobinLoadBalancer.java +++ b/client/src/main/java/com/vesoft/nebula/client/graph/net/RoundRobinLoadBalancer.java @@ -27,22 +27,35 @@ public class RoundRobinLoadBalancer implements LoadBalancer { private final int delayTime = 60; // Unit seconds private final ScheduledExecutorService schedule = Executors.newScheduledThreadPool(1); private SSLParam sslParam; - private boolean enabledSsl; + private boolean enabledSsl = false; + + private boolean useHttp2 = false; public RoundRobinLoadBalancer(List addresses, int timeout, double minClusterHealthRate) { + this(addresses, timeout, minClusterHealthRate, false); + } + + public RoundRobinLoadBalancer(List addresses, int timeout, + double minClusterHealthRate, boolean useHttp2) { this.timeout = timeout; for (HostAddress addr : addresses) { this.addresses.add(addr); this.serversStatus.put(addr, S_BAD); } this.minClusterHealthRate = minClusterHealthRate; + this.useHttp2 = useHttp2; schedule.scheduleAtFixedRate(this::scheduleTask, 0, delayTime, TimeUnit.SECONDS); } public RoundRobinLoadBalancer(List addresses, int timeout, SSLParam sslParam, double minClusterHealthRate) { - this(addresses, timeout, minClusterHealthRate); + this(addresses, timeout, sslParam, minClusterHealthRate, false); + } + + public RoundRobinLoadBalancer(List addresses, int timeout, SSLParam sslParam, + double minClusterHealthRate, boolean useHttp2) { + this(addresses, timeout, minClusterHealthRate, useHttp2); this.sslParam = sslParam; this.enabledSsl = true; } @@ -82,9 +95,9 @@ public boolean ping(HostAddress addr) { try { Connection connection = new SyncConnection(); if (enabledSsl) { - connection.open(addr, this.timeout, sslParam); + connection.open(addr, this.timeout, sslParam, useHttp2); } else { - connection.open(addr, this.timeout); + connection.open(addr, this.timeout, useHttp2); } boolean pong = connection.ping(); connection.close(); diff --git a/client/src/main/java/com/vesoft/nebula/client/graph/net/SyncConnection.java b/client/src/main/java/com/vesoft/nebula/client/graph/net/SyncConnection.java index b7c853feb..641558e19 100644 --- a/client/src/main/java/com/vesoft/nebula/client/graph/net/SyncConnection.java +++ b/client/src/main/java/com/vesoft/nebula/client/graph/net/SyncConnection.java @@ -6,10 +6,12 @@ package com.vesoft.nebula.client.graph.net; import com.facebook.thrift.TException; +import com.facebook.thrift.protocol.TBinaryProtocol; import com.facebook.thrift.protocol.TCompactProtocol; import com.facebook.thrift.protocol.THeaderProtocol; import com.facebook.thrift.protocol.TProtocol; import com.facebook.thrift.transport.THeaderTransport; +import com.facebook.thrift.transport.THttp2Client; import com.facebook.thrift.transport.TSocket; import com.facebook.thrift.transport.TTransport; import com.facebook.thrift.transport.TTransportException; @@ -33,6 +35,7 @@ import java.util.Collections; import java.util.Map; import javax.net.ssl.SSLSocketFactory; +import javax.net.ssl.TrustManager; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -41,23 +44,30 @@ public class SyncConnection extends Connection { private static final Logger LOGGER = LoggerFactory.getLogger(SyncConnection.class); - protected THeaderTransport transport = null; - protected THeaderProtocol protocol = null; + protected TTransport transport = null; + protected TProtocol protocol = null; private GraphService.Client client = null; private int timeout = 0; private SSLParam sslParam = null; private boolean enabledSsl = false; private SSLSocketFactory sslSocketFactory = null; + private boolean useHttp2 = false; @Override public void open(HostAddress address, int timeout, SSLParam sslParam) throws IOErrorException, ClientServerIncompatibleException { - try { + this.open(address, timeout, sslParam, false); + } + @Override + public void open(HostAddress address, int timeout, SSLParam sslParam, boolean isUseHttp2) + throws IOErrorException, ClientServerIncompatibleException { + try { this.serverAddr = address; this.timeout = timeout <= 0 ? Integer.MAX_VALUE : timeout; this.enabledSsl = true; this.sslParam = sslParam; + this.useHttp2 = isUseHttp2; if (sslSocketFactory == null) { if (sslParam.getSignMode() == SSLParam.SignMode.CA_SIGNED) { sslSocketFactory = @@ -67,10 +77,12 @@ public void open(HostAddress address, int timeout, SSLParam sslParam) SslUtil.getSSLSocketFactoryWithoutCA((SelfSignedSSLParam) sslParam); } } - this.transport = new THeaderTransport(new TSocket( - sslSocketFactory.createSocket(address.getHost(), - address.getPort()), this.timeout, this.timeout)); - this.protocol = new THeaderProtocol(transport); + if (isUseHttp2) { + getProtocolWithTlsHttp2(); + } else { + getProtocolForTls(); + } + client = new GraphService.Client(protocol); // check if client version matches server version @@ -88,15 +100,26 @@ public void open(HostAddress address, int timeout, SSLParam sslParam) } @Override - public void open(HostAddress address, int timeout) + public void open(HostAddress address, int timeout) throws IOErrorException, + ClientServerIncompatibleException { + this.open(address, timeout, false); + } + + @Override + public void open(HostAddress address, int timeout, boolean isUseHttp2) throws IOErrorException, ClientServerIncompatibleException { try { this.serverAddr = address; this.timeout = timeout <= 0 ? Integer.MAX_VALUE : timeout; + if (isUseHttp2) { + getProtocolForHttp2(); + } else { + getProtocol(); + } this.transport = new THeaderTransport(new TSocket( address.getHost(), address.getPort(), this.timeout, this.timeout)); this.transport.open(); - this.protocol = new THeaderProtocol(transport); + this.protocol = new THeaderProtocol((THeaderTransport) transport); client = new GraphService.Client(protocol); // check if client version matches server version @@ -112,6 +135,57 @@ public void open(HostAddress address, int timeout) } } + /** + * create protocol for http2 with tls + */ + private void getProtocolWithTlsHttp2() { + String url = "https://" + serverAddr.getHost() + ":" + serverAddr.getPort(); + TrustManager trustManager; + if (SslUtil.getTrustManagers() == null || SslUtil.getTrustManagers().length == 0) { + trustManager = null; + } else { + trustManager = SslUtil.getTrustManagers()[0]; + } + this.transport = new THttp2Client(url, sslSocketFactory, trustManager) + .setConnectTimeout(timeout) + .setReadTimeout(timeout); + transport.open(); + this.protocol = new TBinaryProtocol(transport); + } + + /** + * create protocol for http2 without tls + */ + private void getProtocolForTls() throws IOException { + this.transport = new THeaderTransport(new TSocket( + sslSocketFactory.createSocket(serverAddr.getHost(), + serverAddr.getPort()), this.timeout, this.timeout)); + this.protocol = new THeaderProtocol((THeaderTransport) transport); + } + + /** + * create protocol for http2 + */ + private void getProtocolForHttp2() { + String url = "http://" + serverAddr.getHost() + ":" + serverAddr.getPort(); + this.transport = new THttp2Client(url) + .setConnectTimeout(timeout) + .setReadTimeout(timeout); + transport.open(); + this.protocol = new TBinaryProtocol(transport); + } + + /** + * create protocol for tcp + */ + private void getProtocol() { + this.transport = new THeaderTransport(new TSocket( + serverAddr.getHost(), serverAddr.getPort(), this.timeout, this.timeout)); + transport.open(); + this.protocol = new THeaderProtocol((THeaderTransport) transport); + } + + /* * Because the code generated by Fbthrift does not handle the seqID, * the message will be dislocation when the timeout occurs, @@ -126,9 +200,9 @@ public void open(HostAddress address, int timeout) public void reopen() throws IOErrorException, ClientServerIncompatibleException { close(); if (enabledSsl) { - open(serverAddr, timeout, sslParam); + open(serverAddr, timeout, sslParam, useHttp2); } else { - open(serverAddr, timeout); + open(serverAddr, timeout, useHttp2); } } @@ -257,6 +331,7 @@ public boolean ping(long sessionID) { public void close() { if (transport != null && transport.isOpen()) { transport.close(); + transport = null; } } diff --git a/client/src/main/java/com/vesoft/nebula/util/SslUtil.java b/client/src/main/java/com/vesoft/nebula/util/SslUtil.java index 82133e95e..c31672bda 100644 --- a/client/src/main/java/com/vesoft/nebula/util/SslUtil.java +++ b/client/src/main/java/com/vesoft/nebula/util/SslUtil.java @@ -16,6 +16,7 @@ import javax.net.ssl.KeyManagerFactory; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLSocketFactory; +import javax.net.ssl.TrustManager; import javax.net.ssl.TrustManagerFactory; import org.bouncycastle.cert.X509CertificateHolder; import org.bouncycastle.cert.jcajce.JcaX509CertificateConverter; @@ -32,6 +33,8 @@ public class SslUtil { private static final Logger LOGGER = LoggerFactory.getLogger(SslUtil.class); + private static TrustManager[] trustManagers; + public static SSLSocketFactory getSSLSocketFactoryWithCA(CASignedSSLParam param) { final String caCrtFile = param.getCaCrtFilePath(); final String crtFile = param.getCrtFilePath(); @@ -117,6 +120,8 @@ public static SSLSocketFactory getSSLSocketFactoryWithCA(CASignedSSLParam param) context.init(keyManagerFactory.getKeyManagers(), trustManagerFactory.getTrustManagers(), null); + + trustManagers = trustManagerFactory.getTrustManagers(); // Return the newly created socket factory object return context.getSocketFactory(); @@ -176,8 +181,10 @@ public static SSLSocketFactory getSSLSocketFactoryWithoutCA(SelfSignedSSLParam p } X509Certificate cert = certificateConverter.getCertificate(certHolder); + // certificate is used to authenticate server KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType()); + System.out.println(keyStore); keyStore.load(null, null); keyStore.setCertificateEntry("certificate", cert); @@ -201,12 +208,16 @@ public static SSLSocketFactory getSSLSocketFactoryWithoutCA(SelfSignedSSLParam p context.init(keyManagerFactory.getKeyManagers(), trustManagerFactory.getTrustManagers(), null); + trustManagers = trustManagerFactory.getTrustManagers(); // Return the newly created socket factory object return context.getSocketFactory(); } catch (Exception e) { LOGGER.error(e.getMessage()); + throw new RuntimeException(e); } + } - return null; + public static TrustManager[] getTrustManagers() { + return trustManagers; } } diff --git a/examples/pom.xml b/examples/pom.xml index 335babf0f..0bf73b3c8 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -13,6 +13,39 @@ + + org.apache.maven.plugins + maven-assembly-plugin + 2.5.3 + + + package + + single + + + + + + jar-with-dependencies + + ${project.artifactId}-${project.version}-jar-with-dependencies + false + + + + + org.apache.maven.plugins + maven-jar-plugin + 3.2.0 + + + + test-jar + + + + org.apache.maven.plugins maven-deploy-plugin diff --git a/examples/src/main/java/com/vesoft/nebula/examples/GraphSessionPoolWithHttp2Example.java b/examples/src/main/java/com/vesoft/nebula/examples/GraphSessionPoolWithHttp2Example.java new file mode 100644 index 000000000..4f40d5f94 --- /dev/null +++ b/examples/src/main/java/com/vesoft/nebula/examples/GraphSessionPoolWithHttp2Example.java @@ -0,0 +1,158 @@ +/* Copyright (c) 2023 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License. + */ + +package com.vesoft.nebula.examples; + +import com.vesoft.nebula.client.graph.NebulaPoolConfig; +import com.vesoft.nebula.client.graph.SessionPool; +import com.vesoft.nebula.client.graph.SessionPoolConfig; +import com.vesoft.nebula.client.graph.data.CASignedSSLParam; +import com.vesoft.nebula.client.graph.data.HostAddress; +import com.vesoft.nebula.client.graph.data.ResultSet; +import com.vesoft.nebula.client.graph.data.SSLParam; +import com.vesoft.nebula.client.graph.exception.AuthFailedException; +import com.vesoft.nebula.client.graph.exception.BindSpaceFailedException; +import com.vesoft.nebula.client.graph.exception.ClientServerIncompatibleException; +import com.vesoft.nebula.client.graph.exception.IOErrorException; +import com.vesoft.nebula.client.graph.net.NebulaPool; +import com.vesoft.nebula.client.graph.net.Session; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * when use http2, please call System.exit(0) to exit your process. + */ +public class GraphSessionPoolWithHttp2Example { + private static final Logger log = LoggerFactory.getLogger(GraphClientExample.class); + + private static String host = "192.168.8.202"; + + private static int port = 9119; + + private static String user = "root"; + + private static String password = "nebula"; + + private static String spaceName = "test"; + + private static int parallel = 20; + + private static int executeTimes = 20; + + public static void main(String[] args) { + SSLParam sslParam = new CASignedSSLParam( + "examples/src/main/resources/ssl/root.crt", + "examples/src/main/resources/ssl/server.crt", + "examples/src/main/resources/ssl/server.key"); + prepare(sslParam); + + SessionPoolConfig sessionPoolConfig = new SessionPoolConfig( + Arrays.asList(new HostAddress(host, port)), spaceName, user, password) + .setMaxSessionSize(parallel) + .setMinSessionSize(parallel) + .setRetryConnectTimes(3) + .setWaitTime(100) + .setRetryTimes(3) + .setIntervalTime(100) + .setEnableSsl(true) + .setSslParam(sslParam) + .setUseHttp2(false); + SessionPool sessionPool = new SessionPool(sessionPoolConfig); + if (!sessionPool.init()) { + log.error("session pool init failed."); + System.exit(-1); + } + + executeForSingleThread(sessionPool); + executeForMultiThreads(sessionPool); + + sessionPool.close(); + System.exit(0); + } + + /** + * execute in single thread + */ + private static void executeForSingleThread(SessionPool sessionPool) { + try { + ResultSet resultSet = sessionPool.execute("match(v:player) return v limit 1;"); + System.out.println(resultSet.toString()); + } catch (IOErrorException | ClientServerIncompatibleException | AuthFailedException + | BindSpaceFailedException e) { + e.printStackTrace(); + sessionPool.close(); + System.exit(1); + } + } + + /** + * execute in mutil-threads + */ + private static void executeForMultiThreads(SessionPool sessionPool) { + ExecutorService executorService = Executors.newFixedThreadPool(parallel); + CountDownLatch count = new CountDownLatch(parallel); + for (int i = 0; i < parallel; i++) { + executorService.submit(() -> { + try { + for (int j = 0; j < executeTimes; j++) { + ResultSet result = sessionPool.execute("match(v:player) return v limit 1;"); + System.out.println(result.toString()); + } + count.countDown(); + } catch (Exception e) { + e.printStackTrace(); + } + }); + } + + try { + count.await(); + } catch (InterruptedException e) { + e.printStackTrace(); + } + + executorService.shutdown(); + } + + + private static void prepare(SSLParam sslParam) { + NebulaPool pool = new NebulaPool(); + Session session; + NebulaPoolConfig nebulaPoolConfig = new NebulaPoolConfig(); + nebulaPoolConfig.setUseHttp2(false); + nebulaPoolConfig.setEnableSsl(true); + nebulaPoolConfig.setSslParam(sslParam); + nebulaPoolConfig.setMaxConnSize(10); + List addresses = Arrays.asList(new HostAddress(host, port)); + try { + boolean initResult = pool.init(addresses, nebulaPoolConfig); + if (!initResult) { + log.error("pool init failed."); + System.exit(-1); + } + + session = pool.getSession("root", "nebula", false); + ResultSet res = session.execute( + "CREATE SPACE IF NOT EXISTS " + spaceName + "(vid_type=fixed_string(20));" + + "USE test;" + + "CREATE TAG IF NOT EXISTS player(name string, age int);"); + session.execute("insert vertex player(name,age) values \"1\":(\"Tom\",20);"); + if (!res.isSucceeded()) { + System.out.println(res.getErrorMessage()); + System.exit(-1); + } + } catch (Exception e) { + e.printStackTrace(); + System.exit(1); + } finally { + pool.close(); + } + } +} diff --git a/examples/src/main/resources/ssl/root.crt b/examples/src/main/resources/ssl/root.crt new file mode 100644 index 000000000..6fd938602 --- /dev/null +++ b/examples/src/main/resources/ssl/root.crt @@ -0,0 +1,21 @@ +-----BEGIN CERTIFICATE----- +MIIDiTCCAnGgAwIBAgIUEXgJM36pHqMj9xpMOV+bIeLd8BkwDQYJKoZIhvcNAQEL +BQAwLjELMAkGA1UEBhMCQ0gxEDAOBgNVBAoMB3Rlc3QtY2ExDTALBgNVBAMMBHJv +b3QwHhcNMjMwOTA2MDMzNDI3WhcNMzMwOTAzMDMzNDI3WjAuMQswCQYDVQQGEwJD +SDEQMA4GA1UECgwHdGVzdC1jYTENMAsGA1UEAwwEcm9vdDCCASIwDQYJKoZIhvcN +AQEBBQADggEPADCCAQoCggEBAMjXEshOucs4SFJkJl9GUvWdN5u0mZszwlHSQZUH +LrxRq6Z7QVm++9tDFp68FlQhEms9MTxe/ggPwLY+lpWl3QKQGGazeFSOjD6nlT0r +FmN4G52yryP8F3VOt3APc+NRVNJHMjzyeicxVzCcfzEuII3QZVKh3QVofmhIqJDn +RRYeFcTUOHNQygCoE3alsAv25PKpQN7H/9TefwCkuS37an4ZJm+nskM+CkDwCsai +Hu29C7VL/TspXEtmvat52biZVY/si2vmITXZky1Sfg+FKIrdsnhEsFgTaOamn8iG +tWrSqOdxwdRgUlfXT0I+SXCuG4qWQXWZ7oKK3pSrGc0p4p0CAwEAAaOBnjCBmzAP +BgNVHRMBAf8EBTADAQH/MB0GA1UdDgQWBBQcsh7OkAUiljB9oah4A7hpk+y5VTBp +BgNVHSMEYjBggBQcsh7OkAUiljB9oah4A7hpk+y5VaEypDAwLjELMAkGA1UEBhMC +Q0gxEDAOBgNVBAoMB3Rlc3QtY2ExDTALBgNVBAMMBHJvb3SCFBF4CTN+qR6jI/ca +TDlfmyHi3fAZMA0GCSqGSIb3DQEBCwUAA4IBAQBI+HZE8KlSHDo8Az5+0TZWwKlO +D/aVAh7O7Amhxp0ukM/tOFymdBZ5J5GrlsEmgJCHX2WkGXIH+i8X67Q6VaA0bWN5 +6Wz+cA9XEyK44j2H5lHubbyIuE9qx71s6QW1u/w7YvK+vyd4K4G4jD2IIwNLiCUd +gwuxG2elVxY2qqPLBNqjkWWLZ6N/LfJ4/qJ/hsl7h6g1OzwRZE+6wOZ2Bik8IA/R +k6m+JhBww4FQp6lfroVKshNBmFfY4TiwLHjQ5CtnDdoktsZBvQkK6pVBGtXR5yPf +AA/+vjiwuxueF97h9lUs1eDW2s3zkW1hOHkc/0lhgX6WBmkgDdjz7FEOi3hh +-----END CERTIFICATE----- diff --git a/examples/src/main/resources/ssl/server.crt b/examples/src/main/resources/ssl/server.crt new file mode 100644 index 000000000..7b3bc178d --- /dev/null +++ b/examples/src/main/resources/ssl/server.crt @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIDKzCCAhOgAwIBAgIUczXPKik40SjLEw/bFLb8emHV8Q8wDQYJKoZIhvcNAQEL +BQAwLjELMAkGA1UEBhMCQ0gxEDAOBgNVBAoMB3Rlc3QtY2ExDTALBgNVBAMMBHJv +b3QwHhcNMjMwOTA2MDUyNzE1WhcNMjMwOTE2MDUyNzE1WjAwMQswCQYDVQQGEwJD +SDEQMA4GA1UECgwHdGVzdC1jYTEPMA0GA1UEAwwGc2VydmVyMIIBIjANBgkqhkiG +9w0BAQEFAAOCAQ8AMIIBCgKCAQEA2FADBs38uSLTypNJhuHKRfKnxQ3Coj+St0bo +SRSdfiOc7pcgJRLnU+HtBcNnPI1U9JqsCvQ9lelO/vaJQTSW0+ftFsGHGqdnU8Nx +7Q45gHaTAhXLHqCUYzwdFzVxih96klzpqC3rN4850HSf9MEw2WW3zKKDHngD2jzb +POr0q4n6IG754Hvh95Zs9a4qdkENDv/wehxJ+92Bl7GWT2Pr517AXRwjFC/UDjd8 +7WTIrR+HNLz489NvNwWWtV3XOAfWUyhlKKyqJZ9WIpf/fhqgEqjuFBoD5/G7dyH0 +waL9sTnZLlyR7IDvWel8FAmRgVCpg4Ug+czMPmtojlXefyIW7QIDAQABoz8wPTA7 +BgNVHREENDAyhwR/AAABhwTAqAjKgglsb2NhbGhvc3SCB2dyYXBoZDCCB2dyYXBo +ZDGCB2dyYXBoZDIwDQYJKoZIhvcNAQELBQADggEBAJ07ST/5kTGNxTYoxJxtxGo/ +OLtUsfuu7apYdUPgpr1ZQB3hGCZ/+C8aGHEzf+a+qLSZsofd8VKXDvdPg+bStlAP +aAbHjnj4uA2Jn2efBZ7EBmocxVuJ6lngbOK6ApCQqynP2jHb1VZHgH+AqZio5ahd +2GosmcO+4vGf5p6k8de2sS8ryj1EhoLMQMmjPn0hDS/Zy2A7qxlwOZfVzvV6hz0u +CglVGBGA89oAXUmnRclpDvvl9M0Xue3eH0LmzVHNgGHlc/XtM4hvnyyObA7MhnGB +xLiFfUlxp30DQX1I6BAh3QmCDAyeGL8BWI7KP7dR7DjeSnNKQ0wVTtMaH0aouRk= +-----END CERTIFICATE----- diff --git a/examples/src/main/resources/ssl/server.key b/examples/src/main/resources/ssl/server.key new file mode 100755 index 000000000..861c70273 --- /dev/null +++ b/examples/src/main/resources/ssl/server.key @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEpQIBAAKCAQEA2FADBs38uSLTypNJhuHKRfKnxQ3Coj+St0boSRSdfiOc7pcg +JRLnU+HtBcNnPI1U9JqsCvQ9lelO/vaJQTSW0+ftFsGHGqdnU8Nx7Q45gHaTAhXL +HqCUYzwdFzVxih96klzpqC3rN4850HSf9MEw2WW3zKKDHngD2jzbPOr0q4n6IG75 +4Hvh95Zs9a4qdkENDv/wehxJ+92Bl7GWT2Pr517AXRwjFC/UDjd87WTIrR+HNLz4 +89NvNwWWtV3XOAfWUyhlKKyqJZ9WIpf/fhqgEqjuFBoD5/G7dyH0waL9sTnZLlyR +7IDvWel8FAmRgVCpg4Ug+czMPmtojlXefyIW7QIDAQABAoIBAQChIFBw1CUo/rFG +FxpQ14VlPxALL2nIk5RE4xOJxEpgOETgUEAIfaFEkWiNv3T53MjofwIiEraBIU3P +i3LH2FV1OTAYoEVz7DiCY3ZMPylD8I8moXcwtCp0FMYSkKOnYDVcKst9k96+/vfw +t76igPlTJnqXeoIywvETsfsY0Gc6EmlKYWILwX7cfkOvnCofbSDjkX5x24jDvK9L +O9aaTmU+BfHA8gpaalF69SFOck1zP0Zu9uFI1hpjuL63lqrGA0INfcENZG5FR4wC +CI69BtXcOVbLb3FK+j1QVjU+IX9Hu1dw7zn1/PwLZHhiLmYyz/B0VC6LzpxB5L5P +7v2U9U8dAoGBAPEIMbWjrzGBHx3eEwH1lECh1KXRDSo5xwH8DvGedO2+of/nz/Iv +sfOYRJ/xHyFe5Byz0m/7IX/XdcQ4LlGGpv5Z48olZ2YNvppF5mv5O4/KK0w0dlmm +tNvEzwL7pM8xpKk7rm8RKc5aw4vcn7e+9gBTKuhW7o2z6dF2uvBaQj1/AoGBAOW+ +1w1TeNwKhB/hZhxNRBlTv8MGml3gx3L5s+28L/bJ/IQCpGptmuHgDIFfqOtO0Mzo +SnwNMoYZcqRS29Ms/ygcFvP3nZ6zsNaT3cDnR1UHBc4byaDa8a7INcTDAVDAYBML +4o51dsguWPAqdN3Od5yuK5C7Qexfzb+y4srXWLmTAoGABJFdL84MUenahxxgS9c1 +mgv3FbVihHxX0yfNuLKCEMdeFpV0EWjp/G3UTxuotV8w/4JA6LJfriaNKszNw+nD +XGqjsH8I+JwmEpJkjYNJp63zKByOaaCJKOkP60SNmQed0T86TQyMOEbsEch6lmbe +Dp+E3qZXGwRf2AJiBJARVU0CgYEAwMvuwiMbWGSGzg4j13pLvIlYcjxXTJK2LVk4 +0jdLdOm6O7nP6fRCtmyDcgopwhXpCRuibgnRLVGrsBRMnyGymiFAbcmM/0JCE0AR +JrGvXb4/89/Dy3YQvSEMZitTLkXSGgmuPOh8Hq8uOZUXb4+1Nsm+i31pbAhVrBpd +UeV3cnsCgYEA3nKwC+hjSe5jLJJKD/GEfCcALqWEuX4xJJjXAPPQ5advWdUwQmA8 +LnrcJROvDP9nW0hjiCXWifAYC9fbXdmU1e6HTfHIIbsTqHGMA2mZEe+XQevj/GVT +hZ4gALLJzMtbqyJd/53Vs0CRcEeke8YRs0NqGF6La/mjnHbwoKqIpYM= +-----END RSA PRIVATE KEY-----