diff --git a/cubed/core/optimization.py b/cubed/core/optimization.py index 7f5346ff..d8932005 100644 --- a/cubed/core/optimization.py +++ b/cubed/core/optimization.py @@ -1,4 +1,11 @@ -from cubed.primitive.blockwise import can_fuse_pipelines, fuse +import networkx as nx + +from cubed.primitive.blockwise import ( + can_fuse_multiple_pipelines, + can_fuse_pipelines, + fuse, + fuse_multiple, +) def simple_optimize_dag(dag): @@ -56,3 +63,127 @@ def can_fuse(n): dag.remove_node(op1) return dag + + +sym_counter = 0 + + +def gensym(name="op"): + global sym_counter + sym_counter += 1 + return f"{name}-{sym_counter:03}" + + +def predecessors(dag, name): + """Yield a node's predecessors, with repeats for multiple edges.""" + for pre, _ in dag.in_edges(name): + yield pre + + +def predecessor_ops(dag, name): + """Yield an op node's op predecessors: the predecessors of each of its inputs, with repeats.""" + for input in predecessors(dag, name): + for pre in predecessors(dag, input): + yield pre + + +def is_fusable(node_dict): + "Return True if a node can be fused, i.e. its attribute dict has a pipeline entry." 
+ return "pipeline" in node_dict + + +def can_fuse_predecessors(dag, name, *, max_total_nargs=4): + nodes = dict(dag.nodes(data=True)) + + # if node itself can't be fused then there is nothing to fuse + if not is_fusable(nodes[name]): + return False + + # if no predecessor ops can be fused then there is nothing to fuse + if all(not is_fusable(nodes[pre]) for pre in predecessor_ops(dag, name)): + return False + + # if there is more than a single predecessor op, and the total number of args to + # the fused function would be more than an allowed maximum, then don't fuse + if len(list(predecessor_ops(dag, name))) > 1: + total_nargs = sum( + len(list(predecessors(dag, pre))) if is_fusable(nodes[pre]) else 1 + for pre in predecessor_ops(dag, name) + ) + if total_nargs > max_total_nargs: + return False + + predecessor_pipelines = [ + nodes[pre]["pipeline"] + for pre in predecessor_ops(dag, name) + if is_fusable(nodes[pre]) + ] + return can_fuse_multiple_pipelines(nodes[name]["pipeline"], *predecessor_pipelines) + + +def fuse_predecessors(dag, name): + """Fuse a node with its immediate predecessors, returning a modified copy of the dag (or the dag itself if fusion is not possible).""" + + # if can't fuse then return dag unchanged + if not can_fuse_predecessors(dag, name): + return dag + + nodes = dict(dag.nodes(data=True)) + + pipeline = nodes[name]["pipeline"] + + # if a predecessor op has no pipeline then just use None + predecessor_pipelines = [ + nodes[pre]["pipeline"] if is_fusable(nodes[pre]) else None + for pre in predecessor_ops(dag, name) + ] + + # if a predecessor op is not fusable then use 1 for nargs (its output is passed through as a single arg) + predecessor_funcs_nargs = [ + len(list(predecessors(dag, pre))) if is_fusable(nodes[pre]) else 1 + for pre in predecessor_ops(dag, name) + ] + + fused_pipeline = fuse_multiple( + pipeline, + *predecessor_pipelines, + predecessor_funcs_nargs=predecessor_funcs_nargs, + ) + + fused_dag = dag.copy() + fused_nodes = dict(fused_dag.nodes(data=True)) + + fused_nodes[name]["pipeline"] = fused_pipeline + + # re-wire dag to remove predecessor nodes that 
have been fused + + # 1. update edges to change inputs + for input in predecessors(dag, name): + pre = next(predecessors(dag, input)) + if not is_fusable(fused_nodes[pre]): + # if a predecessor is marked as not fusable then don't change the edge + continue + fused_dag.remove_edge(input, name) + for pre in predecessor_ops(dag, name): + if not is_fusable(fused_nodes[pre]): + # if a predecessor is marked as not fusable then don't change the edge + continue + for input in predecessors(dag, pre): + fused_dag.add_edge(input, name) + + # 2. remove predecessor nodes with no successors + # (ones with successors are needed by other nodes) + for input in predecessors(dag, name): + if fused_dag.out_degree(input) == 0: + for pre in list(predecessors(fused_dag, input)): + fused_dag.remove_node(pre) + fused_dag.remove_node(input) + + return fused_dag + + +def multiple_inputs_optimize_dag(dag): + """Fuse each node with its predecessors where possible, visiting nodes in topological order of the input dag.""" + for name in list(nx.topological_sort(dag)): + dag = fuse_predecessors(dag, name) + return dag diff --git a/cubed/primitive/blockwise.py b/cubed/primitive/blockwise.py index 64d8ea3b..a99fb149 100644 --- a/cubed/primitive/blockwise.py +++ b/cubed/primitive/blockwise.py @@ -15,7 +15,7 @@ from cubed.runtime.types import CubedPipeline from cubed.storage.zarr import T_ZarrArray, lazy_empty from cubed.types import T_Chunks, T_DType, T_Shape, T_Store -from cubed.utils import chunk_memory, get_item, to_chunksize +from cubed.utils import chunk_memory, get_item, split_into, to_chunksize from cubed.vendor.dask.array.core import normalize_chunks from cubed.vendor.dask.blockwise import _get_coord_mapping, _make_dims, lol_product from cubed.vendor.dask.core import flatten @@ -265,6 +265,16 @@ def can_fuse_pipelines(pipeline1: CubedPipeline, pipeline2: CubedPipeline) -> bo return False +def can_fuse_multiple_pipelines( + pipeline: CubedPipeline, *predecessor_pipelines: CubedPipeline +) -> bool: + if is_fuse_candidate(pipeline) and all( + is_fuse_candidate(p) for p in 
predecessor_pipelines + ): + return all(pipeline.num_tasks == p.num_tasks for p in predecessor_pipelines) + return False + + def fuse(pipeline1: CubedPipeline, pipeline2: CubedPipeline) -> CubedPipeline: """ Fuse two blockwise pipelines into a single pipeline, avoiding writing to (or reading from) the target of the first pipeline. @@ -304,6 +314,84 @@ def fused_func(*args): ) +def fuse_multiple( + pipeline: CubedPipeline, + *predecessor_pipelines: CubedPipeline, + predecessor_funcs_nargs=None, +) -> CubedPipeline: + """ + Fuse a blockwise pipeline and its predecessors into a single pipeline, avoiding writing to (or reading from) the targets of the predecessor pipelines. A predecessor pipeline given as None is not fused: its corresponding argument is passed through unchanged. + """ + + assert all( + pipeline.num_tasks == p.num_tasks + for p in predecessor_pipelines + if p is not None + ) + + mappable = pipeline.mappable + + def apply_pipeline_block_func(pipeline, arg): + if pipeline is None: + return (arg,) + return pipeline.config.block_function(arg) + + def fused_blockwise_func(out_key): + # this will change when multiple outputs are supported + args = pipeline.config.block_function(out_key) + # flatten one level of args as the fused_func adds back grouping structure + func_args = tuple( + item + for p, a in zip(predecessor_pipelines, args) + for item in apply_pipeline_block_func(p, a) + ) + return func_args + + def apply_pipeline_func(pipeline, *args): + if pipeline is None: + return args[0] + return pipeline.config.function(*args) + + def fused_func(*args): + # split all args to the fused function into groups, one for each predecessor function + split_args = split_into(args, predecessor_funcs_nargs) + func_args = [ + apply_pipeline_func(p, *a) + for p, a in zip(predecessor_pipelines, split_args) + ] + return pipeline.config.function(*func_args) + + read_proxies = dict(pipeline.config.reads_map) + for p in predecessor_pipelines: + if p is not None: + read_proxies.update(p.config.reads_map) + write_proxy = pipeline.config.write + spec = 
BlockwiseSpec(fused_blockwise_func, fused_func, read_proxies, write_proxy) + + target_array = pipeline.target_array + projected_mem = max( + pipeline.projected_mem, + *(p.projected_mem for p in predecessor_pipelines if p is not None), + ) + reserved_mem = max( + pipeline.reserved_mem, + *(p.reserved_mem for p in predecessor_pipelines if p is not None), + ) + num_tasks = pipeline.num_tasks + + return CubedPipeline( + apply_blockwise, + gensym("fused_apply_blockwise"), + mappable, + spec, + target_array, + projected_mem, + reserved_mem, + num_tasks, + None, + ) + + # blockwise functions diff --git a/cubed/tests/test_optimization.py b/cubed/tests/test_optimization.py index e406064a..d97588b1 100644 --- a/cubed/tests/test_optimization.py +++ b/cubed/tests/test_optimization.py @@ -1,9 +1,16 @@ +from functools import partial + +import networkx as nx import numpy as np import pytest from numpy.testing import assert_array_equal import cubed import cubed.array_api as xp +from cubed.backend_array_api import namespace as nxp +from cubed.core.ops import elemwise +from cubed.core.optimization import fuse_predecessors, gensym +from cubed.core.plan import arrays_to_plan from cubed.tests.utils import TaskCounter @@ -118,3 +125,468 @@ def custom_optimize_function(dag): ) == num_tasks_with_no_optimization ) + + +def get_optimize_function(arr): + # use fuse_predecessors on the first predecessor node of arr in its plan dag, to test one level of fusion + return partial(fuse_predecessors, name=next(arr.plan.dag.predecessors(arr.name))) + + +# utility functions for testing structural equivalence of dags + + +def create_dag(): + return nx.MultiDiGraph() + + +def add_op(dag, func, inputs, outputs, fusable=True): + name = gensym(func.__name__) + dag.add_node(name, func=func, fusable=fusable) + for n in inputs: + dag.add_edge(n, name) + for n in outputs: + dag.add_node(n) + dag.add_edge(name, n) + + return name + + +def placeholder_func(*x): + return 1 + + +def add_placeholder_op(dag, inputs, outputs): + add_op(dag, placeholder_func, [a.name 
for a in inputs], [b.name for b in outputs]) + + +def structurally_equivalent(dag1, dag2): + # compare structure, and node labels for values but not operators since they are placeholders + + # draw_dag(dag1) # uncomment for debugging + + labelled_dag1 = nx.convert_node_labels_to_integers(dag1, label_attribute="label") + labelled_dag2 = nx.convert_node_labels_to_integers(dag2, label_attribute="label") + + def nm(node_attrs1, node_attrs2): + label1 = node_attrs1["label"] + label2 = node_attrs2["label"] + # a "-" in a label indicates that the node is an operator; don't compare these names + if "-" in label1: + return "-" in label2 + return label1 == label2 + + return nx.is_isomorphic(labelled_dag1, labelled_dag2, node_match=nm) + + +def draw_dag(dag, name="dag"): + gv = nx.drawing.nx_pydot.to_pydot(dag) + format = "svg" + full_filename = f"{name}.{format}" + gv.write(full_filename, format=format) + + +# simple unary function +# +# a -> a +# | | +# b c +# | +# c +# +def test_fuse_unary_op(spec): + a = xp.ones((2, 2), chunks=(2, 2), spec=spec) + b = xp.negative(a) + c = xp.negative(b) + + opt_fn = get_optimize_function(c) + + c.visualize(optimize_function=opt_fn) + + # check structure of optimized dag: expect c to be computed directly from a + expected_fused_dag = create_dag() + add_placeholder_op(expected_fused_dag, (), (a,)) + add_placeholder_op(expected_fused_dag, (a,), (c,)) + assert structurally_equivalent( + c.plan.optimize(optimize_function=opt_fn).dag, + expected_fused_dag, + ) + + num_created_arrays = 2 # b, c + assert c.plan.num_tasks(optimize_graph=False) == num_created_arrays + 2 + num_created_arrays = 1 # c + assert c.plan.num_tasks(optimize_function=opt_fn) == num_created_arrays + 1 + + task_counter = TaskCounter() + result = c.compute(callbacks=[task_counter], optimize_function=opt_fn) + assert task_counter.value == num_created_arrays + 1 + + assert_array_equal(result, np.ones((2, 2))) + + +# simple binary function +# +# a b -> a b +# | | \ / +# c d e +# \ / +# e +# +def test_fuse_binary_op(spec): + 
a = xp.ones((2, 2), chunks=(2, 2), spec=spec) + b = xp.ones((2, 2), chunks=(2, 2), spec=spec) + c = xp.negative(a) + d = xp.negative(b) + e = xp.add(c, d) + + opt_fn = get_optimize_function(e) + + e.visualize(optimize_function=opt_fn) + + # check structure of optimized dag: expect e to be computed directly from a and b + expected_fused_dag = create_dag() + add_placeholder_op(expected_fused_dag, (), (a,)) + add_placeholder_op(expected_fused_dag, (), (b,)) + add_placeholder_op(expected_fused_dag, (a, b), (e,)) + assert structurally_equivalent( + e.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag + ) + + num_created_arrays = 3 # c, d, e + assert e.plan.num_tasks(optimize_graph=False) == num_created_arrays + 3 + num_created_arrays = 1 # e + assert e.plan.num_tasks(optimize_function=opt_fn) == num_created_arrays + 1 + + task_counter = TaskCounter() + result = e.compute(callbacks=[task_counter], optimize_function=opt_fn) + assert task_counter.value == num_created_arrays + 1 + + assert_array_equal(result, -2 * np.ones((2, 2))) + + +# unary and binary functions +# +# a b c -> a b c +# | \ / \ | / +# d e f +# \ / +# f +# +def test_fuse_unary_and_binary_op(spec): + a = xp.ones((2, 2), chunks=(2, 2), spec=spec) + b = xp.ones((2, 2), chunks=(2, 2), spec=spec) + c = xp.ones((2, 2), chunks=(2, 2), spec=spec) + d = xp.negative(a) + e = xp.add(b, c) + f = xp.add(d, e) + + opt_fn = get_optimize_function(f) + + f.visualize(optimize_function=opt_fn) + + # check structure of optimized dag: expect f to be computed directly from a, b and c + expected_fused_dag = create_dag() + add_placeholder_op(expected_fused_dag, (), (a,)) + add_placeholder_op(expected_fused_dag, (), (b,)) + add_placeholder_op(expected_fused_dag, (), (c,)) + add_placeholder_op(expected_fused_dag, (a, b, c), (f,)) + assert structurally_equivalent( + f.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag + ) + + result = f.compute(optimize_function=opt_fn) + assert_array_equal(result, np.ones((2, 2))) + + +# mixed levels +# +# b c -> a b c +# \ / \ | / +# a d e +# \ / +# e +# +def 
test_fuse_mixed_levels(spec): + a = xp.ones((2, 2), chunks=(2, 2), spec=spec) + b = xp.ones((2, 2), chunks=(2, 2), spec=spec) + c = xp.ones((2, 2), chunks=(2, 2), spec=spec) + d = xp.add(b, c) + e = xp.add(a, d) + + opt_fn = get_optimize_function(e) + + e.visualize(optimize_function=opt_fn) + + # check structure of optimized dag: expect e to be computed directly from a, b and c + expected_fused_dag = create_dag() + add_placeholder_op(expected_fused_dag, (), (a,)) + add_placeholder_op(expected_fused_dag, (), (b,)) + add_placeholder_op(expected_fused_dag, (), (c,)) + add_placeholder_op(expected_fused_dag, (a, b, c), (e,)) + assert structurally_equivalent( + e.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag + ) + + result = e.compute(optimize_function=opt_fn) + assert_array_equal(result, 3 * np.ones((2, 2))) + + +# diamond +# +# a -> a +# / \ ‖ +# b c d +# \ / +# d +# +def test_fuse_diamond(spec): + a = xp.ones((2, 2), chunks=(2, 2), spec=spec) + b = xp.positive(a) + c = xp.positive(a) + d = xp.add(b, c) + + opt_fn = get_optimize_function(d) + + d.visualize(optimize_function=opt_fn) + + # check structure of optimized dag: expect d to be computed directly from a, used twice + expected_fused_dag = create_dag() + add_placeholder_op(expected_fused_dag, (), (a,)) + add_placeholder_op(expected_fused_dag, (a, a), (d,)) + assert structurally_equivalent( + d.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag + ) + + result = d.compute(optimize_function=opt_fn) + assert_array_equal(result, 2 * np.ones((2, 2))) + + +# mixed levels and diamond +# from https://github.com/tomwhite/cubed/issues/126 +# +# a -> a +# | /| +# b b | +# /| \| +# c | d +# \| +# d +# +def test_fuse_mixed_levels_and_diamond(spec): + a = xp.ones((2, 2), chunks=(2, 2), spec=spec) + b = xp.positive(a) + c = xp.positive(b) + d = xp.add(b, c) + + opt_fn = get_optimize_function(d) + + d.visualize(optimize_function=opt_fn) + + # check structure of optimized dag: expect d to be computed directly from a and b, with b kept + expected_fused_dag = create_dag() + add_placeholder_op(expected_fused_dag, (), (a,)) + add_placeholder_op(expected_fused_dag, 
(a,), (b,)) + add_placeholder_op(expected_fused_dag, (a, b), (d,)) + assert structurally_equivalent( + d.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag + ) + + result = d.compute(optimize_function=opt_fn) + assert_array_equal(result, 2 * np.ones((2, 2))) + + +# repeated argument +# from https://github.com/tomwhite/cubed/issues/65 +# +# a -> a +# | ‖ +# b c +# ‖ +# c +# +def test_fuse_repeated_argument(spec): + a = xp.ones((2, 2), chunks=(2, 2), spec=spec) + b = xp.negative(a) + c = xp.add(b, b) + + opt_fn = get_optimize_function(c) + + c.visualize(optimize_function=opt_fn) + + # check structure of optimized dag: expect c to be computed directly from a, used twice + expected_fused_dag = create_dag() + add_placeholder_op(expected_fused_dag, (), (a,)) + add_placeholder_op(expected_fused_dag, (a, a), (c,)) + assert structurally_equivalent( + c.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag + ) + + result = c.compute(optimize_function=opt_fn) + assert_array_equal(result, -2 * np.ones((2, 2))) + + +# other dependents +# +# a -> a +# | / \ +# b c b +# / \ | +# c d d +# +def test_fuse_other_dependents(spec): + a = xp.ones((2, 2), chunks=(2, 2), spec=spec) + b = xp.negative(a) + c = xp.negative(b) + d = xp.negative(b) + + # only fuse c; leave d unfused + opt_fn = get_optimize_function(c) + + # note multi-arg forms of visualize and compute below + cubed.visualize(c, d, optimize_function=opt_fn) + + # check structure of optimized dag: expect c to be computed directly from a, with b kept for d + expected_fused_dag = create_dag() + add_placeholder_op(expected_fused_dag, (), (a,)) + add_placeholder_op(expected_fused_dag, (a,), (b,)) + add_placeholder_op(expected_fused_dag, (a,), (c,)) + add_placeholder_op(expected_fused_dag, (b,), (d,)) + plan = arrays_to_plan(c, d) + assert structurally_equivalent( + plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag + ) + + c_result, d_result = cubed.compute(c, d, optimize_function=opt_fn) + assert_array_equal(c_result, np.ones((2, 2))) + assert_array_equal(d_result, np.ones((2, 2))) + + +# large fan-in +# 
+# a b c d e f g h -> a b c d e f g h +# \ / \ / \ / \ / \ / \ / \ / \ / +# i j k m i j k m +# \ / \ / \ \ / / +# n o \ \ / / +# \ / ----- p ----- +# \ / +# \ / +# p +# +def test_fuse_large_fan_in(spec): + a = xp.ones((2, 2), chunks=(2, 2), spec=spec) + b = xp.ones((2, 2), chunks=(2, 2), spec=spec) + c = xp.ones((2, 2), chunks=(2, 2), spec=spec) + d = xp.ones((2, 2), chunks=(2, 2), spec=spec) + e = xp.ones((2, 2), chunks=(2, 2), spec=spec) + f = xp.ones((2, 2), chunks=(2, 2), spec=spec) + g = xp.ones((2, 2), chunks=(2, 2), spec=spec) + h = xp.ones((2, 2), chunks=(2, 2), spec=spec) + + i = xp.add(a, b) + j = xp.add(c, d) + k = xp.add(e, f) + m = xp.add(g, h) + + n = xp.add(i, j) + o = xp.add(k, m) + + p = xp.add(n, o) + + opt_fn = get_optimize_function(p) + + p.visualize(optimize_function=opt_fn) + + # check structure of optimized dag: expect p to be computed directly from i, j, k and m + expected_fused_dag = create_dag() + add_placeholder_op(expected_fused_dag, (), (a,)) + add_placeholder_op(expected_fused_dag, (), (b,)) + add_placeholder_op(expected_fused_dag, (), (c,)) + add_placeholder_op(expected_fused_dag, (), (d,)) + add_placeholder_op(expected_fused_dag, (), (e,)) + add_placeholder_op(expected_fused_dag, (), (f,)) + add_placeholder_op(expected_fused_dag, (), (g,)) + add_placeholder_op(expected_fused_dag, (), (h,)) + add_placeholder_op(expected_fused_dag, (a, b), (i,)) + add_placeholder_op(expected_fused_dag, (c, d), (j,)) + add_placeholder_op(expected_fused_dag, (e, f), (k,)) + add_placeholder_op(expected_fused_dag, (g, h), (m,)) + add_placeholder_op( + expected_fused_dag, + ( + i, + j, + k, + m, + ), + (p,), + ) + assert structurally_equivalent( + p.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag + ) + + result = p.compute(optimize_function=opt_fn) + assert_array_equal(result, 8 * np.ones((2, 2))) + + +# unary large fan-in +# +# a b c d e f g h -> a b c d e f g h +# \ \ \ \ / / / / \ \ \ \ / / / / +# \ \ \ \ / / / / \ \ \ \ / / / / +# \ \ \ \ / / / / \ \ \ \ / / / / +# \ \ \ \ / / / / \ \ \ \ 
/ / / / +# ----- i ----- ----- j ----- +# | +# j +# +def test_fuse_unary_large_fan_in(spec): + a = xp.ones((2, 2), chunks=(2, 2), spec=spec) + b = xp.ones((2, 2), chunks=(2, 2), spec=spec) + c = xp.ones((2, 2), chunks=(2, 2), spec=spec) + d = xp.ones((2, 2), chunks=(2, 2), spec=spec) + e = xp.ones((2, 2), chunks=(2, 2), spec=spec) + f = xp.ones((2, 2), chunks=(2, 2), spec=spec) + g = xp.ones((2, 2), chunks=(2, 2), spec=spec) + h = xp.ones((2, 2), chunks=(2, 2), spec=spec) + + # use elemwise and stack since add can only take 2 args + def stack_add(*a): + return nxp.sum(nxp.stack(a), axis=0) + + i = elemwise(stack_add, a, b, c, d, e, f, g, h, dtype=a.dtype) + j = xp.negative(i) + + opt_fn = get_optimize_function(j) + + j.visualize(optimize_function=opt_fn) + + # check structure of optimized dag: expect j to be computed directly from a, b, c, d, e, f, g and h + expected_fused_dag = create_dag() + add_placeholder_op(expected_fused_dag, (), (a,)) + add_placeholder_op(expected_fused_dag, (), (b,)) + add_placeholder_op(expected_fused_dag, (), (c,)) + add_placeholder_op(expected_fused_dag, (), (d,)) + add_placeholder_op(expected_fused_dag, (), (e,)) + add_placeholder_op(expected_fused_dag, (), (f,)) + add_placeholder_op(expected_fused_dag, (), (g,)) + add_placeholder_op(expected_fused_dag, (), (h,)) + add_placeholder_op( + expected_fused_dag, + ( + a, + b, + c, + d, + e, + f, + g, + h, + ), + (j,), + ) + assert structurally_equivalent( + j.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag + ) + + result = j.compute(optimize_function=opt_fn) + assert_array_equal(result, -8 * np.ones((2, 2))) diff --git a/cubed/tests/test_utils.py b/cubed/tests/test_utils.py index fd1ebdaa..b9a46204 100644 --- a/cubed/tests/test_utils.py +++ b/cubed/tests/test_utils.py @@ -10,6 +10,7 @@ join_path, memory_repr, peak_measured_mem, + split_into, to_chunksize, ) @@ -70,3 +71,9 @@ def test_extract_stack_summaries(): assert stack_summaries[-1].name == "test_extract_stack_summaries" assert stack_summaries[-1].module == "cubed.tests.test_utils" 
assert not stack_summaries[-1].is_cubed() + + +def test_split_into(): + assert list(split_into([1, 2, 3, 4, 5, 6], [1, 2, 3])) == [[1], [2, 3], [4, 5, 6]] + assert list(split_into([1, 2, 3, 4, 5, 6], [2, 3])) == [[1, 2], [3, 4, 5]] + assert list(split_into([1, 2, 3, 4], [1, 2, 3, 4])) == [[1], [2, 3], [4], []] diff --git a/cubed/utils.py b/cubed/utils.py index 9ecc7811..9cc66b89 100644 --- a/cubed/utils.py +++ b/cubed/utils.py @@ -5,6 +5,7 @@ import sysconfig import traceback from dataclasses import dataclass +from itertools import islice from math import prod from operator import add from pathlib import Path @@ -242,3 +243,12 @@ def is_numeric_str(s: str) -> bool: return size else: raise ValueError(f"Invalid value: {size}. Must be a positive value") + + +# Based on more_itertools +def split_into(iterable, sizes): + """Yield a list of sequential items from *iterable* of length 'n' for each + integer 'n' in *sizes* (lists are shorter, possibly empty, once *iterable* runs out; surplus items are not yielded).""" + it = iter(iterable) + for size in sizes: + yield list(islice(it, size))