diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index 1bfce85d592c..eb5e7e8bf8de 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -283,15 +283,15 @@ def argnums_partial_except(f: lu.WrappedFun, static_argnums: tuple[int, ...], return _argnums_partial(f, dyn_argnums, tuple(fixed_args)), dyn_args @lu.transformation2 -def _argnums_partial(f, dyn_argnums, fixed_args, *dyn_args, **kwargs): +def _argnums_partial(_fun, _dyn_argnums, _fixed_args, *dyn_args, **kwargs): sentinel = object() - args = [sentinel] * (len(fixed_args) + len(dyn_args)) - for i, arg in zip(dyn_argnums, dyn_args): + args = [sentinel] * (len(_fixed_args) + len(dyn_args)) + for i, arg in zip(_dyn_argnums, dyn_args): args[i] = arg - fixed_args_ = iter(fixed_args) + fixed_args_ = iter(_fixed_args) args = [next(fixed_args_).val if x is sentinel else x for x in args] assert next(fixed_args_, sentinel) is sentinel - return f(*args, **kwargs) + return _fun(*args, **kwargs) def argnames_partial_except(f: lu.WrappedFun, static_argnames: tuple[str, ...], kwargs: dict[str, Any]): @@ -315,9 +315,9 @@ def argnames_partial_except(f: lu.WrappedFun, static_argnames: tuple[str, ...], return _argnames_partial(f, WrapKwArgs(fixed_kwargs)), dyn_kwargs @lu.transformation2 -def _argnames_partial(f, fixed_kwargs: WrapKwArgs, *args, **dyn_kwargs): - kwargs = dict({k: v.val for k, v in fixed_kwargs.val.items()}, **dyn_kwargs) - return f(*args, **kwargs) +def _argnames_partial(_fun, _fixed_kwargs: WrapKwArgs, *args, **dyn_kwargs): + kwargs = dict({k: v.val for k, v in _fixed_kwargs.val.items()}, **dyn_kwargs) + return _fun(*args, **kwargs) @lru_cache(maxsize=4096) @@ -438,9 +438,9 @@ def flat_out_axes( return f, HashableFunction(out_axes, closure=(tuple(leaves), treedef)) @lu.transformation_with_aux2 -def _flat_out_axes(f, store, leaves, treedef, *args, **kwargs): - ans = f(*args, **kwargs) - spec = tree_unflatten(treedef, leaves) +def _flat_out_axes(_fun, _store, _leaves, _treedef, *args, **kwargs): + ans = _fun(*args, **kwargs) + spec = tree_unflatten(_treedef, _leaves) try: spec_flat = tuple(broadcast_prefix(spec, ans, is_leaf=lambda x: x is None)) except ValueError: @@ -451,7 +451,7 @@ def _flat_out_axes(f, store, leaves, treedef, *args, **kwargs): "that the `out_axes` argument to `pmap` is a pytree prefix of the " "pmapped function's output.") raise ValueError(msg) from None - store.store(spec_flat) + _store.store(spec_flat) return ans def check_callable(fun): @@ -687,10 +687,10 @@ def _arg_names(fn_signature, args, kwargs, static_argnums, static_argnames, for path, l in generate_key_paths(x) if l is not static) @lu.transformation_with_aux2 -def result_paths(f, store, *args, **kwargs): +def result_paths(_fun, _store, *args, **kwargs): "linear_util transform to get output pytree paths of pre-flattened function." - ans = f(*args, **kwargs) - store.store([keystr(path) for path, _ in generate_key_paths(ans)]) + ans = _fun(*args, **kwargs) + _store.store([keystr(path) for path, _ in generate_key_paths(ans)]) return ans def jaxpr_debug_info(jaxpr: core.Jaxpr, trace_debug: TracingDebugInfo | None,