diff --git a/tests/private_beam_test.py b/tests/private_beam_test.py index d7b1789a..f69c3808 100644 --- a/tests/private_beam_test.py +++ b/tests/private_beam_test.py @@ -4,6 +4,7 @@ from apache_beam import pvalue from unittest.mock import patch import apache_beam.testing.util as beam_util +from apache_beam.testing.test_pipeline import TestPipeline import pipeline_dp from pipeline_dp import private_beam @@ -22,6 +23,11 @@ class PrivateBeamTest(unittest.TestCase): def privacy_id_extractor(x): return f"pid:{x}" + @staticmethod + def value_per_key_within_tolerance(expected, actual, tolerance): + return actual[0] == expected[0] and abs(actual[1] - + expected[1]) <= tolerance + def test_make_private_transform_succeeds(self): runner = fn_api_runner.FnApiRunner() with beam.Pipeline(runner=runner) as pipeline: @@ -154,6 +160,45 @@ def test_sum_calls_aggregate_with_params(self, mock_aggregate): public_partitions=sum_params.public_partitions) self.assertEqual(args[1], params) + def test_sum_returns_sensible_result(self): + with TestPipeline() as pipeline: + # Arrange + col = [(u, "pk1", 100) for u in range(30)] + col += [(u + 30, "pk1", -100) for u in range(30)] + pcol = pipeline | 'Create produce' >> beam.Create(col) + # Use very high epsilon and delta to minimize noise and test + # flakiness. + budget_accountant = budget_accounting.NaiveBudgetAccountant( + total_epsilon=800, total_delta=0.999) + private_collection = ( + pcol | 'Create private collection' >> private_beam.MakePrivate( + budget_accountant=budget_accountant, + privacy_id_extractor=lambda x: x[0])) + + sum_params = aggregate_params.SumParams( + noise_kind=pipeline_dp.NoiseKind.GAUSSIAN, + max_partitions_contributed=2, + max_contributions_per_partition=3, + low=1, + high=2, + budget_weight=1, + partition_extractor=lambda x: x[1], + value_extractor=lambda x: x[2]) + + # Act + result = private_collection | private_beam.Sum( + sum_params=sum_params) + budget_accountant.compute_budgets() + + # Assert + # This is a health check to validate that the result is sensible. + # Hence, we use a very large tolerance to reduce test flakiness. + beam_util.assert_that( + result, + beam_util.equal_to([("pk1", 90)], + equals_fn=lambda e, a: PrivateBeamTest. + value_per_key_within_tolerance(e, a, 10))) + @patch('pipeline_dp.dp_engine.DPEngine.aggregate') def test_count_calls_aggregate_with_params(self, mock_aggregate): runner = fn_api_runner.FnApiRunner()