diff --git a/src/main/java/io/milvus/v2/client/ConnectConfig.java b/src/main/java/io/milvus/v2/client/ConnectConfig.java index e9247479d..e101ccca2 100644 --- a/src/main/java/io/milvus/v2/client/ConnectConfig.java +++ b/src/main/java/io/milvus/v2/client/ConnectConfig.java @@ -24,6 +24,7 @@ import lombok.NonNull; import lombok.experimental.SuperBuilder; +import javax.net.ssl.SSLContext; import java.net.URI; import java.util.concurrent.TimeUnit; @@ -57,6 +58,8 @@ public class ConnectConfig { @Builder.Default private long idleTimeoutMs = TimeUnit.MILLISECONDS.convert(24, TimeUnit.HOURS); + private SSLContext sslContext; + public String getHost() { URI uri = URI.create(this.uri); return uri.getHost(); diff --git a/src/main/java/io/milvus/v2/utils/ClientUtils.java b/src/main/java/io/milvus/v2/utils/ClientUtils.java index 7ccd19a7a..8164c40cb 100644 --- a/src/main/java/io/milvus/v2/utils/ClientUtils.java +++ b/src/main/java/io/milvus/v2/utils/ClientUtils.java @@ -24,12 +24,17 @@ import io.grpc.Metadata; import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts; import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder; +import io.grpc.netty.shaded.io.netty.handler.ssl.ApplicationProtocolConfig; +import io.grpc.netty.shaded.io.netty.handler.ssl.ClientAuth; +import io.grpc.netty.shaded.io.netty.handler.ssl.IdentityCipherSuiteFilter; +import io.grpc.netty.shaded.io.netty.handler.ssl.JdkSslContext; import io.grpc.netty.shaded.io.netty.handler.ssl.SslContext; import io.grpc.stub.MetadataUtils; import io.milvus.client.MilvusServiceClient; import io.milvus.grpc.*; import io.milvus.v2.client.ConnectConfig; import org.apache.commons.lang3.StringUtils; +import org.jetbrains.annotations.NotNull; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -57,7 +62,25 @@ public ManagedChannel getChannel(ConnectConfig connectConfig){ } try { - if (StringUtils.isNotEmpty(connectConfig.getServerPemPath())) { + if (connectConfig.getSslContext() != null) { + // sslContext from connect config + NettyChannelBuilder builder = NettyChannelBuilder.forAddress(connectConfig.getHost(), connectConfig.getPort()) + .overrideAuthority(connectConfig.getServerName()) + .sslContext(convertJavaSslContextToNetty(connectConfig)) + .maxInboundMessageSize(Integer.MAX_VALUE) + .keepAliveTime(connectConfig.getKeepAliveTimeMs(), TimeUnit.MILLISECONDS) + .keepAliveTimeout(connectConfig.getKeepAliveTimeoutMs(), TimeUnit.MILLISECONDS) + .keepAliveWithoutCalls(connectConfig.isKeepAliveWithoutCalls()) + .idleTimeout(connectConfig.getIdleTimeoutMs(), TimeUnit.MILLISECONDS) + .intercept(MetadataUtils.newAttachHeadersInterceptor(metadata)); + if(connectConfig.isSecure()) { + builder.useTransportSecurity(); + } + if (StringUtils.isNotEmpty(connectConfig.getServerName())) { + builder.overrideAuthority(connectConfig.getServerName()); + } + channel = builder.build(); + } else if (StringUtils.isNotEmpty(connectConfig.getServerPemPath())) { // one-way tls SslContext sslContext = GrpcSslContexts.forClient() .trustManager(new File(connectConfig.getServerPemPath())) @@ -122,6 +145,13 @@ public ManagedChannel getChannel(ConnectConfig connectConfig){ return channel; } + private static JdkSslContext convertJavaSslContextToNetty(ConnectConfig connectConfig) { + ApplicationProtocolConfig applicationProtocolConfig = new ApplicationProtocolConfig(ApplicationProtocolConfig.Protocol.NONE, + ApplicationProtocolConfig.SelectorFailureBehavior.FATAL_ALERT, ApplicationProtocolConfig.SelectedListenerFailureBehavior.FATAL_ALERT); + return new JdkSslContext(connectConfig.getSslContext(), true, null, + IdentityCipherSuiteFilter.INSTANCE, applicationProtocolConfig, ClientAuth.NONE, null, false); + } + public void checkDatabaseExist(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, String dbName) { String title = String.format("Check database %s exist", dbName); ListDatabasesRequest listDatabasesRequest = ListDatabasesRequest.newBuilder().build();