Skip to content

Commit

Permalink
Improve type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
Zac-HD committed Oct 5, 2024
1 parent 2f611af commit 09308df
Show file tree
Hide file tree
Showing 17 changed files with 95 additions and 66 deletions.
26 changes: 18 additions & 8 deletions hypothesis-python/src/_hypothesis_ftz_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,14 @@

import importlib
import sys
from typing import TYPE_CHECKING, Callable, Optional, Set, Tuple

if TYPE_CHECKING:
from multiprocessing import Queue
from typing import TypeAlias

FTZCulprits: "TypeAlias" = Tuple[Optional[bool], Set[str]]


KNOWN_EVER_CULPRITS = (
# https://moyix.blogspot.com/2022/09/someones-been-messing-with-my-subnormals.html
Expand All @@ -35,32 +43,34 @@
)


def flush_to_zero():
def flush_to_zero() -> bool:
# If this subnormal number compares equal to zero we have a problem
return 2.0**-1073 == 0


def run_in_process(fn, *args):
def run_in_process(fn: Callable[..., FTZCulprits], *args: object) -> FTZCulprits:
import multiprocessing as mp

mp.set_start_method("spawn", force=True)
q = mp.Queue()
q: "Queue[FTZCulprits]" = mp.Queue()
p = mp.Process(target=target, args=(q, fn, *args))
p.start()
retval = q.get()
p.join()
return retval


def target(q, fn, *args):
def target(
q: "Queue[FTZCulprits]", fn: Callable[..., FTZCulprits], *args: object
) -> None:
q.put(fn(*args))


def always_imported_modules():
def always_imported_modules() -> FTZCulprits:
return flush_to_zero(), set(sys.modules)


def modules_imported_by(mod):
def modules_imported_by(mod: str) -> FTZCulprits:
"""Return the set of modules imported transitively by mod."""
before = set(sys.modules)
try:
Expand All @@ -77,7 +87,7 @@ def modules_imported_by(mod):
CHECKED_CACHE = set()


def identify_ftz_culprits():
def identify_ftz_culprits() -> str:
"""Find the modules in sys.modules which cause "mod" to be imported."""
# If we've run this function before, return the same result.
global KNOWN_FTZ
Expand All @@ -94,7 +104,7 @@ def identify_ftz_culprits():
# that importing them in a new process sets the FTZ state. As a heuristic, we'll
# start with packages known to have ever enabled FTZ, then top-level packages as
# a way to eliminate large fractions of the search space relatively quickly.
def key(name):
def key(name: str) -> Tuple[bool, int, str]:
"""Prefer known-FTZ modules, then top-level packages, then alphabetical."""
return (name not in KNOWN_EVER_CULPRITS, name.count("."), name)

Expand Down
6 changes: 3 additions & 3 deletions hypothesis-python/src/hypothesis/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,22 @@
import sys
import warnings
from pathlib import Path
from typing import Union

import _hypothesis_globals

from hypothesis.errors import HypothesisSideeffectWarning

__hypothesis_home_directory_default = Path.cwd() / ".hypothesis"

__hypothesis_home_directory = None


def set_hypothesis_home_dir(directory):
def set_hypothesis_home_dir(directory: Union[str, Path, None]) -> None:
global __hypothesis_home_directory
__hypothesis_home_directory = None if directory is None else Path(directory)


