From 13aa400e6dbe1a71ae25122c59dcf66dbd464b6b Mon Sep 17 00:00:00 2001 From: Zac Hatfield-Dodds Date: Sun, 25 Feb 2024 01:16:44 -0800 Subject: [PATCH] MappedSearchStrategy cleanup --- .../src/hypothesis/extra/ghostwriter.py | 4 ++-- .../src/hypothesis/extra/numpy.py | 4 ++-- .../strategies/_internal/collections.py | 14 ++++++------- .../strategies/_internal/strategies.py | 21 ++++++------------- 4 files changed, 17 insertions(+), 26 deletions(-) diff --git a/hypothesis-python/src/hypothesis/extra/ghostwriter.py b/hypothesis-python/src/hypothesis/extra/ghostwriter.py index 8917d5bd87..2854b48c29 100644 --- a/hypothesis-python/src/hypothesis/extra/ghostwriter.py +++ b/hypothesis-python/src/hypothesis/extra/ghostwriter.py @@ -122,7 +122,7 @@ from hypothesis.strategies._internal.lazy import LazyStrategy, unwrap_strategies from hypothesis.strategies._internal.strategies import ( FilteredStrategy, - MappedSearchStrategy, + MappedStrategy, OneOfStrategy, SampledFromStrategy, ) @@ -627,7 +627,7 @@ def _imports_for_strategy(strategy): strategy = unwrap_strategies(strategy) # Get imports for s.map(f), s.filter(f), s.flatmap(f), including both s and f - if isinstance(strategy, MappedSearchStrategy): + if isinstance(strategy, MappedStrategy): imports |= _imports_for_strategy(strategy.mapped_strategy) imports |= _imports_for_object(strategy.pack) if isinstance(strategy, FilteredStrategy): diff --git a/hypothesis-python/src/hypothesis/extra/numpy.py b/hypothesis-python/src/hypothesis/extra/numpy.py index 29d73f76be..4cfb1ca8d8 100644 --- a/hypothesis-python/src/hypothesis/extra/numpy.py +++ b/hypothesis-python/src/hypothesis/extra/numpy.py @@ -50,7 +50,7 @@ from hypothesis.strategies._internal.numbers import Real from hypothesis.strategies._internal.strategies import ( Ex, - MappedSearchStrategy, + MappedStrategy, T, check_strategy, ) @@ -516,7 +516,7 @@ def arrays( # If there's a redundant cast to the requested dtype, remove it. This unlocks # optimizations such as fast unique sampled_from, and saves some time directly too. unwrapped = unwrap_strategies(elements) - if isinstance(unwrapped, MappedSearchStrategy) and unwrapped.pack == dtype.type: + if isinstance(unwrapped, MappedStrategy) and unwrapped.pack == dtype.type: elements = unwrapped.mapped_strategy if isinstance(shape, int): shape = (shape,) diff --git a/hypothesis-python/src/hypothesis/strategies/_internal/collections.py b/hypothesis-python/src/hypothesis/strategies/_internal/collections.py index e8f8f21ba4..8311661b87 100644 --- a/hypothesis-python/src/hypothesis/strategies/_internal/collections.py +++ b/hypothesis-python/src/hypothesis/strategies/_internal/collections.py @@ -23,7 +23,7 @@ T4, T5, Ex, - MappedSearchStrategy, + MappedStrategy, SearchStrategy, T, check_strategy, @@ -302,7 +302,7 @@ def do_draw(self, data): return result -class FixedKeysDictStrategy(MappedSearchStrategy): +class FixedKeysDictStrategy(MappedStrategy): """A strategy which produces dicts with a fixed set of keys, given a strategy for each of their equivalent values. @@ -311,9 +311,12 @@ class FixedKeysDictStrategy(MappedSearchStrategy): """ def __init__(self, strategy_dict): - self.dict_type = type(strategy_dict) + dict_type = type(strategy_dict) self.keys = tuple(strategy_dict.keys()) - super().__init__(strategy=TupleStrategy(strategy_dict[k] for k in self.keys)) + super().__init__( + strategy=TupleStrategy(strategy_dict[k] for k in self.keys), + pack=lambda value: dict_type(zip(self.keys, value)), + ) def calc_is_empty(self, recur): return recur(self.mapped_strategy) @@ -321,9 +324,6 @@ def calc_is_empty(self, recur): def __repr__(self): return f"FixedKeysDictStrategy({self.keys!r}, {self.mapped_strategy!r})" - def pack(self, value): - return self.dict_type(zip(self.keys, value)) - class FixedAndOptionalKeysDictStrategy(SearchStrategy): """A strategy which produces dicts with a fixed set of keys, given a diff --git a/hypothesis-python/src/hypothesis/strategies/_internal/strategies.py b/hypothesis-python/src/hypothesis/strategies/_internal/strategies.py index af2fa72937..75beb17ac6 100644 --- a/hypothesis-python/src/hypothesis/strategies/_internal/strategies.py +++ b/hypothesis-python/src/hypothesis/strategies/_internal/strategies.py @@ -60,7 +60,7 @@ calculating = UniqueIdentifier("calculating") MAPPED_SEARCH_STRATEGY_DO_DRAW_LABEL = calc_label_from_name( - "another attempted draw in MappedSearchStrategy" + "another attempted draw in MappedStrategy" ) FILTERED_SEARCH_STRATEGY_DO_DRAW_LABEL = calc_label_from_name( @@ -346,7 +346,7 @@ def map(self, pack: Callable[[Ex], T]) -> "SearchStrategy[T]": """ if is_identity_function(pack): return self # type: ignore # Mypy has no way to know that `Ex == T` - return MappedSearchStrategy(pack=pack, strategy=self) + return MappedStrategy(self, pack=pack) def flatmap( self, expand: Callable[[Ex], "SearchStrategy[T]"] @@ -468,9 +468,6 @@ class SampledFromStrategy(SearchStrategy): """A strategy which samples from a set of elements. This is essentially equivalent to using a OneOfStrategy over Just strategies but may be more efficient and convenient. - - The conditional distribution chooses uniformly at random from some - non-empty subset of the elements. """ _MAX_FILTER_CALLS = 10_000 @@ -794,18 +791,17 @@ def one_of( return OneOfStrategy(args) -class MappedSearchStrategy(SearchStrategy[Ex]): +class MappedStrategy(SearchStrategy[Ex]): """A strategy which is defined purely by conversion to and from another strategy. Its parameter and distribution come from that other strategy. """ - def __init__(self, strategy, pack=None): + def __init__(self, strategy, pack): super().__init__() self.mapped_strategy = strategy - if pack is not None: - self.pack = pack + self.pack = pack def calc_is_empty(self, recur): return recur(self.mapped_strategy) @@ -821,11 +817,6 @@ def __repr__(self): def do_validate(self): self.mapped_strategy.validate() - def pack(self, x): - """Take a value produced by the underlying mapped_strategy and turn it - into a value suitable for outputting from this strategy.""" - raise NotImplementedError(f"{self.__class__.__name__}.pack()") - def do_draw(self, data: ConjectureData) -> Any: with warnings.catch_warnings(): if isinstance(self.pack, type) and issubclass( @@ -847,7 +838,7 @@ def do_draw(self, data: ConjectureData) -> Any: @property def branches(self) -> List[SearchStrategy[Ex]]: return [ - MappedSearchStrategy(pack=self.pack, strategy=strategy) + MappedStrategy(strategy, pack=self.pack) for strategy in self.mapped_strategy.branches ]