Introduce combine_accumulators_per_key in PipelineBackend (#152)
dvadym authored Jan 19, 2022
1 parent 4e60cc6 commit 7248f72
Showing 2 changed files with 114 additions and 6 deletions.
70 changes: 64 additions & 6 deletions pipeline_dp/pipeline_backend.py
@@ -1,6 +1,6 @@
"""Adapters for working with pipeline frameworks."""

from functools import partial
import functools
import multiprocessing as mp
import random
import numpy as np
@@ -9,6 +9,7 @@
import apache_beam as beam
import apache_beam.transforms.combiners as combiners
import pipeline_dp.accumulator as accumulator
import pipeline_dp.combiners as dp_combiners
import typing
import collections
import itertools
@@ -87,6 +88,23 @@ def reduce_accumulators_per_key(self, col, stage_name: str):
"""
pass

@abc.abstractmethod
def combine_accumulators_per_key(self, col, combiner: dp_combiners.Combiner,
stage_name: str):
"""Reduces the input collection so that all elements per each key are merged.
Args:
col: input collection which contains tuples (key, accumulator).
combiner: combiner which knows how to perform aggregation on
accumulators in col.
stage_name: name of the stage.
Returns:
A collection of tuples (key, accumulator).
"""
pass
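
For reference, a hypothetical plain-Python sketch of the contract this abstract method defines (illustrative only, not part of the backend API): all accumulators sharing a key are folded together with combiner.merge_accumulators, yielding one (key, accumulator) pair per key.

def combine_accumulators_per_key_reference(pairs, combiner):
    # Fold every accumulator that shares a key into a single accumulator.
    merged = {}
    for key, acc in pairs:
        if key in merged:
            merged[key] = combiner.merge_accumulators(merged[key], acc)
        else:
            merged[key] = acc
    return list(merged.items())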

@abc.abstractmethod
def flatten(self, col1, col2, stage_name: str):
"""
@@ -229,6 +247,21 @@ def merge_accumulators(accumulators):
return col | self._ulg.unique(stage_name) >> beam.CombinePerKey(
merge_accumulators)

def combine_accumulators_per_key(self, col, combiner: dp_combiners.Combiner,
stage_name: str):

def merge_accumulators(accumulators):
res = None
for acc in accumulators:
if res is not None:
res = combiner.merge_accumulators(res, acc)
else:
res = acc
return res

return col | self._ulg.unique(stage_name) >> beam.CombinePerKey(
merge_accumulators)
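
As a usage sketch of the Beam path (assuming apache_beam is installed; SumCombiner below is a stand-in for the test combiner added later in this commit, not library code), beam.CombinePerKey applies the merging callable to all accumulators that share a key:

import apache_beam as beam

class SumCombiner:
    def merge_accumulators(self, acc1, acc2):
        return acc1 + acc2

combiner = SumCombiner()

def merge_accumulators(accumulators):
    # Fold the (possibly partially merged) accumulators into one.
    res = None
    for acc in accumulators:
        res = acc if res is None else combiner.merge_accumulators(res, acc)
    return res

with beam.Pipeline() as p:
    merged = (p
              | beam.Create([(6, 1), (7, 1), (6, 1)])
              | beam.CombinePerKey(merge_accumulators))
    # merged contains (6, 2) and (7, 1)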

def flatten(self, col1, col2, stage_name: str):
return (col1, col2) | self._ulg.unique(stage_name) >> beam.Flatten()

@@ -308,6 +341,13 @@ def count_per_element(self, rdd, stage_name: str = None):
def reduce_accumulators_per_key(self, rdd, stage_name: str):
return rdd.reduceByKey(lambda acc1, acc2: acc1.add_accumulator(acc2))

def combine_accumulators_per_key(self,
rdd,
combiner: dp_combiners.Combiner,
stage_name: str = None):
return rdd.reduceByKey(
lambda acc1, acc2: combiner.merge_accumulators(acc1, acc2))
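
A minimal sketch of the same operation on Spark (assumes pyspark is available; the lambda stands in for combiner.merge_accumulators):

from pyspark import SparkContext

sc = SparkContext.getOrCreate()
rdd = sc.parallelize([(1, 2), (2, 1), (1, 4), (3, 8), (2, 3)])
# reduceByKey merges accumulators pairwise for each key.
merged = rdd.reduceByKey(lambda acc1, acc2: acc1 + acc2)
print(sorted(merged.collect()))  # [(1, 6), (2, 4), (3, 8)]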

def is_serialization_immediate_on_reduce_by_key(self):
return True

@@ -382,6 +422,18 @@ def count_per_element(self, col, stage_name: typing.Optional[str] = None):
def reduce_accumulators_per_key(self, col, stage_name: str = None):
return self.map_values(self.group_by_key(col), accumulator.merge)

def combine_accumulators_per_key(self,
col,
combiner: dp_combiners.Combiner,
stage_name: str = None):

def merge_accumulators(accumulators):
return functools.reduce(
lambda acc1, acc2: combiner.merge_accumulators(acc1, acc2),
accumulators)

return self.map_values(self.group_by_key(col), merge_accumulators)
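
The local backend's merge is just a functools.reduce over the accumulators grouped under each key; a small standalone sketch (SumCombiner is again an assumed stand-in combiner):

import functools

class SumCombiner:
    def merge_accumulators(self, acc1, acc2):
        return acc1 + acc2

combiner = SumCombiner()
accumulators = [2, 4]  # all accumulators grouped under one key
merged = functools.reduce(combiner.merge_accumulators, accumulators)
print(merged)  # 6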

def flatten(self, col1, col2, stage_name: str = None):
return itertools.chain(col1, col2)

@@ -463,7 +515,7 @@ def insert_row(captures, row):
key, val = row
results_dict_[key].append(val)

insert_row = partial(insert_row, (self.results_dict,))
insert_row = functools.partial(insert_row, (self.results_dict,))

super().__init__(insert_row,
job_inputs,
@@ -496,7 +548,7 @@ def insert_row(captures, key):
(results_dict_,) = captures
results_dict_[key] += 1

insert_row = partial(insert_row, (self.results_dict,))
insert_row = functools.partial(insert_row, (self.results_dict,))

super().__init__(insert_row,
job_inputs,
@@ -554,7 +606,7 @@ def filter_by_key(self,
def mapped_fn(keys_to_keep_, kv):
return kv, (kv[0] in keys_to_keep_)

mapped_fn = partial(mapped_fn, keys_to_keep)
mapped_fn = functools.partial(mapped_fn, keys_to_keep)
key_keep = self.map(col, mapped_fn, stage_name)
return (row for row, keep in key_keep if keep)

@@ -579,7 +631,7 @@ def mapped_fn(captures, row):
samples = random.sample(samples, n_)
return partition_key, samples

mapped_fn = partial(mapped_fn, (n,))
mapped_fn = functools.partial(mapped_fn, (n,))
groups = self.group_by_key(col, stage_name)
return self.map(groups, mapped_fn, stage_name)

@@ -592,5 +644,11 @@ def reduce_accumulators_per_key(self,
stage_name: typing.Optional[str] = None):
return self.map_values(col, accumulator.merge)

def combine_accumulators_per_key(self, col, combiner: dp_combiners.Combiner,
stage_name: str):
raise NotImplementedError(
"combine_accumulators_per_key is not implmeneted for MultiProcLocalBackend"
)

def flatten(self, col1, col2, stage_name: str = None):
return itertools.chain(col1, col2)
50 changes: 50 additions & 0 deletions (pipeline_backend tests)
@@ -11,6 +11,7 @@
from pipeline_dp.pipeline_backend import MultiProcLocalBackend, SparkRDDBackend
from pipeline_dp.pipeline_backend import LocalBackend
from pipeline_dp.pipeline_backend import BeamBackend
import pipeline_dp.combiners as dp_combiners


class PipelineBackendTest(unittest.TestCase):
@@ -81,6 +82,23 @@ def test_reduce_accumulators_per_key(self):
beam_util.assert_that(result,
beam_util.equal_to([(6, 2), (7, 2), (8, 1)]))

def test_combine_accumulators_per_key(self):
with test_pipeline.TestPipeline() as p:
col = p | "Create PCollection" >> beam.Create([(6, 1), (7, 1),
(6, 1), (7, 1),
(8, 1)])
sum_combiner = SumCombiner()
col = self.ops.group_by_key(col, "group_by_key")
col = self.ops.map_values(col, sum_combiner.create_accumulator,
"Wrap into accumulators")
col = self.ops.combine_accumulators_per_key(
col, sum_combiner, "Combine accumulators per key")
result = self.ops.map_values(col, sum_combiner.compute_metrics,
"Compute metrics")

beam_util.assert_that(result,
beam_util.equal_to([(6, 2), (7, 2), (8, 1)]))


class BeamBackendStageNameTest(unittest.TestCase):

@@ -277,6 +295,16 @@ def test_reduce_accumulators_per_key(self):
result = dict(result)
self.assertDictEqual(result, {1: 41, 2: 47, 3: 33})

def test_combine_accumulators_per_key(self):
data = self.sc.parallelize([(1, 2), (2, 1), (1, 4), (3, 8), (2, 3)])
rdd = self.ops.group_by_key(data)
sum_combiner = SumCombiner()
rdd = self.ops.map_values(rdd, sum_combiner.create_accumulator)
rdd = self.ops.combine_accumulators_per_key(rdd, sum_combiner)
rdd = self.ops.map_values(rdd, sum_combiner.compute_metrics)
result = dict(rdd.collect())
self.assertDictEqual(result, {1: 6, 2: 4, 3: 8})

def test_map_tuple(self):
data = [(1, 2), (3, 4)]
dist_data = self.sc.parallelize(data)
@@ -410,6 +438,16 @@ def test_local_reduce_accumulators_per_key(self):
result = list(map(lambda row: (row[0], row[1].get_metrics()), col))
self.assertEqual(result, [(1, 6), (2, 4), (3, 8)])

def test_local_combine_accumulators_per_key(self):
data = [(1, 2), (2, 1), (1, 4), (3, 8), (2, 3)]
col = self.ops.group_by_key(data)
sum_combiner = SumCombiner()
col = self.ops.map_values(col, sum_combiner.create_accumulator)
col = self.ops.combine_accumulators_per_key(col, sum_combiner)
col = self.ops.map_values(col, sum_combiner.compute_metrics)
result = list(col)
self.assertEqual(result, [(1, 6), (2, 4), (3, 8)])

def test_laziness(self):

def exceptions_generator_function():
@@ -751,5 +789,17 @@ def add_accumulator(self,
return self


class SumCombiner(dp_combiners.Combiner):

def create_accumulator(self, values) -> float:
return sum(values)

def merge_accumulators(self, sum1: float, sum2: float):
return sum1 + sum2

def compute_metrics(self, sum: float) -> float:
return sum


if __name__ == '__main__':
unittest.main()
