Commit
[math] add brainpy.math.defjvp, with support for defining JVP rules for a Primitive with multiple results. See examples in `test_ad_support.py`.
chaoming0625 committed Dec 2, 2023
1 parent 6c2c9bb commit 670937e
Showing 4 changed files with 192 additions and 2 deletions.
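Before the diff, a minimal, self-contained sketch of how the new API is meant to be used. Hedged: the `_double_imp` kernel and `double` wrapper are illustrative, not part of this commit, and assume a CPU run with Numba available; `test_ad_support.py` below is the commit's own example.

import jax
import jax.numpy as jnp
import numba
from jax import core

import brainpy.math as bm

@numba.njit
def _double_imp(x, out):
  # CPU kernel: write 2*x into the pre-allocated output buffer.
  out[:] = 2. * x

double_op = bm.XLACustomOp(_double_imp)

# One JVP rule per primal argument; each rule returns a tuple/list matching
# the primitive's (possibly multiple) outputs.
bm.defjvp(double_op, lambda x_dot, x, **params: [2. * x_dot])

def double(x):
  return double_op(x, outs=[core.ShapedArray(x.shape, x.dtype)])[0]

print(jax.grad(lambda x: double(x).sum())(jnp.ones(3)))  # [2. 2. 2.]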
50 changes: 50 additions & 0 deletions brainpy/_src/math/ad_support.py
@@ -0,0 +1,50 @@
import functools
from functools import partial

from jax import tree_util
from jax.core import Primitive
from jax.interpreters import ad
from brainpy._src.math.op_register.base import XLACustomOp

__all__ = [
  'defjvp',
]


def defjvp(primitive, *jvp_rules):
  """Define JVP rules for the given primitive.

  Unlike ``jax.interpreters.ad.defjvp``, this also supports primitives
  with multiple results.

  Args:
    primitive: Primitive, XLACustomOp. The primitive to register JVP rules for.
    *jvp_rules: The JVP translation rule for each primal argument; pass
      ``None`` for arguments that do not need differentiation.
  """
  if isinstance(primitive, XLACustomOp):
    primitive = primitive.primitive
  assert isinstance(primitive, Primitive)
  if primitive.multiple_results:
    ad.primitive_jvps[primitive] = partial(_standard_jvp, jvp_rules, primitive)
  else:
    ad.primitive_jvps[primitive] = partial(ad.standard_jvp, jvp_rules, primitive)


def _standard_jvp(jvp_rules, primitive: Primitive, primals, tangents, **params):
  assert primitive.multiple_results
  val_out = tuple(primitive.bind(*primals, **params))
  tree = tree_util.tree_structure(val_out)
  tangents_out = []
  for rule, t in zip(jvp_rules, tangents):
    if rule is not None and type(t) is not ad.Zero:
      # Each rule maps one input tangent to a full tuple of output tangents.
      r = tuple(rule(t, *primals, **params))
      tangents_out.append(r)
      assert tree_util.tree_structure(r) == tree
  # Sum the per-input contributions, starting from symbolic zeros.
  return val_out, functools.reduce(_add_tangents,
                                   tangents_out,
                                   tree_util.tree_map(lambda a: ad.Zero.from_value(a), val_out))


def _add_tangents(xs, ys):
  return tree_util.tree_map(ad.add_tangents, xs, ys, is_leaf=lambda a: isinstance(a, ad.Zero))
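To picture the accumulation in `_standard_jvp`: each differentiated input contributes one tuple of output tangents, and they are summed leaf-wise starting from zeros. A toy analogue with plain tuples (illustrative values, no JAX machinery):

import functools

# One tangent tuple per differentiated input, summed element-wise from zeros,
# mirroring how _standard_jvp combines rule outputs per output leaf.
tangents_out = [(1.0, 2.0), (10.0, 20.0)]
add = lambda xs, ys: tuple(x + y for x, y in zip(xs, ys))
print(functools.reduce(add, tangents_out, (0.0, 0.0)))  # (11.0, 22.0)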

4 changes: 2 additions & 2 deletions brainpy/_src/math/op_register/base.py
@@ -139,13 +139,13 @@ def __init__(
     if transpose_translation is not None:
       ad.primitive_transposes[self.primitive] = transpose_translation

-  def __call__(self, *ins, outs: Optional[Sequence[ShapeDtype]] = None):
+  def __call__(self, *ins, outs: Optional[Sequence[ShapeDtype]] = None, **kwargs):
     if outs is None:
       outs = self.outs
     assert outs is not None
     outs = tuple([_transform_to_shapedarray(o) for o in outs])
     ins = jax.tree_util.tree_map(_transform_to_array, ins, is_leaf=_is_bp_array)
-    return self.primitive.bind(*ins, outs=outs)
+    return self.primitive.bind(*ins, outs=outs, **kwargs)

   def def_abstract_eval(self, fun):
     """Define the abstract evaluation function.
136 changes: 136 additions & 0 deletions brainpy/_src/math/tests/test_ad_support.py
@@ -0,0 +1,136 @@
from typing import Tuple

import jax
import numba
from jax import core
from jax import numpy as jnp
from jax.interpreters import ad

import brainpy as bp
import brainpy.math as bm


def csrmv(data, indices, indptr, vector, *, shape: Tuple[int, int], transpose: bool = False):
  data = jnp.atleast_1d(bm.as_jax(data))
  indices = bm.as_jax(indices)
  indptr = bm.as_jax(indptr)
  vector = bm.as_jax(vector)
  if vector.dtype == jnp.bool_:
    vector = bm.as_jax(vector, dtype=data.dtype)
  outs = [core.ShapedArray([shape[1] if transpose else shape[0]], data.dtype)]
  if transpose:
    return prim_trans(data, indices, indptr, vector, outs=outs, shape=shape, transpose=transpose)
  else:
    return prim(data, indices, indptr, vector, outs=outs, shape=shape, transpose=transpose)


@numba.njit(fastmath=True)
def _csr_matvec_transpose_numba_imp(values, col_indices, row_ptr, vector, res_val):
  res_val.fill(0)
  # vec @ csr mat
  if values.shape[0] == 1:  # homogeneous value
    values = values[0]
    for row_i in range(vector.shape[0]):
      v = vector[row_i]
      for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
        res_val[col_indices[j]] += values * v
  else:  # heterogeneous values
    for row_i in range(vector.shape[0]):
      v = vector[row_i]
      for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
        res_val[col_indices[j]] += v * values[j]


@numba.njit(fastmath=True, parallel=True, nogil=True)
def _csr_matvec_numba_imp(values, col_indices, row_ptr, vector, res_val):
  res_val.fill(0)
  # csr mat @ vec
  if values.shape[0] == 1:  # homogeneous value
    values = values[0]
    for row_i in numba.prange(res_val.shape[0]):
      r = 0.
      for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
        r += values * vector[col_indices[j]]
      res_val[row_i] = r
  else:  # heterogeneous values
    for row_i in numba.prange(res_val.shape[0]):
      r = 0.
      for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
        r += values[j] * vector[col_indices[j]]
      res_val[row_i] = r


def _csrmv_jvp_mat(data_dot, data, indices, indptr, v, *, shape, transpose, **kwargs):
  return csrmv(data_dot, indices, indptr, v, shape=shape, transpose=transpose)


def _csrmv_jvp_vec(v_dot, data, indices, indptr, v, *, shape, transpose, **kwargs):
  return csrmv(data, indices, indptr, v_dot, shape=shape, transpose=transpose)


def _csrmv_cusparse_transpose(ct, data, indices, indptr, vector, *, shape, transpose, **kwargs):
  if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr):
    raise ValueError("Cannot transpose with respect to sparse indices.")

  ct = ct[0]
  if ad.is_undefined_primal(vector):
    ct_vector = csrmv(data, indices, indptr, ct, shape=shape, transpose=not transpose)
    return data, indices, indptr, (ad.Zero(vector) if type(ct) is ad.Zero else ct_vector)
  else:
    if type(ct) is ad.Zero:
      ct_data = ad.Zero(data)
    else:
      if data.aval.shape[0] == 1:  # scalar
        ct_data = csrmv(jnp.ones(1), indices, indptr, vector, shape=shape, transpose=transpose)
        ct_data = jnp.inner(ct, ct_data)
      else:  # heterogeneous values
        row, col = bm.sparse.csr_to_coo(indices, indptr)
        ct_data = vector[row] * ct[col] if transpose else vector[col] * ct[row]
    return ct_data, indices, indptr, vector


# Register the two matvec primitives: JVP rules for `data` and `vector`
# (None for the non-differentiable `indices`/`indptr`), plus a transpose
# rule so reverse-mode AD works.
prim_trans = bm.XLACustomOp(_csr_matvec_transpose_numba_imp)
bm.defjvp(prim_trans, _csrmv_jvp_mat, None, None, _csrmv_jvp_vec)
prim_trans.def_transpose_rule(_csrmv_cusparse_transpose)

prim = bm.XLACustomOp(_csr_matvec_numba_imp)
bm.defjvp(prim, _csrmv_jvp_mat, None, None, _csrmv_jvp_vec)
prim.def_transpose_rule(_csrmv_cusparse_transpose)


def sum_op(op):
  def func(*args, **kwargs):
    r = op(*args, **kwargs)
    return r.sum()

  return func


def try_a_trial(transpose, shape):
  rng = bm.random.RandomState()
  conn = bp.conn.FixedProb(0.1)
  indices, indptr = conn(*shape).require('pre2post')
  indices = bm.as_jax(indices)
  indptr = bm.as_jax(indptr)
  heter_data = rng.random(indices.shape)
  heter_data = bm.as_jax(heter_data)
  vector = rng.random(shape[0] if transpose else shape[1])
  vector = bm.as_jax(vector)

  # Gradients through the built-in csrmv (reference) ...
  r5 = jax.grad(sum_op(lambda *args, **kwargs: bm.sparse.csrmv(*args, **kwargs, method='vector')), argnums=(0, 3))(
    heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
  # ... versus the custom primitive defined above.
  r6 = jax.grad(sum_op(lambda *args, **kwargs: csrmv(*args, **kwargs)[0]), argnums=(0, 3))(
    heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
  print(r5)
  print(r6)
  assert bm.allclose(r5[0], r6[0])
  assert bm.allclose(r5[1], r6[1][0])


def test():
  transposes = [True, False]
  shapes = [(100, 200), (10, 1000), (2, 2000)]

  for transpose in transposes:
    for shape in shapes:
      try_a_trial(transpose, shape)
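A note on running: pytest collects `test()` from this file; a main guard (not part of the commit) also lets the module run directly:

if __name__ == '__main__':
  test()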
4 changes: 4 additions & 0 deletions brainpy/math/others.py
@@ -9,3 +9,7 @@
 from brainpy._src.math.object_transform.naming import (
   clear_name_cache,
 )
+
+from brainpy._src.math.ad_support import (
+  defjvp as defjvp,
+)
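After this re-export, the function is reachable at the public path used in the examples above:

import brainpy.math as bm

assert callable(bm.defjvp)  # exposed by the re-export in brainpy/math/others.py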
