From 9e763a50fe5b8c8b06706845e1f5152727cbf11e Mon Sep 17 00:00:00 2001 From: Stu Hood Date: Wed, 19 Jan 2022 15:16:57 -0800 Subject: [PATCH] [internal] Remove the minimum bucket size of batching to improve stability. (#14210) As a followup to #14186, this change improves the stability (and thus cache hit rates) of batching by removing the minimum bucket size. It also fixes an issue in the tests, and expands the range that they test. As mentioned in the expanded comments: capping bucket sizes (in either the `min` or the `max` direction) can cause streaks of bucket changes: when a bucket hits a `min`/`max` threshold and ignores a boundary, it increases the chance that the next bucket will trip a threshold as well. Although it would be most-stable to remove the `max` threshold entirely, it is necessary to resolve the correctness issue of #13462. But we _can_ remove the `min` threshold, and so this change does that. [ci skip-rust] [ci skip-build-wheels] --- src/python/pants/core/goals/fmt.py | 5 +++- src/python/pants/core/goals/lint.py | 5 +++- src/python/pants/core/goals/style_request.py | 5 +++- src/python/pants/util/collections.py | 29 ++++++++++---------- src/python/pants/util/collections_test.py | 17 +++++++----- 5 files changed, 36 insertions(+), 25 deletions(-) diff --git a/src/python/pants/core/goals/fmt.py b/src/python/pants/core/goals/fmt.py index 3aef75e0eb6..326f4dc2cfa 100644 --- a/src/python/pants/core/goals/fmt.py +++ b/src/python/pants/core/goals/fmt.py @@ -212,7 +212,10 @@ async def fmt( ) for fmt_requests, targets in targets_by_fmt_request_order.items() for target_batch in partition_sequentially( - targets, key=lambda t: t.address.spec, size_min=fmt_subsystem.batch_size + targets, + key=lambda t: t.address.spec, + size_target=fmt_subsystem.batch_size, + size_max=4 * fmt_subsystem.batch_size, ) ) diff --git a/src/python/pants/core/goals/lint.py b/src/python/pants/core/goals/lint.py index 609bd9c84f1..bdf0aa913ec 100644 --- a/src/python/pants/core/goals/lint.py +++ b/src/python/pants/core/goals/lint.py @@ -230,7 +230,10 @@ def address_str(fs: FieldSet) -> str: for request in requests if request.field_sets for field_set_batch in partition_sequentially( - request.field_sets, key=address_str, size_min=lint_subsystem.batch_size + request.field_sets, + key=address_str, + size_target=lint_subsystem.batch_size, + size_max=4 * lint_subsystem.batch_size, ) ) diff --git a/src/python/pants/core/goals/style_request.py b/src/python/pants/core/goals/style_request.py index 0fe9ef49e19..93896b2d695 100644 --- a/src/python/pants/core/goals/style_request.py +++ b/src/python/pants/core/goals/style_request.py @@ -26,7 +26,7 @@ def style_batch_size_help(uppercase: str, lowercase: str) -> str: return ( - f"The target minimum number of files that will be included in each {lowercase} batch.\n" + f"The target number of files to be included in each {lowercase} batch.\n" "\n" f"{uppercase} processes are batched for a few reasons:\n" "\n" @@ -38,6 +38,9 @@ def style_batch_size_help(uppercase: str, lowercase: str) -> str: "parallelism, or -- if they do support internal parallelism -- to improve scheduling " "behavior when multiple processes are competing for cores and so internal " "parallelism cannot be used perfectly.\n" + "\n" + "In order to improve cache hit rates (see 2.), batches are created at stable boundaries, " + 'and so this value is only a "target" batch size (rather than an exact value).' ) diff --git a/src/python/pants/util/collections.py b/src/python/pants/util/collections.py index 5606ffd38de..7a1f44e0e72 100644 --- a/src/python/pants/util/collections.py +++ b/src/python/pants/util/collections.py @@ -80,31 +80,32 @@ def partition_sequentially( items: Iterable[_T], *, key: Callable[[_T], str], - size_min: int, + size_target: int, size_max: int | None = None, ) -> Iterator[list[_T]]: - """Stably partitions the given items into batches of at least size_min. + """Stably partitions the given items into batches of around `size_target` items. The "stability" property refers to avoiding adjusting all batches when a single item is added, which could happen if the items were trivially windowed using `itertools.islice` and an item was added near the front of the list. - Batches will be capped to `size_max`, which defaults `size_min*2`. + Batches will optionally be capped to `size_max`, but note that this can weaken the stability + properties of the bucketing, by forcing bucket boundaries to be created where they otherwise + might not. """ - # To stably partition the arguments into ranges of at least `size_min`, we sort them, and - # create a new batch sequentially once we have the minimum number of entries, _and_ we encounter - # an item hash prefixed with a threshold of zeros. + # To stably partition the arguments into ranges of approximately `size_target`, we sort them, + # and create a new batch sequentially once we encounter an item hash prefixed with a threshold + # of zeros. # # The hashes act like a (deterministic) series of rolls of an evenly distributed die. The # probability of a hash prefixed with Z zero bits is 1/2^Z, and so to break after N items on # average, we look for `Z == log2(N)` zero bits. # - # Breaking on these deterministic boundaries means that adding any single item will affect - # either one bucket (if the item does not create a boundary) or two (if it does create a - # boundary). - zero_prefix_threshold = math.log(max(4, size_min) // 4, 2) - size_max = size_min * 2 if size_max is None else size_max + # Breaking on these deterministic boundaries reduces the chance that adding or removing items + # causes multiple buckets to be recalculated. But when a `size_max` value is set, it's possible + # for adding items to cause multiple sequential buckets to be affected. + zero_prefix_threshold = math.log(max(1, size_target), 2) batch: list[_T] = [] @@ -121,10 +122,8 @@ def emit_batch() -> list[_T]: for item_key, item in keyed_items: batch.append(item) - if ( - len(batch) >= size_min - and native_engine.hash_prefix_zero_bits(item_key) >= zero_prefix_threshold - ) or (len(batch) >= size_max): + prefix_zero_bits = native_engine.hash_prefix_zero_bits(item_key) + if prefix_zero_bits >= zero_prefix_threshold or (size_max and len(batch) >= size_max): yield emit_batch() if batch: yield emit_batch() diff --git a/src/python/pants/util/collections_test.py b/src/python/pants/util/collections_test.py index 2acf605f94e..2f070437b49 100644 --- a/src/python/pants/util/collections_test.py +++ b/src/python/pants/util/collections_test.py @@ -88,21 +88,24 @@ def test_ensure_str_list() -> None: ensure_str_list([0, 1]) # type: ignore[list-item] -@pytest.mark.parametrize("size_min", [0, 1, 16, 32, 64, 128]) -def test_partition_sequentially(size_min: int) -> None: +@pytest.mark.parametrize("size_target", [0, 1, 8, 16, 32, 64, 128]) +def test_partition_sequentially(size_target: int) -> None: # Adding an item at any position in the input sequence should affect either 1 or 2 (if the added # item becomes a boundary) buckets in the output. def partitioned_buckets(items: list[str]) -> set[tuple[str, ...]]: - return set(tuple(p) for p in partition_sequentially(items, key=str, size_min=size_min)) + return set( + tuple(p) for p in partition_sequentially(items, key=str, size_target=size_target) + ) # We start with base items containing every other element from a sorted sequence. - all_items = sorted((f"item{i}" for i in range(0, 64))) + all_items = sorted((f"item{i}" for i in range(0, 1024))) base_items = [item for i, item in enumerate(all_items) if i % 2 == 0] base_partitions = partitioned_buckets(base_items) - # Then test that adding any of the remaining items elements (which will be interspersed in the - # base items) only affects 1 or 2 buckets in the output. + # Then test that adding any of the remaining items (which will be interspersed in the base + # items) only affects 1 or 2 buckets in the output (representing between a 1 and 4 delta + # in the `^`/symmetric_difference between before and after). for to_add in [item for i, item in enumerate(all_items) if i % 2 == 1]: updated_partitions = partitioned_buckets([to_add, *base_items]) - assert 1 <= len(base_partitions ^ updated_partitions) <= 2 + assert 1 <= len(base_partitions ^ updated_partitions) <= 4