def storage_directory(*names, intent_to_write=True):
def storage_directory(*names: str, intent_to_write: bool = True) -> Path:
if intent_to_write:
check_sideeffect_during_initialization(
"accessing storage for {}", "/".join(names)
Expand Down
4 changes: 2 additions & 2 deletions hypothesis-python/src/hypothesis/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ class StopTest(BaseException):
the Hypothesis engine, which should then continue normally.
"""

def __init__(self, testcounter):
def __init__(self, testcounter: int) -> None:
super().__init__(repr(testcounter))
self.testcounter = testcounter

Expand All @@ -230,7 +230,7 @@ class Found(HypothesisException):
class RewindRecursive(Exception):
"""Signal that the type inference should be rewound due to recursive types. Internal use only."""

def __init__(self, target):
def __init__(self, target: object) -> None:
self.target = target


Expand Down
2 changes: 1 addition & 1 deletion hypothesis-python/src/hypothesis/internal/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ class LRUCache:
# Anecdotally, OrderedDict seems quite competitive with lru_cache, but perhaps
# that is localized to our access patterns.

def __init__(self, max_size):
def __init__(self, max_size: int) -> None:
assert max_size > 0
self.max_size = max_size
self._threadlocal = threading.local()
Expand Down
2 changes: 1 addition & 1 deletion hypothesis-python/src/hypothesis/internal/cathetus.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from sys import float_info


def cathetus(h, a):
def cathetus(h: float, a: float) -> float:
"""Given the lengths of the hypotenuse and a side of a right triangle,
return the length of the other side.
Expand Down
2 changes: 1 addition & 1 deletion hypothesis-python/src/hypothesis/internal/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def int_to_byte(i: int) -> bytes:
return bytes([i])


def is_typed_named_tuple(cls):
def is_typed_named_tuple(cls: type) -> bool:
"""Return True if cls is probably a subtype of `typing.NamedTuple`.
Unfortunately types created with `class T(NamedTuple):` actually
Expand Down
47 changes: 27 additions & 20 deletions hypothesis-python/src/hypothesis/internal/conjecture/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def label(self) -> int:
return self.owner.labels[self.owner.label_indices[self.index]]

@property
def parent(self):
def parent(self) -> Optional[int]:
"""The index of the example that this one is nested directly within."""
if self.index == 0:
return None
Expand All @@ -298,13 +298,13 @@ def ir_end(self) -> int:
return self.owner.ir_ends[self.index]

@property
def depth(self):
def depth(self) -> int:
"""Depth of this example in the example tree. The top-level example has a
depth of 0."""
return self.owner.depths[self.index]

@property
def trivial(self):
def trivial(self) -> bool:
"""An example is "trivial" if it only contains forced bytes and zero bytes.
All examples start out as trivial, and then get marked non-trivial when
we see a byte that is neither forced nor zero."""
Expand Down Expand Up @@ -352,6 +352,7 @@ def __init__(self, examples: "Examples"):
self.example_count = 0
self.block_count = 0
self.ir_node_count = 0
self.result: Any = None

def run(self) -> Any:
"""Rerun the test case with this visitor and return the
Expand Down Expand Up @@ -425,7 +426,7 @@ def calculated_example_property(cls: Type[ExampleProperty]) -> Any:
name = cls.__name__
cache_name = "__" + name

def lazy_calculate(self: "Examples") -> IntList:
def lazy_calculate(self: "Examples") -> Any:
result = getattr(self, cache_name, None)
if result is None:
result = cls(self).run()
Expand Down Expand Up @@ -465,7 +466,14 @@ def __init__(self) -> None:
def freeze(self) -> None:
self.__index_of_labels = None

def record_ir_draw(self, ir_type, value, *, kwargs, was_forced):
def record_ir_draw(
self,
ir_type: IRTypeName,
value: IRType,
*,
kwargs: IRKWargsType,
was_forced: bool,
) -> None:
self.trail.append(IR_NODE_RECORD)
node = IRNode(
ir_type=ir_type,
Expand Down Expand Up @@ -517,7 +525,7 @@ def __init__(self, record: ExampleRecord, blocks: "Blocks") -> None:
self.__children: "Optional[List[Sequence[int]]]" = None

class _starts_and_ends(ExampleProperty):
def begin(self):
def begin(self) -> None:
self.starts = IntList.of_length(len(self.examples))
self.ends = IntList.of_length(len(self.examples))

Expand All @@ -543,7 +551,7 @@ def ends(self) -> IntList:
return self.starts_and_ends[1]

class _ir_starts_and_ends(ExampleProperty):
def begin(self):
def begin(self) -> None:
self.starts = IntList.of_length(len(self.examples))
self.ends = IntList.of_length(len(self.examples))

Expand All @@ -570,7 +578,7 @@ def ir_ends(self) -> IntList:

class _discarded(ExampleProperty):
def begin(self) -> None:
self.result: "Set[int]" = set() # type: ignore # IntList in parent class
self.result: "Set[int]" = set()

def finish(self) -> FrozenSet[int]:
return frozenset(self.result)
Expand All @@ -584,7 +592,7 @@ def stop_example(self, i: int, *, discarded: bool) -> None:
class _trivial(ExampleProperty):
def begin(self) -> None:
self.nontrivial = IntList.of_length(len(self.examples))
self.result: "Set[int]" = set() # type: ignore # IntList in parent class
self.result: "Set[int]" = set()

def block(self, i: int) -> None:
if not self.examples.blocks.trivial(i):
Expand All @@ -610,7 +618,7 @@ def stop_example(self, i: int, *, discarded: bool) -> None:
parentage: IntList = calculated_example_property(_parentage)

class _depths(ExampleProperty):
def begin(self):
def begin(self) -> None:
self.result = IntList.of_length(len(self.examples))

def start_example(self, i: int, label_index: int) -> None:
Expand All @@ -619,10 +627,10 @@ def start_example(self, i: int, label_index: int) -> None:
depths: IntList = calculated_example_property(_depths)

class _ir_tree_nodes(ExampleProperty):
def begin(self):
def begin(self) -> None:
self.result = []

def ir_node(self, ir_node):
def ir_node(self, ir_node: "IRNode") -> None:
self.result.append(ir_node)

ir_tree_nodes: "List[IRNode]" = calculated_example_property(_ir_tree_nodes)
Expand Down Expand Up @@ -788,7 +796,7 @@ def all_bounds(self) -> Iterable[Tuple[int, int]]:
prev = e

@property
def last_block_length(self):
def last_block_length(self) -> int:
return self.end(-1) - self.start(-1)

def __len__(self) -> int:
Expand Down Expand Up @@ -869,7 +877,7 @@ def __getitem__(self, i: int) -> Block:

return result

def __check_completion(self):
def __check_completion(self) -> None:
"""The list of blocks is complete if we have created every ``Block``
object that we currently good and know that no more will be created.
Expand Down Expand Up @@ -899,7 +907,7 @@ def __repr__(self) -> str:
class _Overrun:
status = Status.OVERRUN

def __repr__(self):
def __repr__(self) -> str:
return "Overrun"


Expand Down Expand Up @@ -1052,7 +1060,7 @@ def __eq__(self, other):
and self.was_forced == other.was_forced
)

def __hash__(self):
def __hash__(self) -> int:
return hash(
(
self.ir_type,
Expand All @@ -1062,7 +1070,7 @@ def __hash__(self):
)
)

def __repr__(self):
def __repr__(self) -> str:
# repr to avoid "BytesWarning: str() on a bytes instance" for bytes nodes
forced_marker = " [forced]" if self.was_forced else ""
return f"{self.ir_type} {self.value!r}{forced_marker} {self.kwargs!r}"
Expand Down Expand Up @@ -1911,8 +1919,7 @@ def _compute_draw_float_init_logic(
"writeup - and good luck!"
)

def permitted(f):
assert isinstance(f, float)
def permitted(f: float) -> bool:
if math.isnan(f):
return allow_nan
if 0 < abs(f) < smallest_nonzero_magnitude:
Expand Down Expand Up @@ -2080,7 +2087,7 @@ def __init__(
self._node_index = 0
self.start_example(TOP_LABEL)

def __repr__(self):
def __repr__(self) -> str:
return "ConjectureData(%s, %d bytes%s)" % (
self.status.name,
len(self.buffer),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ class CallStats(TypedDict):
"shrink-phase": NotRequired[PhaseStatistics],
"stopped-because": NotRequired[str],
"targets": NotRequired[Dict[str, float]],
"nodeid": NotRequired[str],
},
)

Expand Down
2 changes: 1 addition & 1 deletion hypothesis-python/src/hypothesis/internal/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from types import MethodType


def is_hypothesis_test(test):
def is_hypothesis_test(test: object) -> bool:
if isinstance(test, MethodType):
return is_hypothesis_test(test.__func__)
return getattr(test, "is_hypothesis_test", False)
4 changes: 2 additions & 2 deletions hypothesis-python/src/hypothesis/internal/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,9 +331,9 @@ def get_float_predicate_bounds(predicate: Predicate) -> ConstructivePredicate:
return ConstructivePredicate(kwargs, predicate)


def max_len(size: int, element: Collection) -> bool:
def max_len(size: int, element: Collection[object]) -> bool:
return len(element) <= size


def min_len(size: int, element: Collection) -> bool:
def min_len(size: int, element: Collection[object]) -> bool:
return size <= len(element)
12 changes: 6 additions & 6 deletions hypothesis-python/src/hypothesis/internal/scrutineer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from functools import lru_cache, reduce
from os import sep
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple

from hypothesis._settings import Phase, Verbosity
from hypothesis.internal.compat import PYPY
Expand All @@ -35,7 +35,7 @@


@lru_cache(maxsize=None)
def should_trace_file(fname):
def should_trace_file(fname: str) -> bool:
# fname.startswith("<") indicates runtime code-generation via compile,
# e.g. compile("def ...", "<string>", "exec") in e.g. attrs methods.
return not (is_hypothesis_file(fname) or fname.startswith("<"))
Expand All @@ -55,12 +55,12 @@ class Tracer:

__slots__ = ("branches", "_previous_location")

def __init__(self):
def __init__(self) -> None:
self.branches: Trace = set()
self._previous_location = None
self._previous_location: Optional[Location] = None

@staticmethod
def can_trace():
def can_trace() -> bool:
return (
(sys.version_info[:2] < (3, 12) and sys.gettrace() is None)
or (
Expand Down Expand Up @@ -138,7 +138,7 @@ def __exit__(self, *args, **kwargs):
)


def _glob_to_re(locs):
def _glob_to_re(locs: Iterable[str]) -> str:
"""Translate a list of glob patterns to a combined regular expression.
Only the * wildcard is supported, and patterns including special
characters will only work by chance."""
Expand Down
Loading

0 comments on commit 09308df

Please sign in to comment.