Expose ExecutionContext.register_csv to the python bindings (#524)
* Expose register_csv

* Validate delimiter

* Fix tests

* Pass schema

* unused imports

* add linting

* Update deps

* Restore venv
kszucs authored Aug 6, 2021 · 1 parent 01a51ac · commit 5a7bbcc
Showing 9 changed files with 303 additions and 164 deletions.
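
For context, here is a minimal sketch of the newly exposed binding as used from Python, assuming a hypothetical local `data.csv`; the keyword defaults mirror the `#[args]` attribute in `python/src/context.rs` below:

```python
from datafusion import ExecutionContext

ctx = ExecutionContext()

# Only `name` and `path` are required; everything else has a default.
ctx.register_csv("example", "data.csv")

# The full set of options exposed by this commit, spelled out with their
# defaults ("data.csv" is an illustrative path, not part of the repo).
ctx.register_csv(
    "example_explicit",
    "data.csv",
    schema=None,                    # or a pyarrow schema to skip inference
    has_header=True,
    delimiter=",",                  # must be a single character
    schema_infer_max_records=1000,
    file_extension=".csv",
)

batches = ctx.sql("SELECT * FROM example").collect()
```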
19 changes: 11 additions & 8 deletions .github/workflows/python_test.yaml
@@ -41,18 +41,21 @@ jobs:
- uses: actions/setup-python@v2
with:
python-version: "3.9"
- name: Install Python dependencies
run: python -m pip install --upgrade pip setuptools wheel
- name: Run tests
- name: Create Virtualenv
run: |
cd python/
python -m venv venv
source venv/bin/activate
pip install -r requirements.txt
pip install -r python/requirements.txt
- name: Run Linters
run: |
source venv/bin/activate
flake8 python
black --line-length 79 --check python
- name: Run tests
run: |
source venv/bin/activate
cd python
maturin develop
pytest -v .
env:
CARGO_HOME: "/home/runner/.cargo"
10 changes: 6 additions & 4 deletions python/requirements.in
@@ -14,12 +14,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
maturin
toml
pyarrow
pytest

black
flake8
isort
maturin
mypy
numpy
pandas
pyarrow
pytest
toml
285 changes: 148 additions & 137 deletions python/requirements.txt

Large diffs are not rendered by default.

45 changes: 45 additions & 0 deletions python/src/context.rs
@@ -15,16 +15,19 @@
// specific language governing permissions and limitations
// under the License.

use std::path::PathBuf;
use std::{collections::HashSet, sync::Arc};

use rand::distributions::Alphanumeric;
use rand::Rng;

use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;

use datafusion::arrow::record_batch::RecordBatch;
use datafusion::datasource::MemTable;
use datafusion::execution::context::ExecutionContext as _ExecutionContext;
use datafusion::prelude::CsvReadOptions;

use crate::dataframe;
use crate::errors;
@@ -97,6 +100,48 @@ impl ExecutionContext {
Ok(())
}

#[args(
schema = "None",
has_header = "true",
delimiter = "\",\"",
schema_infer_max_records = "1000",
file_extension = "\".csv\""
)]
fn register_csv(
&mut self,
name: &str,
path: PathBuf,
schema: Option<&PyAny>,
has_header: bool,
delimiter: &str,
schema_infer_max_records: usize,
file_extension: &str,
) -> PyResult<()> {
let path = path
.to_str()
.ok_or(PyValueError::new_err("Unable to convert path to a string"))?;
let schema = match schema {
Some(s) => Some(to_rust::to_rust_schema(s)?),
None => None,
};
let delimiter = delimiter.as_bytes();
if delimiter.len() != 1 {
return Err(PyValueError::new_err(
"Delimiter must be a single character",
));
}

let mut options = CsvReadOptions::new()
.has_header(has_header)
.delimiter(delimiter[0])
.schema_infer_max_records(schema_infer_max_records)
.file_extension(file_extension);
options.schema = schema.as_ref();

errors::wrap(self.ctx.register_csv(name, path, options))?;
Ok(())
}

fn register_udf(
&mut self,
name: &str,
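
Two details of the implementation above are worth calling out: `path` is accepted as a `PathBuf`, so both `str` and `pathlib.Path` work from Python, and a delimiter longer than one byte is rejected with a `ValueError` before the table is registered. A hedged sketch, where `data.csv` is again an illustrative path:

```python
from pathlib import Path

from datafusion import ExecutionContext

ctx = ExecutionContext()
ctx.register_csv("t1", "data.csv")        # str path
ctx.register_csv("t2", Path("data.csv"))  # pathlib.Path is also accepted

try:
    ctx.register_csv("t3", "data.csv", delimiter="||")
except ValueError as exc:
    # Raised by the single-byte delimiter check in register_csv above.
    assert "Delimiter must be a single character" in str(exc)
```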
9 changes: 9 additions & 0 deletions python/src/to_rust.rs
@@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use std::convert::TryFrom;
use std::sync::Arc;

use datafusion::arrow::{
@@ -111,3 +112,11 @@ pub fn to_rust_scalar(ob: &PyAny) -> PyResult<ScalarValue> {
}
})
}

pub fn to_rust_schema(ob: &PyAny) -> PyResult<Schema> {
let c_schema = ffi::FFI_ArrowSchema::empty();
let c_schema_ptr = &c_schema as *const ffi::FFI_ArrowSchema;
ob.call_method1("_export_to_c", (c_schema_ptr as uintptr_t,))?;
let schema = Schema::try_from(&c_schema).map_err(errors::DataFusionError::from)?;
Ok(schema)
}
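
`to_rust_schema` bridges the Arrow C Data Interface: the pyarrow `Schema` exports itself into an empty `FFI_ArrowSchema` through the private `_export_to_c` hook, and `Schema::try_from` rebuilds it on the Rust side. From Python none of this is visible; passing a pyarrow schema to `register_csv` simply bypasses inference. A sketch with illustrative file and column names:

```python
import pyarrow as pa

from datafusion import ExecutionContext

ctx = ExecutionContext()
schema = pa.schema(
    [
        ("some_int", pa.int16()),
        ("some_bytes", pa.string()),
        ("some_floats", pa.float32()),
    ]
)
# The schema crosses the FFI boundary via _export_to_c; no inference runs.
ctx.register_csv("csv3", "data.csv", schema=schema)
```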
13 changes: 10 additions & 3 deletions python/tests/generic.py
@@ -19,6 +19,7 @@

import numpy as np
import pyarrow as pa
import pyarrow.csv
import pyarrow.parquet as pq

# used to write parquet files
@@ -49,7 +50,9 @@ def data_datetime(f):
datetime.datetime.now() - datetime.timedelta(days=1),
datetime.datetime.now() + datetime.timedelta(days=1),
]
return pa.array(data, type=pa.timestamp(f), mask=np.array([False, True, False]))
return pa.array(
data, type=pa.timestamp(f), mask=np.array([False, True, False])
)


def data_date32():
@@ -58,7 +61,9 @@ def data_date32():
datetime.date(1980, 1, 1),
datetime.date(2030, 1, 1),
]
return pa.array(data, type=pa.date32(), mask=np.array([False, True, False]))
return pa.array(
data, type=pa.date32(), mask=np.array([False, True, False])
)


def data_timedelta(f):
@@ -67,7 +72,9 @@ def data_timedelta(f):
datetime.timedelta(days=1),
datetime.timedelta(seconds=1),
]
return pa.array(data, type=pa.duration(f), mask=np.array([False, True, False]))
return pa.array(
data, type=pa.duration(f), mask=np.array([False, True, False])
)


def data_binary_other():
16 changes: 12 additions & 4 deletions python/tests/test_math_functions.py
@@ -26,7 +26,9 @@
def df():
ctx = ExecutionContext()
# create a RecordBatch and a new DataFrame from it
batch = pa.RecordBatch.from_arrays([pa.array([0.1, -0.7, 0.55])], names=["value"])
batch = pa.RecordBatch.from_arrays(
[pa.array([0.1, -0.7, 0.55])], names=["value"]
)
return ctx.create_dataframe([[batch]])


@@ -56,7 +58,13 @@ def test_math_functions(df):
np.testing.assert_array_almost_equal(result.column(4), np.arcsin(values))
np.testing.assert_array_almost_equal(result.column(5), np.arccos(values))
np.testing.assert_array_almost_equal(result.column(6), np.exp(values))
np.testing.assert_array_almost_equal(result.column(7), np.log(values + 1.0))
np.testing.assert_array_almost_equal(result.column(8), np.log2(values + 1.0))
np.testing.assert_array_almost_equal(result.column(9), np.log10(values + 1.0))
np.testing.assert_array_almost_equal(
result.column(7), np.log(values + 1.0)
)
np.testing.assert_array_almost_equal(
result.column(8), np.log2(values + 1.0)
)
np.testing.assert_array_almost_equal(
result.column(9), np.log10(values + 1.0)
)
np.testing.assert_array_less(result.column(10), np.ones_like(values))
5 changes: 2 additions & 3 deletions python/tests/test_pa_types.py
@@ -19,8 +19,8 @@


def test_type_ids():
"""having this fixed is very important because internally we rely on this id to parse from
python"""
# Having this fixed is very important because internally we rely on this id
# to parse from python
for idx, arrow_type in [
(0, pa.null()),
(1, pa.bool_()),
@@ -47,5 +47,4 @@ def test_type_ids():
(34, pa.large_utf8()),
(35, pa.large_binary()),
]:

assert idx == arrow_type.id
65 changes: 60 additions & 5 deletions python/tests/test_sql.py
@@ -18,8 +18,8 @@
import numpy as np
import pyarrow as pa
import pytest
from datafusion import ExecutionContext

from datafusion import ExecutionContext
from . import generic as helpers


@@ -33,12 +33,63 @@ def test_no_table(ctx):
ctx.sql("SELECT a FROM b").collect()


def test_register(ctx, tmp_path):
def test_register_csv(ctx, tmp_path):
path = tmp_path / "test.csv"

table = pa.Table.from_arrays(
[
[1, 2, 3, 4],
["a", "b", "c", "d"],
[1.1, 2.2, 3.3, 4.4],
],
names=["int", "str", "float"],
)
pa.csv.write_csv(table, path)

ctx.register_csv("csv", path)
ctx.register_csv("csv1", str(path))
ctx.register_csv(
"csv2",
path,
has_header=True,
delimiter=",",
schema_infer_max_records=10,
)
alternative_schema = pa.schema(
[
("some_int", pa.int16()),
("some_bytes", pa.string()),
("some_floats", pa.float32()),
]
)
ctx.register_csv("csv3", path, schema=alternative_schema)

assert ctx.tables() == {"csv", "csv1", "csv2", "csv3"}

for table in ["csv", "csv1", "csv2"]:
result = ctx.sql(f"SELECT COUNT(int) FROM {table}").collect()
result = pa.Table.from_batches(result)
assert result.to_pydict() == {"COUNT(int)": [4]}

result = ctx.sql("SELECT * FROM csv3").collect()
result = pa.Table.from_batches(result)
assert result.schema == alternative_schema

with pytest.raises(
ValueError, match="Delimiter must be a single character"
):
ctx.register_csv("csv4", path, delimiter="wrong")


def test_register_parquet(ctx, tmp_path):
path = helpers.write_parquet(tmp_path / "a.parquet", helpers.data())
ctx.register_parquet("t", path)

assert ctx.tables() == {"t"}

result = ctx.sql("SELECT COUNT(a) FROM t").collect()
result = pa.Table.from_batches(result)
assert result.to_pydict() == {"COUNT(a)": [100]}


def test_execute(ctx, tmp_path):
data = [1, 1, 2, 2, 3, 11, 12]
@@ -112,7 +163,9 @@ def test_cast(ctx, tmp_path):
"float",
]

select = ", ".join([f"CAST(9 AS {t}) AS A{i}" for i, t in enumerate(valid_types)])
select = ", ".join(
[f"CAST(9 AS {t}) AS A{i}" for i, t in enumerate(valid_types)]
)

# can execute, which implies that we can cast
ctx.sql(f"SELECT {select} FROM t").collect()
@@ -141,7 +194,9 @@ def test_udf(
ctx, tmp_path, fn, input_types, output_type, input_values, expected_values
):
# write to disk
path = helpers.write_parquet(tmp_path / "a.parquet", pa.array(input_values))
path = helpers.write_parquet(
tmp_path / "a.parquet", pa.array(input_values)
)
ctx.register_parquet("t", path)
ctx.register_udf("udf", fn, input_types, output_type)

