Skip to content

Commit

Permalink
Fix python tests
Browse files (browse the repository at this point in the history)
  • Loading branch information
kszucs committed Sep 1, 2021
1 parent fccd0f1 commit af41f59
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 21 deletions.
3 changes: 0 additions & 3 deletions python/datafusion/tests/test_catalog.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -15,14 +15,11 @@
# specific language governing permissions and limitations
# under the License.

import numpy as np
import pyarrow as pa
import pytest

from datafusion import ExecutionContext

from . import generic as helpers


@pytest.fixture
def ctx():
Expand Down
4 changes: 1 addition & 3 deletions python/datafusion/tests/test_dataframe.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -18,11 +18,9 @@
import pyarrow as pa
import pytest

from datafusion import ExecutionContext, DataFrame
from datafusion import DataFrame, ExecutionContext
from datafusion import functions as f

from . import generic as helpers


@pytest.fixture
def df():
Expand Down
28 changes: 13 additions & 15 deletions python/datafusion/tests/test_sql.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -68,9 +68,9 @@ def test_register_csv(ctx, tmp_path):
assert ctx.tables() == {"csv", "csv1", "csv2", "csv3"}

for table in ["csv", "csv1", "csv2"]:
result = ctx.sql(f"SELECT COUNT(int) FROM {table}").collect()
result = ctx.sql(f"SELECT COUNT(int) AS cnt FROM {table}").collect()
result = pa.Table.from_batches(result)
assert result.to_pydict() == {"COUNT(int)": [4]}
assert result.to_pydict() == {"cnt": [4]}

result = ctx.sql("SELECT * FROM csv3").collect()
result = pa.Table.from_batches(result)
Expand All @@ -87,9 +87,9 @@ def test_register_parquet(ctx, tmp_path):
ctx.register_parquet("t", path)
assert ctx.tables() == {"t"}

result = ctx.sql("SELECT COUNT(a) FROM t").collect()
result = ctx.sql("SELECT COUNT(a) AS cnt FROM t").collect()
result = pa.Table.from_batches(result)
assert result.to_pydict() == {"COUNT(a)": [100]}
assert result.to_pydict() == {"cnt": [100]}


def test_execute(ctx, tmp_path):
Expand All @@ -102,30 +102,30 @@ def test_execute(ctx, tmp_path):
assert ctx.tables() == {"t"}

# count
result = ctx.sql("SELECT COUNT(a) FROM t").collect()
result = ctx.sql("SELECT COUNT(a) AS cnt FROM t").collect()

expected = pa.array([7], pa.uint64())
expected = [pa.RecordBatch.from_arrays([expected], ["COUNT(a)"])]
expected = [pa.RecordBatch.from_arrays([expected], ["cnt"])]
assert result == expected

# where
expected = pa.array([2], pa.uint64())
expected = [pa.RecordBatch.from_arrays([expected], ["COUNT(a)"])]
result = ctx.sql("SELECT COUNT(a) FROM t WHERE a > 10").collect()
expected = [pa.RecordBatch.from_arrays([expected], ["cnt"])]
result = ctx.sql("SELECT COUNT(a) AS cnt FROM t WHERE a > 10").collect()
assert result == expected

# group by
results = ctx.sql(
"SELECT CAST(a as int), COUNT(a) FROM t GROUP BY CAST(a as int)"
"SELECT CAST(a as int) AS a, COUNT(a) AS cnt FROM t GROUP BY a"
).collect()

# group by returns batches
result_keys = []
result_values = []
for result in results:
pydict = result.to_pydict()
result_keys.extend(pydict["CAST(a AS Int32)"])
result_values.extend(pydict["COUNT(a)"])
result_keys.extend(pydict["a"])
result_values.extend(pydict["cnt"])

result_keys, result_values = (
list(t) for t in zip(*sorted(zip(result_keys, result_values)))
Expand All @@ -136,14 +136,12 @@ def test_execute(ctx, tmp_path):

# order by
result = ctx.sql(
"SELECT a, CAST(a AS int) FROM t ORDER BY a DESC LIMIT 2"
"SELECT a, CAST(a AS int) AS a_int FROM t ORDER BY a DESC LIMIT 2"
).collect()
expected_a = pa.array([50.0219, 50.0152], pa.float64())
expected_cast = pa.array([50, 50], pa.int32())
expected = [
pa.RecordBatch.from_arrays(
[expected_a, expected_cast], ["a", "CAST(a AS Int32)"]
)
pa.RecordBatch.from_arrays([expected_a, expected_cast], ["a", "a_int"])
]
np.testing.assert_equal(expected[0].column(1), expected[0].column(1))

Expand Down

0 comments on commit af41f59

Please sign in to comment.