From 2ac41fc723630fe8ce6eaa948fe4c41d693c1178 Mon Sep 17 00:00:00 2001 From: Doug Davis Date: Mon, 25 Sep 2023 15:33:11 -0500 Subject: [PATCH] add tests --- src/dask_awkward/lib/optimize.py | 5 ++++ tests/test_optimize.py | 49 ++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+) diff --git a/src/dask_awkward/lib/optimize.py b/src/dask_awkward/lib/optimize.py index 585864383..80e49152c 100644 --- a/src/dask_awkward/lib/optimize.py +++ b/src/dask_awkward/lib/optimize.py @@ -231,6 +231,9 @@ def rewrite_layer_chains(dsk: HighLevelGraph, keys: Any) -> HighLevelGraph: chains = [] deps = copy.copy(dsk.dependencies) + # TODO: add some comments to the chaining algorithm w.r.t. when we + # use it and when we don't. + required_layers = {k[0] for k in keys} layers = {} # find chains; each chain list is at least two keys long dependents = dsk.dependents @@ -250,6 +253,7 @@ def rewrite_layer_chains(dsk: HighLevelGraph, keys: Any) -> HighLevelGraph: and dsk.dependencies[list(children)[0]] == {lay} and isinstance(dsk.layers[list(children)[0]], AwkwardBlockwiseLayer) and len(dsk.layers[lay]) == len(dsk.layers[list(children)[0]]) + and lay not in required_layers ): # walk forwards lay = list(children)[0] @@ -263,6 +267,7 @@ def rewrite_layer_chains(dsk: HighLevelGraph, keys: Any) -> HighLevelGraph: and dependents[list(parents)[0]] == {lay} and isinstance(dsk.layers[list(parents)[0]], AwkwardBlockwiseLayer) and len(dsk.layers[lay]) == len(dsk.layers[list(parents)[0]]) + and list(parents)[0] not in required_layers ): # walk backwards lay = list(parents)[0] diff --git a/tests/test_optimize.py b/tests/test_optimize.py index afb2da88b..3aa57531d 100644 --- a/tests/test_optimize.py +++ b/tests/test_optimize.py @@ -4,6 +4,7 @@ import dask import dask_awkward as dak +from dask_awkward.lib.testutils import assert_eq def test_multiple_computes(pq_points_dir: str) -> None: @@ -27,3 +28,51 @@ def test_multiple_computes(pq_points_dir: str) -> None: things = dask.compute(ds1.points, 
ds2.points.x, ds2.points.y, ds1.points.y, ds3) assert things[-1].tolist() == ak.Array(lists[0] + lists[1]).tolist() # type: ignore + + +def identity(x): + return x + + +def test_multiple_compute_incapsulated(): + array = ak.Array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])[[0, 2]] + darray = dak.from_awkward(array, 1) + darray_result = darray.map_partitions(identity) + + first, second = dask.compute(darray, darray_result) + + assert ak.almost_equal(first, second) + assert first.layout.form == second.layout.form + + +def test_multiple_computes_multiple_incapsulated(daa, caa): + dstep1 = daa.points.x + dstep2 = dstep1**2 + dstep3 = dstep2 + 2 + dstep4 = dstep3 - 1 + dstep5 = dstep4 - dstep2 + + cstep1 = caa.points.x + cstep2 = cstep1**2 + cstep3 = cstep2 + 2 + cstep4 = cstep3 - 1 + cstep5 = cstep4 - cstep2 + + # multiple computes all work and evaluate to the expected result + c5, c4, c2 = dask.compute(dstep5, dstep4, dstep2) + assert_eq(c5, cstep5) + assert_eq(c2, cstep2) + assert_eq(c4, cstep4) + + # if optimized together we still have 2 layers + opt4, opt3 = dask.optimize(dstep4, dstep3) + assert len(opt4.dask.layers) == 2 + assert len(opt3.dask.layers) == 2 + assert_eq(opt4, cstep4) + assert_eq(opt3, cstep3) + + # if optimized alone, the entire chain gets fused (smushed) + # down to 1 layer + (opt4_alone,) = dask.optimize(dstep4) + assert len(opt4_alone.dask.layers) == 1 + assert_eq(opt4_alone, opt4)