Skip to content

Commit

Permalink
Merge pull request #3895 from Zac-HD/more-rewrites
Browse files Browse the repository at this point in the history
Rewrite length filters for diverse collection types
  • Loading branch information
Zac-HD authored Feb 25, 2024
2 parents 3eea8c5 + aebcdd9 commit 7f5b065
Show file tree
Hide file tree
Showing 10 changed files with 165 additions and 59 deletions.
6 changes: 6 additions & 0 deletions hypothesis-python/RELEASE.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
RELEASE_TYPE: patch

This patch implements filter-rewriting for most length filters on some
additional collection types (:issue:`3795`), and fixes several latent
bugs where unsatisfiable or partially-infeasible rewrites could trigger
internal errors.
4 changes: 2 additions & 2 deletions hypothesis-python/src/hypothesis/extra/ghostwriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@
from hypothesis.strategies._internal.lazy import LazyStrategy, unwrap_strategies
from hypothesis.strategies._internal.strategies import (
FilteredStrategy,
MappedSearchStrategy,
MappedStrategy,
OneOfStrategy,
SampledFromStrategy,
)
Expand Down Expand Up @@ -627,7 +627,7 @@ def _imports_for_strategy(strategy):
strategy = unwrap_strategies(strategy)

# Get imports for s.map(f), s.filter(f), s.flatmap(f), including both s and f
if isinstance(strategy, MappedSearchStrategy):
if isinstance(strategy, MappedStrategy):
imports |= _imports_for_strategy(strategy.mapped_strategy)
imports |= _imports_for_object(strategy.pack)
if isinstance(strategy, FilteredStrategy):
Expand Down
4 changes: 2 additions & 2 deletions hypothesis-python/src/hypothesis/extra/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
from hypothesis.strategies._internal.numbers import Real
from hypothesis.strategies._internal.strategies import (
Ex,
MappedSearchStrategy,
MappedStrategy,
T,
check_strategy,
)
Expand Down Expand Up @@ -516,7 +516,7 @@ def arrays(
# If there's a redundant cast to the requested dtype, remove it. This unlocks
# optimizations such as fast unique sampled_from, and saves some time directly too.
unwrapped = unwrap_strategies(elements)
if isinstance(unwrapped, MappedSearchStrategy) and unwrapped.pack == dtype.type:
if isinstance(unwrapped, MappedStrategy) and unwrapped.pack == dtype.type:
elements = unwrapped.mapped_strategy
if isinstance(shape, int):
shape = (shape,)
Expand Down
13 changes: 10 additions & 3 deletions hypothesis-python/src/hypothesis/internal/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@

from hypothesis.internal.compat import ceil, floor
from hypothesis.internal.floats import next_down, next_up
from hypothesis.internal.reflection import extract_lambda_source
from hypothesis.internal.reflection import (
extract_lambda_source,
get_pretty_function_description,
)

Ex = TypeVar("Ex")
Predicate = Callable[[Ex], bool]
Expand Down Expand Up @@ -64,6 +67,10 @@ class ConstructivePredicate(NamedTuple):
def unchanged(cls, predicate: Predicate) -> "ConstructivePredicate":
return cls({}, predicate)

def __repr__(self) -> str:
fn = get_pretty_function_description(self.predicate)
return f"{self.__class__.__name__}(kwargs={self.kwargs!r}, predicate={fn})"


ARG = object()

Expand Down Expand Up @@ -147,8 +154,8 @@ def merge_preds(*con_predicates: ConstructivePredicate) -> ConstructivePredicate
elif kw["max_value"] == base["max_value"]:
base["exclude_max"] |= kw.get("exclude_max", False)

has_len = {"len" in kw for kw, _ in con_predicates}
assert len(has_len) == 1, "can't mix numeric with length constraints"
has_len = {"len" in kw for kw, _ in con_predicates if kw}
assert len(has_len) <= 1, "can't mix numeric with length constraints"
if has_len == {True}:
base["len"] = True

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
T4,
T5,
Ex,
MappedSearchStrategy,
MappedStrategy,
SearchStrategy,
T,
check_strategy,
Expand Down Expand Up @@ -211,6 +211,9 @@ def filter(self, condition):
new = copy.copy(self)
new.min_size = max(self.min_size, kwargs.get("min_value", self.min_size))
new.max_size = min(self.max_size, kwargs.get("max_value", self.max_size))
# Unsatisfiable filters are easiest to understand without rewriting.
if new.min_size > new.max_size:
return SearchStrategy.filter(self, condition)
# Recompute average size; this is cheaper than making it into a property.
new.average_size = min(
max(new.min_size * 2, new.min_size + 5),
Expand Down Expand Up @@ -302,7 +305,7 @@ def do_draw(self, data):
return result


class FixedKeysDictStrategy(MappedSearchStrategy):
class FixedKeysDictStrategy(MappedStrategy):
"""A strategy which produces dicts with a fixed set of keys, given a
strategy for each of their equivalent values.
Expand All @@ -311,19 +314,19 @@ class FixedKeysDictStrategy(MappedSearchStrategy):
"""

def __init__(self, strategy_dict):
self.dict_type = type(strategy_dict)
dict_type = type(strategy_dict)
self.keys = tuple(strategy_dict.keys())
super().__init__(strategy=TupleStrategy(strategy_dict[k] for k in self.keys))
super().__init__(
strategy=TupleStrategy(strategy_dict[k] for k in self.keys),
pack=lambda value: dict_type(zip(self.keys, value)),
)

def calc_is_empty(self, recur):
return recur(self.mapped_strategy)

def __repr__(self):
return f"FixedKeysDictStrategy({self.keys!r}, {self.mapped_strategy!r})"

def pack(self, value):
return self.dict_type(zip(self.keys, value))


class FixedAndOptionalKeysDictStrategy(SearchStrategy):
"""A strategy which produces dicts with a fixed set of keys, given a
Expand Down
36 changes: 20 additions & 16 deletions hypothesis-python/src/hypothesis/strategies/_internal/lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,25 +61,21 @@ def unwrap_strategies(s):
assert unwrap_depth >= 0


def _repr_filter(condition):
return f".filter({get_pretty_function_description(condition)})"


class LazyStrategy(SearchStrategy):
"""A strategy which is defined purely by conversion to and from another
strategy.
Its parameter and distribution come from that other strategy.
"""

def __init__(self, function, args, kwargs, filters=(), *, force_repr=None):
def __init__(self, function, args, kwargs, *, transforms=(), force_repr=None):
super().__init__()
self.__wrapped_strategy = None
self.__representation = force_repr
self.function = function
self.__args = args
self.__kwargs = kwargs
self.__filters = filters
self._transformations = transforms

@property
def supports_find(self):
Expand Down Expand Up @@ -115,23 +111,28 @@ def wrapped_strategy(self):
self.__wrapped_strategy = self.function(
*unwrapped_args, **unwrapped_kwargs
)
for f in self.__filters:
self.__wrapped_strategy = self.__wrapped_strategy.filter(f)
for method, fn in self._transformations:
self.__wrapped_strategy = getattr(self.__wrapped_strategy, method)(fn)
return self.__wrapped_strategy

def filter(self, condition):
try:
repr_ = f"{self!r}{_repr_filter(condition)}"
except Exception:
repr_ = None
return LazyStrategy(
def __with_transform(self, method, fn):
repr_ = self.__representation
if repr_:
repr_ = f"{repr_}.{method}({get_pretty_function_description(fn)})"
return type(self)(
self.function,
self.__args,
self.__kwargs,
(*self.__filters, condition),
transforms=(*self._transformations, (method, fn)),
force_repr=repr_,
)

def map(self, pack):
return self.__with_transform("map", pack)

def filter(self, condition):
return self.__with_transform("filter", condition)

def do_validate(self):
w = self.wrapped_strategy
assert isinstance(w, SearchStrategy), f"{self!r} returned non-strategy {w!r}"
Expand All @@ -156,7 +157,10 @@ def __repr__(self):
}
self.__representation = repr_call(
self.function, _args, kwargs_for_repr, reorder=False
) + "".join(map(_repr_filter, self.__filters))
) + "".join(
f".{method}({get_pretty_function_description(fn)})"
for method, fn in self._transformations
)
return self.__representation

def do_draw(self, data):
Expand Down
84 changes: 68 additions & 16 deletions hypothesis-python/src/hypothesis/strategies/_internal/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import sys
import warnings
from collections import abc, defaultdict
from functools import lru_cache
from random import shuffle
from typing import (
Any,
Expand Down Expand Up @@ -60,7 +61,7 @@
calculating = UniqueIdentifier("calculating")

MAPPED_SEARCH_STRATEGY_DO_DRAW_LABEL = calc_label_from_name(
"another attempted draw in MappedSearchStrategy"
"another attempted draw in MappedStrategy"
)

FILTERED_SEARCH_STRATEGY_DO_DRAW_LABEL = calc_label_from_name(
Expand Down Expand Up @@ -346,7 +347,7 @@ def map(self, pack: Callable[[Ex], T]) -> "SearchStrategy[T]":
"""
if is_identity_function(pack):
return self # type: ignore # Mypy has no way to know that `Ex == T`
return MappedSearchStrategy(pack=pack, strategy=self)
return MappedStrategy(self, pack=pack)

def flatmap(
self, expand: Callable[[Ex], "SearchStrategy[T]"]
Expand Down Expand Up @@ -468,9 +469,6 @@ class SampledFromStrategy(SearchStrategy):
"""A strategy which samples from a set of elements. This is essentially
equivalent to using a OneOfStrategy over Just strategies but may be more
efficient and convenient.
The conditional distribution chooses uniformly at random from some
non-empty subset of the elements.
"""

_MAX_FILTER_CALLS = 10_000
Expand Down Expand Up @@ -521,7 +519,10 @@ def _transform(self, element):
# Used in UniqueSampledListStrategy
for name, f in self._transformations:
if name == "map":
element = f(element)
result = f(element)
if build_context := _current_build_context.value:
build_context.record_call(result, f, [element], {})
element = result
else:
assert name == "filter"
if not f(element):
Expand Down Expand Up @@ -794,18 +795,17 @@ def one_of(
return OneOfStrategy(args)


class MappedSearchStrategy(SearchStrategy[Ex]):
class MappedStrategy(SearchStrategy[Ex]):
"""A strategy which is defined purely by conversion to and from another
strategy.
Its parameter and distribution come from that other strategy.
"""

def __init__(self, strategy, pack=None):
def __init__(self, strategy, pack):
super().__init__()
self.mapped_strategy = strategy
if pack is not None:
self.pack = pack
self.pack = pack

def calc_is_empty(self, recur):
return recur(self.mapped_strategy)
Expand All @@ -821,11 +821,6 @@ def __repr__(self):
def do_validate(self):
self.mapped_strategy.validate()

def pack(self, x):
"""Take a value produced by the underlying mapped_strategy and turn it
into a value suitable for outputting from this strategy."""
raise NotImplementedError(f"{self.__class__.__name__}.pack()")

def do_draw(self, data: ConjectureData) -> Any:
with warnings.catch_warnings():
if isinstance(self.pack, type) and issubclass(
Expand All @@ -847,10 +842,67 @@ def do_draw(self, data: ConjectureData) -> Any:
@property
def branches(self) -> List[SearchStrategy[Ex]]:
return [
MappedSearchStrategy(pack=self.pack, strategy=strategy)
MappedStrategy(strategy, pack=self.pack)
for strategy in self.mapped_strategy.branches
]

def filter(self, condition: Callable[[Ex], Any]) -> "SearchStrategy[Ex]":
# Includes a special case so that we can rewrite filters on collection
# lengths, when most collections are `st.lists(...).map(the_type)`.
ListStrategy = _list_strategy_type()
if not isinstance(self.mapped_strategy, ListStrategy) or not (
(isinstance(self.pack, type) and issubclass(self.pack, abc.Collection))
or self.pack in _collection_ish_functions()
):
return super().filter(condition)

# Check whether our inner list strategy can rewrite this filter condition.
# If not, discard the result and _only_ apply a new outer filter.
new = ListStrategy.filter(self.mapped_strategy, condition)
if getattr(new, "filtered_strategy", None) is self.mapped_strategy:
return super().filter(condition) # didn't rewrite

# Apply a new outer filter even though we rewrote the inner strategy,
# because some collections can change the list length (dict, set, etc).
return FilteredStrategy(type(self)(new, self.pack), conditions=(condition,))


@lru_cache
def _list_strategy_type():
from hypothesis.strategies._internal.collections import ListStrategy

return ListStrategy


def _collection_ish_functions():
funcs = [sorted]
if np := sys.modules.get("numpy"):
# c.f. https://numpy.org/doc/stable/reference/routines.array-creation.html
# Probably only `np.array` and `np.asarray` will be used in practice,
# but why should that stop us when we've already gone this far?
funcs += [
np.empty_like,
np.eye,
np.identity,
np.ones_like,
np.zeros_like,
np.array,
np.asarray,
np.asanyarray,
np.ascontiguousarray,
np.asmatrix,
np.copy,
np.rec.array,
np.rec.fromarrays,
np.rec.fromrecords,
np.diag,
# bonus undocumented functions from tab-completion:
np.asarray_chkfinite,
np.asfarray,
np.asfortranarray,
]
return funcs


filter_not_satisfied = UniqueIdentifier("filter not satisfied")

Expand Down
2 changes: 1 addition & 1 deletion hypothesis-python/tests/conjecture/test_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ def test_can_override_label():
def test_will_mark_too_deep_examples_as_invalid():
d = ConjectureData.for_buffer(bytes(0))

s = st.none()
s = st.integers()
for _ in range(MAX_DEPTH + 1):
s = s.map(lambda x: None)

Expand Down
2 changes: 1 addition & 1 deletion hypothesis-python/tests/cover/test_custom_reprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class Bar(Foo):

def test_reprs_as_created():
@given(foo=st.builds(Foo), bar=st.from_type(Bar), baz=st.none().map(Foo))
@settings(print_blob=False, max_examples=10_000)
@settings(print_blob=False, max_examples=10_000, derandomize=True)
def inner(foo, bar, baz):
assert baz.x is None
assert foo.x <= 0 or bar.x >= 0
Expand Down
Loading

0 comments on commit 7f5b065

Please sign in to comment.