diff --git a/python/datafusion/tests/test_context.py b/python/datafusion/tests/test_context.py index 4af00a3b..ab86faa9 100644 --- a/python/datafusion/tests/test_context.py +++ b/python/datafusion/tests/test_context.py @@ -372,6 +372,25 @@ def test_dataset_filter(ctx, capfd): assert result[0].column(1) == pa.array([-3]) +def test_dataset_count(ctx): + # `datafusion-python` issue: https://github.com/apache/datafusion-python/issues/800 + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3]), pa.array([4, 5, 6])], + names=["a", "b"], + ) + dataset = ds.dataset([batch]) + ctx.register_dataset("t", dataset) + + # Testing the dataframe API + df = ctx.table("t") + assert df.count() == 3 + + # Testing the SQL API + count = ctx.sql("SELECT COUNT(*) FROM t") + count = count.collect() + assert count[0].column(0) == pa.array([3]) + + def test_pyarrow_predicate_pushdown_is_null(ctx, capfd): """Ensure that pyarrow filter gets pushed down for `IsNull`""" # create a RecordBatch and register it as a pyarrow.dataset.Dataset