Skip to content

Commit

Permalink
Merge pull request #729 from mit-ll-responsible-ai/pyright-up
Browse files Browse the repository at this point in the history
Update pyright and add support for jax 0.4.32+
  • Loading branch information
rsokl committed Sep 14, 2024
2 parents 9b6a037 + 4fdbfda commit cdd6dd3
Show file tree
Hide file tree
Showing 7 changed files with 49 additions and 17 deletions.
2 changes: 1 addition & 1 deletion deps/requirements-pyright.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
pyright==1.1.375
pyright==1.1.380
4 changes: 3 additions & 1 deletion docs/source/api_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,10 @@ hydra-zen also provides auto-config support for some third-pary libraries:

- `pydantic.dataclasses.dataclass`
- `pydantic.Field`
- `pydantic.Field`
- `torch.optim.optimizer.required` (i.e. the default parameter for `lr` in `Optimizer`)
- `numpy.ufunc` and nunmpy array dispatchers (e.g. `np.sum`)
- `jax.numpy.ufunc` and jax compiled functions (e.g. `jax.vmap`)



*********************
Expand Down
27 changes: 23 additions & 4 deletions src/hydra_zen/structured_configs/_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,13 @@ def _is_ufunc(value: Any) -> bool:
return isinstance(value, numpy.ufunc)


def _is_jax_ufunc(value: Any) -> bool: # pragma: no cover
# checks without importing numpy
if (jnp := sys.modules.get("jax.numpy")) is None: # pragma: no cover
return False
return isinstance(value, jnp.ufunc)


def _is_numpy_array_func_dispatcher(value: Any) -> bool:
if (numpy := sys.modules.get("numpy")) is None: # pragma: no cover
return False
Expand Down Expand Up @@ -1047,7 +1054,7 @@ def _mutable_value(cls, x: _T, *, zen_convert: Optional[ZenConvert] = None) -> _

if cast in {list, tuple, dict}:
x = cls._sanitize_collection(x, convert_dataclass=settings["dataclass"])
return field(default_factory=lambda: cast(x)) # type: ignore
return field(default_factory=lambda: cast(x))
return field(default_factory=lambda: x)

