Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix ad compatibility #694

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 36 additions & 30 deletions brainpy/_src/math/op_register/ad_support.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,57 @@
import functools
from functools import partial

import jax
from jax import tree_util
from jax.core import Primitive
from jax.interpreters import ad

__all__ = [
'defjvp',
'defjvp',
]


def defjvp(primitive, *jvp_rules):
"""Define JVP rules for any JAX primitive.
"""Define JVP rules for any JAX primitive.

This function is similar to ``jax.interpreters.ad.defjvp``.
However, the JAX one only supports primitive with ``multiple_results=False``.
``brainpy.math.defjvp`` enables to define the independent JVP rule for
each input parameter no matter ``multiple_results=False/True``.
This function is similar to ``jax.interpreters.ad.defjvp``.
However, the JAX one only supports primitive with ``multiple_results=False``.
``brainpy.math.defjvp`` enables to define the independent JVP rule for
each input parameter no matter ``multiple_results=False/True``.

For examples, please see ``test_ad_support.py``.
For examples, please see ``test_ad_support.py``.

Args:
primitive: Primitive, XLACustomOp.
*jvp_rules: The JVP translation rule for each primal.
"""
assert isinstance(primitive, Primitive)
if primitive.multiple_results:
ad.primitive_jvps[primitive] = partial(_standard_jvp, jvp_rules, primitive)
else:
ad.primitive_jvps[primitive] = partial(ad.standard_jvp, jvp_rules, primitive)
Args:
primitive: Primitive, XLACustomOp.
*jvp_rules: The JVP translation rule for each primal.
"""
assert isinstance(primitive, Primitive)
if primitive.multiple_results:
ad.primitive_jvps[primitive] = partial(_standard_jvp, jvp_rules, primitive)
else:
ad.primitive_jvps[primitive] = partial(ad.standard_jvp, jvp_rules, primitive)


def _standard_jvp(jvp_rules, primitive: Primitive, primals, tangents, **params):
assert primitive.multiple_results
val_out = tuple(primitive.bind(*primals, **params))
tree = tree_util.tree_structure(val_out)
tangents_out = []
for rule, t in zip(jvp_rules, tangents):
if rule is not None and type(t) is not ad.Zero:
r = tuple(rule(t, *primals, **params))
tangents_out.append(r)
assert tree_util.tree_structure(r) == tree
return val_out, functools.reduce(_add_tangents,
tangents_out,
tree_util.tree_map(lambda a: ad.Zero.from_value(a), val_out))
assert primitive.multiple_results
val_out = tuple(primitive.bind(*primals, **params))
tree = tree_util.tree_structure(val_out)
tangents_out = []
for rule, t in zip(jvp_rules, tangents):
if rule is not None and type(t) is not ad.Zero:
r = tuple(rule(t, *primals, **params))
tangents_out.append(r)
assert tree_util.tree_structure(r) == tree
return val_out, functools.reduce(
_add_tangents,
tangents_out,
tree_util.tree_map(
# compatible with JAX 0.4.34
lambda a: ad.Zero.from_primal_value(a) if jax.__version__ >= '0.4.34' else ad.Zero.from_value(a),
val_out
)
)


def _add_tangents(xs, ys):
return tree_util.tree_map(ad.add_tangents, xs, ys, is_leaf=lambda a: isinstance(a, ad.Zero))

return tree_util.tree_map(ad.add_tangents, xs, ys, is_leaf=lambda a: isinstance(a, ad.Zero))
Loading