diff --git a/brainpy/_src/math/object_transform/autograd.py b/brainpy/_src/math/object_transform/autograd.py index 97f26712c..5f06b4e67 100644 --- a/brainpy/_src/math/object_transform/autograd.py +++ b/brainpy/_src/math/object_transform/autograd.py @@ -915,9 +915,8 @@ def _valid_jaxtype(arg): def _check_output_dtype_revderiv(name, holomorphic, x): aval = core.get_aval(x) - if core.is_opaque_dtype(aval.dtype): - raise TypeError( - f"{name} with output element type {aval.dtype.name}") + # if jnp.issubdtype(aval.dtype, dtypes.extended): + # raise TypeError(f"{name} with output element type {aval.dtype.name}") if holomorphic: if not dtypes.issubdtype(aval.dtype, np.complexfloating): raise TypeError(f"{name} with holomorphic=True requires outputs with complex dtype, " @@ -938,9 +937,8 @@ def _check_output_dtype_revderiv(name, holomorphic, x): def _check_input_dtype_revderiv(name, holomorphic, allow_int, x): _check_arg(x) aval = core.get_aval(x) - if core.is_opaque_dtype(aval.dtype): - raise TypeError( - f"{name} with input element type {aval.dtype.name}") + # if jnp.issubdtype(aval.dtype, dtypes.extended): + # raise TypeError(f"{name} with input element type {aval.dtype.name}") if holomorphic: if not dtypes.issubdtype(aval.dtype, np.complexfloating): raise TypeError(f"{name} with holomorphic=True requires inputs with complex dtype, " @@ -972,8 +970,8 @@ def _check_output_dtype_jacfwd(holomorphic, x): def _check_input_dtype_jacfwd(holomorphic: bool, x: Any) -> None: _check_arg(x) aval = core.get_aval(x) - if core.is_opaque_dtype(aval.dtype): - raise TypeError(f"jacfwd with input element type {aval.dtype.name}") + # if jnp.issubdtype(aval.dtype, dtypes.extended): + # raise TypeError(f"jacfwd with input element type {aval.dtype.name}") if holomorphic: if not dtypes.issubdtype(aval.dtype, np.complexfloating): raise TypeError("jacfwd with holomorphic=True requires inputs with complex " diff --git a/brainpy/_src/math/op_registers/numba_approach/__init__.py b/brainpy/_src/math/op_registers/numba_approach/__init__.py index 4740b98e2..cd05cab7b 100644 --- a/brainpy/_src/math/op_registers/numba_approach/__init__.py +++ b/brainpy/_src/math/op_registers/numba_approach/__init__.py @@ -45,7 +45,7 @@ class XLACustomOp(BrainPyObject): cpu_func: callable The function defines the computation on CPU backend. Same as ``con_compute``. gpu_func: callable - The function defines the computation on GPU backend. Currently, this function is not supportted. + The function defines the computation on GPU backend. Currently, this function is not supported. apply_cpu_func_to_gpu: bool Whether allows to apply CPU function on GPU backend. If True, the GPU data will move to CPU, and after calculation, the returned outputs on CPU backend will move to GPU. 
diff --git a/brainpy/_src/math/op_registers/tests/test_ei_net.py b/brainpy/_src/math/op_registers/tests/test_ei_net.py index 24a1a6a6c..28d106cb2 100644 --- a/brainpy/_src/math/op_registers/tests/test_ei_net.py +++ b/brainpy/_src/math/op_registers/tests/test_ei_net.py @@ -3,9 +3,6 @@ from jax.core import ShapedArray -bm.set_platform('cpu') - - def abs_eval(events, indices, indptr, *, weight, post_num): return [ShapedArray((post_num,), bm.float32), ] @@ -25,7 +22,7 @@ def con_compute(outs, ins): event_sum = bm.XLACustomOp(eval_shape=abs_eval, cpu_func=con_compute) -class ExponentialV2(bp.TwoEndConn): +class ExponentialV2(bp.synapses.TwoEndConn): """Exponential synapse model using customized operator written in C++.""" def __init__(self, pre, post, conn, g_max=1., delay=0., tau=8.0, E=0.): @@ -46,8 +43,8 @@ def __init__(self, pre, post, conn, g_max=1., delay=0., tau=8.0, E=0.): # function self.integral = bp.odeint(lambda g, t: -g / self.tau, method='exp_auto') - def update(self, tdi): - self.g.value = self.integral(self.g, tdi.t, tdi.dt) + def update(self): + self.g.value = self.integral(self.g, bp.share['t']) self.g += event_sum(self.pre.spike, self.pre2post[0], self.pre2post[1], @@ -56,31 +53,25 @@ def update(self, tdi): self.post.input += self.g * (self.E - self.post.V) -class EINet(bp.Network): +class EINet(bp.DynSysGroup): def __init__(self, scale): + super().__init__() # neurons - bm.random.seed() pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - E = bp.neurons.LIF(int(3200 * scale), **pars, method='exp_auto') - I = bp.neurons.LIF(int(800 * scale), **pars, method='exp_auto') + V_initializer=bp.init.Normal(-55., 2.), method='exp_auto') + self.E = bp.neurons.LIF(int(3200 * scale), **pars) + self.I = bp.neurons.LIF(int(800 * scale), **pars) # synapses - E2E = ExponentialV2(E, E, bp.conn.FixedProb(prob=0.02), E=0., g_max=0.6 / scale, tau=5.) - E2I = ExponentialV2(E, I, bp.conn.FixedProb(prob=0.02), E=0., g_max=0.6 / scale, tau=5.) - I2E = ExponentialV2(I, E, bp.conn.FixedProb(prob=0.02), E=-80., g_max=6.7 / scale, tau=10.) - I2I = ExponentialV2(I, I, bp.conn.FixedProb(prob=0.02), E=-80., g_max=6.7 / scale, tau=10.) - - super(EINet, self).__init__(E2E, E2I, I2E, I2I, E=E, I=I) + self.E2E = ExponentialV2(self.E, self.E, bp.conn.FixedProb(prob=0.02), E=0., g_max=0.6 / scale, tau=5.) + self.E2I = ExponentialV2(self.E, self.I, bp.conn.FixedProb(prob=0.02), E=0., g_max=0.6 / scale, tau=5.) + self.I2E = ExponentialV2(self.I, self.E, bp.conn.FixedProb(prob=0.02), E=-80., g_max=6.7 / scale, tau=10.) + self.I2I = ExponentialV2(self.I, self.I, bp.conn.FixedProb(prob=0.02), E=-80., g_max=6.7 / scale, tau=10.) def test1(): - bm.random.seed() + bm.set_platform('cpu') net2 = EINet(scale=0.1) - runner2 = bp.DSRunner(net2, inputs=[('E.input', 20.), ('I.input', 20.)]) - r = runner2.predict(100., eval_time=True) + runner = bp.DSRunner(net2, inputs=[('E.input', 20.), ('I.input', 20.)]) + r = runner.predict(100., eval_time=True) bm.clear_buffer_memory() - - - - diff --git a/docs/tutorial_advanced/3_dedicated_operators.rst b/docs/tutorial_advanced/3_dedicated_operators.rst index 746891cfa..36696e4aa 100644 --- a/docs/tutorial_advanced/3_dedicated_operators.rst +++ b/docs/tutorial_advanced/3_dedicated_operators.rst @@ -3,3 +3,5 @@ Brain Dynamics Dedicated Operators .. 
toctree:: :maxdepth: 1 + + low-level_operator_customization.ipynb \ No newline at end of file diff --git a/docs/tutorial_advanced/low-level_operator_customization.ipynb b/docs/tutorial_advanced/low-level_operator_customization.ipynb new file mode 100644 index 000000000..f914cb7aa --- /dev/null +++ b/docs/tutorial_advanced/low-level_operator_customization.ipynb @@ -0,0 +1,404 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "collapsed": true + }, + "source": [ + "# Low-level Operator Customization" + ] + }, + { + "cell_type": "markdown", + "source": [ + "@[Tianqiu Zhang](https://github.com/ztqakita)" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "BrainPy is built on Jax and can accelerate model execution through [Just-in-Time (JIT) compilation](./compilation.ipynb). In order to enhance performance on CPU and GPU, we publish another package, ``BrainPyLib``, which provides several built-in low-level operators for synaptic computation. These operators are written in C++/CUDA and wrapped as Jax primitives by using ``XLA``. However, users cannot easily customize their own operators unless they have the relevant background. To solve this problem, we introduce `numba.cfunc` here and provide convenient interfaces for users to customize operators without touching the underlying logic. In this tutorial, we will introduce how to customize operators on CPU. Please note that BrainPy currently only supports operator customization on CPU; GPU operators will be supported in the future." + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "source": [ + "import brainpy as bp\n", + "import brainpy.math as bm\n", + "import jax\n", + "from jax import jit\n", + "import jax.numpy as jnp\n", + "from jax.core import ShapedArray\n", + "import numba\n", + "import time\n", + "\n", + "bm.set_platform('cpu')" + ], + "metadata": { + "collapsed": false + }, + "execution_count": 1, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/ztqakita/opt/anaconda3/envs/bdp/lib/python3.9/site-packages/flax/struct.py:136: FutureWarning: jax.tree_util.register_keypaths is deprecated, and will be removed in a future release. Please use `register_pytree_with_keys()` instead.\n", + " jax.tree_util.register_keypaths(data_clz, keypaths)\n", + "/Users/ztqakita/opt/anaconda3/envs/bdp/lib/python3.9/site-packages/flax/struct.py:136: FutureWarning: jax.tree_util.register_keypaths is deprecated, and will be removed in a future release. Please use `register_pytree_with_keys()` instead.\n", + " jax.tree_util.register_keypaths(data_clz, keypaths)\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "We have previously discussed the benefits of computation with our built-in operators. These operators are provided by the `brainpylib` package and can be accessed through the `brainpy.math` module. To be more specific, in order to speed up sparse synaptic computation, we customize several low-level operators for CPU and GPU, which are written in C++/CUDA and converted into Jax/XLA-compatible primitives by using `Pybind11`." + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "It is not easy to write a C++/CUDA operator and implement the whole series of conversions. Users have to learn how to write a C++/CUDA operator, how to write a customized Jax primitive, and how to convert their C++/CUDA operator into a Jax primitive. 
Here are some links for users who prefer to dive into the details: [Jax primitives](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html), [XLA custom calls](https://www.tensorflow.org/xla/custom_call).\n", + "\n", + "However, we can only provide a limited number of built-in operators, and it would be better if users could customize their own operators in a relatively simple way. To achieve this goal, BrainPy provides a convenient interface, `XLACustomOp`, to register customized operators on CPU. Users no longer need to do any C++ programming or XLA compilation. This is accomplished with the help of [`numba.cfunc`](https://numba.pydata.org/numba-doc/latest/user/cfunc.html), which wraps Python code as a compiled function callable from foreign C code. The C function object exposes the address of the compiled C callback so that it can be passed into XLA and registered as a jittable Jax primitive. Here is an example of using `XLACustomOp` on CPU." + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "## How to customize operators?" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "### CPU version\n", + "\n", + "First, users can customize a simple operator written in Python. Notice that this Python operator will be jitted in nopython mode, but some language features are not available inside Numba-compiled functions. Please refer to the [Numba documentation](https://numba.pydata.org/numba-doc/latest/reference/pysupported.html#pysupported) for details." + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 2, + "outputs": [], + "source": [ + "def custom_op(outs, ins):\n", + " y, y1 = outs\n", + " x, x2 = ins\n", + " y[:] = x + 1\n", + " y1[:] = x2 + 2" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "There are some restrictions that users should know:\n", + "- The parameters of the operator are `outs` and `ins`, corresponding to the output variable(s) and input variable(s). Their order cannot be changed.\n", + "- The function cannot have any return value.\n", + "- When applying a CPU function to GPU, users only need to implement the CPU operator." + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "Then users should describe the shapes and types of the outputs, because JAX/Python can deduce the shapes and types of the inputs at call time, but it cannot infer those of the outputs. The argument can be:\n", + "- a `ShapedArray`,\n", + "- a sequence of `ShapedArray`,\n", + "- a function that returns the correct output shapes as `ShapedArray` objects.\n", + "\n", + "Here we use a function to describe the output shapes and types. Its arguments are all the inputs of the custom operator, but only their shapes and types are accessible." + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 3, + "outputs": [], + "source": [ + "def abs_eval_1(*ins):\n", + " # ins: input arguments; only shapes and types are accessible.\n", + " # Because the output shapes and types of custom_op are exactly the\n", + " # same as its inputs, we can simply return the inputs here.\n", + " return ins" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "The function above may look somewhat abstract, so we give an alternative function below for passing shape information. 
Note that ``abs_eval_1`` and ``abs_eval_2`` do the same thing." + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 4, + "outputs": [], + "source": [ + "def abs_eval_2(*ins):\n", + " return ShapedArray(ins[0].shape, ins[0].dtype), ShapedArray(ins[1].shape, ins[1].dtype)" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "Now we are ready to register a CPU operator. `XLACustomOp` will be called to wrap your operator and return a jittable Jax primitive. Here are some parameters users should define:\n", + "- `name`: Name of the operator.\n", + "- `eval_shape`: The function to evaluate the shapes and dtypes of the outputs according to the inputs. It should receive the abstract information of the inputs and return the abstract information of the outputs.\n", + "- `con_compute`: The function that performs the concrete computation. It receives the output and input buffers and fills the outputs in place.\n", + "- `cpu_func`: The function that defines the computation on the CPU backend. Same as ``con_compute``.\n", + "- `gpu_func`: The function that defines the computation on the GPU backend. Currently, this function is not supported.\n", + "- `apply_cpu_func_to_gpu`: Whether to allow applying the CPU function on the GPU backend. If True, the GPU data will be moved to CPU, and after the calculation the outputs returned on the CPU backend will be moved back to GPU.\n", + "- `batching_translation`: The batching translation for the primitive.\n", + "- `jvp_translation`: The forward autodiff translation rule.\n", + "- `transpose_translation`: The backward autodiff translation rule.\n", + "- `multiple_results`: Whether the primitive returns multiple results." + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 5, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Array([[2., 2.]], dtype=float32), Array([[3., 3.]], dtype=float32)]\n" + ] + } + ], + "source": [ + "z = jnp.ones((1, 2), dtype=jnp.float32)\n", + "# Users could try eval_shape=abs_eval_2 and see if the result is different\n", + "op = bm.XLACustomOp(\n", + " name='add',\n", + " eval_shape=abs_eval_1,\n", + " cpu_func=custom_op,\n", + ")\n", + "jit_op = jit(op)\n", + "print(jit_op(z, z))" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "### GPU version\n", + "\n", + "We have discussed how to customize a CPU operator above; next we will talk about the GPU operator, which is slightly different from the CPU version. There are two additional parameters users need to provide:\n", + "- `gpu_func`: The customized operator for the GPU version.\n", + "- `apply_cpu_func_to_gpu`: Whether to run the kernel function on CPU as an alternative for the GPU version.\n", + "\n", + "```{warning}\n", + " GPU operators would be wrapped by `cuda.jit` in `numba`, but `numba` currently does not support launching CUDA kernels from `cfunc`s. For this reason, `gpu_func` is None by default, and an error will be raised if users pass a GPU operator to `gpu_func`.\n", + "```" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "Therefore, BrainPy enables users to set `apply_cpu_func_to_gpu` to True as a fallback. All the inputs will be initialized on GPU and transferred to CPU for the computation. The user-defined operator will be executed on CPU, and the results will be transferred back to GPU for further tasks."
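+ ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "A minimal sketch of this fallback is shown below. It reuses `custom_op` and `abs_eval_1` from the CPU example and only adds `apply_cpu_func_to_gpu=True`; the operator name is arbitrary, and a GPU platform is assumed to be available (this sketch is not benchmarked here):\n", + "\n", + "```python\n", + "op_fallback = bm.XLACustomOp(\n", + " name='add_fallback',\n", + " eval_shape=abs_eval_1,\n", + " cpu_func=custom_op,\n", + " # data moves GPU -> CPU for the computation, then back to GPU\n", + " apply_cpu_func_to_gpu=True,\n", + ")\n", + "```"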
+ ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "## Performance" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "To illustrate the effectiveness of this approach, we will compare the customized operators with BrainPy built-in operators. Here we use `event_sum` as an example. The implementation of `event_sum` by using our customization is shown as below:" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 6, + "outputs": [], + "source": [ + "def abs_eval(data, indices, indptr, vector, shape):\n", + " out_shape = shape[0]\n", + " return ShapedArray((out_shape,), data.dtype),\n", + "\n", + "@numba.njit(fastmath=True)\n", + "def sparse_op(outs, ins):\n", + " res_val = outs[0]\n", + " res_val.fill(0)\n", + " values, col_indices, row_ptr, vector, shape = ins\n", + "\n", + " for row_i in range(shape[0]):\n", + " v = vector[row_i]\n", + " for j in range(row_ptr[row_i], row_ptr[row_i + 1]):\n", + " res_val[col_indices[j]] += values * v\n", + "\n", + "sparse_cus_op = bm.XLACustomOp(name='sparse', eval_shape=abs_eval, con_compute=sparse_op)" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "We will use sparse matrix vector multiplication to be our benchmark for testing the speed. We will use built-in operator `event` first." + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "source": [ + "def sparse(size, prob):\n", + " bm.random.seed()\n", + " vector = bm.random.randn(size)\n", + " sparse_A = bp.conn.FixedProb(prob=prob, allow_multi_conn=True)(size, size).require('pre2post')\n", + " t0 = time.time()\n", + " for _ in range(100):\n", + " hidden = jax.block_until_ready(bm.sparse.csrmv(1., sparse_A[0], sparse_A[1], vector, shape=(size, size), transpose=True, method='vector'))\n", + " cost_t = time.time() - t0\n", + " print(f'Sparse: size {size}, prob {prob}, cost_t {cost_t} s.')\n", + " bm.clear_buffer_memory()\n", + "\n", + "sparse(50000, 0.01)" + ], + "metadata": { + "collapsed": false + }, + "execution_count": 7, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Sparse: size 50000, prob 0.01, cost_t 2.222744941711426 s.\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "The total time is 2.22 seconds. Next we use our customized operator." + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 9, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Sparse: size 50000, prob 0.01, cost_t 2.364152193069458 s.\n" + ] + } + ], + "source": [ + "def sparse_customize(size, prob):\n", + " bm.random.seed()\n", + " vector = bm.random.randn(size)\n", + " sparse_A = bp.conn.FixedProb(prob=prob, allow_multi_conn=True)(size, size).require('pre2post')\n", + " t0 = time.time()\n", + " f = jit(lambda a: sparse_cus_op(a, sparse_A[0], sparse_A[1], vector, shape=(size, size)))\n", + " for _ in range(100):\n", + " hidden = jax.block_until_ready(f(1.))\n", + " cost_t = time.time() - t0\n", + " print(f'Sparse: size {size}, prob {prob}, cost_t {cost_t} s.')\n", + " bm.clear_buffer_memory()\n", + "\n", + "sparse_customize(50000, 0.01)" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "After comparison, the customization method is almost as fast as the built-in method. 
Users can thus build their own operators without worrying about a loss of computation speed." + ], + "metadata": { + "collapsed": false + } + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/requirements-dev.txt b/requirements-dev.txt index 01184540a..49fa49722 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,8 +1,8 @@ numpy numba brainpylib -jax>=0.4.1 -jaxlib>=0.4.1 +jax>=0.4.1, <0.4.16 +jaxlib>=0.4.1, <0.4.16 matplotlib>=3.4 msgpack tqdm diff --git a/requirements-doc.txt b/requirements-doc.txt index d88a0c02a..e6e498937 100644 --- a/requirements-doc.txt +++ b/requirements-doc.txt @@ -2,8 +2,8 @@ numpy tqdm msgpack numba -jax>=0.4.1 -jaxlib>=0.4.1 +jax>=0.4.1, <0.4.16 +jaxlib>=0.4.1, <0.4.16 matplotlib>=3.4 scipy>=1.1.0 numba diff --git a/requirements.txt b/requirements.txt index 74db0a68a..ebf85b86e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ numpy -jax>=0.4.1 +jax>=0.4.1, <0.4.16 tqdm msgpack numba \ No newline at end of file diff --git a/setup.py b/setup.py index 343ca3a89..68debcdee 100644 --- a/setup.py +++ b/setup.py @@ -51,7 +51,7 @@ author_email='chao.brain@qq.com', packages=packages, python_requires='>=3.8', - install_requires=['numpy>=1.15', 'jax>=0.4.1', 'tqdm', 'msgpack', 'numba'], + install_requires=['numpy>=1.15', 'jax>=0.4.1, <0.4.16', 'tqdm', 'msgpack', 'numba'], url='https://github.com/brainpy/BrainPy', project_urls={ "Bug Tracker": "https://github.com/brainpy/BrainPy/issues",