diff --git a/brainpy/_src/math/op_register/ad_support.py b/brainpy/_src/math/op_register/ad_support.py index 342093ea..57ce9064 100644 --- a/brainpy/_src/math/op_register/ad_support.py +++ b/brainpy/_src/math/op_register/ad_support.py @@ -1,51 +1,57 @@ import functools from functools import partial +import jax from jax import tree_util from jax.core import Primitive from jax.interpreters import ad __all__ = [ - 'defjvp', + 'defjvp', ] def defjvp(primitive, *jvp_rules): - """Define JVP rules for any JAX primitive. + """Define JVP rules for any JAX primitive. - This function is similar to ``jax.interpreters.ad.defjvp``. - However, the JAX one only supports primitive with ``multiple_results=False``. - ``brainpy.math.defjvp`` enables to define the independent JVP rule for - each input parameter no matter ``multiple_results=False/True``. + This function is similar to ``jax.interpreters.ad.defjvp``. + However, the JAX one only supports primitive with ``multiple_results=False``. + ``brainpy.math.defjvp`` enables to define the independent JVP rule for + each input parameter no matter ``multiple_results=False/True``. - For examples, please see ``test_ad_support.py``. + For examples, please see ``test_ad_support.py``. - Args: - primitive: Primitive, XLACustomOp. - *jvp_rules: The JVP translation rule for each primal. - """ - assert isinstance(primitive, Primitive) - if primitive.multiple_results: - ad.primitive_jvps[primitive] = partial(_standard_jvp, jvp_rules, primitive) - else: - ad.primitive_jvps[primitive] = partial(ad.standard_jvp, jvp_rules, primitive) + Args: + primitive: Primitive, XLACustomOp. + *jvp_rules: The JVP translation rule for each primal. + """ + assert isinstance(primitive, Primitive) + if primitive.multiple_results: + ad.primitive_jvps[primitive] = partial(_standard_jvp, jvp_rules, primitive) + else: + ad.primitive_jvps[primitive] = partial(ad.standard_jvp, jvp_rules, primitive) def _standard_jvp(jvp_rules, primitive: Primitive, primals, tangents, **params): - assert primitive.multiple_results - val_out = tuple(primitive.bind(*primals, **params)) - tree = tree_util.tree_structure(val_out) - tangents_out = [] - for rule, t in zip(jvp_rules, tangents): - if rule is not None and type(t) is not ad.Zero: - r = tuple(rule(t, *primals, **params)) - tangents_out.append(r) - assert tree_util.tree_structure(r) == tree - return val_out, functools.reduce(_add_tangents, - tangents_out, - tree_util.tree_map(lambda a: ad.Zero.from_value(a), val_out)) + assert primitive.multiple_results + val_out = tuple(primitive.bind(*primals, **params)) + tree = tree_util.tree_structure(val_out) + tangents_out = [] + for rule, t in zip(jvp_rules, tangents): + if rule is not None and type(t) is not ad.Zero: + r = tuple(rule(t, *primals, **params)) + tangents_out.append(r) + assert tree_util.tree_structure(r) == tree + return val_out, functools.reduce( + _add_tangents, + tangents_out, + tree_util.tree_map( + # compatible with JAX 0.4.34 + lambda a: ad.Zero.from_primal_value(a) if jax.__version__ >= '0.4.34' else ad.Zero.from_value(a), + val_out + ) + ) def _add_tangents(xs, ys): - return tree_util.tree_map(ad.add_tangents, xs, ys, is_leaf=lambda a: isinstance(a, ad.Zero)) - + return tree_util.tree_map(ad.add_tangents, xs, ys, is_leaf=lambda a: isinstance(a, ad.Zero))