diff --git a/.github/workflows/java.yml b/.github/workflows/java.yml index 656bc5e9a8..88609cba50 100644 --- a/.github/workflows/java.yml +++ b/.github/workflows/java.yml @@ -48,6 +48,7 @@ jobs: with: fetch-depth: 0 persist-credentials: false + submodules: recursive - uses: actions/setup-java@v4 with: cache: "maven" @@ -81,6 +82,7 @@ jobs: with: fetch-depth: 0 persist-credentials: false + submodules: recursive - uses: actions/setup-java@v4 with: cache: "maven" diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000000..e880666df1 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[submodule "testing"] + path = testing + url = https://github.com/apache/arrow-testing.git diff --git a/java/driver-manager/pom.xml b/java/driver-manager/pom.xml index 79ad55b992..8a083a58a8 100644 --- a/java/driver-manager/pom.xml +++ b/java/driver-manager/pom.xml @@ -42,7 +42,6 @@ org.apache.arrow arrow-memory-unsafe - ${dep.arrow.version} test diff --git a/java/driver/flight-sql/pom.xml b/java/driver/flight-sql/pom.xml index 26dd8a7262..ceceaa8249 100644 --- a/java/driver/flight-sql/pom.xml +++ b/java/driver/flight-sql/pom.xml @@ -78,5 +78,43 @@ junit-jupiter test + + org.junit.vintage + junit-vintage-engine + test + + + + junit + junit + 4.13.1 + test + + + org.apache.arrow + flight-sql-jdbc-core + test + + + org.apache.arrow + flight-sql-jdbc-core + ${dep.arrow.version} + test + tests + + + + + + org.apache.maven.plugins + maven-surefire-plugin + + + ${project.basedir}/../../../testing/data + + + + + diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightInfoReader.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightInfoReader.java index 9b0cda91dc..6850be0639 100644 --- a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightInfoReader.java +++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightInfoReader.java @@ -24,12 +24,10 @@ import java.util.Objects; import org.apache.arrow.adbc.core.AdbcException; import org.apache.arrow.adbc.core.AdbcStatusCode; -import org.apache.arrow.flight.FlightClient; import org.apache.arrow.flight.FlightEndpoint; import org.apache.arrow.flight.FlightRuntimeException; import org.apache.arrow.flight.FlightStream; import org.apache.arrow.flight.Location; -import org.apache.arrow.flight.sql.FlightSqlClient; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.VectorUnloader; @@ -40,8 +38,8 @@ /** An ArrowReader that wraps a FlightInfo. */ public class FlightInfoReader extends ArrowReader { private final Schema schema; - private final FlightSqlClient client; - private final LoadingCache clientCache; + private final FlightSqlClientWithCallOptions client; + private final LoadingCache clientCache; private final List flightEndpoints; private int nextEndpointIndex; private FlightStream currentStream; @@ -49,8 +47,8 @@ public class FlightInfoReader extends ArrowReader { FlightInfoReader( BufferAllocator allocator, - FlightSqlClient client, - LoadingCache clientCache, + FlightSqlClientWithCallOptions client, + LoadingCache clientCache, List flightEndpoints) throws AdbcException { super(allocator); diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlClientWithCallOptions.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlClientWithCallOptions.java new file mode 100644 index 0000000000..f7028cb55f --- /dev/null +++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlClientWithCallOptions.java @@ -0,0 +1,289 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.adbc.driver.flightsql; + +import static org.apache.arrow.flight.sql.FlightSqlClient.PreparedStatement; +import static org.apache.arrow.flight.sql.FlightSqlClient.Savepoint; +import static org.apache.arrow.flight.sql.FlightSqlClient.SubstraitPlan; +import static org.apache.arrow.flight.sql.FlightSqlClient.Transaction; + +import java.util.List; +import org.apache.arrow.flight.CallOption; +import org.apache.arrow.flight.CancelFlightInfoRequest; +import org.apache.arrow.flight.CancelFlightInfoResult; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightEndpoint; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.RenewFlightEndpointRequest; +import org.apache.arrow.flight.SchemaResult; +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.flight.sql.CancelResult; +import org.apache.arrow.flight.sql.FlightSqlClient; +import org.apache.arrow.flight.sql.impl.FlightSql; +import org.apache.arrow.flight.sql.util.TableRef; +import org.apache.arrow.util.AutoCloseables; + +/** A wrapper around FlightSqlClient which automatically adds CallOptions to each RPC call. */ +public class FlightSqlClientWithCallOptions implements AutoCloseable { + private final FlightSqlClient client; + private final CallOption[] connectionOptions; + + public FlightSqlClientWithCallOptions(FlightSqlClient client, CallOption... options) { + this.client = client; + this.connectionOptions = options; + } + + public FlightInfo execute(String query, CallOption... options) { + return client.execute(query, combine(options)); + } + + public FlightInfo execute(String query, Transaction transaction, CallOption... options) { + return client.execute(query, transaction, combine(options)); + } + + public FlightInfo executeSubstrait(SubstraitPlan plan, CallOption... options) { + return client.executeSubstrait(plan, combine(options)); + } + + public FlightInfo executeSubstrait( + SubstraitPlan plan, Transaction transaction, CallOption... options) { + return client.executeSubstrait(plan, transaction, combine(options)); + } + + public SchemaResult getExecuteSchema( + String query, Transaction transaction, CallOption... options) { + return client.getExecuteSchema(query, transaction, combine(options)); + } + + public SchemaResult getExecuteSchema(String query, CallOption... options) { + return client.getExecuteSchema(query, combine(options)); + } + + public SchemaResult getExecuteSubstraitSchema( + SubstraitPlan plan, Transaction transaction, CallOption... options) { + return client.getExecuteSubstraitSchema(plan, transaction, combine(options)); + } + + public SchemaResult getExecuteSubstraitSchema( + SubstraitPlan substraitPlan, CallOption... options) { + return client.getExecuteSubstraitSchema(substraitPlan, combine(options)); + } + + public long executeUpdate(String query, CallOption... options) { + return client.executeUpdate(query, combine(options)); + } + + public long executeUpdate(String query, Transaction transaction, CallOption... options) { + return client.executeUpdate(query, transaction, combine(options)); + } + + public long executeSubstraitUpdate(SubstraitPlan plan, CallOption... options) { + return client.executeSubstraitUpdate(plan, combine(options)); + } + + public long executeSubstraitUpdate( + SubstraitPlan plan, Transaction transaction, CallOption... options) { + return client.executeSubstraitUpdate(plan, transaction, combine(options)); + } + + public FlightInfo getCatalogs(CallOption... options) { + return client.getCatalogs(options); + } + + public SchemaResult getCatalogsSchema(CallOption... options) { + return client.getCatalogsSchema(options); + } + + public FlightInfo getSchemas( + String catalog, String dbSchemaFilterPattern, CallOption... options) { + return client.getSchemas(catalog, dbSchemaFilterPattern, combine(options)); + } + + public SchemaResult getSchemasSchema(CallOption... options) { + return client.getSchemasSchema(options); + } + + public SchemaResult getSchema(FlightDescriptor descriptor, CallOption... options) { + return client.getSchema(descriptor, combine(options)); + } + + public FlightStream getStream(Ticket ticket, CallOption... options) { + return client.getStream(ticket, combine(options)); + } + + public FlightInfo getSqlInfo(FlightSql.SqlInfo... info) { + return client.getSqlInfo(info); + } + + public FlightInfo getSqlInfo(FlightSql.SqlInfo[] info, CallOption... options) { + return client.getSqlInfo(info, combine(options)); + } + + public FlightInfo getSqlInfo(int[] info, CallOption... options) { + return client.getSqlInfo(info, combine(options)); + } + + public FlightInfo getSqlInfo(Iterable info, CallOption... options) { + return client.getSqlInfo(info, combine(options)); + } + + public SchemaResult getSqlInfoSchema(CallOption... options) { + return client.getSqlInfoSchema(options); + } + + public FlightInfo getXdbcTypeInfo(int dataType, CallOption... options) { + return client.getXdbcTypeInfo(dataType, combine(options)); + } + + public FlightInfo getXdbcTypeInfo(CallOption... options) { + return client.getXdbcTypeInfo(options); + } + + public SchemaResult getXdbcTypeInfoSchema(CallOption... options) { + return client.getXdbcTypeInfoSchema(options); + } + + public FlightInfo getTables( + String catalog, + String dbSchemaFilterPattern, + String tableFilterPattern, + List tableTypes, + boolean includeSchema, + CallOption... options) { + return client.getTables( + catalog, + dbSchemaFilterPattern, + tableFilterPattern, + tableTypes, + includeSchema, + combine(options)); + } + + public SchemaResult getTablesSchema(boolean includeSchema, CallOption... options) { + return client.getTablesSchema(includeSchema, combine(options)); + } + + public FlightInfo getPrimaryKeys(TableRef tableRef, CallOption... options) { + return client.getPrimaryKeys(tableRef, combine(options)); + } + + public SchemaResult getPrimaryKeysSchema(CallOption... options) { + return client.getPrimaryKeysSchema(options); + } + + public FlightInfo getExportedKeys(TableRef tableRef, CallOption... options) { + return client.getExportedKeys(tableRef, combine(options)); + } + + public SchemaResult getExportedKeysSchema(CallOption... options) { + return client.getExportedKeysSchema(options); + } + + public FlightInfo getImportedKeys(TableRef tableRef, CallOption... options) { + return client.getImportedKeys(tableRef, combine(options)); + } + + public SchemaResult getImportedKeysSchema(CallOption... options) { + return client.getImportedKeysSchema(options); + } + + public FlightInfo getCrossReference( + TableRef pkTableRef, TableRef fkTableRef, CallOption... options) { + return client.getCrossReference(pkTableRef, fkTableRef, combine(options)); + } + + public SchemaResult getCrossReferenceSchema(CallOption... options) { + return client.getCrossReferenceSchema(options); + } + + public FlightInfo getTableTypes(CallOption... options) { + return client.getTableTypes(options); + } + + public SchemaResult getTableTypesSchema(CallOption... options) { + return client.getTableTypesSchema(options); + } + + public PreparedStatement prepare(String query, CallOption... options) { + return client.prepare(query, combine(options)); + } + + public PreparedStatement prepare( + String query, FlightSqlClient.Transaction transaction, CallOption... options) { + return client.prepare(query, transaction, combine(options)); + } + + public PreparedStatement prepare(SubstraitPlan plan, CallOption... options) { + return client.prepare(plan, combine(options)); + } + + public PreparedStatement prepare( + SubstraitPlan plan, Transaction transaction, CallOption... options) { + return client.prepare(plan, transaction, combine(options)); + } + + public Transaction beginTransaction(CallOption... options) { + return client.beginTransaction(options); + } + + public Savepoint beginSavepoint(Transaction transaction, String name, CallOption... options) { + return client.beginSavepoint(transaction, name, combine(options)); + } + + public void commit(Transaction transaction, CallOption... options) { + client.commit(transaction, combine(options)); + } + + public void release(Savepoint savepoint, CallOption... options) { + client.release(savepoint, combine(options)); + } + + public void rollback(Transaction transaction, CallOption... options) { + client.rollback(transaction, combine(options)); + } + + public void rollback(Savepoint savepoint, CallOption... options) { + client.rollback(savepoint, combine(options)); + } + + public CancelFlightInfoResult cancelFlightInfo( + CancelFlightInfoRequest request, CallOption... options) { + return client.cancelFlightInfo(request, combine(options)); + } + + public CancelResult cancelQuery(FlightInfo info, CallOption... options) { + return client.cancelQuery(info, combine(options)); + } + + public FlightEndpoint renewFlightEndpoint( + RenewFlightEndpointRequest request, CallOption... options) { + return client.renewFlightEndpoint(request, combine(options)); + } + + public void close() throws Exception { + AutoCloseables.close(client); + } + + private CallOption[] combine(CallOption... options) { + final CallOption[] result = new CallOption[connectionOptions.length + options.length]; + System.arraycopy(connectionOptions, 0, result, 0, connectionOptions.length); + System.arraycopy(options, 0, result, connectionOptions.length, options.length); + return result; + } +} diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnection.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnection.java index f583f2b866..c079060ec8 100644 --- a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnection.java +++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnection.java @@ -20,19 +20,35 @@ import com.github.benmanes.caffeine.cache.LoadingCache; import com.github.benmanes.caffeine.cache.RemovalCause; import com.google.protobuf.InvalidProtocolBufferException; +import java.io.ByteArrayInputStream; +import java.io.DataInputStream; +import java.io.IOException; +import java.io.InputStream; import java.net.URISyntaxException; import java.nio.ByteBuffer; +import java.util.ArrayList; import java.util.Collections; +import java.util.Map; +import java.util.Optional; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import org.apache.arrow.adbc.core.AdbcConnection; +import org.apache.arrow.adbc.core.AdbcDriver; import org.apache.arrow.adbc.core.AdbcException; import org.apache.arrow.adbc.core.AdbcStatement; +import org.apache.arrow.adbc.core.AdbcStatusCode; import org.apache.arrow.adbc.core.BulkIngestMode; import org.apache.arrow.adbc.sql.SqlQuirks; +import org.apache.arrow.flight.CallOption; +import org.apache.arrow.flight.FlightCallHeaders; import org.apache.arrow.flight.FlightClient; import org.apache.arrow.flight.FlightEndpoint; +import org.apache.arrow.flight.HeaderCallOption; import org.apache.arrow.flight.Location; import org.apache.arrow.flight.Ticket; +import org.apache.arrow.flight.auth2.BasicAuthCredentialWriter; +import org.apache.arrow.flight.client.ClientCookieMiddleware; +import org.apache.arrow.flight.grpc.CredentialCallOption; import org.apache.arrow.flight.impl.Flight; import org.apache.arrow.flight.sql.FlightSqlClient; import org.apache.arrow.memory.BufferAllocator; @@ -42,28 +58,57 @@ public class FlightSqlConnection implements AdbcConnection { private final BufferAllocator allocator; - private final FlightSqlClient client; + private final AtomicInteger counter; + private final FlightSqlClientWithCallOptions client; private final SqlQuirks quirks; - private final LoadingCache clientCache; + private final Map parameters; + private final LoadingCache clientCache; - FlightSqlConnection(BufferAllocator allocator, FlightClient client, SqlQuirks quirks) { + // Cached data to use across additional connections. + private ClientCookieMiddleware.Factory cookieMiddlewareFactory; + private CallOption[] callOptions; + + // Used to cache the InputStream content as a byte array since + // subsequent connections may need to use it but it is supplied as a stream. + private byte[] mtlsCertChainBytes; + private byte[] mtlsPrivateKeyBytes; + private byte[] tlsRootCertsBytes; + + FlightSqlConnection( + BufferAllocator allocator, + SqlQuirks quirks, + Location location, + Map parameters) + throws AdbcException { this.allocator = allocator; - this.client = new FlightSqlClient(client); + this.counter = new AtomicInteger(0); this.quirks = quirks; + this.parameters = parameters; + FlightSqlClient flightSqlClient = new FlightSqlClient(createInitialConnection(location)); + this.client = new FlightSqlClientWithCallOptions(flightSqlClient, callOptions); this.clientCache = Caffeine.newBuilder() .expireAfterAccess(5, TimeUnit.MINUTES) .removalListener( - (Location key, FlightClient value, RemovalCause cause) -> { + (Location key, FlightSqlClientWithCallOptions value, RemovalCause cause) -> { if (value == null) return; try { value.close(); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new RuntimeException(e); + } catch (Exception ex) { + if (ex instanceof InterruptedException) { + Thread.currentThread().interrupt(); + } + throw new RuntimeException(ex); } }) - .build(location -> FlightClient.builder(allocator, location).build()); + .build( + loc -> { + FlightClient client = buildClient(loc); + client.handshake(callOptions); + return new FlightSqlClientWithCallOptions( + new FlightSqlClient(client), callOptions); + }); + this.clientCache.put(location, this.client); } @Override @@ -137,6 +182,7 @@ public void setAutoCommit(boolean enableAutoCommit) throws AdbcException { @Override public void close() throws Exception { + clientCache.invalidateAll(); AutoCloseables.close(client, allocator); } @@ -144,4 +190,164 @@ public void close() throws Exception { public String toString() { return "FlightSqlConnection{" + "client=" + client + '}'; } + + /** + * Initialize cached data to share between connections and create, test, and authenticate the + * first connection. + */ + private FlightClient createInitialConnection(Location location) throws AdbcException { + // Setup cached pre-connection properties. + try { + final InputStream mtlsCertChain = + FlightSqlConnectionProperties.MTLS_CERT_CHAIN.get(parameters); + if (mtlsCertChain != null) { + this.mtlsCertChainBytes = inputStreamToBytes(mtlsCertChain); + } + + final InputStream mtlsPrivateKey = + FlightSqlConnectionProperties.MTLS_PRIVATE_KEY.get(parameters); + if (mtlsPrivateKey != null) { + this.mtlsPrivateKeyBytes = inputStreamToBytes(mtlsPrivateKey); + } + + final InputStream tlsRootCerts = FlightSqlConnectionProperties.TLS_ROOT_CERTS.get(parameters); + if (tlsRootCerts != null) { + this.tlsRootCertsBytes = inputStreamToBytes(tlsRootCerts); + } + } catch (IOException ex) { + throw new AdbcException( + String.format( + "Error reading stream for one of the options %s, %s, %s.", + FlightSqlConnectionProperties.MTLS_CERT_CHAIN.getKey(), + FlightSqlConnectionProperties.MTLS_PRIVATE_KEY.getKey(), + FlightSqlConnectionProperties.TLS_ROOT_CERTS.getKey()), + ex, + AdbcStatusCode.IO, + null, + 0); + } + + final boolean useCookieMiddleware = + Boolean.TRUE.equals(FlightSqlConnectionProperties.WITH_COOKIE_MIDDLEWARE.get(parameters)); + if (useCookieMiddleware) { + this.cookieMiddlewareFactory = new ClientCookieMiddleware.Factory(); + } + + // Build the client using the above properties. + final FlightClient client = buildClient(location); + + // Add user-specified headers. + ArrayList options = new ArrayList<>(); + final FlightCallHeaders callHeaders = new FlightCallHeaders(); + for (Map.Entry parameter : parameters.entrySet()) { + if (parameter.getKey().startsWith(FlightSqlConnectionProperties.RPC_CALL_HEADER_PREFIX)) { + String userHeaderName = + parameter + .getKey() + .substring(FlightSqlConnectionProperties.RPC_CALL_HEADER_PREFIX.length()); + + if (parameter.getValue() instanceof String) { + callHeaders.insert(userHeaderName, (String) parameter.getValue()); + } else if (parameter.getValue() instanceof byte[]) { + callHeaders.insert(userHeaderName, (byte[]) parameter.getValue()); + } else { + throw new AdbcException( + String.format( + "Header values must be String or byte[]. The header failing was %s.", + parameter.getKey()), + null, + AdbcStatusCode.INVALID_ARGUMENT, + null, + 0); + } + } + } + + options.add(new HeaderCallOption(callHeaders)); + + // Test the connection. + String username = AdbcDriver.PARAM_USERNAME.get(parameters); + String password = AdbcDriver.PARAM_PASSWORD.get(parameters); + if (username != null && password != null) { + Optional bearerToken = + client.authenticateBasicToken(username, password); + options.add( + bearerToken.orElse( + new CredentialCallOption(new BasicAuthCredentialWriter(username, password)))); + this.callOptions = options.toArray(new CallOption[0]); + } else { + this.callOptions = options.toArray(new CallOption[0]); + client.handshake(this.callOptions); + } + + return client; + } + + /** Returns a yet-to-be authenticated FlightClient */ + private FlightClient buildClient(Location location) throws AdbcException { + final FlightClient.Builder builder = + FlightClient.builder() + .allocator( + allocator.newChildAllocator( + "adbc-flightclient-connection-" + counter.getAndIncrement(), + 0, + allocator.getLimit())) + .location(location); + + // Configure TLS options. + if (mtlsCertChainBytes != null && mtlsPrivateKeyBytes != null) { + builder.clientCertificate( + new ByteArrayInputStream(mtlsCertChainBytes), + new ByteArrayInputStream(mtlsPrivateKeyBytes)); + } else if (mtlsCertChainBytes != null) { + throw new AdbcException( + String.format( + "Must provide both %s and %s or neither. %s provided only.", + FlightSqlConnectionProperties.MTLS_CERT_CHAIN.getKey(), + FlightSqlConnectionProperties.MTLS_PRIVATE_KEY.getKey(), + FlightSqlConnectionProperties.MTLS_CERT_CHAIN.getKey()), + null, + AdbcStatusCode.INVALID_ARGUMENT, + null, + 0); + } else if (mtlsPrivateKeyBytes != null) { + throw new AdbcException( + String.format( + "Must provide both %s and %s or neither. %s provided only.", + FlightSqlConnectionProperties.MTLS_PRIVATE_KEY.getKey(), + FlightSqlConnectionProperties.MTLS_CERT_CHAIN.getKey(), + FlightSqlConnectionProperties.MTLS_PRIVATE_KEY.getKey()), + null, + AdbcStatusCode.INVALID_ARGUMENT, + null, + 0); + } + + if (tlsRootCertsBytes != null) { + builder.trustedCertificates(new ByteArrayInputStream(tlsRootCertsBytes)); + } + + if (Boolean.TRUE.equals(FlightSqlConnectionProperties.TLS_SKIP_VERIFY.get(parameters))) { + builder.verifyServer(false); + } + + String hostnameOverride = FlightSqlConnectionProperties.TLS_OVERRIDE_HOSTNAME.get(parameters); + if (hostnameOverride != null) { + builder.overrideHostname(hostnameOverride); + } + + // Setup cookies if needed. + if (cookieMiddlewareFactory != null) { + builder.intercept(cookieMiddlewareFactory); + } + + return builder.build(); + } + + private static byte[] inputStreamToBytes(InputStream stream) throws IOException { + byte[] bytes = new byte[stream.available()]; + DataInputStream dataInputStream = new DataInputStream(stream); + dataInputStream.readFully(bytes); + return bytes; + } } diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnectionProperties.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnectionProperties.java new file mode 100644 index 0000000000..4ab1955a1b --- /dev/null +++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnectionProperties.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.adbc.driver.flightsql; + +import java.io.InputStream; +import org.apache.arrow.adbc.core.TypedKey; + +/** Defines connection options that are used by the FlightSql driver. */ +public interface FlightSqlConnectionProperties { + TypedKey MTLS_CERT_CHAIN = + new TypedKey<>("adbc.flight.sql.client_option.mtls_cert_chain", InputStream.class); + TypedKey MTLS_PRIVATE_KEY = + new TypedKey<>("adbc.flight.sql.client_option.mtls_private_key", InputStream.class); + TypedKey TLS_OVERRIDE_HOSTNAME = + new TypedKey<>("adbc.flight.sql.client_option.tls_override_hostname", String.class); + TypedKey TLS_SKIP_VERIFY = + new TypedKey<>("adbc.flight.sql.client_option.tls_skip_verify", Boolean.class); + TypedKey TLS_ROOT_CERTS = + new TypedKey<>("adbc.flight.sql.client_option.tls_root_certs", InputStream.class); + TypedKey WITH_COOKIE_MIDDLEWARE = + new TypedKey<>("adbc.flight.sql.rpc.with_cookie_middleware", Boolean.class); + String RPC_CALL_HEADER_PREFIX = "adbc.flight.sql.rpc.call_header."; +} diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDatabase.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDatabase.java index 11ca360579..af8221d6b0 100644 --- a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDatabase.java +++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDatabase.java @@ -17,56 +17,66 @@ package org.apache.arrow.adbc.driver.flightsql; +import java.util.Map; import java.util.concurrent.atomic.AtomicInteger; import org.apache.arrow.adbc.core.AdbcConnection; import org.apache.arrow.adbc.core.AdbcDatabase; import org.apache.arrow.adbc.core.AdbcException; import org.apache.arrow.adbc.sql.SqlQuirks; -import org.apache.arrow.flight.FlightClient; import org.apache.arrow.flight.FlightRuntimeException; import org.apache.arrow.flight.Location; import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.util.AutoCloseables; /** An instance of a database (e.g. a handle to an in-memory database). */ public final class FlightSqlDatabase implements AdbcDatabase { private final BufferAllocator allocator; private final Location location; private final SqlQuirks quirks; - private final FlightClient client; private final AtomicInteger counter; + private final Map parameters; - FlightSqlDatabase(BufferAllocator allocator, Location location, SqlQuirks quirks) + FlightSqlDatabase( + BufferAllocator allocator, + Location location, + SqlQuirks quirks, + Map parameters) throws AdbcException { this.allocator = allocator; this.location = location; this.quirks = quirks; - try { - this.client = FlightClient.builder(allocator, location).build(); - } catch (FlightRuntimeException e) { - throw FlightSqlDriverUtil.fromFlightException(e); - } this.counter = new AtomicInteger(); + this.parameters = parameters; } @Override public AdbcConnection connect() throws AdbcException { - final FlightClient client; + final int count = counter.getAndIncrement(); + BufferAllocator connectionAllocator = + allocator.newChildAllocator("adbc-flight-connection-" + count, 0, allocator.getLimit()); try { - client = FlightClient.builder(allocator, location).build(); - } catch (FlightRuntimeException e) { - throw FlightSqlDriverUtil.fromFlightException(e); + return new FlightSqlConnection(connectionAllocator, quirks, location, parameters); + } catch (FlightRuntimeException ex) { + AdbcException adbcException = FlightSqlDriverUtil.fromFlightException(ex); + try { + AutoCloseables.close(connectionAllocator); + } catch (Exception e) { + adbcException.addSuppressed(e); + } + throw adbcException; + } catch (Exception ex) { + AdbcException adbcException = FlightSqlDriverUtil.fromGeneralException(ex); + try { + AutoCloseables.close(connectionAllocator); + } catch (Exception e) { + adbcException.addSuppressed(e); + } + throw adbcException; } - final int count = counter.getAndIncrement(); - return new FlightSqlConnection( - allocator.newChildAllocator("adbc-jdbc-connection-" + count, 0, allocator.getLimit()), - client, - quirks); } @Override - public void close() throws Exception { - client.close(); - } + public void close() throws Exception {} @Override public String toString() { diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriver.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriver.java index 1ea0fce094..179e0c416e 100644 --- a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriver.java +++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriver.java @@ -39,7 +39,7 @@ public FlightSqlDriver(BufferAllocator allocator) { public AdbcDatabase open(Map parameters) throws AdbcException { String uri = PARAM_URI.get(parameters); if (uri == null) { - Object target = parameters.get("adbc.url"); + Object target = parameters.get(AdbcDriver.PARAM_URL); if (!(target instanceof String)) { throw AdbcException.invalidArgument( "[Flight SQL] Must provide String " + PARAM_URI + " parameter"); @@ -65,6 +65,6 @@ public AdbcDatabase open(Map parameters) throws AdbcException { } else { quirks = new SqlQuirks(); } - return new FlightSqlDatabase(allocator, location, (SqlQuirks) quirks); + return new FlightSqlDatabase(allocator, location, (SqlQuirks) quirks, parameters); } } diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriverUtil.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriverUtil.java index 45b42df2ee..a4ea23dd0b 100644 --- a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriverUtil.java +++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriverUtil.java @@ -43,6 +43,10 @@ static AdbcException fromSqlException(SQLException e) { e.getErrorCode()); } + static AdbcException fromGeneralException(Exception ex) { + return new AdbcException(ex.getMessage(), ex, AdbcStatusCode.UNKNOWN, null, 0); + } + static AdbcStatusCode fromFlightStatusCode(FlightStatusCode code) { switch (code) { case OK: diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlStatement.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlStatement.java index e64508b4bf..77cb2622d1 100644 --- a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlStatement.java +++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlStatement.java @@ -29,7 +29,6 @@ import org.apache.arrow.adbc.core.BulkIngestMode; import org.apache.arrow.adbc.core.PartitionDescriptor; import org.apache.arrow.adbc.sql.SqlQuirks; -import org.apache.arrow.flight.FlightClient; import org.apache.arrow.flight.FlightEndpoint; import org.apache.arrow.flight.FlightInfo; import org.apache.arrow.flight.FlightRuntimeException; @@ -44,8 +43,8 @@ public class FlightSqlStatement implements AdbcStatement { private final BufferAllocator allocator; - private final FlightSqlClient client; - private final LoadingCache clientCache; + private final FlightSqlClientWithCallOptions client; + private final LoadingCache clientCache; private final SqlQuirks quirks; // State for SQL queries @@ -57,8 +56,8 @@ public class FlightSqlStatement implements AdbcStatement { FlightSqlStatement( BufferAllocator allocator, - FlightSqlClient client, - LoadingCache clientCache, + FlightSqlClientWithCallOptions client, + LoadingCache clientCache, SqlQuirks quirks) { this.allocator = allocator; this.client = client; @@ -69,8 +68,8 @@ public class FlightSqlStatement implements AdbcStatement { static FlightSqlStatement ingestRoot( BufferAllocator allocator, - FlightSqlClient client, - LoadingCache clientCache, + FlightSqlClientWithCallOptions client, + LoadingCache clientCache, SqlQuirks quirks, String targetTableName, BulkIngestMode mode) { @@ -188,7 +187,7 @@ interface Execute { private R execute( Execute doPrepared, - Execute doRegular) + Execute doRegular) throws AdbcException { try { if (preparedStatement != null) { diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/InfoMetadataBuilder.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/InfoMetadataBuilder.java index 318405d6ce..aef679f63c 100644 --- a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/InfoMetadataBuilder.java +++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/InfoMetadataBuilder.java @@ -31,7 +31,6 @@ import org.apache.arrow.flight.FlightInfo; import org.apache.arrow.flight.FlightRuntimeException; import org.apache.arrow.flight.FlightStream; -import org.apache.arrow.flight.sql.FlightSqlClient; import org.apache.arrow.flight.sql.impl.FlightSql; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.util.AutoCloseables; @@ -47,7 +46,7 @@ final class InfoMetadataBuilder implements AutoCloseable { private static final Map SUPPORTED_CODES = new HashMap<>(); private final Collection requestedCodes; - private final FlightSqlClient client; + private final FlightSqlClientWithCallOptions client; private VectorSchemaRoot root; private final UInt4Vector infoCodes; @@ -80,7 +79,8 @@ interface AddInfo { }); } - InfoMetadataBuilder(BufferAllocator allocator, FlightSqlClient client, int[] infoCodes) { + InfoMetadataBuilder( + BufferAllocator allocator, FlightSqlClientWithCallOptions client, int[] infoCodes) { if (infoCodes == null) { this.requestedCodes = new ArrayList<>(SUPPORTED_CODES.keySet()); this.requestedCodes.add(AdbcInfoCode.DRIVER_NAME.getValue()); diff --git a/java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/HeaderTest.java b/java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/HeaderTest.java new file mode 100644 index 0000000000..19c66f8116 --- /dev/null +++ b/java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/HeaderTest.java @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.adbc.driver.flightsql; + +import static org.junit.Assert.assertThrows; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.HashMap; +import java.util.Map; +import org.apache.arrow.adbc.core.AdbcConnection; +import org.apache.arrow.adbc.core.AdbcDatabase; +import org.apache.arrow.adbc.core.AdbcDriver; +import org.apache.arrow.adbc.core.AdbcException; +import org.apache.arrow.adbc.core.AdbcStatusCode; +import org.apache.arrow.adbc.drivermanager.AdbcDriverManager; +import org.apache.arrow.driver.jdbc.utils.MockFlightSqlProducer; +import org.apache.arrow.flight.CallHeaders; +import org.apache.arrow.flight.CallInfo; +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.FlightServerMiddleware; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.RequestContext; +import org.apache.arrow.flight.auth2.BasicCallHeaderAuthenticator; +import org.apache.arrow.flight.auth2.GeneratedBearerTokenAuthenticator; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.ipc.ArrowReader; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class HeaderTest { + + private FlightServer.Builder builder; + private FlightServer server; + private Map params; + private AdbcConnection connection; + private BufferAllocator allocator; + private HeaderValidator.Factory headerValidatorFactory; + + @BeforeEach + public void setUp() { + allocator = new RootAllocator(Long.MAX_VALUE); + headerValidatorFactory = new HeaderValidator.Factory(); + builder = + FlightServer.builder() + .middleware(HeaderValidator.KEY, headerValidatorFactory) + .location(Location.forGrpcInsecure("localhost", 0)) + .producer(new MockFlightSqlProducer()); + params = new HashMap<>(); + } + + @AfterEach + public void tearDown() throws Exception { + AutoCloseables.close(connection, server, allocator); + connection = null; + server = null; + allocator = null; + } + + @Test + public void testArbitraryHeader() throws Exception { + final String dummyValue = "dummy"; + final String dummyHeaderName = "test-header"; + params.put(FlightSqlConnectionProperties.RPC_CALL_HEADER_PREFIX + dummyHeaderName, dummyValue); + server = builder.build(); + server.start(); + connect(); + + CallHeaders headers = headerValidatorFactory.getHeadersReceivedAtRequest(0); + assertEquals(dummyValue, headers.get(dummyHeaderName)); + } + + @Test + public void testCookies() throws Exception { + builder.middleware(CookieMiddleware.KEY, new CookieMiddleware.Factory()); + server = builder.build(); + server.start(); + + params.put(FlightSqlConnectionProperties.WITH_COOKIE_MIDDLEWARE.getKey(), true); + connect(); + try (ArrowReader reader = connection.getInfo(new int[] {})) { + + } catch (Exception ex) { + // Swallow exceptions from the RPC call. Only interested in tracking metadata. + } + CallHeaders secondHeaders = headerValidatorFactory.getHeadersReceivedAtRequest(1); + assertTrue(secondHeaders.containsKey("cookie")); + } + + @Test + public void testBearerToken() throws Exception { + builder.headerAuthenticator( + new GeneratedBearerTokenAuthenticator( + new BasicCallHeaderAuthenticator((username, password) -> () -> username))); + server = builder.build(); + server.start(); + + params.put(AdbcDriver.PARAM_USERNAME.getKey(), "dummy_user"); + params.put(AdbcDriver.PARAM_PASSWORD.getKey(), "dummy_password"); + connect(); + try (ArrowReader reader = connection.getInfo(new int[] {})) { + + } catch (Exception ex) { + // Swallow exceptions from the RPC call. Only interested in tracking metadata. + } + CallHeaders secondHeaders = headerValidatorFactory.getHeadersReceivedAtRequest(1); + assertTrue(secondHeaders.get("authorization").contains("Bearer")); + } + + @Test + public void testUnauthenticated() throws Exception { + builder.headerAuthenticator( + new GeneratedBearerTokenAuthenticator( + new BasicCallHeaderAuthenticator((username, password) -> () -> username))); + server = builder.build(); + server.start(); + + AdbcException adbcException = assertThrows(AdbcException.class, this::connect); + assertEquals(AdbcStatusCode.UNAUTHENTICATED, adbcException.getStatus()); + } + + static class CookieMiddleware implements FlightServerMiddleware { + + public static final Key KEY = Key.of("CookieMiddleware"); + + @Override + public void onBeforeSendingHeaders(CallHeaders callHeaders) { + callHeaders.insert("set-cookie", "test=test_val"); + } + + @Override + public void onCallCompleted(CallStatus callStatus) {} + + @Override + public void onCallErrored(Throwable throwable) {} + + public static class Factory implements FlightServerMiddleware.Factory { + + @Override + public CookieMiddleware onCallStarted( + CallInfo callInfo, CallHeaders callHeaders, RequestContext requestContext) { + return new CookieMiddleware(); + } + } + } + + private void connect() throws Exception { + int port = server.getPort(); + String uri = String.format("grpc+tcp://%s:%d", "localhost", port); + params.put(AdbcDriver.PARAM_URI.getKey(), uri); + AdbcDatabase db = + AdbcDriverManager.getInstance() + .connect(FlightSqlDriverFactory.class.getCanonicalName(), allocator, params); + connection = db.connect(); + } +} diff --git a/java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/HeaderValidator.java b/java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/HeaderValidator.java new file mode 100644 index 0000000000..f543dd9958 --- /dev/null +++ b/java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/HeaderValidator.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.adbc.driver.flightsql; + +import java.util.ArrayList; +import org.apache.arrow.flight.CallHeaders; +import org.apache.arrow.flight.CallInfo; +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.FlightCallHeaders; +import org.apache.arrow.flight.FlightServerMiddleware; +import org.apache.arrow.flight.RequestContext; + +public class HeaderValidator implements FlightServerMiddleware { + public static final Key KEY = Key.of("HeaderValidator"); + + @Override + public void onBeforeSendingHeaders(CallHeaders callHeaders) {} + + @Override + public void onCallCompleted(CallStatus callStatus) {} + + @Override + public void onCallErrored(Throwable throwable) {} + + public static class Factory implements FlightServerMiddleware.Factory { + + private final ArrayList headersReceived = new ArrayList<>(); + + @Override + public HeaderValidator onCallStarted( + CallInfo callInfo, CallHeaders callHeaders, RequestContext requestContext) { + CallHeaders cloneHeaders = cloneHeaders(callHeaders); + headersReceived.add(cloneHeaders); + return new HeaderValidator(); + } + + public CallHeaders getHeadersReceivedAtRequest(int request) { + return cloneHeaders(headersReceived.get(request)); + } + + private static CallHeaders cloneHeaders(CallHeaders headers) { + FlightCallHeaders cloneHeaders = new FlightCallHeaders(); + for (String key : headers.keys()) { + for (String value : headers.getAll(key)) { + cloneHeaders.insert(key, value); + } + } + return cloneHeaders; + } + } +} diff --git a/java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/MutualTlsTest.java b/java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/MutualTlsTest.java new file mode 100644 index 0000000000..f98a7c3470 --- /dev/null +++ b/java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/MutualTlsTest.java @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.adbc.driver.flightsql; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.HashMap; +import java.util.Map; +import org.apache.arrow.adbc.core.AdbcConnection; +import org.apache.arrow.adbc.core.AdbcDatabase; +import org.apache.arrow.adbc.core.AdbcDriver; +import org.apache.arrow.adbc.core.AdbcException; +import org.apache.arrow.adbc.core.AdbcStatusCode; +import org.apache.arrow.adbc.drivermanager.AdbcDriverManager; +import org.apache.arrow.driver.jdbc.FlightServerTestRule; +import org.apache.arrow.driver.jdbc.authentication.UserPasswordAuthentication; +import org.apache.arrow.driver.jdbc.utils.FlightSqlTestCertificates; +import org.apache.arrow.driver.jdbc.utils.MockFlightSqlProducer; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.junit.After; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Test; + +public class MutualTlsTest { + + @ClassRule public static final FlightServerTestRule FLIGHT_SERVER_TEST_RULE; + + private static final String USER_1 = "user1"; + private static final String PASS_1 = "pass1"; + private static final String TLS_ROOT_CERTS_PATH; + private static final String CLIENT_MTLS_CERT_PATH; + private static final String CLIENT_MTLS_KEY_PATH; + + static { + final FlightSqlTestCertificates.CertKeyPair certKey = + FlightSqlTestCertificates.exampleTlsCerts().get(0); + + TLS_ROOT_CERTS_PATH = certKey.cert.getPath(); + + final FlightSqlTestCertificates.CertKeyPair clientMTlsCertKey = + FlightSqlTestCertificates.exampleTlsCerts().get(1); + + CLIENT_MTLS_CERT_PATH = clientMTlsCertKey.cert.getPath(); + CLIENT_MTLS_KEY_PATH = clientMTlsCertKey.key.getPath(); + + UserPasswordAuthentication authentication = + new UserPasswordAuthentication.Builder().user(USER_1, PASS_1).build(); + + FLIGHT_SERVER_TEST_RULE = + new FlightServerTestRule.Builder() + .authentication(authentication) + .useEncryption(certKey.cert, certKey.key) + .useMTlsClientVerification(FlightSqlTestCertificates.exampleCACert()) + .producer(new MockFlightSqlProducer()) + .build(); + } + + private BufferAllocator allocator; + private Map params; + + @Before + public void setUp() { + allocator = new RootAllocator(Long.MAX_VALUE); + params = new HashMap<>(); + params.put(AdbcDriver.PARAM_USERNAME.getKey(), USER_1); + params.put(AdbcDriver.PARAM_PASSWORD.getKey(), PASS_1); + } + + @After + public void tearDown() throws Exception { + AutoCloseables.close(allocator); + } + + @Test + public void testClientTlsOnVerifyOffServerOnNoCertSpecified() throws Exception { + params.put(AdbcDriver.PARAM_URI.getKey(), getUri(true)); + params.put(FlightSqlConnectionProperties.TLS_SKIP_VERIFY.getKey(), true); + AdbcException adbcException = + assertThrows( + AdbcException.class, + () -> { + AdbcDatabase db = + AdbcDriverManager.getInstance() + .connect(FlightSqlDriverFactory.class.getCanonicalName(), allocator, params); + try (AdbcConnection conn = db.connect()) {} + }); + assertEquals(AdbcStatusCode.IO, adbcException.getStatus()); + } + + @Test + public void testClientTlsOnVerifyOnServerOnNoCertSpecified() throws Exception { + params.put(AdbcDriver.PARAM_URI.getKey(), getUri(true)); + try (InputStream stream = Files.newInputStream(Paths.get(TLS_ROOT_CERTS_PATH))) { + params.put(FlightSqlConnectionProperties.TLS_ROOT_CERTS.getKey(), stream); + AdbcException adbcException = + assertThrows( + AdbcException.class, + () -> { + AdbcDatabase db = + AdbcDriverManager.getInstance() + .connect( + FlightSqlDriverFactory.class.getCanonicalName(), allocator, params); + try (AdbcConnection conn = db.connect()) {} + }); + assertEquals(AdbcStatusCode.IO, adbcException.getStatus()); + } + } + + @Test + public void testClientTlsOnVerifyOnCertsSpecifiedServerOnNoCertSpecified() throws Exception { + params.put(AdbcDriver.PARAM_URI.getKey(), getUri(true)); + try (InputStream rootCertStream = Files.newInputStream(Paths.get(TLS_ROOT_CERTS_PATH)); + InputStream privateKeyStream = Files.newInputStream(Paths.get(CLIENT_MTLS_KEY_PATH)); + InputStream clientCertStream = Files.newInputStream(Paths.get(CLIENT_MTLS_CERT_PATH))) { + params.put(FlightSqlConnectionProperties.TLS_ROOT_CERTS.getKey(), rootCertStream); + params.put(FlightSqlConnectionProperties.MTLS_PRIVATE_KEY.getKey(), privateKeyStream); + params.put(FlightSqlConnectionProperties.MTLS_CERT_CHAIN.getKey(), clientCertStream); + AdbcDatabase db = + AdbcDriverManager.getInstance() + .connect(FlightSqlDriverFactory.class.getCanonicalName(), allocator, params); + try (AdbcConnection conn = db.connect()) {} + } + } + + private String getUri(boolean withTls) { + String protocol = String.format("grpc%s", withTls ? "+tls" : "+tcp"); + return String.format( + "%s://%s:%d", + protocol, FLIGHT_SERVER_TEST_RULE.getHost(), FLIGHT_SERVER_TEST_RULE.getPort()); + } +} diff --git a/java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/TlsTest.java b/java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/TlsTest.java new file mode 100644 index 0000000000..0301f3d8fb --- /dev/null +++ b/java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/TlsTest.java @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.adbc.driver.flightsql; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.HashMap; +import java.util.Map; +import org.apache.arrow.adbc.core.AdbcConnection; +import org.apache.arrow.adbc.core.AdbcDatabase; +import org.apache.arrow.adbc.core.AdbcDriver; +import org.apache.arrow.adbc.core.AdbcException; +import org.apache.arrow.adbc.core.AdbcStatusCode; +import org.apache.arrow.adbc.drivermanager.AdbcDriverManager; +import org.apache.arrow.driver.jdbc.FlightServerTestRule; +import org.apache.arrow.driver.jdbc.authentication.UserPasswordAuthentication; +import org.apache.arrow.driver.jdbc.utils.FlightSqlTestCertificates; +import org.apache.arrow.driver.jdbc.utils.MockFlightSqlProducer; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.junit.After; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Test; + +public class TlsTest { + + @ClassRule public static final FlightServerTestRule FLIGHT_SERVER_TEST_RULE; + + private static final String USER_1 = "user1"; + private static final String PASS_1 = "pass1"; + private static final String TLS_ROOT_CERTS_PATH; + + static { + final FlightSqlTestCertificates.CertKeyPair certKey = + FlightSqlTestCertificates.exampleTlsCerts().get(0); + + TLS_ROOT_CERTS_PATH = certKey.cert.getPath(); + + UserPasswordAuthentication authentication = + new UserPasswordAuthentication.Builder().user(USER_1, PASS_1).build(); + + FLIGHT_SERVER_TEST_RULE = + new FlightServerTestRule.Builder() + .authentication(authentication) + .useEncryption(certKey.cert, certKey.key) + .producer(new MockFlightSqlProducer()) + .build(); + } + + private BufferAllocator allocator; + private Map params; + + @Before + public void setUp() { + allocator = new RootAllocator(Long.MAX_VALUE); + params = new HashMap<>(); + params.put(AdbcDriver.PARAM_USERNAME.getKey(), USER_1); + params.put(AdbcDriver.PARAM_PASSWORD.getKey(), PASS_1); + } + + @After + public void tearDown() throws Exception { + AutoCloseables.close(allocator); + } + + @Test + public void testClientTlsOffServerOn() { + params.put(AdbcDriver.PARAM_URI.getKey(), getUri(false)); + AdbcException adbcException = + assertThrows( + AdbcException.class, + () -> { + AdbcDatabase db = + AdbcDriverManager.getInstance() + .connect(FlightSqlDriverFactory.class.getCanonicalName(), allocator, params); + try (AdbcConnection conn = db.connect()) {} + }); + assertEquals(AdbcStatusCode.IO, adbcException.getStatus()); + } + + @Test + public void testClientTlsOnServerOnNoCertSpecified() { + params.put(AdbcDriver.PARAM_URI.getKey(), getUri(true)); + AdbcException adbcException = + assertThrows( + AdbcException.class, + () -> { + AdbcDatabase db = + AdbcDriverManager.getInstance() + .connect(FlightSqlDriverFactory.class.getCanonicalName(), allocator, params); + try (AdbcConnection conn = db.connect()) {} + }); + assertEquals(AdbcStatusCode.IO, adbcException.getStatus()); + } + + @Test + public void testClientTlsOnVerifyOffServerOnNoCertSpecified() throws Exception { + params.put(AdbcDriver.PARAM_URI.getKey(), getUri(true)); + params.put(FlightSqlConnectionProperties.TLS_SKIP_VERIFY.getKey(), true); + AdbcDatabase db = + AdbcDriverManager.getInstance() + .connect(FlightSqlDriverFactory.class.getCanonicalName(), allocator, params); + try (AdbcConnection conn = db.connect()) {} + } + + @Test + public void testClientTlsOnVerifyOnCertsSpecifiedServerOnNoCertSpecified() throws Exception { + params.put(AdbcDriver.PARAM_URI.getKey(), getUri(true)); + try (InputStream stream = Files.newInputStream(Paths.get(TLS_ROOT_CERTS_PATH))) { + params.put(FlightSqlConnectionProperties.TLS_ROOT_CERTS.getKey(), stream); + AdbcDatabase db = + AdbcDriverManager.getInstance() + .connect(FlightSqlDriverFactory.class.getCanonicalName(), allocator, params); + try (AdbcConnection conn = db.connect()) {} + } + } + + @Test + public void testClientTlsOnBadHostnameOverride() throws Exception { + params.put(AdbcDriver.PARAM_URI.getKey(), getUri(true)); + params.put(FlightSqlConnectionProperties.TLS_OVERRIDE_HOSTNAME.getKey(), "fakehost"); + try (InputStream stream = Files.newInputStream(Paths.get(TLS_ROOT_CERTS_PATH))) { + params.put(FlightSqlConnectionProperties.TLS_ROOT_CERTS.getKey(), stream); + AdbcException adbcException = + assertThrows( + AdbcException.class, + () -> { + AdbcDatabase db = + AdbcDriverManager.getInstance() + .connect( + FlightSqlDriverFactory.class.getCanonicalName(), allocator, params); + try (AdbcConnection conn = db.connect()) {} + }); + assertEquals(AdbcStatusCode.IO, adbcException.getStatus()); + } + } + + @Test + public void testClientTlsOnGoodHostnameOverride() throws Exception { + params.put(AdbcDriver.PARAM_URI.getKey(), getUri(true)); + params.put( + FlightSqlConnectionProperties.TLS_OVERRIDE_HOSTNAME.getKey(), + FLIGHT_SERVER_TEST_RULE.getHost()); + try (InputStream stream = Files.newInputStream(Paths.get(TLS_ROOT_CERTS_PATH))) { + params.put(FlightSqlConnectionProperties.TLS_ROOT_CERTS.getKey(), stream); + AdbcDatabase db = + AdbcDriverManager.getInstance() + .connect(FlightSqlDriverFactory.class.getCanonicalName(), allocator, params); + try (AdbcConnection conn = db.connect()) {} + } + } + + private String getUri(boolean withTls) { + String protocol = String.format("grpc%s", withTls ? "+tls" : "+tcp"); + return String.format( + "%s://%s:%d", + protocol, FLIGHT_SERVER_TEST_RULE.getHost(), FLIGHT_SERVER_TEST_RULE.getPort()); + } +} diff --git a/java/pom.xml b/java/pom.xml index a632959f51..5b02f67f6c 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -97,28 +97,10 @@ org.apache.arrow - arrow-jdbc - ${dep.arrow.version} - - - org.apache.arrow - arrow-memory-core - ${dep.arrow.version} - - - org.apache.arrow - arrow-vector - ${dep.arrow.version} - - - org.apache.arrow - flight-core - ${dep.arrow.version} - - - org.apache.arrow - flight-sql + arrow-bom ${dep.arrow.version} + pom + import diff --git a/testing b/testing new file mode 160000 index 0000000000..25d16511e8 --- /dev/null +++ b/testing @@ -0,0 +1 @@ +Subproject commit 25d16511e8d42c2744a1d94d90169e3a36e92631