@classmethod
Expand Down Expand Up @@ -1157,6 +1164,7 @@ def _make_hydra_compatible(
or _is_ufunc(value)
or _is_numpy_array_func_dispatcher(value=value)
or _is_jax_compiled_func(value=value)
or _is_jax_ufunc(value=value)
)
):
# `value` is importable callable -- create config that will import
Expand Down Expand Up @@ -1576,7 +1584,13 @@ def builds(
@classmethod
def builds(
cls: Type[Self],
*pos_args: Union[Importable, Callable[P, R], Type[AnyBuilds[Importable]], Any],
*pos_args: Union[
Importable,
Callable[P, R],
Type[AnyBuilds[Importable]],
Type[BuildsWithSig[Type[R], P]],
Any,
],
zen_partial: Optional[bool] = None,
zen_wrappers: ZenWrappers[Callable[..., Any]] = tuple(),
zen_meta: Optional[Mapping[str, SupportedPrimitive]] = None,
Expand Down Expand Up @@ -2584,7 +2598,7 @@ def builds(self,target, populate_full_signature=False, **kw):
if is_dataclass(target):
_fields = {f.name: f for f in fields(target)}
else:
_fields = target.__fields__ # type: ignore
_fields = target.__fields__
_update = {}
for name, param in signature_params.items():
if name not in _fields:
Expand Down Expand Up @@ -3379,7 +3393,12 @@ def kwargs_of(
@classmethod
def kwargs_of(
cls: Type[Self],
__hydra_target: Callable[P, Any],
__hydra_target: Union[
Callable[P, Any],
Callable[Concatenate[Any, P], Any],
Callable[Concatenate[Any, Any, P], Any],
Callable[Concatenate[Any, Any, Any, P], Any],
],
*,
zen_dataclass: Optional[DataclassOptions] = None,
zen_exclude: Union[
Expand Down
15 changes: 13 additions & 2 deletions src/hydra_zen/typing/_builds_overloads.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,12 @@ def __call__(

def __call__(
self,
__hydra_target: Union[Callable[P, R], Type[Builds[Importable]], Importable],
__hydra_target: Union[
Callable[P, R],
Type[Builds[Importable]],
Importable,
Type[BuildsWithSig[Type[R], P]],
],
*pos_args: T,
zen_partial: Optional[bool] = None,
populate_full_signature: bool = False,
Expand Down Expand Up @@ -476,6 +481,7 @@ def __call__(
Importable,
Type[Builds[Importable]],
Type[PartialBuilds[Importable]],
Type[BuildsWithSig[Type[R], P]],
],
*pos_args: T,
zen_partial: Optional[bool] = None,
Expand Down Expand Up @@ -656,7 +662,12 @@ def __call__(

def __call__(
self,
__hydra_target: Union[Callable[P, R], Type[AnyBuilds[Importable]], Importable],
__hydra_target: Union[
Callable[P, R],
Type[AnyBuilds[Importable]],
Importable,
Type[BuildsWithSig[Type[R], P]],
],
*pos_args: T,
zen_partial: Optional[bool] = True,
populate_full_signature: bool = False,
Expand Down
6 changes: 3 additions & 3 deletions src/hydra_zen/wrapper/_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ def validate(self, __cfg: Union[ConfigLike, str]) -> None:
p.kind is p.POSITIONAL_ONLY for p in self.parameters.values()
)

_args_ = getattr(cfg, "_args_", [])
_args_: List[Any] = getattr(cfg, "_args_", [])

if not isinstance(_args_, Sequence):
raise HydraZenValidationError(
Expand Down Expand Up @@ -1603,12 +1603,12 @@ def __call__(self: Self, __target: Optional[F] = None, **kw: Any) -> Union[F, Se
if "provider" not in kw: # pragma: no branch
provider = "hydra_zen"

_name: NodeName = name(__target) if callable(name) else name
_name: NodeName = name(__target) if callable(name) else name # type: ignore
if not isinstance(_name, str):
raise TypeError(f"`name` must be a string, got {_name}")
del name

_group: GroupName = group(__target) if callable(group) else group
_group: GroupName = group(__target) if callable(group) else group # type: ignore
if _group is not None and not isinstance(_group, str):
raise TypeError(f"`group` must be a string or None, got {_group}")
del group
Expand Down
10 changes: 5 additions & 5 deletions tests/annotations/declarations.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,19 +564,19 @@ def f(x: int, y: str, z: bool = False):
# The following should be ok
reveal_type(
Conf_f(1, "hi"),
expected_text="BuildsWithSig[type[C], (x: int, y: str, z: bool = ...)]",
expected_text="BuildsWithSig[type[C], (x: int, y: str, z: bool = False)]",
)
reveal_type(
Conf_f(1, "hi", True),
expected_text="BuildsWithSig[type[C], (x: int, y: str, z: bool = ...)]",
expected_text="BuildsWithSig[type[C], (x: int, y: str, z: bool = False)]",
)
reveal_type(
Conf_f(1, y="hi"),
expected_text="BuildsWithSig[type[C], (x: int, y: str, z: bool = ...)]",
expected_text="BuildsWithSig[type[C], (x: int, y: str, z: bool = False)]",
)
reveal_type(
Conf_f(x=1, y="hi", z=False),
expected_text="BuildsWithSig[type[C], (x: int, y: str, z: bool = ...)]",
expected_text="BuildsWithSig[type[C], (x: int, y: str, z: bool = False)]",
)

# check instantiation
Expand Down Expand Up @@ -727,7 +727,7 @@ def f(x: int, y: str, z: bool = False):

reveal_type(
Conf,
expected_text="type[BuildsWithSig[type[int], (x: int, y: str, z: bool = ...)]]",
expected_text="type[BuildsWithSig[type[int], (x: int, y: str, z: bool = False)]]",
)


Expand Down
2 changes: 1 addition & 1 deletion tests/test_docs_typecheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def test_docstrings_scan_clean_via_pyright(func):
raw_files,
pyright_analyze(
*raw_files,
report_unnecessary_type_ignore_comment=True,
report_unnecessary_type_ignore_comment=False,
preamble=preamble,
pyright_config=pyright_config,
),
Expand Down

0 comments on commit cdd6dd3

Please sign in to comment.