diff --git a/.github/workflows/java.yml b/.github/workflows/java.yml
index 656bc5e9a8..88609cba50 100644
--- a/.github/workflows/java.yml
+++ b/.github/workflows/java.yml
@@ -48,6 +48,7 @@ jobs:
with:
fetch-depth: 0
persist-credentials: false
+ submodules: recursive
- uses: actions/setup-java@v4
with:
cache: "maven"
@@ -81,6 +82,7 @@ jobs:
with:
fetch-depth: 0
persist-credentials: false
+ submodules: recursive
- uses: actions/setup-java@v4
with:
cache: "maven"
diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 0000000000..e880666df1
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+[submodule "testing"]
+ path = testing
+ url = https://github.com/apache/arrow-testing.git
diff --git a/java/driver-manager/pom.xml b/java/driver-manager/pom.xml
index 79ad55b992..8a083a58a8 100644
--- a/java/driver-manager/pom.xml
+++ b/java/driver-manager/pom.xml
@@ -42,7 +42,6 @@
org.apache.arrow
arrow-memory-unsafe
- ${dep.arrow.version}
test
diff --git a/java/driver/flight-sql/pom.xml b/java/driver/flight-sql/pom.xml
index 26dd8a7262..ceceaa8249 100644
--- a/java/driver/flight-sql/pom.xml
+++ b/java/driver/flight-sql/pom.xml
@@ -78,5 +78,43 @@
junit-jupiter
test
+
+ org.junit.vintage
+ junit-vintage-engine
+ test
+
+
+
+ junit
+ junit
+ 4.13.1
+ test
+
+
+ org.apache.arrow
+ flight-sql-jdbc-core
+ test
+
+
+ org.apache.arrow
+ flight-sql-jdbc-core
+ ${dep.arrow.version}
+ test
+ tests
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-surefire-plugin
+
+
+ ${project.basedir}/../../../testing/data
+
+
+
+
+
diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightInfoReader.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightInfoReader.java
index 9b0cda91dc..6850be0639 100644
--- a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightInfoReader.java
+++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightInfoReader.java
@@ -24,12 +24,10 @@
import java.util.Objects;
import org.apache.arrow.adbc.core.AdbcException;
import org.apache.arrow.adbc.core.AdbcStatusCode;
-import org.apache.arrow.flight.FlightClient;
import org.apache.arrow.flight.FlightEndpoint;
import org.apache.arrow.flight.FlightRuntimeException;
import org.apache.arrow.flight.FlightStream;
import org.apache.arrow.flight.Location;
-import org.apache.arrow.flight.sql.FlightSqlClient;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.VectorUnloader;
@@ -40,8 +38,8 @@
/** An ArrowReader that wraps a FlightInfo. */
public class FlightInfoReader extends ArrowReader {
private final Schema schema;
- private final FlightSqlClient client;
- private final LoadingCache clientCache;
+ private final FlightSqlClientWithCallOptions client;
+ private final LoadingCache clientCache;
private final List flightEndpoints;
private int nextEndpointIndex;
private FlightStream currentStream;
@@ -49,8 +47,8 @@ public class FlightInfoReader extends ArrowReader {
FlightInfoReader(
BufferAllocator allocator,
- FlightSqlClient client,
- LoadingCache clientCache,
+ FlightSqlClientWithCallOptions client,
+ LoadingCache clientCache,
List flightEndpoints)
throws AdbcException {
super(allocator);
diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlClientWithCallOptions.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlClientWithCallOptions.java
new file mode 100644
index 0000000000..f7028cb55f
--- /dev/null
+++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlClientWithCallOptions.java
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.adbc.driver.flightsql;
+
+import static org.apache.arrow.flight.sql.FlightSqlClient.PreparedStatement;
+import static org.apache.arrow.flight.sql.FlightSqlClient.Savepoint;
+import static org.apache.arrow.flight.sql.FlightSqlClient.SubstraitPlan;
+import static org.apache.arrow.flight.sql.FlightSqlClient.Transaction;
+
+import java.util.List;
+import org.apache.arrow.flight.CallOption;
+import org.apache.arrow.flight.CancelFlightInfoRequest;
+import org.apache.arrow.flight.CancelFlightInfoResult;
+import org.apache.arrow.flight.FlightDescriptor;
+import org.apache.arrow.flight.FlightEndpoint;
+import org.apache.arrow.flight.FlightInfo;
+import org.apache.arrow.flight.FlightStream;
+import org.apache.arrow.flight.RenewFlightEndpointRequest;
+import org.apache.arrow.flight.SchemaResult;
+import org.apache.arrow.flight.Ticket;
+import org.apache.arrow.flight.sql.CancelResult;
+import org.apache.arrow.flight.sql.FlightSqlClient;
+import org.apache.arrow.flight.sql.impl.FlightSql;
+import org.apache.arrow.flight.sql.util.TableRef;
+import org.apache.arrow.util.AutoCloseables;
+
+/** A wrapper around FlightSqlClient which automatically adds CallOptions to each RPC call. */
+public class FlightSqlClientWithCallOptions implements AutoCloseable {
+ private final FlightSqlClient client;
+ private final CallOption[] connectionOptions;
+
+ public FlightSqlClientWithCallOptions(FlightSqlClient client, CallOption... options) {
+ this.client = client;
+ this.connectionOptions = options;
+ }
+
+ public FlightInfo execute(String query, CallOption... options) {
+ return client.execute(query, combine(options));
+ }
+
+ public FlightInfo execute(String query, Transaction transaction, CallOption... options) {
+ return client.execute(query, transaction, combine(options));
+ }
+
+ public FlightInfo executeSubstrait(SubstraitPlan plan, CallOption... options) {
+ return client.executeSubstrait(plan, combine(options));
+ }
+
+ public FlightInfo executeSubstrait(
+ SubstraitPlan plan, Transaction transaction, CallOption... options) {
+ return client.executeSubstrait(plan, transaction, combine(options));
+ }
+
+ public SchemaResult getExecuteSchema(
+ String query, Transaction transaction, CallOption... options) {
+ return client.getExecuteSchema(query, transaction, combine(options));
+ }
+
+ public SchemaResult getExecuteSchema(String query, CallOption... options) {
+ return client.getExecuteSchema(query, combine(options));
+ }
+
+ public SchemaResult getExecuteSubstraitSchema(
+ SubstraitPlan plan, Transaction transaction, CallOption... options) {
+ return client.getExecuteSubstraitSchema(plan, transaction, combine(options));
+ }
+
+ public SchemaResult getExecuteSubstraitSchema(
+ SubstraitPlan substraitPlan, CallOption... options) {
+ return client.getExecuteSubstraitSchema(substraitPlan, combine(options));
+ }
+
+ public long executeUpdate(String query, CallOption... options) {
+ return client.executeUpdate(query, combine(options));
+ }
+
+ public long executeUpdate(String query, Transaction transaction, CallOption... options) {
+ return client.executeUpdate(query, transaction, combine(options));
+ }
+
+ public long executeSubstraitUpdate(SubstraitPlan plan, CallOption... options) {
+ return client.executeSubstraitUpdate(plan, combine(options));
+ }
+
+ public long executeSubstraitUpdate(
+ SubstraitPlan plan, Transaction transaction, CallOption... options) {
+ return client.executeSubstraitUpdate(plan, transaction, combine(options));
+ }
+
+ public FlightInfo getCatalogs(CallOption... options) {
+ return client.getCatalogs(options);
+ }
+
+ public SchemaResult getCatalogsSchema(CallOption... options) {
+ return client.getCatalogsSchema(options);
+ }
+
+ public FlightInfo getSchemas(
+ String catalog, String dbSchemaFilterPattern, CallOption... options) {
+ return client.getSchemas(catalog, dbSchemaFilterPattern, combine(options));
+ }
+
+ public SchemaResult getSchemasSchema(CallOption... options) {
+ return client.getSchemasSchema(options);
+ }
+
+ public SchemaResult getSchema(FlightDescriptor descriptor, CallOption... options) {
+ return client.getSchema(descriptor, combine(options));
+ }
+
+ public FlightStream getStream(Ticket ticket, CallOption... options) {
+ return client.getStream(ticket, combine(options));
+ }
+
+ public FlightInfo getSqlInfo(FlightSql.SqlInfo... info) {
+ return client.getSqlInfo(info);
+ }
+
+ public FlightInfo getSqlInfo(FlightSql.SqlInfo[] info, CallOption... options) {
+ return client.getSqlInfo(info, combine(options));
+ }
+
+ public FlightInfo getSqlInfo(int[] info, CallOption... options) {
+ return client.getSqlInfo(info, combine(options));
+ }
+
+ public FlightInfo getSqlInfo(Iterable info, CallOption... options) {
+ return client.getSqlInfo(info, combine(options));
+ }
+
+ public SchemaResult getSqlInfoSchema(CallOption... options) {
+ return client.getSqlInfoSchema(options);
+ }
+
+ public FlightInfo getXdbcTypeInfo(int dataType, CallOption... options) {
+ return client.getXdbcTypeInfo(dataType, combine(options));
+ }
+
+ public FlightInfo getXdbcTypeInfo(CallOption... options) {
+ return client.getXdbcTypeInfo(options);
+ }
+
+ public SchemaResult getXdbcTypeInfoSchema(CallOption... options) {
+ return client.getXdbcTypeInfoSchema(options);
+ }
+
+ public FlightInfo getTables(
+ String catalog,
+ String dbSchemaFilterPattern,
+ String tableFilterPattern,
+ List tableTypes,
+ boolean includeSchema,
+ CallOption... options) {
+ return client.getTables(
+ catalog,
+ dbSchemaFilterPattern,
+ tableFilterPattern,
+ tableTypes,
+ includeSchema,
+ combine(options));
+ }
+
+ public SchemaResult getTablesSchema(boolean includeSchema, CallOption... options) {
+ return client.getTablesSchema(includeSchema, combine(options));
+ }
+
+ public FlightInfo getPrimaryKeys(TableRef tableRef, CallOption... options) {
+ return client.getPrimaryKeys(tableRef, combine(options));
+ }
+
+ public SchemaResult getPrimaryKeysSchema(CallOption... options) {
+ return client.getPrimaryKeysSchema(options);
+ }
+
+ public FlightInfo getExportedKeys(TableRef tableRef, CallOption... options) {
+ return client.getExportedKeys(tableRef, combine(options));
+ }
+
+ public SchemaResult getExportedKeysSchema(CallOption... options) {
+ return client.getExportedKeysSchema(options);
+ }
+
+ public FlightInfo getImportedKeys(TableRef tableRef, CallOption... options) {
+ return client.getImportedKeys(tableRef, combine(options));
+ }
+
+ public SchemaResult getImportedKeysSchema(CallOption... options) {
+ return client.getImportedKeysSchema(options);
+ }
+
+ public FlightInfo getCrossReference(
+ TableRef pkTableRef, TableRef fkTableRef, CallOption... options) {
+ return client.getCrossReference(pkTableRef, fkTableRef, combine(options));
+ }
+
+ public SchemaResult getCrossReferenceSchema(CallOption... options) {
+ return client.getCrossReferenceSchema(options);
+ }
+
+ public FlightInfo getTableTypes(CallOption... options) {
+ return client.getTableTypes(options);
+ }
+
+ public SchemaResult getTableTypesSchema(CallOption... options) {
+ return client.getTableTypesSchema(options);
+ }
+
+ public PreparedStatement prepare(String query, CallOption... options) {
+ return client.prepare(query, combine(options));
+ }
+
+ public PreparedStatement prepare(
+ String query, FlightSqlClient.Transaction transaction, CallOption... options) {
+ return client.prepare(query, transaction, combine(options));
+ }
+
+ public PreparedStatement prepare(SubstraitPlan plan, CallOption... options) {
+ return client.prepare(plan, combine(options));
+ }
+
+ public PreparedStatement prepare(
+ SubstraitPlan plan, Transaction transaction, CallOption... options) {
+ return client.prepare(plan, transaction, combine(options));
+ }
+
+ public Transaction beginTransaction(CallOption... options) {
+ return client.beginTransaction(options);
+ }
+
+ public Savepoint beginSavepoint(Transaction transaction, String name, CallOption... options) {
+ return client.beginSavepoint(transaction, name, combine(options));
+ }
+
+ public void commit(Transaction transaction, CallOption... options) {
+ client.commit(transaction, combine(options));
+ }
+
+ public void release(Savepoint savepoint, CallOption... options) {
+ client.release(savepoint, combine(options));
+ }
+
+ public void rollback(Transaction transaction, CallOption... options) {
+ client.rollback(transaction, combine(options));
+ }
+
+ public void rollback(Savepoint savepoint, CallOption... options) {
+ client.rollback(savepoint, combine(options));
+ }
+
+ public CancelFlightInfoResult cancelFlightInfo(
+ CancelFlightInfoRequest request, CallOption... options) {
+ return client.cancelFlightInfo(request, combine(options));
+ }
+
+ public CancelResult cancelQuery(FlightInfo info, CallOption... options) {
+ return client.cancelQuery(info, combine(options));
+ }
+
+ public FlightEndpoint renewFlightEndpoint(
+ RenewFlightEndpointRequest request, CallOption... options) {
+ return client.renewFlightEndpoint(request, combine(options));
+ }
+
+ public void close() throws Exception {
+ AutoCloseables.close(client);
+ }
+
+ private CallOption[] combine(CallOption... options) {
+ final CallOption[] result = new CallOption[connectionOptions.length + options.length];
+ System.arraycopy(connectionOptions, 0, result, 0, connectionOptions.length);
+ System.arraycopy(options, 0, result, connectionOptions.length, options.length);
+ return result;
+ }
+}
diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnection.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnection.java
index f583f2b866..c079060ec8 100644
--- a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnection.java
+++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnection.java
@@ -20,19 +20,35 @@
import com.github.benmanes.caffeine.cache.LoadingCache;
import com.github.benmanes.caffeine.cache.RemovalCause;
import com.google.protobuf.InvalidProtocolBufferException;
+import java.io.ByteArrayInputStream;
+import java.io.DataInputStream;
+import java.io.IOException;
+import java.io.InputStream;
import java.net.URISyntaxException;
import java.nio.ByteBuffer;
+import java.util.ArrayList;
import java.util.Collections;
+import java.util.Map;
+import java.util.Optional;
import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
import org.apache.arrow.adbc.core.AdbcConnection;
+import org.apache.arrow.adbc.core.AdbcDriver;
import org.apache.arrow.adbc.core.AdbcException;
import org.apache.arrow.adbc.core.AdbcStatement;
+import org.apache.arrow.adbc.core.AdbcStatusCode;
import org.apache.arrow.adbc.core.BulkIngestMode;
import org.apache.arrow.adbc.sql.SqlQuirks;
+import org.apache.arrow.flight.CallOption;
+import org.apache.arrow.flight.FlightCallHeaders;
import org.apache.arrow.flight.FlightClient;
import org.apache.arrow.flight.FlightEndpoint;
+import org.apache.arrow.flight.HeaderCallOption;
import org.apache.arrow.flight.Location;
import org.apache.arrow.flight.Ticket;
+import org.apache.arrow.flight.auth2.BasicAuthCredentialWriter;
+import org.apache.arrow.flight.client.ClientCookieMiddleware;
+import org.apache.arrow.flight.grpc.CredentialCallOption;
import org.apache.arrow.flight.impl.Flight;
import org.apache.arrow.flight.sql.FlightSqlClient;
import org.apache.arrow.memory.BufferAllocator;
@@ -42,28 +58,57 @@
public class FlightSqlConnection implements AdbcConnection {
private final BufferAllocator allocator;
- private final FlightSqlClient client;
+ private final AtomicInteger counter;
+ private final FlightSqlClientWithCallOptions client;
private final SqlQuirks quirks;
- private final LoadingCache clientCache;
+ private final Map parameters;
+ private final LoadingCache clientCache;
- FlightSqlConnection(BufferAllocator allocator, FlightClient client, SqlQuirks quirks) {
+ // Cached data to use across additional connections.
+ private ClientCookieMiddleware.Factory cookieMiddlewareFactory;
+ private CallOption[] callOptions;
+
+ // Used to cache the InputStream content as a byte array since
+ // subsequent connections may need to use it but it is supplied as a stream.
+ private byte[] mtlsCertChainBytes;
+ private byte[] mtlsPrivateKeyBytes;
+ private byte[] tlsRootCertsBytes;
+
+ FlightSqlConnection(
+ BufferAllocator allocator,
+ SqlQuirks quirks,
+ Location location,
+ Map parameters)
+ throws AdbcException {
this.allocator = allocator;
- this.client = new FlightSqlClient(client);
+ this.counter = new AtomicInteger(0);
this.quirks = quirks;
+ this.parameters = parameters;
+ FlightSqlClient flightSqlClient = new FlightSqlClient(createInitialConnection(location));
+ this.client = new FlightSqlClientWithCallOptions(flightSqlClient, callOptions);
this.clientCache =
Caffeine.newBuilder()
.expireAfterAccess(5, TimeUnit.MINUTES)
.removalListener(
- (Location key, FlightClient value, RemovalCause cause) -> {
+ (Location key, FlightSqlClientWithCallOptions value, RemovalCause cause) -> {
if (value == null) return;
try {
value.close();
- } catch (InterruptedException e) {
- Thread.currentThread().interrupt();
- throw new RuntimeException(e);
+ } catch (Exception ex) {
+ if (ex instanceof InterruptedException) {
+ Thread.currentThread().interrupt();
+ }
+ throw new RuntimeException(ex);
}
})
- .build(location -> FlightClient.builder(allocator, location).build());
+ .build(
+ loc -> {
+ FlightClient client = buildClient(loc);
+ client.handshake(callOptions);
+ return new FlightSqlClientWithCallOptions(
+ new FlightSqlClient(client), callOptions);
+ });
+ this.clientCache.put(location, this.client);
}
@Override
@@ -137,6 +182,7 @@ public void setAutoCommit(boolean enableAutoCommit) throws AdbcException {
@Override
public void close() throws Exception {
+ clientCache.invalidateAll();
AutoCloseables.close(client, allocator);
}
@@ -144,4 +190,164 @@ public void close() throws Exception {
public String toString() {
return "FlightSqlConnection{" + "client=" + client + '}';
}
+
+ /**
+ * Initialize cached data to share between connections and create, test, and authenticate the
+ * first connection.
+ */
+ private FlightClient createInitialConnection(Location location) throws AdbcException {
+ // Setup cached pre-connection properties.
+ try {
+ final InputStream mtlsCertChain =
+ FlightSqlConnectionProperties.MTLS_CERT_CHAIN.get(parameters);
+ if (mtlsCertChain != null) {
+ this.mtlsCertChainBytes = inputStreamToBytes(mtlsCertChain);
+ }
+
+ final InputStream mtlsPrivateKey =
+ FlightSqlConnectionProperties.MTLS_PRIVATE_KEY.get(parameters);
+ if (mtlsPrivateKey != null) {
+ this.mtlsPrivateKeyBytes = inputStreamToBytes(mtlsPrivateKey);
+ }
+
+ final InputStream tlsRootCerts = FlightSqlConnectionProperties.TLS_ROOT_CERTS.get(parameters);
+ if (tlsRootCerts != null) {
+ this.tlsRootCertsBytes = inputStreamToBytes(tlsRootCerts);
+ }
+ } catch (IOException ex) {
+ throw new AdbcException(
+ String.format(
+ "Error reading stream for one of the options %s, %s, %s.",
+ FlightSqlConnectionProperties.MTLS_CERT_CHAIN.getKey(),
+ FlightSqlConnectionProperties.MTLS_PRIVATE_KEY.getKey(),
+ FlightSqlConnectionProperties.TLS_ROOT_CERTS.getKey()),
+ ex,
+ AdbcStatusCode.IO,
+ null,
+ 0);
+ }
+
+ final boolean useCookieMiddleware =
+ Boolean.TRUE.equals(FlightSqlConnectionProperties.WITH_COOKIE_MIDDLEWARE.get(parameters));
+ if (useCookieMiddleware) {
+ this.cookieMiddlewareFactory = new ClientCookieMiddleware.Factory();
+ }
+
+ // Build the client using the above properties.
+ final FlightClient client = buildClient(location);
+
+ // Add user-specified headers.
+ ArrayList options = new ArrayList<>();
+ final FlightCallHeaders callHeaders = new FlightCallHeaders();
+ for (Map.Entry parameter : parameters.entrySet()) {
+ if (parameter.getKey().startsWith(FlightSqlConnectionProperties.RPC_CALL_HEADER_PREFIX)) {
+ String userHeaderName =
+ parameter
+ .getKey()
+ .substring(FlightSqlConnectionProperties.RPC_CALL_HEADER_PREFIX.length());
+
+ if (parameter.getValue() instanceof String) {
+ callHeaders.insert(userHeaderName, (String) parameter.getValue());
+ } else if (parameter.getValue() instanceof byte[]) {
+ callHeaders.insert(userHeaderName, (byte[]) parameter.getValue());
+ } else {
+ throw new AdbcException(
+ String.format(
+ "Header values must be String or byte[]. The header failing was %s.",
+ parameter.getKey()),
+ null,
+ AdbcStatusCode.INVALID_ARGUMENT,
+ null,
+ 0);
+ }
+ }
+ }
+
+ options.add(new HeaderCallOption(callHeaders));
+
+ // Test the connection.
+ String username = AdbcDriver.PARAM_USERNAME.get(parameters);
+ String password = AdbcDriver.PARAM_PASSWORD.get(parameters);
+ if (username != null && password != null) {
+ Optional bearerToken =
+ client.authenticateBasicToken(username, password);
+ options.add(
+ bearerToken.orElse(
+ new CredentialCallOption(new BasicAuthCredentialWriter(username, password))));
+ this.callOptions = options.toArray(new CallOption[0]);
+ } else {
+ this.callOptions = options.toArray(new CallOption[0]);
+ client.handshake(this.callOptions);
+ }
+
+ return client;
+ }
+
+ /** Returns a yet-to-be authenticated FlightClient */
+ private FlightClient buildClient(Location location) throws AdbcException {
+ final FlightClient.Builder builder =
+ FlightClient.builder()
+ .allocator(
+ allocator.newChildAllocator(
+ "adbc-flightclient-connection-" + counter.getAndIncrement(),
+ 0,
+ allocator.getLimit()))
+ .location(location);
+
+ // Configure TLS options.
+ if (mtlsCertChainBytes != null && mtlsPrivateKeyBytes != null) {
+ builder.clientCertificate(
+ new ByteArrayInputStream(mtlsCertChainBytes),
+ new ByteArrayInputStream(mtlsPrivateKeyBytes));
+ } else if (mtlsCertChainBytes != null) {
+ throw new AdbcException(
+ String.format(
+ "Must provide both %s and %s or neither. %s provided only.",
+ FlightSqlConnectionProperties.MTLS_CERT_CHAIN.getKey(),
+ FlightSqlConnectionProperties.MTLS_PRIVATE_KEY.getKey(),
+ FlightSqlConnectionProperties.MTLS_CERT_CHAIN.getKey()),
+ null,
+ AdbcStatusCode.INVALID_ARGUMENT,
+ null,
+ 0);
+ } else if (mtlsPrivateKeyBytes != null) {
+ throw new AdbcException(
+ String.format(
+ "Must provide both %s and %s or neither. %s provided only.",
+ FlightSqlConnectionProperties.MTLS_PRIVATE_KEY.getKey(),
+ FlightSqlConnectionProperties.MTLS_CERT_CHAIN.getKey(),
+ FlightSqlConnectionProperties.MTLS_PRIVATE_KEY.getKey()),
+ null,
+ AdbcStatusCode.INVALID_ARGUMENT,
+ null,
+ 0);
+ }
+
+ if (tlsRootCertsBytes != null) {
+ builder.trustedCertificates(new ByteArrayInputStream(tlsRootCertsBytes));
+ }
+
+ if (Boolean.TRUE.equals(FlightSqlConnectionProperties.TLS_SKIP_VERIFY.get(parameters))) {
+ builder.verifyServer(false);
+ }
+
+ String hostnameOverride = FlightSqlConnectionProperties.TLS_OVERRIDE_HOSTNAME.get(parameters);
+ if (hostnameOverride != null) {
+ builder.overrideHostname(hostnameOverride);
+ }
+
+ // Setup cookies if needed.
+ if (cookieMiddlewareFactory != null) {
+ builder.intercept(cookieMiddlewareFactory);
+ }
+
+ return builder.build();
+ }
+
+ private static byte[] inputStreamToBytes(InputStream stream) throws IOException {
+ byte[] bytes = new byte[stream.available()];
+ DataInputStream dataInputStream = new DataInputStream(stream);
+ dataInputStream.readFully(bytes);
+ return bytes;
+ }
}
diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnectionProperties.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnectionProperties.java
new file mode 100644
index 0000000000..4ab1955a1b
--- /dev/null
+++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnectionProperties.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.adbc.driver.flightsql;
+
+import java.io.InputStream;
+import org.apache.arrow.adbc.core.TypedKey;
+
+/** Defines connection options that are used by the FlightSql driver. */
+public interface FlightSqlConnectionProperties {
+ TypedKey MTLS_CERT_CHAIN =
+ new TypedKey<>("adbc.flight.sql.client_option.mtls_cert_chain", InputStream.class);
+ TypedKey MTLS_PRIVATE_KEY =
+ new TypedKey<>("adbc.flight.sql.client_option.mtls_private_key", InputStream.class);
+ TypedKey TLS_OVERRIDE_HOSTNAME =
+ new TypedKey<>("adbc.flight.sql.client_option.tls_override_hostname", String.class);
+ TypedKey TLS_SKIP_VERIFY =
+ new TypedKey<>("adbc.flight.sql.client_option.tls_skip_verify", Boolean.class);
+ TypedKey TLS_ROOT_CERTS =
+ new TypedKey<>("adbc.flight.sql.client_option.tls_root_certs", InputStream.class);
+ TypedKey WITH_COOKIE_MIDDLEWARE =
+ new TypedKey<>("adbc.flight.sql.rpc.with_cookie_middleware", Boolean.class);
+ String RPC_CALL_HEADER_PREFIX = "adbc.flight.sql.rpc.call_header.";
+}
diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDatabase.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDatabase.java
index 11ca360579..af8221d6b0 100644
--- a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDatabase.java
+++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDatabase.java
@@ -17,56 +17,66 @@
package org.apache.arrow.adbc.driver.flightsql;
+import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.arrow.adbc.core.AdbcConnection;
import org.apache.arrow.adbc.core.AdbcDatabase;
import org.apache.arrow.adbc.core.AdbcException;
import org.apache.arrow.adbc.sql.SqlQuirks;
-import org.apache.arrow.flight.FlightClient;
import org.apache.arrow.flight.FlightRuntimeException;
import org.apache.arrow.flight.Location;
import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.util.AutoCloseables;
/** An instance of a database (e.g. a handle to an in-memory database). */
public final class FlightSqlDatabase implements AdbcDatabase {
private final BufferAllocator allocator;
private final Location location;
private final SqlQuirks quirks;
- private final FlightClient client;
private final AtomicInteger counter;
+ private final Map parameters;
- FlightSqlDatabase(BufferAllocator allocator, Location location, SqlQuirks quirks)
+ FlightSqlDatabase(
+ BufferAllocator allocator,
+ Location location,
+ SqlQuirks quirks,
+ Map parameters)
throws AdbcException {
this.allocator = allocator;
this.location = location;
this.quirks = quirks;
- try {
- this.client = FlightClient.builder(allocator, location).build();
- } catch (FlightRuntimeException e) {
- throw FlightSqlDriverUtil.fromFlightException(e);
- }
this.counter = new AtomicInteger();
+ this.parameters = parameters;
}
@Override
public AdbcConnection connect() throws AdbcException {
- final FlightClient client;
+ final int count = counter.getAndIncrement();
+ BufferAllocator connectionAllocator =
+ allocator.newChildAllocator("adbc-flight-connection-" + count, 0, allocator.getLimit());
try {
- client = FlightClient.builder(allocator, location).build();
- } catch (FlightRuntimeException e) {
- throw FlightSqlDriverUtil.fromFlightException(e);
+ return new FlightSqlConnection(connectionAllocator, quirks, location, parameters);
+ } catch (FlightRuntimeException ex) {
+ AdbcException adbcException = FlightSqlDriverUtil.fromFlightException(ex);
+ try {
+ AutoCloseables.close(connectionAllocator);
+ } catch (Exception e) {
+ adbcException.addSuppressed(e);
+ }
+ throw adbcException;
+ } catch (Exception ex) {
+ AdbcException adbcException = FlightSqlDriverUtil.fromGeneralException(ex);
+ try {
+ AutoCloseables.close(connectionAllocator);
+ } catch (Exception e) {
+ adbcException.addSuppressed(e);
+ }
+ throw adbcException;
}
- final int count = counter.getAndIncrement();
- return new FlightSqlConnection(
- allocator.newChildAllocator("adbc-jdbc-connection-" + count, 0, allocator.getLimit()),
- client,
- quirks);
}
@Override
- public void close() throws Exception {
- client.close();
- }
+ public void close() throws Exception {}
@Override
public String toString() {
diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriver.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriver.java
index 1ea0fce094..179e0c416e 100644
--- a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriver.java
+++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriver.java
@@ -39,7 +39,7 @@ public FlightSqlDriver(BufferAllocator allocator) {
public AdbcDatabase open(Map parameters) throws AdbcException {
String uri = PARAM_URI.get(parameters);
if (uri == null) {
- Object target = parameters.get("adbc.url");
+ Object target = parameters.get(AdbcDriver.PARAM_URL);
if (!(target instanceof String)) {
throw AdbcException.invalidArgument(
"[Flight SQL] Must provide String " + PARAM_URI + " parameter");
@@ -65,6 +65,6 @@ public AdbcDatabase open(Map parameters) throws AdbcException {
} else {
quirks = new SqlQuirks();
}
- return new FlightSqlDatabase(allocator, location, (SqlQuirks) quirks);
+ return new FlightSqlDatabase(allocator, location, (SqlQuirks) quirks, parameters);
}
}
diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriverUtil.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriverUtil.java
index 45b42df2ee..a4ea23dd0b 100644
--- a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriverUtil.java
+++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriverUtil.java
@@ -43,6 +43,10 @@ static AdbcException fromSqlException(SQLException e) {
e.getErrorCode());
}
+ static AdbcException fromGeneralException(Exception ex) {
+ return new AdbcException(ex.getMessage(), ex, AdbcStatusCode.UNKNOWN, null, 0);
+ }
+
static AdbcStatusCode fromFlightStatusCode(FlightStatusCode code) {
switch (code) {
case OK:
diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlStatement.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlStatement.java
index e64508b4bf..77cb2622d1 100644
--- a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlStatement.java
+++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlStatement.java
@@ -29,7 +29,6 @@
import org.apache.arrow.adbc.core.BulkIngestMode;
import org.apache.arrow.adbc.core.PartitionDescriptor;
import org.apache.arrow.adbc.sql.SqlQuirks;
-import org.apache.arrow.flight.FlightClient;
import org.apache.arrow.flight.FlightEndpoint;
import org.apache.arrow.flight.FlightInfo;
import org.apache.arrow.flight.FlightRuntimeException;
@@ -44,8 +43,8 @@
public class FlightSqlStatement implements AdbcStatement {
private final BufferAllocator allocator;
- private final FlightSqlClient client;
- private final LoadingCache clientCache;
+ private final FlightSqlClientWithCallOptions client;
+ private final LoadingCache clientCache;
private final SqlQuirks quirks;
// State for SQL queries
@@ -57,8 +56,8 @@ public class FlightSqlStatement implements AdbcStatement {
FlightSqlStatement(
BufferAllocator allocator,
- FlightSqlClient client,
- LoadingCache clientCache,
+ FlightSqlClientWithCallOptions client,
+ LoadingCache clientCache,
SqlQuirks quirks) {
this.allocator = allocator;
this.client = client;
@@ -69,8 +68,8 @@ public class FlightSqlStatement implements AdbcStatement {
static FlightSqlStatement ingestRoot(
BufferAllocator allocator,
- FlightSqlClient client,
- LoadingCache clientCache,
+ FlightSqlClientWithCallOptions client,
+ LoadingCache clientCache,
SqlQuirks quirks,
String targetTableName,
BulkIngestMode mode) {
@@ -188,7 +187,7 @@ interface Execute {
private R execute(
Execute doPrepared,
- Execute doRegular)
+ Execute doRegular)
throws AdbcException {
try {
if (preparedStatement != null) {
diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/InfoMetadataBuilder.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/InfoMetadataBuilder.java
index 318405d6ce..aef679f63c 100644
--- a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/InfoMetadataBuilder.java
+++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/InfoMetadataBuilder.java
@@ -31,7 +31,6 @@
import org.apache.arrow.flight.FlightInfo;
import org.apache.arrow.flight.FlightRuntimeException;
import org.apache.arrow.flight.FlightStream;
-import org.apache.arrow.flight.sql.FlightSqlClient;
import org.apache.arrow.flight.sql.impl.FlightSql;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.util.AutoCloseables;
@@ -47,7 +46,7 @@ final class InfoMetadataBuilder implements AutoCloseable {
private static final Map SUPPORTED_CODES = new HashMap<>();
private final Collection requestedCodes;
- private final FlightSqlClient client;
+ private final FlightSqlClientWithCallOptions client;
private VectorSchemaRoot root;
private final UInt4Vector infoCodes;
@@ -80,7 +79,8 @@ interface AddInfo {
});
}
- InfoMetadataBuilder(BufferAllocator allocator, FlightSqlClient client, int[] infoCodes) {
+ InfoMetadataBuilder(
+ BufferAllocator allocator, FlightSqlClientWithCallOptions client, int[] infoCodes) {
if (infoCodes == null) {
this.requestedCodes = new ArrayList<>(SUPPORTED_CODES.keySet());
this.requestedCodes.add(AdbcInfoCode.DRIVER_NAME.getValue());
diff --git a/java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/HeaderTest.java b/java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/HeaderTest.java
new file mode 100644
index 0000000000..19c66f8116
--- /dev/null
+++ b/java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/HeaderTest.java
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.arrow.adbc.driver.flightsql;
+
+import static org.junit.Assert.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import java.util.HashMap;
+import java.util.Map;
+import org.apache.arrow.adbc.core.AdbcConnection;
+import org.apache.arrow.adbc.core.AdbcDatabase;
+import org.apache.arrow.adbc.core.AdbcDriver;
+import org.apache.arrow.adbc.core.AdbcException;
+import org.apache.arrow.adbc.core.AdbcStatusCode;
+import org.apache.arrow.adbc.drivermanager.AdbcDriverManager;
+import org.apache.arrow.driver.jdbc.utils.MockFlightSqlProducer;
+import org.apache.arrow.flight.CallHeaders;
+import org.apache.arrow.flight.CallInfo;
+import org.apache.arrow.flight.CallStatus;
+import org.apache.arrow.flight.FlightServer;
+import org.apache.arrow.flight.FlightServerMiddleware;
+import org.apache.arrow.flight.Location;
+import org.apache.arrow.flight.RequestContext;
+import org.apache.arrow.flight.auth2.BasicCallHeaderAuthenticator;
+import org.apache.arrow.flight.auth2.GeneratedBearerTokenAuthenticator;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.util.AutoCloseables;
+import org.apache.arrow.vector.ipc.ArrowReader;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+public class HeaderTest {
+
+ private FlightServer.Builder builder;
+ private FlightServer server;
+ private Map params;
+ private AdbcConnection connection;
+ private BufferAllocator allocator;
+ private HeaderValidator.Factory headerValidatorFactory;
+
+ @BeforeEach
+ public void setUp() {
+ allocator = new RootAllocator(Long.MAX_VALUE);
+ headerValidatorFactory = new HeaderValidator.Factory();
+ builder =
+ FlightServer.builder()
+ .middleware(HeaderValidator.KEY, headerValidatorFactory)
+ .location(Location.forGrpcInsecure("localhost", 0))
+ .producer(new MockFlightSqlProducer());
+ params = new HashMap<>();
+ }
+
+ @AfterEach
+ public void tearDown() throws Exception {
+ AutoCloseables.close(connection, server, allocator);
+ connection = null;
+ server = null;
+ allocator = null;
+ }
+
+ @Test
+ public void testArbitraryHeader() throws Exception {
+ final String dummyValue = "dummy";
+ final String dummyHeaderName = "test-header";
+ params.put(FlightSqlConnectionProperties.RPC_CALL_HEADER_PREFIX + dummyHeaderName, dummyValue);
+ server = builder.build();
+ server.start();
+ connect();
+
+ CallHeaders headers = headerValidatorFactory.getHeadersReceivedAtRequest(0);
+ assertEquals(dummyValue, headers.get(dummyHeaderName));
+ }
+
+ @Test
+ public void testCookies() throws Exception {
+ builder.middleware(CookieMiddleware.KEY, new CookieMiddleware.Factory());
+ server = builder.build();
+ server.start();
+
+ params.put(FlightSqlConnectionProperties.WITH_COOKIE_MIDDLEWARE.getKey(), true);
+ connect();
+ try (ArrowReader reader = connection.getInfo(new int[] {})) {
+
+ } catch (Exception ex) {
+ // Swallow exceptions from the RPC call. Only interested in tracking metadata.
+ }
+ CallHeaders secondHeaders = headerValidatorFactory.getHeadersReceivedAtRequest(1);
+ assertTrue(secondHeaders.containsKey("cookie"));
+ }
+
+ @Test
+ public void testBearerToken() throws Exception {
+ builder.headerAuthenticator(
+ new GeneratedBearerTokenAuthenticator(
+ new BasicCallHeaderAuthenticator((username, password) -> () -> username)));
+ server = builder.build();
+ server.start();
+
+ params.put(AdbcDriver.PARAM_USERNAME.getKey(), "dummy_user");
+ params.put(AdbcDriver.PARAM_PASSWORD.getKey(), "dummy_password");
+ connect();
+ try (ArrowReader reader = connection.getInfo(new int[] {})) {
+
+ } catch (Exception ex) {
+ // Swallow exceptions from the RPC call. Only interested in tracking metadata.
+ }
+ CallHeaders secondHeaders = headerValidatorFactory.getHeadersReceivedAtRequest(1);
+ assertTrue(secondHeaders.get("authorization").contains("Bearer"));
+ }
+
+ @Test
+ public void testUnauthenticated() throws Exception {
+ builder.headerAuthenticator(
+ new GeneratedBearerTokenAuthenticator(
+ new BasicCallHeaderAuthenticator((username, password) -> () -> username)));
+ server = builder.build();
+ server.start();
+
+ AdbcException adbcException = assertThrows(AdbcException.class, this::connect);
+ assertEquals(AdbcStatusCode.UNAUTHENTICATED, adbcException.getStatus());
+ }
+
+ static class CookieMiddleware implements FlightServerMiddleware {
+
+ public static final Key KEY = Key.of("CookieMiddleware");
+
+ @Override
+ public void onBeforeSendingHeaders(CallHeaders callHeaders) {
+ callHeaders.insert("set-cookie", "test=test_val");
+ }
+
+ @Override
+ public void onCallCompleted(CallStatus callStatus) {}
+
+ @Override
+ public void onCallErrored(Throwable throwable) {}
+
+ public static class Factory implements FlightServerMiddleware.Factory {
+
+ @Override
+ public CookieMiddleware onCallStarted(
+ CallInfo callInfo, CallHeaders callHeaders, RequestContext requestContext) {
+ return new CookieMiddleware();
+ }
+ }
+ }
+
+ private void connect() throws Exception {
+ int port = server.getPort();
+ String uri = String.format("grpc+tcp://%s:%d", "localhost", port);
+ params.put(AdbcDriver.PARAM_URI.getKey(), uri);
+ AdbcDatabase db =
+ AdbcDriverManager.getInstance()
+ .connect(FlightSqlDriverFactory.class.getCanonicalName(), allocator, params);
+ connection = db.connect();
+ }
+}
diff --git a/java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/HeaderValidator.java b/java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/HeaderValidator.java
new file mode 100644
index 0000000000..f543dd9958
--- /dev/null
+++ b/java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/HeaderValidator.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.arrow.adbc.driver.flightsql;
+
+import java.util.ArrayList;
+import org.apache.arrow.flight.CallHeaders;
+import org.apache.arrow.flight.CallInfo;
+import org.apache.arrow.flight.CallStatus;
+import org.apache.arrow.flight.FlightCallHeaders;
+import org.apache.arrow.flight.FlightServerMiddleware;
+import org.apache.arrow.flight.RequestContext;
+
+public class HeaderValidator implements FlightServerMiddleware {
+ public static final Key KEY = Key.of("HeaderValidator");
+
+ @Override
+ public void onBeforeSendingHeaders(CallHeaders callHeaders) {}
+
+ @Override
+ public void onCallCompleted(CallStatus callStatus) {}
+
+ @Override
+ public void onCallErrored(Throwable throwable) {}
+
+ public static class Factory implements FlightServerMiddleware.Factory {
+
+ private final ArrayList headersReceived = new ArrayList<>();
+
+ @Override
+ public HeaderValidator onCallStarted(
+ CallInfo callInfo, CallHeaders callHeaders, RequestContext requestContext) {
+ CallHeaders cloneHeaders = cloneHeaders(callHeaders);
+ headersReceived.add(cloneHeaders);
+ return new HeaderValidator();
+ }
+
+ public CallHeaders getHeadersReceivedAtRequest(int request) {
+ return cloneHeaders(headersReceived.get(request));
+ }
+
+ private static CallHeaders cloneHeaders(CallHeaders headers) {
+ FlightCallHeaders cloneHeaders = new FlightCallHeaders();
+ for (String key : headers.keys()) {
+ for (String value : headers.getAll(key)) {
+ cloneHeaders.insert(key, value);
+ }
+ }
+ return cloneHeaders;
+ }
+ }
+}
diff --git a/java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/MutualTlsTest.java b/java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/MutualTlsTest.java
new file mode 100644
index 0000000000..f98a7c3470
--- /dev/null
+++ b/java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/MutualTlsTest.java
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.arrow.adbc.driver.flightsql;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThrows;
+
+import java.io.InputStream;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.HashMap;
+import java.util.Map;
+import org.apache.arrow.adbc.core.AdbcConnection;
+import org.apache.arrow.adbc.core.AdbcDatabase;
+import org.apache.arrow.adbc.core.AdbcDriver;
+import org.apache.arrow.adbc.core.AdbcException;
+import org.apache.arrow.adbc.core.AdbcStatusCode;
+import org.apache.arrow.adbc.drivermanager.AdbcDriverManager;
+import org.apache.arrow.driver.jdbc.FlightServerTestRule;
+import org.apache.arrow.driver.jdbc.authentication.UserPasswordAuthentication;
+import org.apache.arrow.driver.jdbc.utils.FlightSqlTestCertificates;
+import org.apache.arrow.driver.jdbc.utils.MockFlightSqlProducer;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.util.AutoCloseables;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.ClassRule;
+import org.junit.Test;
+
+public class MutualTlsTest {
+
+ @ClassRule public static final FlightServerTestRule FLIGHT_SERVER_TEST_RULE;
+
+ private static final String USER_1 = "user1";
+ private static final String PASS_1 = "pass1";
+ private static final String TLS_ROOT_CERTS_PATH;
+ private static final String CLIENT_MTLS_CERT_PATH;
+ private static final String CLIENT_MTLS_KEY_PATH;
+
+ static {
+ final FlightSqlTestCertificates.CertKeyPair certKey =
+ FlightSqlTestCertificates.exampleTlsCerts().get(0);
+
+ TLS_ROOT_CERTS_PATH = certKey.cert.getPath();
+
+ final FlightSqlTestCertificates.CertKeyPair clientMTlsCertKey =
+ FlightSqlTestCertificates.exampleTlsCerts().get(1);
+
+ CLIENT_MTLS_CERT_PATH = clientMTlsCertKey.cert.getPath();
+ CLIENT_MTLS_KEY_PATH = clientMTlsCertKey.key.getPath();
+
+ UserPasswordAuthentication authentication =
+ new UserPasswordAuthentication.Builder().user(USER_1, PASS_1).build();
+
+ FLIGHT_SERVER_TEST_RULE =
+ new FlightServerTestRule.Builder()
+ .authentication(authentication)
+ .useEncryption(certKey.cert, certKey.key)
+ .useMTlsClientVerification(FlightSqlTestCertificates.exampleCACert())
+ .producer(new MockFlightSqlProducer())
+ .build();
+ }
+
+ private BufferAllocator allocator;
+ private Map params;
+
+ @Before
+ public void setUp() {
+ allocator = new RootAllocator(Long.MAX_VALUE);
+ params = new HashMap<>();
+ params.put(AdbcDriver.PARAM_USERNAME.getKey(), USER_1);
+ params.put(AdbcDriver.PARAM_PASSWORD.getKey(), PASS_1);
+ }
+
+ @After
+ public void tearDown() throws Exception {
+ AutoCloseables.close(allocator);
+ }
+
+ @Test
+ public void testClientTlsOnVerifyOffServerOnNoCertSpecified() throws Exception {
+ params.put(AdbcDriver.PARAM_URI.getKey(), getUri(true));
+ params.put(FlightSqlConnectionProperties.TLS_SKIP_VERIFY.getKey(), true);
+ AdbcException adbcException =
+ assertThrows(
+ AdbcException.class,
+ () -> {
+ AdbcDatabase db =
+ AdbcDriverManager.getInstance()
+ .connect(FlightSqlDriverFactory.class.getCanonicalName(), allocator, params);
+ try (AdbcConnection conn = db.connect()) {}
+ });
+ assertEquals(AdbcStatusCode.IO, adbcException.getStatus());
+ }
+
+ @Test
+ public void testClientTlsOnVerifyOnServerOnNoCertSpecified() throws Exception {
+ params.put(AdbcDriver.PARAM_URI.getKey(), getUri(true));
+ try (InputStream stream = Files.newInputStream(Paths.get(TLS_ROOT_CERTS_PATH))) {
+ params.put(FlightSqlConnectionProperties.TLS_ROOT_CERTS.getKey(), stream);
+ AdbcException adbcException =
+ assertThrows(
+ AdbcException.class,
+ () -> {
+ AdbcDatabase db =
+ AdbcDriverManager.getInstance()
+ .connect(
+ FlightSqlDriverFactory.class.getCanonicalName(), allocator, params);
+ try (AdbcConnection conn = db.connect()) {}
+ });
+ assertEquals(AdbcStatusCode.IO, adbcException.getStatus());
+ }
+ }
+
+ @Test
+ public void testClientTlsOnVerifyOnCertsSpecifiedServerOnNoCertSpecified() throws Exception {
+ params.put(AdbcDriver.PARAM_URI.getKey(), getUri(true));
+ try (InputStream rootCertStream = Files.newInputStream(Paths.get(TLS_ROOT_CERTS_PATH));
+ InputStream privateKeyStream = Files.newInputStream(Paths.get(CLIENT_MTLS_KEY_PATH));
+ InputStream clientCertStream = Files.newInputStream(Paths.get(CLIENT_MTLS_CERT_PATH))) {
+ params.put(FlightSqlConnectionProperties.TLS_ROOT_CERTS.getKey(), rootCertStream);
+ params.put(FlightSqlConnectionProperties.MTLS_PRIVATE_KEY.getKey(), privateKeyStream);
+ params.put(FlightSqlConnectionProperties.MTLS_CERT_CHAIN.getKey(), clientCertStream);
+ AdbcDatabase db =
+ AdbcDriverManager.getInstance()
+ .connect(FlightSqlDriverFactory.class.getCanonicalName(), allocator, params);
+ try (AdbcConnection conn = db.connect()) {}
+ }
+ }
+
+ private String getUri(boolean withTls) {
+ String protocol = String.format("grpc%s", withTls ? "+tls" : "+tcp");
+ return String.format(
+ "%s://%s:%d",
+ protocol, FLIGHT_SERVER_TEST_RULE.getHost(), FLIGHT_SERVER_TEST_RULE.getPort());
+ }
+}
diff --git a/java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/TlsTest.java b/java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/TlsTest.java
new file mode 100644
index 0000000000..0301f3d8fb
--- /dev/null
+++ b/java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/TlsTest.java
@@ -0,0 +1,179 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.arrow.adbc.driver.flightsql;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThrows;
+
+import java.io.InputStream;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.HashMap;
+import java.util.Map;
+import org.apache.arrow.adbc.core.AdbcConnection;
+import org.apache.arrow.adbc.core.AdbcDatabase;
+import org.apache.arrow.adbc.core.AdbcDriver;
+import org.apache.arrow.adbc.core.AdbcException;
+import org.apache.arrow.adbc.core.AdbcStatusCode;
+import org.apache.arrow.adbc.drivermanager.AdbcDriverManager;
+import org.apache.arrow.driver.jdbc.FlightServerTestRule;
+import org.apache.arrow.driver.jdbc.authentication.UserPasswordAuthentication;
+import org.apache.arrow.driver.jdbc.utils.FlightSqlTestCertificates;
+import org.apache.arrow.driver.jdbc.utils.MockFlightSqlProducer;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.util.AutoCloseables;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.ClassRule;
+import org.junit.Test;
+
+public class TlsTest {
+
+ @ClassRule public static final FlightServerTestRule FLIGHT_SERVER_TEST_RULE;
+
+ private static final String USER_1 = "user1";
+ private static final String PASS_1 = "pass1";
+ private static final String TLS_ROOT_CERTS_PATH;
+
+ static {
+ final FlightSqlTestCertificates.CertKeyPair certKey =
+ FlightSqlTestCertificates.exampleTlsCerts().get(0);
+
+ TLS_ROOT_CERTS_PATH = certKey.cert.getPath();
+
+ UserPasswordAuthentication authentication =
+ new UserPasswordAuthentication.Builder().user(USER_1, PASS_1).build();
+
+ FLIGHT_SERVER_TEST_RULE =
+ new FlightServerTestRule.Builder()
+ .authentication(authentication)
+ .useEncryption(certKey.cert, certKey.key)
+ .producer(new MockFlightSqlProducer())
+ .build();
+ }
+
+ private BufferAllocator allocator;
+ private Map params;
+
+ @Before
+ public void setUp() {
+ allocator = new RootAllocator(Long.MAX_VALUE);
+ params = new HashMap<>();
+ params.put(AdbcDriver.PARAM_USERNAME.getKey(), USER_1);
+ params.put(AdbcDriver.PARAM_PASSWORD.getKey(), PASS_1);
+ }
+
+ @After
+ public void tearDown() throws Exception {
+ AutoCloseables.close(allocator);
+ }
+
+ @Test
+ public void testClientTlsOffServerOn() {
+ params.put(AdbcDriver.PARAM_URI.getKey(), getUri(false));
+ AdbcException adbcException =
+ assertThrows(
+ AdbcException.class,
+ () -> {
+ AdbcDatabase db =
+ AdbcDriverManager.getInstance()
+ .connect(FlightSqlDriverFactory.class.getCanonicalName(), allocator, params);
+ try (AdbcConnection conn = db.connect()) {}
+ });
+ assertEquals(AdbcStatusCode.IO, adbcException.getStatus());
+ }
+
+ @Test
+ public void testClientTlsOnServerOnNoCertSpecified() {
+ params.put(AdbcDriver.PARAM_URI.getKey(), getUri(true));
+ AdbcException adbcException =
+ assertThrows(
+ AdbcException.class,
+ () -> {
+ AdbcDatabase db =
+ AdbcDriverManager.getInstance()
+ .connect(FlightSqlDriverFactory.class.getCanonicalName(), allocator, params);
+ try (AdbcConnection conn = db.connect()) {}
+ });
+ assertEquals(AdbcStatusCode.IO, adbcException.getStatus());
+ }
+
+ @Test
+ public void testClientTlsOnVerifyOffServerOnNoCertSpecified() throws Exception {
+ params.put(AdbcDriver.PARAM_URI.getKey(), getUri(true));
+ params.put(FlightSqlConnectionProperties.TLS_SKIP_VERIFY.getKey(), true);
+ AdbcDatabase db =
+ AdbcDriverManager.getInstance()
+ .connect(FlightSqlDriverFactory.class.getCanonicalName(), allocator, params);
+ try (AdbcConnection conn = db.connect()) {}
+ }
+
+ @Test
+ public void testClientTlsOnVerifyOnCertsSpecifiedServerOnNoCertSpecified() throws Exception {
+ params.put(AdbcDriver.PARAM_URI.getKey(), getUri(true));
+ try (InputStream stream = Files.newInputStream(Paths.get(TLS_ROOT_CERTS_PATH))) {
+ params.put(FlightSqlConnectionProperties.TLS_ROOT_CERTS.getKey(), stream);
+ AdbcDatabase db =
+ AdbcDriverManager.getInstance()
+ .connect(FlightSqlDriverFactory.class.getCanonicalName(), allocator, params);
+ try (AdbcConnection conn = db.connect()) {}
+ }
+ }
+
+ @Test
+ public void testClientTlsOnBadHostnameOverride() throws Exception {
+ params.put(AdbcDriver.PARAM_URI.getKey(), getUri(true));
+ params.put(FlightSqlConnectionProperties.TLS_OVERRIDE_HOSTNAME.getKey(), "fakehost");
+ try (InputStream stream = Files.newInputStream(Paths.get(TLS_ROOT_CERTS_PATH))) {
+ params.put(FlightSqlConnectionProperties.TLS_ROOT_CERTS.getKey(), stream);
+ AdbcException adbcException =
+ assertThrows(
+ AdbcException.class,
+ () -> {
+ AdbcDatabase db =
+ AdbcDriverManager.getInstance()
+ .connect(
+ FlightSqlDriverFactory.class.getCanonicalName(), allocator, params);
+ try (AdbcConnection conn = db.connect()) {}
+ });
+ assertEquals(AdbcStatusCode.IO, adbcException.getStatus());
+ }
+ }
+
+ @Test
+ public void testClientTlsOnGoodHostnameOverride() throws Exception {
+ params.put(AdbcDriver.PARAM_URI.getKey(), getUri(true));
+ params.put(
+ FlightSqlConnectionProperties.TLS_OVERRIDE_HOSTNAME.getKey(),
+ FLIGHT_SERVER_TEST_RULE.getHost());
+ try (InputStream stream = Files.newInputStream(Paths.get(TLS_ROOT_CERTS_PATH))) {
+ params.put(FlightSqlConnectionProperties.TLS_ROOT_CERTS.getKey(), stream);
+ AdbcDatabase db =
+ AdbcDriverManager.getInstance()
+ .connect(FlightSqlDriverFactory.class.getCanonicalName(), allocator, params);
+ try (AdbcConnection conn = db.connect()) {}
+ }
+ }
+
+ private String getUri(boolean withTls) {
+ String protocol = String.format("grpc%s", withTls ? "+tls" : "+tcp");
+ return String.format(
+ "%s://%s:%d",
+ protocol, FLIGHT_SERVER_TEST_RULE.getHost(), FLIGHT_SERVER_TEST_RULE.getPort());
+ }
+}
diff --git a/java/pom.xml b/java/pom.xml
index a632959f51..5b02f67f6c 100644
--- a/java/pom.xml
+++ b/java/pom.xml
@@ -97,28 +97,10 @@
org.apache.arrow
- arrow-jdbc
- ${dep.arrow.version}
-
-
- org.apache.arrow
- arrow-memory-core
- ${dep.arrow.version}
-
-
- org.apache.arrow
- arrow-vector
- ${dep.arrow.version}
-
-
- org.apache.arrow
- flight-core
- ${dep.arrow.version}
-
-
- org.apache.arrow
- flight-sql
+ arrow-bom
${dep.arrow.version}
+ pom
+ import
diff --git a/testing b/testing
new file mode 160000
index 0000000000..25d16511e8
--- /dev/null
+++ b/testing
@@ -0,0 +1 @@
+Subproject commit 25d16511e8d42c2744a1d94d90169e3a36e92631