Skip to content

Commit

Permalink
MappedSearchStrategy cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
Zac-HD committed Feb 25, 2024
1 parent 3eea8c5 commit 13aa400
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 26 deletions.
4 changes: 2 additions & 2 deletions hypothesis-python/src/hypothesis/extra/ghostwriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@
from hypothesis.strategies._internal.lazy import LazyStrategy, unwrap_strategies
from hypothesis.strategies._internal.strategies import (
FilteredStrategy,
MappedSearchStrategy,
MappedStrategy,
OneOfStrategy,
SampledFromStrategy,
)
Expand Down Expand Up @@ -627,7 +627,7 @@ def _imports_for_strategy(strategy):
strategy = unwrap_strategies(strategy)

# Get imports for s.map(f), s.filter(f), s.flatmap(f), including both s and f
if isinstance(strategy, MappedSearchStrategy):
if isinstance(strategy, MappedStrategy):
imports |= _imports_for_strategy(strategy.mapped_strategy)
imports |= _imports_for_object(strategy.pack)
if isinstance(strategy, FilteredStrategy):
Expand Down
4 changes: 2 additions & 2 deletions hypothesis-python/src/hypothesis/extra/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
from hypothesis.strategies._internal.numbers import Real
from hypothesis.strategies._internal.strategies import (
Ex,
MappedSearchStrategy,
MappedStrategy,
T,
check_strategy,
)
Expand Down Expand Up @@ -516,7 +516,7 @@ def arrays(
# If there's a redundant cast to the requested dtype, remove it. This unlocks
# optimizations such as fast unique sampled_from, and saves some time directly too.
unwrapped = unwrap_strategies(elements)
if isinstance(unwrapped, MappedSearchStrategy) and unwrapped.pack == dtype.type:
if isinstance(unwrapped, MappedStrategy) and unwrapped.pack == dtype.type:
elements = unwrapped.mapped_strategy
if isinstance(shape, int):
shape = (shape,)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
T4,
T5,
Ex,
MappedSearchStrategy,
MappedStrategy,
SearchStrategy,
T,
check_strategy,
Expand Down Expand Up @@ -302,7 +302,7 @@ def do_draw(self, data):
return result


class FixedKeysDictStrategy(MappedSearchStrategy):
class FixedKeysDictStrategy(MappedStrategy):
"""A strategy which produces dicts with a fixed set of keys, given a
strategy for each of their equivalent values.
Expand All @@ -311,19 +311,19 @@ class FixedKeysDictStrategy(MappedSearchStrategy):
"""

def __init__(self, strategy_dict):
self.dict_type = type(strategy_dict)
dict_type = type(strategy_dict)
self.keys = tuple(strategy_dict.keys())
super().__init__(strategy=TupleStrategy(strategy_dict[k] for k in self.keys))
super().__init__(
strategy=TupleStrategy(strategy_dict[k] for k in self.keys),
pack=lambda value: dict_type(zip(self.keys, value)),
)

def calc_is_empty(self, recur):
return recur(self.mapped_strategy)

def __repr__(self):
return f"FixedKeysDictStrategy({self.keys!r}, {self.mapped_strategy!r})"

def pack(self, value):
return self.dict_type(zip(self.keys, value))


class FixedAndOptionalKeysDictStrategy(SearchStrategy):
"""A strategy which produces dicts with a fixed set of keys, given a
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
calculating = UniqueIdentifier("calculating")

MAPPED_SEARCH_STRATEGY_DO_DRAW_LABEL = calc_label_from_name(
"another attempted draw in MappedSearchStrategy"
"another attempted draw in MappedStrategy"
)

FILTERED_SEARCH_STRATEGY_DO_DRAW_LABEL = calc_label_from_name(
Expand Down Expand Up @@ -346,7 +346,7 @@ def map(self, pack: Callable[[Ex], T]) -> "SearchStrategy[T]":
"""
if is_identity_function(pack):
return self # type: ignore # Mypy has no way to know that `Ex == T`
return MappedSearchStrategy(pack=pack, strategy=self)
return MappedStrategy(self, pack=pack)

def flatmap(
self, expand: Callable[[Ex], "SearchStrategy[T]"]
Expand Down Expand Up @@ -468,9 +468,6 @@ class SampledFromStrategy(SearchStrategy):
"""A strategy which samples from a set of elements. This is essentially
equivalent to using a OneOfStrategy over Just strategies but may be more
efficient and convenient.
The conditional distribution chooses uniformly at random from some
non-empty subset of the elements.
"""

_MAX_FILTER_CALLS = 10_000
Expand Down Expand Up @@ -794,18 +791,17 @@ def one_of(
return OneOfStrategy(args)


class MappedSearchStrategy(SearchStrategy[Ex]):
class MappedStrategy(SearchStrategy[Ex]):
"""A strategy which is defined purely by conversion to and from another
strategy.
Its parameter and distribution come from that other strategy.
"""

def __init__(self, strategy, pack=None):
def __init__(self, strategy, pack):
super().__init__()
self.mapped_strategy = strategy
if pack is not None:
self.pack = pack
self.pack = pack

def calc_is_empty(self, recur):
return recur(self.mapped_strategy)
Expand All @@ -821,11 +817,6 @@ def __repr__(self):
def do_validate(self):
self.mapped_strategy.validate()

def pack(self, x):
"""Take a value produced by the underlying mapped_strategy and turn it
into a value suitable for outputting from this strategy."""
raise NotImplementedError(f"{self.__class__.__name__}.pack()")

def do_draw(self, data: ConjectureData) -> Any:
with warnings.catch_warnings():
if isinstance(self.pack, type) and issubclass(
Expand All @@ -847,7 +838,7 @@ def do_draw(self, data: ConjectureData) -> Any:
@property
def branches(self) -> List[SearchStrategy[Ex]]:
return [
MappedSearchStrategy(pack=self.pack, strategy=strategy)
MappedStrategy(strategy, pack=self.pack)
for strategy in self.mapped_strategy.branches
]

Expand Down

0 comments on commit 13aa400

Please sign in to comment.