diff --git a/pipeline_dp/pipeline_backend.py b/pipeline_dp/pipeline_backend.py
index 29f8b1e9..42c249db 100644
--- a/pipeline_dp/pipeline_backend.py
+++ b/pipeline_dp/pipeline_backend.py
@@ -1,6 +1,6 @@
 """Adapters for working with pipeline frameworks."""

-from functools import partial
+import functools
 import multiprocessing as mp
 import random
 import numpy as np
@@ -9,6 +9,7 @@
 import apache_beam as beam
 import apache_beam.transforms.combiners as combiners
 import pipeline_dp.accumulator as accumulator
+import pipeline_dp.combiners as dp_combiners
 import typing
 import collections
 import itertools
@@ -87,6 +88,23 @@ def reduce_accumulators_per_key(self, col, stage_name: str):
         """
         pass

+    @abc.abstractmethod
+    def combine_accumulators_per_key(self, col, combiner: dp_combiners.Combiner,
+                                     stage_name: str):
+        """Reduces the input collection so that all elements for each key are merged.
+
+        Args:
+          col: input collection which contains tuples (key, accumulator).
+          combiner: combiner which knows how to perform aggregation on
+            accumulators in col.
+          stage_name: name of the stage.
+
+        Returns:
+          A collection of tuples (key, accumulator).
+
+        """
+        pass
+
     @abc.abstractmethod
     def flatten(self, col1, col2, stage_name: str):
         """
@@ -229,6 +247,21 @@ def merge_accumulators(accumulators):
         return col | self._ulg.unique(stage_name) >> beam.CombinePerKey(
             merge_accumulators)

+    def combine_accumulators_per_key(self, col, combiner: dp_combiners.Combiner,
+                                     stage_name: str):
+
+        def merge_accumulators(accumulators):
+            res = None
+            for acc in accumulators:
+                if res:
+                    res = combiner.merge_accumulators(res, acc)
+                else:
+                    res = acc
+            return res
+
+        return col | self._ulg.unique(stage_name) >> beam.CombinePerKey(
+            merge_accumulators)
+
     def flatten(self, col1, col2, stage_name: str):
         return (col1, col2) | self._ulg.unique(stage_name) >> beam.Flatten()

@@ -308,6 +341,13 @@ def count_per_element(self, rdd, stage_name: str = None):
     def reduce_accumulators_per_key(self, rdd, stage_name: str):
         return rdd.reduceByKey(lambda acc1, acc2: acc1.add_accumulator(acc2))

+    def combine_accumulators_per_key(self,
+                                     rdd,
+                                     combiner: dp_combiners.Combiner,
+                                     stage_name: str = None):
+        return rdd.reduceByKey(
+            lambda acc1, acc2: combiner.merge_accumulators(acc1, acc2))
+
     def is_serialization_immediate_on_reduce_by_key(self):
         return True

@@ -382,6 +422,18 @@ def count_per_element(self, col, stage_name: typing.Optional[str] = None):
     def reduce_accumulators_per_key(self, col, stage_name: str = None):
         return self.map_values(self.group_by_key(col), accumulator.merge)

+    def combine_accumulators_per_key(self,
+                                     col,
+                                     combiner: dp_combiners.Combiner,
+                                     stage_name: str = None):
+
+        def merge_accumulators(accumulators):
+            return functools.reduce(
+                lambda acc1, acc2: combiner.merge_accumulators(acc1, acc2),
+                accumulators)
+
+        return self.map_values(self.group_by_key(col), merge_accumulators)
+
     def flatten(self, col1, col2, stage_name: str = None):
         return itertools.chain(col1, col2)

@@ -463,7 +515,7 @@ def insert_row(captures, row):
             key, val = row
             results_dict_[key].append(val)

-        insert_row = partial(insert_row, (self.results_dict,))
+        insert_row = functools.partial(insert_row, (self.results_dict,))

         super().__init__(insert_row,
                          job_inputs,
@@ -496,7 +548,7 @@ def insert_row(captures, key):
             (results_dict_,) = captures
             results_dict_[key] += 1

-        insert_row = partial(insert_row, (self.results_dict,))
+        insert_row = functools.partial(insert_row, (self.results_dict,))

         super().__init__(insert_row,
                          job_inputs,
@@ -554,7 +606,7 @@ def filter_by_key(self,
         def mapped_fn(keys_to_keep_, kv):
             return kv, (kv[0] in keys_to_keep_)

-        mapped_fn = partial(mapped_fn, keys_to_keep)
+        mapped_fn = functools.partial(mapped_fn, keys_to_keep)
         key_keep = self.map(col, mapped_fn, stage_name)
         return (row for row, keep in key_keep if keep)

@@ -579,7 +631,7 @@ def mapped_fn(captures, row):
                 samples = random.sample(samples, n_)
             return partition_key, samples

-        mapped_fn = partial(mapped_fn, (n,))
+        mapped_fn = functools.partial(mapped_fn, (n,))
         groups = self.group_by_key(col, stage_name)
         return self.map(groups, mapped_fn, stage_name)

@@ -592,5 +644,11 @@ def reduce_accumulators_per_key(self,
                                     stage_name: typing.Optional[str] = None):
         return self.map_values(col, accumulator.merge)

+    def combine_accumulators_per_key(self, col, combiner: dp_combiners.Combiner,
+                                     stage_name: str):
+        raise NotImplementedError(
+            "combine_accumulators_per_key is not implemented for MultiProcLocalBackend"
+        )
+
     def flatten(self, col1, col2, stage_name: str = None):
-        return itertools.chain(col1, col2)
\ No newline at end of file
+        return itertools.chain(col1, col2)
diff --git a/tests/pipeline_operations_test.py b/tests/pipeline_backend_test.py
similarity index 93%
rename from tests/pipeline_operations_test.py
rename to tests/pipeline_backend_test.py
index 44a5bfb9..3dc5ec0d 100644
--- a/tests/pipeline_operations_test.py
+++ b/tests/pipeline_backend_test.py
@@ -11,6 +11,7 @@
 from pipeline_dp.pipeline_backend import MultiProcLocalBackend, SparkRDDBackend
 from pipeline_dp.pipeline_backend import LocalBackend
 from pipeline_dp.pipeline_backend import BeamBackend
+import pipeline_dp.combiners as dp_combiners


 class PipelineBackendTest(unittest.TestCase):
@@ -81,6 +82,23 @@ def test_reduce_accumulators_per_key(self):
             beam_util.assert_that(result,
                                   beam_util.equal_to([(6, 2), (7, 2), (8, 1)]))

+    def test_combine_accumulators_per_key(self):
+        with test_pipeline.TestPipeline() as p:
+            col = p | "Create PCollection" >> beam.Create([(6, 1), (7, 1),
+                                                           (6, 1), (7, 1),
+                                                           (8, 1)])
+            sum_combiner = SumCombiner()
+            col = self.ops.group_by_key(col, "group_by_key")
+            col = self.ops.map_values(col, sum_combiner.create_accumulator,
+                                      "Wrap into accumulators")
+            col = self.ops.combine_accumulators_per_key(
+                col, sum_combiner, "Reduce accumulators per key")
+            result = self.ops.map_values(col, sum_combiner.compute_metrics,
+                                         "Compute metrics")
+
+            beam_util.assert_that(result,
+                                  beam_util.equal_to([(6, 2), (7, 2), (8, 1)]))
+

 class BeamBackendStageNameTest(unittest.TestCase):
@@ -277,6 +295,16 @@ def test_reduce_accumulators_per_key(self):
         result = dict(result)
         self.assertDictEqual(result, {1: 41, 2: 47, 3: 33})

+    def test_combine_accumulators_per_key(self):
+        data = self.sc.parallelize([(1, 2), (2, 1), (1, 4), (3, 8), (2, 3)])
+        rdd = self.ops.group_by_key(data)
+        sum_combiner = SumCombiner()
+        rdd = self.ops.map_values(rdd, sum_combiner.create_accumulator)
+        rdd = self.ops.combine_accumulators_per_key(rdd, sum_combiner)
+        rdd = self.ops.map_values(rdd, sum_combiner.compute_metrics)
+        result = dict(rdd.collect())
+        self.assertDictEqual(result, {1: 6, 2: 4, 3: 8})
+
     def test_map_tuple(self):
         data = [(1, 2), (3, 4)]
         dist_data = self.sc.parallelize(data)
@@ -410,6 +438,16 @@ def test_local_reduce_accumulators_per_key(self):
         result = list(map(lambda row: (row[0], row[1].get_metrics()), col))
         self.assertEqual(result, [(1, 6), (2, 4), (3, 8)])

+    def test_local_combine_accumulators_per_key(self):
+        data = [(1, 2), (2, 1), (1, 4), (3, 8), (2, 3)]
+        col = self.ops.group_by_key(data)
+        sum_combiner = SumCombiner()
+        col = self.ops.map_values(col, sum_combiner.create_accumulator)
+        col = self.ops.combine_accumulators_per_key(col, sum_combiner)
+        col = self.ops.map_values(col, sum_combiner.compute_metrics)
+        result = list(col)
+        self.assertEqual(result, [(1, 6), (2, 4), (3, 8)])
+
     def test_laziness(self):

         def exceptions_generator_function():
@@ -751,5 +789,17 @@ def add_accumulator(self,
         return self


+class SumCombiner(dp_combiners.Combiner):
+
+    def create_accumulator(self, values) -> float:
+        return sum(values)
+
+    def merge_accumulators(self, sum1: float, sum2: float):
+        return sum1 + sum2
+
+    def compute_metrics(self, sum: float) -> float:
+        return sum
+
+
 if __name__ == '__main__':
     unittest.main()
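Below is a minimal sketch (not part of the patch) of how the new combine_accumulators_per_key is intended to be wired up end to end, mirroring test_local_combine_accumulators_per_key above. SumCombiner is the same toy combiner the tests define rather than a library combiner, and the zero-argument LocalBackend() construction is assumed from the existing test setup.

```python
# Usage sketch for combine_accumulators_per_key on LocalBackend.
# SumCombiner is a toy combiner (as in the tests), not part of pipeline_dp itself.
import pipeline_dp.combiners as dp_combiners
from pipeline_dp.pipeline_backend import LocalBackend


class SumCombiner(dp_combiners.Combiner):

    def create_accumulator(self, values) -> float:
        # One accumulator per key: the sum of that key's values.
        return sum(values)

    def merge_accumulators(self, sum1: float, sum2: float) -> float:
        return sum1 + sum2

    def compute_metrics(self, total: float) -> float:
        return total


backend = LocalBackend()  # assumed zero-argument construction, as in the tests
combiner = SumCombiner()

data = [(1, 2), (2, 1), (1, 4), (3, 8), (2, 3)]
col = backend.group_by_key(data)                            # (key, [values])
col = backend.map_values(col, combiner.create_accumulator)  # (key, accumulator)
col = backend.combine_accumulators_per_key(col, combiner)   # one accumulator per key
col = backend.map_values(col, combiner.compute_metrics)     # (key, metric)
print(list(col))  # [(1, 6), (2, 4), (3, 8)]
```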