Skip to content

Commit

Permalink
fix(src/Drivers/Apache): correctly handle empty response and add Clie…
Browse files Browse the repository at this point in the history
…nt tests
  • Loading branch information
birschick-bq committed Oct 23, 2024
1 parent 65406ce commit 6d2a6dd
Show file tree
Hide file tree
Showing 8 changed files with 412 additions and 18 deletions.
25 changes: 14 additions & 11 deletions csharp/src/Drivers/Apache/Hive2/HiveServer2Reader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ public HiveServer2Reader(
HiveServer2Statement statement,
Schema schema,
DataTypeConversion dataTypeConversion,
CancellationToken cancellationToken = default)
CancellationToken _ = default)
{
_statement = statement;
Schema = schema;
Expand All @@ -88,22 +88,20 @@ public HiveServer2Reader(
// Await the fetch response
TFetchResultsResp response = await FetchNext(_statement, cancellationToken);

// Build the current batch
RecordBatch result = CreateBatch(response, out int fetchedRows);

if ((_statement.BatchSize > 0 && fetchedRows < _statement.BatchSize) || fetchedRows == 0)
int columnCount = GetColumnCount(response);
int rowCount = GetRowCount(response, columnCount); ;
if ((_statement.BatchSize > 0 && rowCount < _statement.BatchSize) || rowCount == 0)
{
// This is the last batch
_statement = null;
}

// Return the current batch.
return result;
// Build the current batch, if any data exists
return rowCount > 0 ? CreateBatch(response, columnCount, rowCount) : null;
}

private RecordBatch CreateBatch(TFetchResultsResp response, out int length)
private RecordBatch CreateBatch(TFetchResultsResp response, int columnCount, int rowCount)
{
int columnCount = response.Results.Columns.Count;
IList<IArrowArray> columnData = [];
bool shouldConvertScalar = _dataTypeConversion.HasFlag(DataTypeConversion.Scalar);
for (int i = 0; i < columnCount; i++)
Expand All @@ -113,10 +111,15 @@ private RecordBatch CreateBatch(TFetchResultsResp response, out int length)
columnData.Add(columnArray);
}

length = columnCount > 0 ? GetArray(response.Results.Columns[0]).Length : 0;
return new RecordBatch(Schema, columnData, length);
return new RecordBatch(Schema, columnData, rowCount);
}

private static int GetColumnCount(TFetchResultsResp response) =>
response.Results.Columns.Count;

private static int GetRowCount(TFetchResultsResp response, int columnCount) =>
columnCount > 0 ? GetArray(response.Results.Columns[0]).Length : 0;

