Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

migrate Float shrinker to the ir #3899

Merged
merged 9 commits into from
Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions hypothesis-python/RELEASE.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
RELEASE_TYPE: patch

This patch starts work on refactoring our shrinker internals. There is no user-visible change.
67 changes: 64 additions & 3 deletions hypothesis-python/src/hypothesis/internal/conjecture/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -920,6 +920,50 @@ class IRNode:
kwargs: IRKWargsType = attr.ib()
was_forced: bool = attr.ib()

def copy(self, *, with_value: IRType) -> "IRNode":
    """Return a new node that differs from this one only in its value.

    ``ir_type``, ``kwargs`` and ``was_forced`` are carried over unchanged;
    ``with_value`` becomes the new node's ``value``.
    """
    # Copying a forced node with a different value might be allowed one day,
    # but at present it is a footgun, so it is rejected outright.
    assert not self.was_forced, "modifying a forced node doesn't make sense"
    fields = {
        "ir_type": self.ir_type,
        "value": with_value,
        "kwargs": self.kwargs,
        "was_forced": self.was_forced,
    }
    return IRNode(**fields)


def ir_value_permitted(value, ir_type, kwargs):
    """Return whether ``value`` is an allowed choice for ``ir_type`` under ``kwargs``.

    The check performed depends on ``ir_type``:

    * ``"integer"``: ``value`` must lie within ``min_value``/``max_value``,
      where a bound of ``None`` means that side is unbounded.
    * ``"float"``: NaN is allowed exactly when ``allow_nan`` is set; any other
      value must lie within ``[min_value, max_value]`` (compared with
      ``sign_aware_lte``) and must not be a nonzero value whose magnitude is
      below ``smallest_nonzero_magnitude``.
    * ``"string"``: the length must respect ``min_size`` and (if not ``None``)
      ``max_size``, and every character's code point must be in ``intervals``.
    * ``"bytes"``: the length must equal ``size``.
    * ``"boolean"``: ``True`` is rejected when ``p <= 2**-64`` and ``False``
      is rejected when ``p >= 1 - 2**-64``.

    Raises:
        NotImplementedError: if ``ir_type`` is not one of the names above.
    """
    if ir_type == "integer":
        if kwargs["min_value"] is not None and value < kwargs["min_value"]:
            return False
        if kwargs["max_value"] is not None and value > kwargs["max_value"]:
            return False
        return True

    if ir_type == "float":
        if math.isnan(value):
            return kwargs["allow_nan"]
        within_bounds = sign_aware_lte(kwargs["min_value"], value) and sign_aware_lte(
            value, kwargs["max_value"]
        )
        # Zero itself is fine; only nonzero values strictly smaller in
        # magnitude than smallest_nonzero_magnitude are excluded.
        return within_bounds and not (
            0 < abs(value) < kwargs["smallest_nonzero_magnitude"]
        )

    if ir_type == "string":
        if len(value) < kwargs["min_size"]:
            return False
        if kwargs["max_size"] is not None and len(value) > kwargs["max_size"]:
            return False
        return all(ord(c) in kwargs["intervals"] for c in value)

    if ir_type == "bytes":
        return len(value) == kwargs["size"]

    if ir_type == "boolean":
        # Extreme probabilities make one of the two outcomes impossible.
        if value and kwargs["p"] <= 2 ** (-64):
            return False
        if not value and kwargs["p"] >= (1 - 2 ** (-64)):
            return False
        return True

    # Dispatch is on ir_type, so that is the name to report as unhandled.
    raise NotImplementedError(
        f"unhandled ir_type {ir_type} (ir value {value!r} of type {type(value)})"
    )


@dataclass_transform()
@attr.s(slots=True)
Expand Down Expand Up @@ -1991,8 +2035,8 @@ def draw_boolean(
p: float = 0.5,
*,
forced: Optional[bool] = None,
observe: bool = True,
fake_forced: bool = False,
observe: bool = True,
) -> bool:
# Internally, we treat probabilities lower than 1 / 2**64 as
# unconditionally false.
Expand Down Expand Up @@ -2049,9 +2093,26 @@ def _pooled_kwargs(self, ir_type, kwargs):

