From 7ca568bcf1cc02ca78dc03bc28786cbdf76fa7fb Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?=
Date: Mon, 15 Nov 2021 10:45:49 +0100
Subject: [PATCH] Migrate to new backend API

---
 ibis/backends/datafusion/__init__.py | 119 ++++++++++++++-
 ibis/backends/datafusion/client.py   | 214 ---------------------------
 ibis/backends/datafusion/compiler.py |  88 ++++++++++-
 pyproject.toml                       |   1 +
 4 files changed, 195 insertions(+), 227 deletions(-)
 delete mode 100644 ibis/backends/datafusion/client.py

diff --git a/ibis/backends/datafusion/__init__.py b/ibis/backends/datafusion/__init__.py
index 728072c5b7777..89f2bcedad004 100644
--- a/ibis/backends/datafusion/__init__.py
+++ b/ibis/backends/datafusion/__init__.py
@@ -1,17 +1,25 @@
 # from pkg_resources import parse_version
+import re
+from typing import Mapping
+
+import datafusion as df
+import pyarrow as pa
+
+import ibis.common.exceptions as com
+import ibis.expr.operations as ops
+import ibis.expr.schema as sch
+import ibis.expr.types as ir
 from ibis.backends.base import BaseBackend
 
-from .client import DataFusionClient, DataFusionDatabase, DataFusionTable
 from .compiler import DataFusionExprTranslator
 
+# TODO(kszucs): support nested and parametric types
+# consolidate with the logic from the parquet backend
+
 
 class Backend(BaseBackend):
     name = 'datafusion'
     builder = None
-    client_class = DataFusionClient
-    database_class = DataFusionDatabase
-    table_class = DataFusionTable
     translator_class = DataFusionExprTranslator
 
     def version(self):
@@ -30,7 +38,22 @@ def connect(self, config):
         -------
-        DataFusionClient
+        Backend
         """
-        return self.client_class(backend=self, config=config)
+        new_backend = self.__class__()
+        if isinstance(config, df.ExecutionContext):
+            new_backend.context = config
+        else:
+            new_backend.context = df.ExecutionContext()
+            # only mapping-like configs carry files to register; an
+            # ExecutionContext is expected to have its tables set up already
+            for name, path in config.items():
+                strpath = str(path)
+                if strpath.endswith('.csv'):
+                    new_backend.register_csv(name, path)
+                elif strpath.endswith('.parquet'):
+                    new_backend.register_parquet(name, path)
+                else:
+                    raise ValueError(f'Unsupported file format: {strpath}')
+
+        return new_backend
 
     def current_database(self):
         raise NotImplementedError("")
@@ -38,5 +61,87 @@ def current_database(self):
     def list_databases(self):
         raise NotImplementedError("")
 
-    def list_tables(self):
-        raise NotImplementedError("")
+    def list_tables(self, like=None):
+        """List the available tables."""
+        tables = list(self.context.tables())
+        if like is not None:
+            pattern = re.compile(like)
+            return list(filter(lambda t: pattern.findall(t), tables))
+        return tables
+
+    # def database(self, name='public'):
+    #     '''Construct a database called `name`.'''
+    #     catalog = self.context.catalog()
+    #     database = catalog.database(name)
+    #     return self.database_class(name, self)
+
+    def table(self, name, schema=None):
+        catalog = self.context.catalog()
+        database = catalog.database('public')
+        table = database.table(name)
+        schema = sch.infer(table.schema)
+        # table_class was removed from this backend, so construct the
+        # operation the compiler registers a rule for directly
+        return ops.DatabaseTable(name, schema, self).to_expr()
+
+    def register_csv(self, name, path, schema=None):
+        self.context.register_csv(name, path, schema=schema)
+
+    def register_parquet(self, name, path, schema=None):
+        self.context.register_parquet(name, path, schema=schema)
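+
+    # A minimal usage sketch of the connect/register/execute flow (the
+    # file name and table name below are hypothetical):
+    #
+    #   >>> con = Backend().connect({'diamonds': 'diamonds.csv'})
+    #   >>> t = con.table('diamonds')
+    #   >>> con.execute(t.count())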
+
+    def execute(
+        self,
+        expr: ir.Expr,
+        params: Mapping[ir.Expr, object] = None,
+        limit: str = 'default',
+        **kwargs,
+    ):
+        def _collect(frame):
+            batches = frame.collect()
+            if batches:
+                table = pa.Table.from_batches(batches)
+            else:
+                # TODO(kszucs): file a bug to datafusion because the fields'
+                # nullability from frame.schema() is not always consistent
+                # with the first record batch's schema
+                table = pa.Table.from_batches(batches, schema=frame.schema())
+            return table.to_pandas()
+
+        if isinstance(expr, ir.TableExpr):
+            frame = self.compile(expr, params, **kwargs)
+            return _collect(frame)
+        elif isinstance(expr, ir.ColumnExpr):
+            # expression must be named for the projection
+            expr = expr.name('tmp').to_projection()
+            frame = self.compile(expr, params, **kwargs)
+            return _collect(frame)['tmp']
+        elif isinstance(expr, ir.ScalarExpr):
+            if expr.op().root_tables():
+                # there are associated datafusion tables so convert the expr
+                # to a selection which we can directly convert to a datafusion
+                # plan
+                expr = expr.name('tmp').to_projection()
+                frame = self.compile(expr, params, **kwargs)
+            else:
+                # doesn't have any tables associated so create a plan from a
+                # dummy datafusion table
+                compiled = self.compile(expr, params, **kwargs)
+                frame = self.context.empty_table().select(compiled)
+            return _collect(frame).iloc[0, 0]
+        else:
+            raise com.IbisError(
+                f"Cannot execute expression of type: {type(expr)}"
+            )
+
+    def compile(
+        self, expr: ir.Expr, params: Mapping[ir.Expr, object] = None, **kwargs
+    ):
+        """Compile `expr`.
+
+        Notes
+        -----
+        Returns a lazy DataFusion object (an expression or ``DataFrame``);
+        nothing is executed until the result is collected.
+
+        """
+        translator = self.translator_class()
+        return translator.translate(expr)
diff --git a/ibis/backends/datafusion/client.py b/ibis/backends/datafusion/client.py
deleted file mode 100644
index 34b010fa87bfa..0000000000000
--- a/ibis/backends/datafusion/client.py
+++ /dev/null
@@ -1,214 +0,0 @@
-import re
-from typing import Mapping
-
-import datafusion as df
-import pyarrow as pa
-
-import ibis.common.exceptions as com
-import ibis.expr.datatypes as dt
-import ibis.expr.operations as ops
-import ibis.expr.schema as sch
-import ibis.expr.types as ir
-from ibis.backends.base import Client
-
-# TODO(kszucs): support nested and parametric types
-# consolidate with the logic from the parquet backend
-
-_to_ibis_dtypes = {
-    pa.int8(): dt.Int8,
-    pa.int16(): dt.Int16,
-    pa.int32(): dt.Int32,
-    pa.int64(): dt.Int64,
-    pa.uint8(): dt.UInt8,
-    pa.uint16(): dt.UInt16,
-    pa.uint32(): dt.UInt32,
-    pa.uint64(): dt.UInt64,
-    pa.float16(): dt.Float16,
-    pa.float32(): dt.Float32,
-    pa.float64(): dt.Float64,
-    pa.string(): dt.String,
-    pa.binary(): dt.Binary,
-    pa.bool_(): dt.Boolean,
-    pa.timestamp('ns'): dt.Timestamp,
-    pa.timestamp('ms'): dt.Timestamp,
-}
-
-_to_pyarrow_types = {
-    dt.Int8: pa.int8(),
-    dt.Int16: pa.int16(),
-    dt.Int32: pa.int32(),
-    dt.Int64: pa.int64(),
-    dt.UInt8: pa.uint8(),
-    dt.UInt16: pa.uint16(),
-    dt.UInt32: pa.uint32(),
-    dt.UInt64: pa.uint64(),
-    dt.Float16: pa.float16(),
-    dt.Float32: pa.float32(),
-    dt.Float64: pa.float64(),
-    dt.String: pa.string(),
-    dt.Binary: pa.binary(),
-    dt.Boolean: pa.bool_(),
-    dt.Timestamp: pa.timestamp('ns'),
-}
-
-# TODO(kszucs): the following conversions are really rudimentary
-# we should have a pyarrow backend which would be responsible
-# for conversions between ibis types to pyarrow types
-
-
-def to_pyarrow_type(dtype):
-    try:
-        return _to_pyarrow_types[dtype.__class__]
-    except KeyError as e:
-        if isinstance(dtype, dt.Array):
-            return pa.list_(to_pyarrow_type(dtype.value_type))
-        elif isinstance(dtype, dt.Set):
-            return pa.list_(to_pyarrow_type(dtype.value_type))
-        elif isinstance(dtype, dt.Interval):
-            try:
-                return pa.duration(dtype.unit)
-            except ValueError:
-                raise com.IbisTypeError(
-                    f"Unsupported interval unit: {dtype.unit}"
-                )
-        raise e
-
-
-@dt.dtype.register(pa.DataType) -def from_pyarrow_type(arrow_type, nullable=True): - return _to_ibis_dtypes[arrow_type](nullable=nullable) - - -@sch.infer.register(pa.Schema) -def infer_pyarrow_schema(schema): - fields = [(f.name, dt.dtype(f.type, nullable=f.nullable)) for f in schema] - return sch.schema(fields) - - -class DataFusionTable(ops.DatabaseTable): - pass - - -class DataFusionDatabase(ops.DatabaseTable): - pass - - -class DataFusionClient(Client): - def __init__(self, backend, config): - self.backend = backend - - if isinstance(config, df.ExecutionContext): - self.context = config - else: - self.context = df.ExecutionContext() - for name, path in config.items(): - strpath = str(path) - if strpath.endswith('.csv'): - self.register_csv(name, path) - elif strpath.endswith('.parquet'): - self.register_parquet(name, path) - else: - raise ValueError('Wrong format') - - def register_csv(self, name, path, schema=None): - self.context.register_csv(name, path, schema=schema) - - def register_parquet(self, name, path, schema=None): - self.context.register_parquet(name, path, schema=schema) - - # def database(self, name='public'): - # '''Construct a database called `name`.''' - # catalog = self.context.catalog() - # database = catalog.database(name) - # return self.database_class(name, self) - - def table(self, name, schema=None): - catalog = self.context.catalog() - database = catalog.database('public') - table = database.table(name) - schema = sch.infer(table.schema) - return self.backend.table_class(name, schema, self).to_expr() - - def list_tables(self, like=None): - """List the available tables.""" - tables = list(self.context.tables()) - if like is not None: - pattern = re.compile(like) - return list(filter(lambda t: pattern.findall(t), tables)) - return tables - - def list_databases(self, like=None): - raise com.NotImplementedError() - - def exists_table(self, name): - """ - Determine if the indicated table or view exists. 
- - Parameters - ---------- - name : str - database : str - - Returns - ------- - bool - """ - return bool(self.list_tables(like=name)) - - def execute( - self, - expr: ir.Expr, - params: Mapping[ir.Expr, object] = None, - limit: str = 'default', - **kwargs, - ): - def _collect(frame): - batches = frame.collect() - if batches: - table = pa.Table.from_batches(batches) - else: - # TODO(kszucs): file a bug to datafusion because the fields' - # nullability from frame.schema() is not always consistent - # with the first record batch's schema - table = pa.Table.from_batches(batches, schema=frame.schema()) - return table.to_pandas() - - if isinstance(expr, ir.TableExpr): - frame = self.compile(expr, params, **kwargs) - return _collect(frame) - elif isinstance(expr, ir.ColumnExpr): - # expression must be named for the projection - expr = expr.name('tmp').to_projection() - frame = self.compile(expr, params, **kwargs) - return _collect(frame)['tmp'] - elif isinstance(expr, ir.ScalarExpr): - if expr.op().root_tables(): - # there are associated datafusion tables so convert the expr - # to a selection which we can directly convert to a datafusion - # plan - expr = expr.name('tmp').to_projection() - frame = self.compile(expr, params, **kwargs) - else: - # doesn't have any tables associated so create a plan from a - # dummy datafusion table - compiled = self.compile(expr, params, **kwargs) - frame = self.context.empty_table().select(compiled) - return _collect(frame).iloc[0, 0] - else: - raise com.IbisError( - "Cannot execute expression of type: {}".format(type(expr)) - ) - - def compile( - self, expr: ir.Expr, params: Mapping[ir.Expr, object] = None, **kwargs - ): - """Compile `expr`. - - Notes - ----- - For the dask backend returns a dask graph that you can run ``.compute`` - on to get a pandas object. 
-
-        """
-        translator = self.backend.translator_class()
-        return translator.translate(expr)
diff --git a/ibis/backends/datafusion/compiler.py b/ibis/backends/datafusion/compiler.py
index 26a08a55f2051..4fd1732c5ceb0 100644
--- a/ibis/backends/datafusion/compiler.py
+++ b/ibis/backends/datafusion/compiler.py
@@ -2,13 +2,84 @@
 import operator
 
 import datafusion as df
+import datafusion.functions
 import pyarrow as pa
 
 import ibis.common.exceptions as com
+import ibis.expr.datatypes as dt
 import ibis.expr.operations as ops
+import ibis.expr.schema as sch
 import ibis.expr.types as ir
 
-from .client import DataFusionTable, to_pyarrow_type
+_to_ibis_dtypes = {
+    pa.int8(): dt.Int8,
+    pa.int16(): dt.Int16,
+    pa.int32(): dt.Int32,
+    pa.int64(): dt.Int64,
+    pa.uint8(): dt.UInt8,
+    pa.uint16(): dt.UInt16,
+    pa.uint32(): dt.UInt32,
+    pa.uint64(): dt.UInt64,
+    pa.float16(): dt.Float16,
+    pa.float32(): dt.Float32,
+    pa.float64(): dt.Float64,
+    pa.string(): dt.String,
+    pa.binary(): dt.Binary,
+    pa.bool_(): dt.Boolean,
+    pa.timestamp('ns'): dt.Timestamp,
+    pa.timestamp('ms'): dt.Timestamp,
+}
+
+_to_pyarrow_types = {
+    dt.Int8: pa.int8(),
+    dt.Int16: pa.int16(),
+    dt.Int32: pa.int32(),
+    dt.Int64: pa.int64(),
+    dt.UInt8: pa.uint8(),
+    dt.UInt16: pa.uint16(),
+    dt.UInt32: pa.uint32(),
+    dt.UInt64: pa.uint64(),
+    dt.Float16: pa.float16(),
+    dt.Float32: pa.float32(),
+    dt.Float64: pa.float64(),
+    dt.String: pa.string(),
+    dt.Binary: pa.binary(),
+    dt.Boolean: pa.bool_(),
+    dt.Timestamp: pa.timestamp('ns'),
+}
+
+# TODO(kszucs): the following conversions are really rudimentary
+# we should have a pyarrow backend which would be responsible
+# for conversions between ibis types and pyarrow types
+
+
+def to_pyarrow_type(dtype):
+    try:
+        return _to_pyarrow_types[dtype.__class__]
+    except KeyError as e:
+        if isinstance(dtype, dt.Array):
+            return pa.list_(to_pyarrow_type(dtype.value_type))
+        elif isinstance(dtype, dt.Set):
+            return pa.list_(to_pyarrow_type(dtype.value_type))
+        elif isinstance(dtype, dt.Interval):
+            try:
+                return pa.duration(dtype.unit)
+            except ValueError:
+                raise com.IbisTypeError(
+                    f"Unsupported interval unit: {dtype.unit}"
+                )
+        raise e
+
+
+@dt.dtype.register(pa.DataType)
+def from_pyarrow_type(arrow_type, nullable=True):
+    return _to_ibis_dtypes[arrow_type](nullable=nullable)
+
+
+@sch.infer.register(pa.Schema)
+def infer_pyarrow_schema(schema):
+    fields = [(f.name, dt.dtype(f.type, nullable=f.nullable)) for f in schema]
+    return sch.schema(fields)
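+
+# A rough usage sketch of the conversion helpers above (the column name
+# is hypothetical, not part of this change):
+#
+#   >>> to_pyarrow_type(dt.Array(dt.int64))        # -> pa.list_(pa.int64())
+#   >>> sch.infer(pa.schema([('a', pa.int64())]))  # -> ibis schema {'a': int64}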
 
 
 class DataFusionExprTranslator:
@@ -41,7 +112,7 @@ def translate(self, expr, **kwargs):
             formatter = self._registry[type(op)]
         except KeyError:
             raise com.OperationNotDefinedError(
-                'No translation rule for {}'.format(type(op))
+                f'No translation rule for {type(op)}'
             )
 
         result = formatter(self, expr, **kwargs)
@@ -51,7 +122,7 @@
 compiles = DataFusionExprTranslator.compiles
 
 
-@compiles(DataFusionTable)
+@compiles(ops.DatabaseTable)
 def compile_table(t, expr):
     op = expr.op()
     name, _, client = op.args
@@ -431,14 +502,16 @@ def compile_mean(t, expr):
 
 
-def _prepare_contains_options(options):
+def _prepare_contains_options(t, options):
     if isinstance(options, ir.AnyScalar):
         # TODO(kszucs): it would be better if we could pass an arrow
         # ListScalar to datafusions in_list function
-        return [df.literal(pa.scalar(v)) for v in options.op().value]
+        return [df.literal(v) for v in options.op().value]
     elif isinstance(options, ir.ListExpr):
         return options
     else:
+        # `t` used to be a free (undefined) name here, so thread the
+        # translator through explicitly instead of returning None
         return t.translate(options)
 
 
 @compiles(ops.Contains)
@@ -446,6 +519,8 @@ def compile_contains(t, expr):
     op = expr.op()
     value = t.translate(op.value)
-    options = _prepare_contains_options(op.options)
+    options = _prepare_contains_options(t, op.options)
     return df.functions.in_list(value, options, False)
@@ -461,10 +536,11 @@ def compile_not_contains(t, expr):
 
 def compile_elementwise_udf(t, expr):
     op = expr.op()
-    udf = df.functions.udf(
+    udf = df.udf(
         op.func,
         input_types=list(map(to_pyarrow_type, op.input_type)),
         return_type=to_pyarrow_type(op.return_type),
+        volatility="volatile",
     )
 
     args = map(t.translate, op.func_args)
diff --git a/pyproject.toml b/pyproject.toml
index 72874f20ce7ed..6c4768fcf14c1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -39,6 +39,7 @@
 dask = { version = "^2021.2.0", optional = true, extras = [ "dataframe" ] }
 # TODO(kszucs): pin datafusion's version
+# datafusion = { version = ">=0.4", optional = true }
 geoalchemy2 = { version = ">=0.6,<0.10", optional = true }
 geopandas = { version = ">=0.6,<0.11", optional = true }
 graphviz = { version = "^0.16", optional = true }