private static async Task<TFetchResultsResp> FetchNext(HiveServer2Statement statement, CancellationToken cancellationToken = default)
{
var request = new TFetchResultsReq(statement.OperationHandle, TFetchOrientation.FETCH_NEXT, statement.BatchSize);
Expand Down
16 changes: 10 additions & 6 deletions csharp/test/Apache.Arrow.Adbc.Tests/ClientTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -72,20 +72,22 @@ public static void CanClientExecuteUpdate(Adbc.Client.AdbcConnection adbcConnect
/// </summary>
/// <param name="adbcConnection">The <see cref="Adbc.Client.AdbcConnection"/> to use.</param>
/// <param name="testConfiguration">The <see cref="TestConfiguration"/> to use</param>
public static void CanClientGetSchema(Adbc.Client.AdbcConnection adbcConnection, TestConfiguration testConfiguration)
/// <param name="customQuery">The custom query to use instead of query from <see cref="TestConfiguration.Query" /></param>"/>
/// <param name="expectedColumnCount">The custom column count to use instead of query from <see cref="TestMetadata.ExpectedColumnCount" /></param>
public static void CanClientGetSchema(Adbc.Client.AdbcConnection adbcConnection, TestConfiguration testConfiguration, string? customQuery = default, int? expectedColumnCount = default)
{
if (adbcConnection == null) throw new ArgumentNullException(nameof(adbcConnection));
if (testConfiguration == null) throw new ArgumentNullException(nameof(testConfiguration));

adbcConnection.Open();

using AdbcCommand adbcCommand = new AdbcCommand(testConfiguration.Query, adbcConnection);
using AdbcCommand adbcCommand = new AdbcCommand(customQuery ?? testConfiguration.Query, adbcConnection);
using AdbcDataReader reader = adbcCommand.ExecuteReader(CommandBehavior.SchemaOnly);

DataTable? table = reader.GetSchemaTable();

// there is one row per field
Assert.Equal(testConfiguration.Metadata.ExpectedColumnCount, table?.Rows.Count);
Assert.Equal(expectedColumnCount ?? testConfiguration.Metadata.ExpectedColumnCount, table?.Rows.Count);
}

/// <summary>
Expand All @@ -98,7 +100,9 @@ public static void CanClientGetSchema(Adbc.Client.AdbcConnection adbcConnection,
public static void CanClientExecuteQuery(
Adbc.Client.AdbcConnection adbcConnection,
TestConfiguration testConfiguration,
Action<AdbcCommand>? additionalCommandOptionsSetter = null)
Action<AdbcCommand>? additionalCommandOptionsSetter = null,
string? customQuery = default,
int? expectedResultsCount = default)
{
if (adbcConnection == null) throw new ArgumentNullException(nameof(adbcConnection));
if (testConfiguration == null) throw new ArgumentNullException(nameof(testConfiguration));
Expand All @@ -107,7 +111,7 @@ public static void CanClientExecuteQuery(

adbcConnection.Open();

using AdbcCommand adbcCommand = new AdbcCommand(testConfiguration.Query, adbcConnection);
using AdbcCommand adbcCommand = new AdbcCommand(customQuery ?? testConfiguration.Query, adbcConnection);
additionalCommandOptionsSetter?.Invoke(adbcCommand);
using AdbcDataReader reader = adbcCommand.ExecuteReader();

Expand All @@ -131,7 +135,7 @@ public static void CanClientExecuteQuery(
}
finally { reader.Close(); }

Assert.Equal(testConfiguration.ExpectedResultsCount, count);
Assert.Equal(expectedResultsCount ?? testConfiguration.ExpectedResultsCount, count);
}

/// <summary>
Expand Down
5 changes: 5 additions & 0 deletions csharp/test/Apache.Arrow.Adbc.Tests/TestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,11 @@ protected string[] GetQueries()
return queries;
}

protected SampleDataBuilder GetSampleDataBuilder()
{
return TestEnvironment.GetSampleDataBuilder();
}

/// <summary>
/// Gets a the Spark ADBC driver with settings from the <see cref="SparkTestConfiguration"/>.
/// </summary>
Expand Down
2 changes: 2 additions & 0 deletions csharp/test/Apache.Arrow.Adbc.Tests/TestEnvironment.cs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ protected TestEnvironment(Func<AdbcConnection> getConnection)

public abstract AdbcDriver CreateNewDriver();

public abstract SampleDataBuilder GetSampleDataBuilder();

public abstract Dictionary<string, string> GetDriverParameters(TConfig testConfiguration);

public virtual string GetCreateTemporaryTableStatement(string tableName, string columns)
Expand Down
1 change: 1 addition & 0 deletions csharp/test/Drivers/Apache/Impala/ImpalaTestEnvironment.cs
Original file line number Diff line number Diff line change
Expand Up @@ -73,5 +73,6 @@ public override Dictionary<string, string> GetDriverParameters(ApacheTestConfigu

public override string GetInsertStatement(string tableName, string columnName, string? value) =>
string.Format("INSERT INTO {0} ({1}) SELECT {2};", tableName, columnName, value ?? "NULL");
public override SampleDataBuilder GetSampleDataBuilder() => throw new NotImplementedException();
}
}
3 changes: 3 additions & 0 deletions csharp/test/Drivers/Apache/Spark/BinaryBooleanValueTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ await ValidateInsertSelectDeleteTwoValuesAsync(
[InlineData("CAST(NULL AS CHAR(10))")]
[InlineData("CAST(NULL AS BOOLEAN)")]
[InlineData("CAST(NULL AS BINARY)")]
[InlineData("CAST(NULL AS MAP<STRING, INT>)")]
[InlineData("CAST(NULL AS STRUCT<NAME: STRING>)")]
[InlineData("CAST(NULL AS ARRAY<INT>)")]
public async Task TestNullData(string projectionClause)
{
string selectStatement = $"SELECT {projectionClause};";
Expand Down
227 changes: 227 additions & 0 deletions csharp/test/Drivers/Apache/Spark/ClientTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

using System;
using System.Collections.Generic;
using Apache.Arrow.Adbc.Drivers.Apache.Spark;
using Apache.Arrow.Adbc.Tests.Xunit;
using Xunit;
using Xunit.Abstractions;

namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
{
/// <summary>
/// Class for testing the ADBC Client using the BigQuery ADBC driver.
/// </summary>
/// <remarks>
/// Tests are ordered to ensure data is created for the other
/// queries to run.
/// </remarks>
[TestCaseOrderer("Apache.Arrow.Adbc.Tests.Xunit.TestOrderer", "Apache.Arrow.Adbc.Tests")]
public class ClientTests : TestBase<SparkTestConfiguration, SparkTestEnvironment>
{
public ClientTests(ITestOutputHelper? outputHelper) : base(outputHelper, new SparkTestEnvironment.Factory())
{
Skip.IfNot(Utils.CanExecuteTestConfig(TestConfigVariable));
}

/// <summary>
/// Validates if the client execute updates.
/// </summary>
[SkippableFact, Order(1)]
public void CanClientExecuteUpdate()
{
using (Adbc.Client.AdbcConnection adbcConnection = GetAdbcConnection())
{
adbcConnection.Open();

string[] queries = GetQueries();
int affectedRows = ValidateAffectedRows ? 1 : -1;

List<int> expectedResults = TestEnvironment.ServerType != SparkServerType.Databricks
? [
-1, // DROP TABLE
-1, // CREATE TABLE
affectedRows, // INSERT
affectedRows, // INSERT
affectedRows, // INSERT
//1, // UPDATE
//1, // DELETE
]
: [
-1, // DROP TABLE
-1, // CREATE TABLE
affectedRows, // INSERT
affectedRows, // INSERT
affectedRows, // INSERT
affectedRows, // UPDATE
affectedRows, // DELETE
];


Tests.ClientTests.CanClientExecuteUpdate(adbcConnection, TestConfiguration, queries, expectedResults);
}
}

/// <summary>
/// Validates if the client can get the schema.
/// </summary>
[SkippableFact, Order(2)]
public void CanClientGetSchema()
{
using (Adbc.Client.AdbcConnection adbcConnection = GetAdbcConnection())
{
Tests.ClientTests.CanClientGetSchema(adbcConnection, TestConfiguration, $"SELECT * FROM {TestConfiguration.Metadata.Table}");
}
}

/// <summary>
/// Validates if the client can connect to a live server and
/// parse the results.
/// </summary>
[SkippableFact, Order(3)]
public void CanClientExecuteQuery()
{
using (Adbc.Client.AdbcConnection adbcConnection = GetAdbcConnection())
{
Tests.ClientTests.CanClientExecuteQuery(adbcConnection, TestConfiguration);
}
}

/// <summary>
/// Validates if the client can connect to a live server and
/// parse the results.
/// </summary>
[SkippableFact, Order(5)]
public void CanClientExecuteEmptyQuery()
{
using (Adbc.Client.AdbcConnection adbcConnection = GetAdbcConnection())
{
Tests.ClientTests.CanClientExecuteQuery(
adbcConnection,
TestConfiguration,
customQuery: $"SELECT * FROM {TestConfiguration.Metadata.Table} WHERE FALSE",
expectedResultsCount: 0);
}
}

/// <summary>
/// Validates if the client is retrieving and converting values
/// to the expected types.
/// </summary>
[SkippableFact, Order(4)]
public void VerifyTypesAndValues()
{
using (Adbc.Client.AdbcConnection dbConnection = GetAdbcConnection())
{
SampleDataBuilder sampleDataBuilder = GetSampleDataBuilder();

Tests.ClientTests.VerifyTypesAndValues(dbConnection, sampleDataBuilder);
}
}

[SkippableFact]
public void VerifySchemaTablesWithNoConstraints()
{
using (Adbc.Client.AdbcConnection adbcConnection = GetAdbcConnection(includeTableConstraints: false))
{
adbcConnection.Open();

string schema = "Tables";

var tables = adbcConnection.GetSchema(schema);

Assert.True(tables.Rows.Count > 0, $"No tables were found in the schema '{schema}'");
}
}


[SkippableFact]
public void VerifySchemaTables()
{
using (Adbc.Client.AdbcConnection adbcConnection = GetAdbcConnection())
{
adbcConnection.Open();

var collections = adbcConnection.GetSchema("MetaDataCollections");
Assert.Equal(7, collections.Rows.Count);
Assert.Equal(2, collections.Columns.Count);

var restrictions = adbcConnection.GetSchema("Restrictions");
Assert.Equal(11, restrictions.Rows.Count);
Assert.Equal(3, restrictions.Columns.Count);

var catalogs = adbcConnection.GetSchema("Catalogs");
Assert.Single(catalogs.Columns);
var catalog = (string?)catalogs.Rows[0].ItemArray[0];

catalogs = adbcConnection.GetSchema("Catalogs", new[] { catalog });
Assert.Equal(1, catalogs.Rows.Count);

string random = "X" + Guid.NewGuid().ToString("N");

catalogs = adbcConnection.GetSchema("Catalogs", new[] { random });
Assert.Equal(0, catalogs.Rows.Count);

var schemas = adbcConnection.GetSchema("Schemas", new[] { catalog });
Assert.Equal(2, schemas.Columns.Count);
var schema = (string?)schemas.Rows[0].ItemArray[1];

schemas = adbcConnection.GetSchema("Schemas", new[] { catalog, schema });
Assert.Equal(1, schemas.Rows.Count);

schemas = adbcConnection.GetSchema("Schemas", new[] { random });
Assert.Equal(0, schemas.Rows.Count);

schemas = adbcConnection.GetSchema("Schemas", new[] { catalog, random });
Assert.Equal(0, schemas.Rows.Count);

schemas = adbcConnection.GetSchema("Schemas", new[] { random, random });
Assert.Equal(0, schemas.Rows.Count);

var tableTypes = adbcConnection.GetSchema("TableTypes");
Assert.Single(tableTypes.Columns);

var tables = adbcConnection.GetSchema("Tables", new[] { catalog, schema });
Assert.Equal(4, tables.Columns.Count);

tables = adbcConnection.GetSchema("Tables", new[] { catalog, random });
Assert.Equal(0, tables.Rows.Count);

tables = adbcConnection.GetSchema("Tables", new[] { random, schema });
Assert.Equal(0, tables.Rows.Count);

tables = adbcConnection.GetSchema("Tables", new[] { random, random });
Assert.Equal(0, tables.Rows.Count);

tables = adbcConnection.GetSchema("Tables", new[] { catalog, schema, random });
Assert.Equal(0, tables.Rows.Count);

var columns = adbcConnection.GetSchema("Columns", new[] { catalog, schema });
Assert.Equal(16, columns.Columns.Count);
}
}

private Adbc.Client.AdbcConnection GetAdbcConnection(bool includeTableConstraints = true)
{
return new Adbc.Client.AdbcConnection(
NewDriver, GetDriverParameters(TestConfiguration),
[]
);
}
}
}
Loading

0 comments on commit 6d2a6dd

Please sign in to comment.