Skip to content

Commit

Permalink
Rewrite length filters on collection types
Browse files Browse the repository at this point in the history
  • Loading branch information
Zac-HD committed Feb 25, 2024
1 parent a807d15 commit 52f9245
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 10 deletions.
4 changes: 4 additions & 0 deletions hypothesis-python/RELEASE.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
RELEASE_TYPE: patch

This patch implements filter-rewriting for most length filters on some
additional collection types (:issue:`3795`).
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import sys
import warnings
from collections import abc, defaultdict
from functools import lru_cache
from random import shuffle
from typing import (
Any,
Expand Down Expand Up @@ -842,6 +843,33 @@ def branches(self) -> List[SearchStrategy[Ex]]:
for strategy in self.mapped_strategy.branches
]

def filter(self, condition: Callable[[Ex], Any]) -> "SearchStrategy[Ex]":
# Includes a special case so that we can rewrite filters on collection
# lengths, when most collections are `st.lists(...).map(the_type)`.
ListStrategy = _list_strategy_type()
if not isinstance(self.mapped_strategy, ListStrategy) or not (
(isinstance(self.pack, type) and issubclass(self.pack, abc.Collection))
or self.pack is sorted
):
return super().filter(condition)

# Check whether our inner list strategy can rewrite this filter condition.
# If not, discard the result and _only_ apply a new outer filter.
new = ListStrategy.filter(self.mapped_strategy, condition)
if getattr(new, "filtered_strategy", None) is self.mapped_strategy:
return super().filter(condition) # didn't rewrite

# Apply a new outer filter even though we rewrote the inner strategy,
# because some collections can change the list length (dict, set, etc).
return FilteredStrategy(type(self)(new, self.pack), conditions=(condition,))


@lru_cache
def _list_strategy_type():
from hypothesis.strategies._internal.collections import ListStrategy

return ListStrategy


filter_not_satisfied = UniqueIdentifier("filter not satisfied")

Expand Down
31 changes: 21 additions & 10 deletions hypothesis-python/tests/cover/test_filter_rewriting.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@

import pytest

from hypothesis import given, strategies as st
from hypothesis import given, settings, strategies as st
from hypothesis.errors import HypothesisWarning, Unsatisfiable
from hypothesis.internal.filtering import max_len, min_len
from hypothesis.internal.floats import next_down, next_up
from hypothesis.internal.reflection import get_pretty_function_description
from hypothesis.strategies._internal.core import data
from hypothesis.strategies._internal.lazy import LazyStrategy, unwrap_strategies
from hypothesis.strategies._internal.numbers import FloatStrategy, IntegersStrategy
from hypothesis.strategies._internal.strategies import FilteredStrategy
from hypothesis.strategies._internal.strategies import FilteredStrategy, MappedStrategy
from hypothesis.strategies._internal.strings import TextStrategy

from tests.common.debug import check_can_generate_examples
Expand Down Expand Up @@ -469,24 +469,35 @@ def test_can_rewrite_multiple_length_filters_if_not_lambdas(data):
st.lists(st.integers()),
st.lists(st.integers(), unique=True),
st.lists(st.sampled_from([1, 2, 3])),
# TODO: support more collection types. Might require messing around with
# strategy internals, e.g. in MappedStrategy/FilteredStrategy.
# st.binary(),
# st.binary.map(bytearray),
# st.sets(st.integers()),
# st.dictionaries(st.integers(), st.none()),
st.binary(),
st.sets(st.integers()),
st.frozensets(st.integers()),
st.dictionaries(st.integers(), st.none()),
st.lists(st.integers(), unique_by=lambda x: x % 17).map(tuple),
],
ids=get_pretty_function_description,
)
@settings(max_examples=15)
@given(data=st.data())
def test_filter_rewriting_text_lambda_len(data, strategy, predicate, start, end):
s = strategy.filter(predicate)
unwrapped_nofilter = unwrap_strategies(strategy)
unwrapped = unwrap_strategies(s)
assert isinstance(unwrapped, FilteredStrategy)
assert isinstance(unwrapped.filtered_strategy, type(unwrap_strategies(strategy)))

if was_mapped := isinstance(unwrapped, MappedStrategy):
unwrapped = unwrapped.mapped_strategy

assert isinstance(unwrapped, FilteredStrategy), f"{unwrapped=} {type(unwrapped)=}"
assert isinstance(
unwrapped.filtered_strategy,
type(unwrapped_nofilter.mapped_strategy if was_mapped else unwrapped_nofilter),
)
for pred in unwrapped.flat_conditions:
assert pred.__name__ == "<lambda>"

if isinstance(unwrapped.filtered_strategy, MappedStrategy):
unwrapped = unwrapped.filtered_strategy.mapped_strategy

assert unwrapped.filtered_strategy.min_size == start
assert unwrapped.filtered_strategy.max_size == end
value = data.draw(s)
Expand Down

0 comments on commit 52f9245

Please sign in to comment.