From 3a45f2ef5d3fbdead0b30e3052d848eabdc8b540 Mon Sep 17 00:00:00 2001 From: "Richard (Rick) Zamora" Date: Wed, 10 May 2023 09:07:01 -0500 Subject: [PATCH] Fix `Projection` `meta` (#78) --- dask_expr/expr.py | 19 ++++++++++++++++++- dask_expr/reductions.py | 4 ---- dask_expr/tests/test_collection.py | 3 ++- 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/dask_expr/expr.py b/dask_expr/expr.py index b8617c02..7f04c818 100644 --- a/dask_expr/expr.py +++ b/dask_expr/expr.py @@ -17,8 +17,10 @@ _get_meta_map_partitions, apply_and_enforce, is_dataframe_like, + is_index_like, + is_series_like, ) -from dask.utils import M, apply, funcname, import_required +from dask.utils import M, apply, funcname, import_required, is_arraylike replacement_rules = [] @@ -83,6 +85,14 @@ def _tree_repr_lines(self, indent=0, recursive=True): if isinstance(op, pd.core.base.PandasObject): op = "" + elif is_dataframe_like(op): + op = "" + elif is_index_like(op): + op = "" + elif is_series_like(op): + op = "" + elif is_arraylike(op): + op = "" elif repr(op) != repr(default): if param: @@ -777,6 +787,13 @@ def columns(self): else: return self.operand("columns") + @property + def _meta(self): + if is_dataframe_like(self.frame._meta): + return super()._meta + # Avoid column selection for Series/Index + return self.frame._meta + def _node_label_args(self): return [self.frame, self.operand("columns")] diff --git a/dask_expr/reductions.py b/dask_expr/reductions.py index 442a6864..10533ac9 100644 --- a/dask_expr/reductions.py +++ b/dask_expr/reductions.py @@ -188,10 +188,6 @@ def chunk_kwargs(self): min_count=self.min_count, ) - @property - def _meta(self): - return self.frame._meta.sum(**self.chunk_kwargs) - def _simplify_up(self, parent): if isinstance(parent, Projection): return self.frame[parent.operand("columns")].sum(*self.operands[1:]) diff --git a/dask_expr/tests/test_collection.py b/dask_expr/tests/test_collection.py index 4d33f842..e8fd5ce0 100644 --- a/dask_expr/tests/test_collection.py +++ b/dask_expr/tests/test_collection.py @@ -2,6 +2,7 @@ import pickle import dask +import numpy as np import pandas as pd import pytest from dask.dataframe.utils import assert_eq @@ -47,7 +48,7 @@ def test_meta_divisions_name(): assert list(df.columns) == list(a.columns) assert df.npartitions == 2 - assert df.x.sum()._meta == 0 + assert np.isscalar(df.x.sum()._meta) assert df.x.sum().npartitions == 1 assert "mul" in df._name