Skip to content

Commit

Permalink
Moved print_bindings into storage.py
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Feb 25, 2024
1 parent d7fd59a commit 28ad527
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 57 deletions.
3 changes: 2 additions & 1 deletion jaxtyping/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,14 @@
set_array_name_format as set_array_name_format,
)
from ._config import config as config
from ._decorator import jaxtyped as jaxtyped, print_bindings as print_bindings
from ._decorator import jaxtyped as jaxtyped
from ._errors import (
AnnotationError as AnnotationError,
TypeCheckError as TypeCheckError,
)
from ._import_hook import install_import_hook as install_import_hook
from ._ipython_extension import load_ipython_extension as load_ipython_extension
from ._storage import print_bindings as print_bindings


# Now import Array and ArrayLike
Expand Down
60 changes: 4 additions & 56 deletions jaxtyping/_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@

from ._config import config
from ._errors import AnnotationError, TypeCheckError
from ._storage import get_shape_memo, pop_shape_memo, push_shape_memo
from ._storage import pop_shape_memo, push_shape_memo, shape_str


class _Sentinel:
Expand Down Expand Up @@ -319,7 +319,7 @@ def wrapped_fn(*args, **kwargs): # pyright: ignore
return fn(*args, **kwargs)
except Exception as e:
if sys.version_info >= (3, 11) and _no_jaxtyping_note(e):
shape_info = _exc_shape_info(memos)
shape_info = shape_str(memos)
if shape_info != "":
msg = (
"The preceding error occurred within the scope of a "
Expand Down Expand Up @@ -411,7 +411,7 @@ def wrapped_fn(*args, **kwargs):
"----------------------\n"
f"Called with parameters: {param_values}\n"
f"Parameter annotations: {param_hints}.\n"
+ _exc_shape_info(memos)
+ shape_str(memos)
)
if config.jaxtyping_remove_typechecker_stack:
raise TypeCheckError(msg) from None
Expand Down Expand Up @@ -464,7 +464,7 @@ def wrapped_fn(*args, **kwargs):
"----------------------\n"
f"Called with parameters: {param_values}\n"
f"Parameter annotations: {param_hints}.\n"
+ _exc_shape_info(memos)
+ shape_str(memos)
)
if config.jaxtyping_remove_typechecker_stack:
raise TypeCheckError(msg) from None
Expand Down Expand Up @@ -756,40 +756,6 @@ def _pformat(x, short_self: bool):
return pformat(x)


def _exc_shape_info(memos) -> str:
"""Gives debug information on the current state of jaxtyping's internal memos.
Used in type-checking error messages.
"""
single_memo, variadic_memo, pytree_memo, _ = memos
single_memo = {
name: size
for name, size in single_memo.items()
if not name.startswith("~~delete~~")
}
variadic_memo = {
name: shape
for name, (_, shape) in variadic_memo.items()
if not name.startswith("~~delete~~")
}
pieces = []
if len(single_memo) > 0 or len(variadic_memo) > 0:
pieces.append(
"The current values for each jaxtyping axis annotation are as follows."
)
for name, size in single_memo.items():
pieces.append(f"{name}={size}")
for name, shape in variadic_memo.items():
pieces.append(f"{name}={shape}")
if len(pytree_memo) > 0:
pieces.append(
"The current values for each jaxtyping PyTree structure annotation are as "
"follows."
)
for name, structure in pytree_memo.items():
pieces.append(f"{name}={structure}")
return "\n".join(pieces)


class _jaxtyping_note_str(str):
"""Used with `_no_jaxtyping_note` to flag that a note came from jaxtyping."""

Expand All @@ -808,21 +774,3 @@ def _no_jaxtyping_note(e: Exception) -> bool:


_spacer = "--------------------\n"


def print_bindings():
"""Prints the values of the current jaxtyping axis bindings. Intended for debugging.
That is, whilst doing runtime type checking, so that e.g. the `foo` and `bar` of
`Float[Array, "foo bar"]` are assigned values -- this function will print out those
values.
**Arguments:**
Nothing.
**Returns:**
Nothing.
"""
print(_exc_shape_info(get_shape_memo()))
56 changes: 56 additions & 0 deletions jaxtyping/_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,62 @@ def pop_shape_memo() -> None:
_shape_storage.memo_stack.pop()


def shape_str(memos) -> str:
"""Gives debug information on the current state of jaxtyping's internal memos.
Used in type-checking error messages.
**Arguments:**
- `memos`: as returned by `get_shape_memo` or `push_shape_memo`.
"""
single_memo, variadic_memo, pytree_memo, _ = memos
single_memo = {
name: size
for name, size in single_memo.items()
if not name.startswith("~~delete~~")
}
variadic_memo = {
name: shape
for name, (_, shape) in variadic_memo.items()
if not name.startswith("~~delete~~")
}
pieces = []
if len(single_memo) > 0 or len(variadic_memo) > 0:
pieces.append(
"The current values for each jaxtyping axis annotation are as follows."
)
for name, size in single_memo.items():
pieces.append(f"{name}={size}")
for name, shape in variadic_memo.items():
pieces.append(f"{name}={shape}")
if len(pytree_memo) > 0:
pieces.append(
"The current values for each jaxtyping PyTree structure annotation are as "
"follows."
)
for name, structure in pytree_memo.items():
pieces.append(f"{name}={structure}")
return "\n".join(pieces)


def print_bindings():
"""Prints the values of the current jaxtyping axis bindings. Intended for debugging.
That is, whilst doing runtime type checking, so that e.g. the `foo` and `bar` of
`Float[Array, "foo bar"]` are assigned values -- this function will print out those
values.
**Arguments:**
Nothing.
**Returns:**
Nothing.
"""
print(shape_str(get_shape_memo()))


_treepath_storage = threading.local()


Expand Down

0 comments on commit 28ad527

Please sign in to comment.