def _pop_ir_tree_node(self, ir_type: IRTypeName, kwargs: IRKWargsType) -> IRNode:
assert self.ir_tree_nodes is not None

if self.ir_tree_nodes == []:
self.mark_overrun()

node = self.ir_tree_nodes.pop(0)
assert node.ir_type == ir_type
assert kwargs == node.kwargs
# Unlike buffers, not every ir tree is a valid choice sequence. We
# don't have many options here beyond giving up when a modified tree
# becomes misaligned.
#
# For what it's worth, misaligned buffers — albeit valid — are rather
# unlikely to be *useful* buffers, so this isn't an enormous
# downgrade.
if node.ir_type != ir_type:
self.mark_overrun()
tybug marked this conversation as resolved.
Show resolved Hide resolved

# if a node has different kwargs (and so is misaligned), but has a value
# that is allowed by the expected kwargs, then we can coerce this node
# into an aligned one by using its value. It's unclear how useful this is.
if not ir_value_permitted(node.value, node.ir_type, kwargs):
self.mark_overrun()

return node
Zac-HD marked this conversation as resolved.
Show resolved Hide resolved

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,12 @@
Status,
StringKWargs,
)
from hypothesis.internal.floats import count_between_floats, float_to_int, int_to_float
from hypothesis.internal.floats import (
count_between_floats,
float_to_int,
int_to_float,
sign_aware_lte,
)


class PreviouslyUnseenBehaviour(HypothesisException):
Expand Down Expand Up @@ -184,7 +189,35 @@ def compute_max_children(ir_type, kwargs):
return sum(len(intervals) ** k for k in range(min_size, max_size + 1))

elif ir_type == "float":
return count_between_floats(kwargs["min_value"], kwargs["max_value"])
min_value = kwargs["min_value"]
max_value = kwargs["max_value"]
smallest_nonzero_magnitude = kwargs["smallest_nonzero_magnitude"]

count = count_between_floats(min_value, max_value)

# we have two intervals:
# a. [min_value, max_value]
# b. [-smallest_nonzero_magnitude, smallest_nonzero_magnitude]
#
# which could be subsets (in either order), overlapping, or disjoint. We
# want the interval difference a - b.

# next_down because endpoints are ok with smallest_nonzero_magnitude
min_point = max(min_value, -flt.next_down(smallest_nonzero_magnitude))
max_point = min(max_value, flt.next_down(smallest_nonzero_magnitude))

if min_point > max_point:
# case: disjoint intervals.
return count

count -= count_between_floats(min_point, max_point)
if sign_aware_lte(min_value, -0.0) and sign_aware_lte(-0.0, max_value):
# account for -0.0
count += 1
if sign_aware_lte(min_value, 0.0) and sign_aware_lte(0.0, max_value):
# account for 0.0
count += 1
return count

raise NotImplementedError(f"unhandled ir_type {ir_type}")

Expand Down Expand Up @@ -247,16 +280,30 @@ def floats_between(a, b):

min_value = kwargs["min_value"]
max_value = kwargs["max_value"]
smallest_nonzero_magnitude = kwargs["smallest_nonzero_magnitude"]

# handle zeroes separately so smallest_nonzero_magnitude can think of
# itself as a complete interval (instead of a hole at ±0).
if sign_aware_lte(min_value, -0.0) and sign_aware_lte(-0.0, max_value):
yield -0.0
if sign_aware_lte(min_value, 0.0) and sign_aware_lte(0.0, max_value):
yield 0.0

if flt.is_negative(min_value):
if flt.is_negative(max_value):
# if both are negative, have to invert order
yield from floats_between(max_value, min_value)
# case: both negative.
max_point = min(max_value, -smallest_nonzero_magnitude)
# float_to_int increases as negative magnitude increases, so
# invert order.
yield from floats_between(max_point, min_value)
else:
yield from floats_between(-0.0, min_value)
yield from floats_between(0.0, max_value)
# case: straddles midpoint (which is between -0.0 and 0.0).
yield from floats_between(-smallest_nonzero_magnitude, min_value)
yield from floats_between(smallest_nonzero_magnitude, max_value)
else:
yield from floats_between(min_value, max_value)
# case: both positive.
min_point = max(min_value, smallest_nonzero_magnitude)
yield from floats_between(min_point, max_value)


@attr.s(slots=True)
Expand Down
67 changes: 37 additions & 30 deletions hypothesis-python/src/hypothesis/internal/conjecture/shrinker.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.

import math
from collections import defaultdict
from typing import TYPE_CHECKING, Callable, Dict, Optional

Expand All @@ -19,14 +20,9 @@
prefix_selection_order,
random_selection_order,
)
from hypothesis.internal.conjecture.data import (
DRAW_FLOAT_LABEL,
ConjectureData,
ConjectureResult,
Status,
)
from hypothesis.internal.conjecture.data import ConjectureData, ConjectureResult, Status
from hypothesis.internal.conjecture.dfa import ConcreteDFA
from hypothesis.internal.conjecture.floats import float_to_lex, lex_to_float
from hypothesis.internal.conjecture.floats import is_simple
from hypothesis.internal.conjecture.junkdrawer import (
binary_search,
find_integer,
Expand Down Expand Up @@ -379,6 +375,12 @@ def calls(self):
test function."""
return self.engine.call_count

def consider_new_tree(self, tree):
    """Run the test function on ``tree`` and consider the buffer it produced.

    A ``ConjectureData`` is built from ``tree`` and handed to the engine's
    test function; the buffer recorded on that data is then passed to
    ``consider_new_buffer``, whose result is returned.
    """
    candidate = ConjectureData.for_ir_tree(tree)
    self.engine.test_function(candidate)
    resulting_buffer = candidate.buffer
    return self.consider_new_buffer(resulting_buffer)

def consider_new_buffer(self, buffer):
"""Returns True if after running this buffer the result would be
the current shrink_target."""
Expand Down Expand Up @@ -774,6 +776,10 @@ def buffer(self):
def blocks(self):
return self.shrink_target.blocks

@property
def nodes(self):
    """The ``ir_tree_nodes`` of the current shrink target's examples."""
    target_examples = self.shrink_target.examples
    return target_examples.ir_tree_nodes

@property
def examples(self):
return self.shrink_target.examples
Expand Down Expand Up @@ -1207,31 +1213,32 @@ def minimize_floats(self, chooser):
anything particularly meaningful for non-float values.
"""

ex = chooser.choose(
self.examples,
lambda ex: (
ex.label == DRAW_FLOAT_LABEL
and len(ex.children) == 2
and ex.children[1].length == 8
),
node = chooser.choose(
self.nodes, lambda node: node.ir_type == "float" and not node.was_forced
)
# avoid shrinking integer-valued floats. In our current ordering, these
# are already simpler than all other floats, so it's better to shrink
# them in other passes.
if is_simple(node.value):
return
tybug marked this conversation as resolved.
Show resolved Hide resolved

u = ex.children[1].start
v = ex.children[1].end
buf = self.shrink_target.buffer
b = buf[u:v]
f = lex_to_float(int_from_bytes(b))
b2 = int_to_bytes(float_to_lex(f), 8)
if b == b2 or self.consider_new_buffer(buf[:u] + b2 + buf[v:]):
Float.shrink(
f,
lambda x: self.consider_new_buffer(
self.shrink_target.buffer[:u]
+ int_to_bytes(float_to_lex(x), 8)
+ self.shrink_target.buffer[v:]
),
random=self.random,
)
i = self.nodes.index(node)
# the Float shrinker was only built to handle positive floats. We'll
# shrink the positive portion and reapply the sign after, which is
# equivalent to this shrinker's previous behavior. We'll want to refactor
# Float to handle negative floats natively in the future. (likely a pure
# code quality change, with no shrinking impact.)
sign = math.copysign(1.0, node.value)
Float.shrink(
abs(node.value),
lambda val: self.consider_new_tree(
self.nodes[:i]
+ [node.copy(with_value=sign * val)]
+ self.nodes[i + 1 :]
),
random=self.random,
node=node,
)

@defines_shrink_pass()
def redistribute_block_pairs(self, chooser):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def incorporate(self, value):
def consider(self, value):
"""Returns True if make_immutable(value) == self.current after calling
self.incorporate(value)."""
self.debug(f"considering {value}")
value = self.make_immutable(value)
if value == self.current:
return True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import math
import sys

from hypothesis.internal.conjecture.data import ir_value_permitted
from hypothesis.internal.conjecture.floats import float_to_lex
from hypothesis.internal.conjecture.shrinking.common import Shrinker
from hypothesis.internal.conjecture.shrinking.integer import Integer
Expand All @@ -19,9 +20,20 @@


class Float(Shrinker):
def setup(self):
def setup(self, node):
self.NAN = math.nan
self.debugging_enabled = True
self.node = node

def consider(self, value):
    """Consider ``value`` only if the node's kwargs permit it as a float.

    Values rejected by ``ir_value_permitted`` are logged via ``self.debug``
    and reported as ``False``; everything else is delegated to the parent
    class's ``consider``.
    """
    node_kwargs = self.node.kwargs
    min_value = node_kwargs["min_value"]
    max_value = node_kwargs["max_value"]
    if ir_value_permitted(value, "float", node_kwargs):
        return super().consider(value)

    self.debug(
        f"rejecting {value} as out of bounds for [{min_value}, {max_value}]"
    )
    return False

def make_immutable(self, f):
f = float(f)
Expand Down
30 changes: 28 additions & 2 deletions hypothesis-python/tests/conjecture/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from hypothesis.internal.conjecture.engine import BUFFER_SIZE, ConjectureRunner
from hypothesis.internal.conjecture.utils import calc_label_from_name
from hypothesis.internal.entropy import deterministic_PRNG
from hypothesis.internal.floats import sign_aware_lte
from hypothesis.internal.floats import SMALLEST_SUBNORMAL, sign_aware_lte
from hypothesis.strategies._internal.strings import OneCharStringStrategy, TextStrategy

from tests.common.strategies import intervals
Expand Down Expand Up @@ -220,6 +220,8 @@ def draw_float_kwargs(
pivot = forced if (use_forced and not math.isnan(forced)) else None
min_value = -math.inf
max_value = math.inf
smallest_nonzero_magnitude = SMALLEST_SUBNORMAL
allow_nan = True if (use_forced and math.isnan(forced)) else draw(st.booleans())

if use_min_value:
min_value = draw(st.floats(max_value=pivot, allow_nan=False))
Expand All @@ -231,7 +233,31 @@ def draw_float_kwargs(
min_val = pivot if sign_aware_lte(min_value, pivot) else min_value
max_value = draw(st.floats(min_value=min_val, allow_nan=False))

return {"min_value": min_value, "max_value": max_value, "forced": forced}
largest_magnitude = max(abs(min_value), abs(max_value))
# can't force something smaller than our smallest magnitude.
if pivot is not None and pivot != 0.0:
largest_magnitude = min(largest_magnitude, pivot)

# avoid drawing from an empty range
if largest_magnitude > 0:
smallest_nonzero_magnitude = draw(
st.floats(
min_value=0,
# smallest_nonzero_magnitude breaks internal clamper invariants if
# it is allowed to be larger than the magnitude of {min, max}_value.
max_value=largest_magnitude,
allow_nan=False,
exclude_min=True,
allow_infinity=False,
)
)
return {
"min_value": min_value,
"max_value": max_value,
"forced": forced,
"allow_nan": allow_nan,
"smallest_nonzero_magnitude": smallest_nonzero_magnitude,
}


@st.composite
Expand Down
Loading
Loading