From 2dc40a9ce44301f5ccb8de63a97712fccae8c6a7 Mon Sep 17 00:00:00 2001 From: chaoming Date: Wed, 11 Oct 2023 07:00:10 +0800 Subject: [PATCH] [doc] update operator customization --- brainpy/_src/math/ndarray.py | 38 +- brainpy/_src/math/op_registers/__init__.py | 1 + .../op_registers/numba_approach/__init__.py | 65 +- brainpy/_src/math/surrogate/_one_input.py | 8 +- brainpy/math/op_register.py | 1 + docs/apis/brainpy.math.op_register.rst | 29 + docs/apis/math.rst | 1 + .../3_dedicated_operators.rst | 2 +- .../low-level_operator_customization.ipynb | 404 ------------ .../operator_custom_with_numba.ipynb | 580 ++++++++++++++++++ 10 files changed, 705 insertions(+), 424 deletions(-) create mode 100644 docs/apis/brainpy.math.op_register.rst delete mode 100644 docs/tutorial_advanced/low-level_operator_customization.ipynb create mode 100644 docs/tutorial_advanced/operator_custom_with_numba.ipynb diff --git a/brainpy/_src/math/ndarray.py b/brainpy/_src/math/ndarray.py index c83c43eea..cb5e739e4 100644 --- a/brainpy/_src/math/ndarray.py +++ b/brainpy/_src/math/ndarray.py @@ -1181,7 +1181,7 @@ def addr_( *, beta: float = 1.0, alpha: float = 1.0 - ) -> None: + ): vec1 = _as_jax_array_(vec1) vec2 = _as_jax_array_(vec2) r = alpha * jnp.outer(vec1, vec2) + beta * self.value @@ -1217,7 +1217,7 @@ def absolute(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = Non """ return self.abs(out=out) - def absolute_(self) -> None: + def absolute_(self): """ alias of Array.abs_() """ @@ -1258,11 +1258,11 @@ def sin(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> _check_out(out) out.value = r - def sin_(self) -> None: + def sin_(self): self.value = jnp.sin(self.value) return self - def cos_(self) -> None: + def cos_(self): self.value = jnp.cos(self.value) return self @@ -1274,7 +1274,7 @@ def cos(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> _check_out(out) out.value = r - def tan_(self) -> None: + def tan_(self): self.value = jnp.tan(self.value) return self @@ -1286,7 +1286,7 @@ def tan(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> _check_out(out) out.value = r - def sinh_(self) -> None: + def sinh_(self): self.value = jnp.tanh(self.value) return self @@ -1298,7 +1298,7 @@ def sinh(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) - _check_out(out) out.value = r - def cosh_(self) -> None: + def cosh_(self): self.value = jnp.cosh(self.value) return self @@ -1310,7 +1310,7 @@ def cosh(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) - _check_out(out) out.value = r - def tanh_(self) -> None: + def tanh_(self): self.value = jnp.tanh(self.value) return self @@ -1322,7 +1322,7 @@ def tanh(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) - _check_out(out) out.value = r - def arcsin_(self) -> None: + def arcsin_(self): self.value = jnp.arcsin(self.value) return self @@ -1334,7 +1334,7 @@ def arcsin(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) _check_out(out) out.value = r - def arccos_(self) -> None: + def arccos_(self): self.value = jnp.arccos(self.value) return self @@ -1346,7 +1346,7 @@ def arccos(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) _check_out(out) out.value = r - def arctan_(self) -> None: + def arctan_(self): self.value = jnp.arctan(self.value) return self @@ -1381,7 +1381,7 @@ def clamp( def clamp_(self, min_value: Optional[Union['Array', jax.Array, np.ndarray]] = None, - max_value: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> None: + max_value: Optional[Union['Array', jax.Array, np.ndarray]] = None): """ return the value between min_value and max_value, if min_value is None, then no lower bound, @@ -1392,7 +1392,7 @@ def clamp_(self, def clip_(self, min_value: Optional[Union['Array', jax.Array, np.ndarray]] = None, - max_value: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> None: + max_value: Optional[Union['Array', jax.Array, np.ndarray]] = None): """ alias for clamp_ """ @@ -1402,7 +1402,7 @@ def clip_(self, def clone(self) -> 'Array': return Array(self.value.copy()) - def copy_(self, src: Union['Array', jax.Array, np.ndarray]) -> None: + def copy_(self, src: Union['Array', jax.Array, np.ndarray]) -> 'Array': self.value = jnp.copy(_as_jax_array_(src)) return self @@ -1507,6 +1507,16 @@ def cpu(self): self.value = jax.device_put(self.value, jax.devices('cpu')[0]) return self + # dtype exchanging # + # ---------------- # + + def bool(self): return jnp.asarray(self.value, dtypt=jnp.bool_) + def int(self): return jnp.asarray(self.value, dtypt=jnp.int32) + def long(self): return jnp.asarray(self.value, dtypt=jnp.int64) + def half(self): return jnp.asarray(self.value, dtypt=jnp.float16) + def float(self): return jnp.asarray(self.value, dtypt=jnp.float32) + def double(self): return jnp.asarray(self.value, dtype=jnp.float64) + JaxArray = Array ndarray = Array diff --git a/brainpy/_src/math/op_registers/__init__.py b/brainpy/_src/math/op_registers/__init__.py index 685f3c37b..3628c3279 100644 --- a/brainpy/_src/math/op_registers/__init__.py +++ b/brainpy/_src/math/op_registers/__init__.py @@ -1,5 +1,6 @@ from .numba_approach import (XLACustomOp, + CustomOpByNumba, register_op_with_numba, compile_cpu_signature_with_numba) from .utils import register_general_batching diff --git a/brainpy/_src/math/op_registers/numba_approach/__init__.py b/brainpy/_src/math/op_registers/numba_approach/__init__.py index cd05cab7b..ed960a738 100644 --- a/brainpy/_src/math/op_registers/numba_approach/__init__.py +++ b/brainpy/_src/math/op_registers/numba_approach/__init__.py @@ -16,12 +16,74 @@ from .cpu_translation import _cpu_translation, compile_cpu_signature_with_numba __all__ = [ + 'CustomOpByNumba', 'XLACustomOp', 'register_op_with_numba', 'compile_cpu_signature_with_numba', ] +class CustomOpByNumba(BrainPyObject): + """Creating a XLA custom call operator with Numba JIT on CPU backend. + + Parameters + ---------- + name: str + The name of operator. + eval_shape: callable + The function to evaluate the shape and dtype of the output according to the input. + This function should receive the abstract information of inputs, and return the + abstract information of the outputs. For example: + + >>> def eval_shape(inp1_info, inp2_info, inp3_info, ...): + >>> return out1_info, out2_info + con_compute: callable + The function to make the concrete computation. This function receives inputs, + and returns outputs. For example: + + >>> def con_compute(inp1, inp2, inp3, ...): + >>> return out1, out2 + """ + + def __init__( + self, + eval_shape: Callable = None, + con_compute: Callable = None, + name: str = None, + batching_translation: Callable = None, + jvp_translation: Callable = None, + transpose_translation: Callable = None, + multiple_results: bool = True, + ): + super().__init__(name=name) + + # abstract evaluation function + if eval_shape is None: + raise ValueError('Must provide "eval_shape" for abstract evaluation.') + + # cpu function + cpu_func = con_compute + + # register OP + self.op = register_op_with_numba( + self.name, + cpu_func=cpu_func, + out_shapes=eval_shape, + batching_translation=batching_translation, + jvp_translation=jvp_translation, + transpose_translation=transpose_translation, + multiple_results=multiple_results, + ) + + def __call__(self, *args, **kwargs): + args = tree_map(lambda a: a.value if isinstance(a, Array) else a, + args, is_leaf=lambda a: isinstance(a, Array)) + kwargs = tree_map(lambda a: a.value if isinstance(a, Array) else a, + kwargs, is_leaf=lambda a: isinstance(a, Array)) + res = self.op.bind(*args, **kwargs) + return res + + class XLACustomOp(BrainPyObject): """Creating a XLA custom call operator. @@ -175,8 +237,9 @@ def abs_eval_rule(*input_shapes, **info): shapes = out_shapes if isinstance(shapes, core.ShapedArray): - pass + assert not multiple_results, "multiple_results is True, while the abstract evaluation returns only one data." elif isinstance(shapes, (tuple, list)): + assert multiple_results, "multiple_results is False, while the abstract evaluation returns multiple data." for elem in shapes: if not isinstance(elem, core.ShapedArray): raise ValueError(f'Elements in "out_shapes" must be instances of ' diff --git a/brainpy/_src/math/surrogate/_one_input.py b/brainpy/_src/math/surrogate/_one_input.py index c967622ee..382bfdda3 100644 --- a/brainpy/_src/math/surrogate/_one_input.py +++ b/brainpy/_src/math/surrogate/_one_input.py @@ -37,7 +37,7 @@ def __init__(self, forward_use_surrogate=False): self.forward_use_surrogate = forward_use_surrogate self._true_call_ = jax.custom_gradient(self.call) - def __call__(self, x: Union[jax.Array, Array]): + def __call__(self, x: jax.Array): return self._true_call_(as_jax(x)) def call(self, x): @@ -70,7 +70,7 @@ class Sigmoid(_OneInpSurrogate): """ - def __init__(self, alpha=4., forward_use_surrogate=False): + def __init__(self, alpha: float = 4., forward_use_surrogate=False): super().__init__(forward_use_surrogate) self.alpha = alpha @@ -154,7 +154,7 @@ class PiecewiseQuadratic(_OneInpSurrogate): """ - def __init__(self, alpha=1., forward_use_surrogate=False): + def __init__(self, alpha: float = 1., forward_use_surrogate=False): super().__init__(forward_use_surrogate) self.alpha = alpha @@ -258,7 +258,7 @@ class PiecewiseExp(_OneInpSurrogate): piecewise_exp """ - def __init__(self, alpha=1., forward_use_surrogate=False): + def __init__(self, alpha: float = 1., forward_use_surrogate=False): super().__init__(forward_use_surrogate) self.alpha = alpha diff --git a/brainpy/math/op_register.py b/brainpy/math/op_register.py index b15a0d6de..7fb7df73f 100644 --- a/brainpy/math/op_register.py +++ b/brainpy/math/op_register.py @@ -2,6 +2,7 @@ from brainpy._src.math.op_registers import ( + CustomOpByNumba, XLACustomOp, compile_cpu_signature_with_numba, ) diff --git a/docs/apis/brainpy.math.op_register.rst b/docs/apis/brainpy.math.op_register.rst new file mode 100644 index 000000000..7010b64eb --- /dev/null +++ b/docs/apis/brainpy.math.op_register.rst @@ -0,0 +1,29 @@ +Operator Registration +===================== + +.. contents:: + :local: + :depth: 1 + + +CPU Operator Customization with Numba +------------------------------------- + +.. currentmodule:: brainpy.math +.. automodule:: brainpy.math + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + CustomOpByNumba + XLACustomOp + + +.. autosummary:: + :toctree: generated/ + + register_op_with_numba + compile_cpu_signature_with_numba + diff --git a/docs/apis/math.rst b/docs/apis/math.rst index 97d7749be..e3f0b765a 100644 --- a/docs/apis/math.rst +++ b/docs/apis/math.rst @@ -35,4 +35,5 @@ dynamics programming. For more information and usage examples, please refer to t brainpy.math.sharding.rst brainpy.math.environment.rst brainpy.math.modes.rst + brainpy.math.op_register.rst diff --git a/docs/tutorial_advanced/3_dedicated_operators.rst b/docs/tutorial_advanced/3_dedicated_operators.rst index 36696e4aa..7885d7c7f 100644 --- a/docs/tutorial_advanced/3_dedicated_operators.rst +++ b/docs/tutorial_advanced/3_dedicated_operators.rst @@ -4,4 +4,4 @@ Brain Dynamics Dedicated Operators .. toctree:: :maxdepth: 1 - low-level_operator_customization.ipynb \ No newline at end of file + operator_custom_with_numba.ipynb \ No newline at end of file diff --git a/docs/tutorial_advanced/low-level_operator_customization.ipynb b/docs/tutorial_advanced/low-level_operator_customization.ipynb deleted file mode 100644 index f914cb7aa..000000000 --- a/docs/tutorial_advanced/low-level_operator_customization.ipynb +++ /dev/null @@ -1,404 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "collapsed": true - }, - "source": [ - "# Low-level Operator Customization" - ] - }, - { - "cell_type": "markdown", - "source": [ - "@[Tianqiu Zhang](https://github.com/ztqakita)" - ], - "metadata": { - "collapsed": false - } - }, - { - "cell_type": "markdown", - "source": [ - "BrainPy is built on Jax and can accelerate model running performance based on [Just-in-Time(JIT) compilation](./compilation.ipynb). In order to enhance performance on CPU and GPU, we publish another package ``BrainPyLib`` to provide several built-in low-level operators in synaptic computation. These operators are written in C++/CUDA and wrapped as Jax primitives by using ``XLA``. However, users cannot simply customize their own operators unless they have specific background. To solve this problem, we introduce `numba.cfunc` here and provide convenient interfaces for users to customize operators without touching the underlying logic. In this tutorial, we will introduce how to customize operators on CPU. Please notice that BrainPy currently only supports CPU operators customization, and GPU operators will be supported in the future." - ], - "metadata": { - "collapsed": false - } - }, - { - "cell_type": "code", - "source": [ - "import brainpy as bp\n", - "import brainpy.math as bm\n", - "import jax\n", - "from jax import jit\n", - "import jax.numpy as jnp\n", - "from jax.core import ShapedArray\n", - "import numba\n", - "import time\n", - "\n", - "bm.set_platform('cpu')" - ], - "metadata": { - "collapsed": false - }, - "execution_count": 1, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/ztqakita/opt/anaconda3/envs/bdp/lib/python3.9/site-packages/flax/struct.py:136: FutureWarning: jax.tree_util.register_keypaths is deprecated, and will be removed in a future release. Please use `register_pytree_with_keys()` instead.\n", - " jax.tree_util.register_keypaths(data_clz, keypaths)\n", - "/Users/ztqakita/opt/anaconda3/envs/bdp/lib/python3.9/site-packages/flax/struct.py:136: FutureWarning: jax.tree_util.register_keypaths is deprecated, and will be removed in a future release. Please use `register_pytree_with_keys()` instead.\n", - " jax.tree_util.register_keypaths(data_clz, keypaths)\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "We have formally discussed the benefits of computation with our built-in operators. These operators are provided by `brainpylib` package and can be accessed through `brainpy.math` module. To be more specific, in order to speed up sparse synaptic computation, we customize several low-level operators for CPU and GPU, which are written in C++/CUDA and converted into Jax/XLA compatible primitive by using `Pybind11`." - ], - "metadata": { - "collapsed": false - } - }, - { - "cell_type": "markdown", - "source": [ - "It is not easy to write a C++/CUDA operator and implement a series of conversion. Users have to learn how to write a C++/CUDA operator, how to write a customized Jax primitive, and how to convert your C++/CUDA operator into a Jax primitive. Here are some links for users who prefer to dive into the details: [Jax primitives](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html), [XLA custom calls](https://www.tensorflow.org/xla/custom_call).\n", - "\n", - "However, we can only provide limit amounts of operators for users, and it would be great if users can customize their own operators in a relatively simple way. To achieve this goal, BrainPy provides a convenient interface `XLACustomOp` to register customized operators on CPU. Users no longer need to involve any C++ programming and XLA compilation. This is accomplished with the help of [`numba.cfunc`](https://numba.pydata.org/numba-doc/latest/user/cfunc.html), which will wrap python code as a compiled function callable from foreign C code. The C function object exposes the address of the compiled C callback so that it can be passed into XLA and registered as a jittable Jax primitives. Here is an example of using `XLACustomOp` on CPU." - ], - "metadata": { - "collapsed": false - } - }, - { - "cell_type": "markdown", - "source": [ - "## How to customize operators?" - ], - "metadata": { - "collapsed": false - } - }, - { - "cell_type": "markdown", - "source": [ - "### CPU version\n", - "\n", - "First, users can customize a simple operator written in python. Notice that this python operator will be jitted in nopython mode, but some language features are not available inside Numba-compiled functions. Please look up [numba documentations](https://numba.pydata.org/numba-doc/latest/reference/pysupported.html#pysupported) for details." - ], - "metadata": { - "collapsed": false - } - }, - { - "cell_type": "code", - "execution_count": 2, - "outputs": [], - "source": [ - "def custom_op(outs, ins):\n", - " y, y1 = outs\n", - " x, x2 = ins\n", - " y[:] = x + 1\n", - " y1[:] = x2 + 2" - ], - "metadata": { - "collapsed": false - } - }, - { - "cell_type": "markdown", - "source": [ - "There are some restrictions that users should know:\n", - "- Parameters of the operators are `outs` and `ins`, corresponding to output variable(s) and input variable(s). The order cannot be changed.\n", - "- The function cannot have any return value.\n", - "- When applying CPU function to GPU, users only need to implement CPU operators." - ], - "metadata": { - "collapsed": false - } - }, - { - "cell_type": "markdown", - "source": [ - "Then users should describe the shapes and types of the outputs, because JAX/python can deduce the shapes and types of inputs when you call it, but it cannot infer the shapes and types of the outputs. The argument can be:\n", - "- a `ShapedArray`,\n", - "- a sequence of `ShapedArray`,\n", - "- a function, it should return correct output shapes of `ShapedArray`.\n", - "\n", - "Here we use function to describe the output shapes and types. The arguments include all the inputs of custom operators, but only shapes and types are accessible." - ], - "metadata": { - "collapsed": false - } - }, - { - "cell_type": "code", - "execution_count": 3, - "outputs": [], - "source": [ - "def abs_eval_1(*ins):\n", - " # ins: inputs arguments, only shapes and types are accessible.\n", - " # Because custom_op outputs shapes and types are exactly the\n", - " # same as inputs, so here we can only return ordinary inputs.\n", - " return ins" - ], - "metadata": { - "collapsed": false - } - }, - { - "cell_type": "markdown", - "source": [ - "The function above is somewhat abstract for users, so here we give an alternative function below for passing shape information. We want you to know ``abs_eval_1`` and ``abs_eval_2`` are doing the same thing." - ], - "metadata": { - "collapsed": false - } - }, - { - "cell_type": "code", - "execution_count": 4, - "outputs": [], - "source": [ - "def abs_eval_2(*ins):\n", - " return ShapedArray(ins[0].shape, ins[0].dtype), ShapedArray(ins[1].shape, ins[1].dtype)" - ], - "metadata": { - "collapsed": false - } - }, - { - "cell_type": "markdown", - "source": [ - "Now we have prepared for registering a CPU operator. `XLACustomOp` will be called to wrap your operator and return a jittable Jax primitives. Here are some parameters users should define:\n", - "- `name`: Name of the operator.\n", - "- `eval_shape`: The function to evaluate the shape and dtype of the output according to the input. This function should receive the abstract information of inputs, and return the abstract information of the outputs.\n", - "- `con_compute`: The function to make the concrete computation. This function receives inputs and returns outputs.\n", - "- `cpu_func`: The function defines the computation on CPU backend. Same as ``con_compute``.\n", - "- `gpu_func`: The function defines the computation on GPU backend. Currently, this function is not supported.\n", - "- `apply_cpu_func_to_gpu`: Whether allows to apply CPU function on GPU backend. If True, the GPU data will be moved to CPU, and after calculation returned outputs on CPU backend will move to GPU.\n", - "- `batching_translation`: The batching translation for the primitive.\n", - "- `jvp_translation`: The forward autodiff translation rule.\n", - "- `transpose_translation`: The backward autodiff translation rule.\n", - "- `multiple_results`: Whether the primitive returns multiple results." - ], - "metadata": { - "collapsed": false - } - }, - { - "cell_type": "code", - "execution_count": 5, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[Array([[2., 2.]], dtype=float32), Array([[3., 3.]], dtype=float32)]\n" - ] - } - ], - "source": [ - "z = jnp.ones((1, 2), dtype=jnp.float32)\n", - "# Users could try out_shapes=abs_eval_2 and see if the result is different\n", - "op = bm.XLACustomOp(\n", - " name='add',\n", - " eval_shape=abs_eval_1,\n", - " cpu_func=custom_op,\n", - ")\n", - "jit_op = jit(op)\n", - "print(jit_op(z, z))" - ], - "metadata": { - "collapsed": false - } - }, - { - "cell_type": "markdown", - "source": [ - "### GPU version\n", - "\n", - "We have discussed how to customize a CPU operator above, next we will talk about GPU operator, which is slightly different from CPU version. There are two additional parameters users need to provide:\n", - "- `gpu_func`: Customized operator of GPU version.\n", - "- `apply_cpu_func_to_gpu`: Whether to run kernel function on CPU for an alternative way for GPU version.\n", - "\n", - "```{warning}\n", - " GPU operators will be wrapped by `cuda.jit` in `numba`, but `numba` currently is not support to launch CUDA kernels from `cfuncs`. For this reason, `gpu_func` is none for default, and there will be an error if users pass a gpu operator to `gpu_func`.\n", - "```" - ], - "metadata": { - "collapsed": false - } - }, - { - "cell_type": "markdown", - "source": [ - "Therefore, BrainPy enables users to set `apply_cpu_func_to_gpu` to true for a backup method. All the inputs will be initialized on GPU and transferred to CPU for computing. The operator users have defined will be implemented on CPU and the results will be transferred back to GPU for further tasks." - ], - "metadata": { - "collapsed": false - } - }, - { - "cell_type": "markdown", - "source": [ - "## Performance" - ], - "metadata": { - "collapsed": false - } - }, - { - "cell_type": "markdown", - "source": [ - "To illustrate the effectiveness of this approach, we will compare the customized operators with BrainPy built-in operators. Here we use `event_sum` as an example. The implementation of `event_sum` by using our customization is shown as below:" - ], - "metadata": { - "collapsed": false - } - }, - { - "cell_type": "code", - "execution_count": 6, - "outputs": [], - "source": [ - "def abs_eval(data, indices, indptr, vector, shape):\n", - " out_shape = shape[0]\n", - " return ShapedArray((out_shape,), data.dtype),\n", - "\n", - "@numba.njit(fastmath=True)\n", - "def sparse_op(outs, ins):\n", - " res_val = outs[0]\n", - " res_val.fill(0)\n", - " values, col_indices, row_ptr, vector, shape = ins\n", - "\n", - " for row_i in range(shape[0]):\n", - " v = vector[row_i]\n", - " for j in range(row_ptr[row_i], row_ptr[row_i + 1]):\n", - " res_val[col_indices[j]] += values * v\n", - "\n", - "sparse_cus_op = bm.XLACustomOp(name='sparse', eval_shape=abs_eval, con_compute=sparse_op)" - ], - "metadata": { - "collapsed": false - } - }, - { - "cell_type": "markdown", - "source": [ - "We will use sparse matrix vector multiplication to be our benchmark for testing the speed. We will use built-in operator `event` first." - ], - "metadata": { - "collapsed": false - } - }, - { - "cell_type": "code", - "source": [ - "def sparse(size, prob):\n", - " bm.random.seed()\n", - " vector = bm.random.randn(size)\n", - " sparse_A = bp.conn.FixedProb(prob=prob, allow_multi_conn=True)(size, size).require('pre2post')\n", - " t0 = time.time()\n", - " for _ in range(100):\n", - " hidden = jax.block_until_ready(bm.sparse.csrmv(1., sparse_A[0], sparse_A[1], vector, shape=(size, size), transpose=True, method='vector'))\n", - " cost_t = time.time() - t0\n", - " print(f'Sparse: size {size}, prob {prob}, cost_t {cost_t} s.')\n", - " bm.clear_buffer_memory()\n", - "\n", - "sparse(50000, 0.01)" - ], - "metadata": { - "collapsed": false - }, - "execution_count": 7, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Sparse: size 50000, prob 0.01, cost_t 2.222744941711426 s.\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "The total time is 2.22 seconds. Next we use our customized operator." - ], - "metadata": { - "collapsed": false - } - }, - { - "cell_type": "code", - "execution_count": 9, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Sparse: size 50000, prob 0.01, cost_t 2.364152193069458 s.\n" - ] - } - ], - "source": [ - "def sparse_customize(size, prob):\n", - " bm.random.seed()\n", - " vector = bm.random.randn(size)\n", - " sparse_A = bp.conn.FixedProb(prob=prob, allow_multi_conn=True)(size, size).require('pre2post')\n", - " t0 = time.time()\n", - " f = jit(lambda a: sparse_cus_op(a, sparse_A[0], sparse_A[1], vector, shape=(size, size)))\n", - " for _ in range(100):\n", - " hidden = jax.block_until_ready(f(1.))\n", - " cost_t = time.time() - t0\n", - " print(f'Sparse: size {size}, prob {prob}, cost_t {cost_t} s.')\n", - " bm.clear_buffer_memory()\n", - "\n", - "sparse_customize(50000, 0.01)" - ], - "metadata": { - "collapsed": false - } - }, - { - "cell_type": "markdown", - "source": [ - "After comparison, the customization method is almost as fast as the built-in method. Users can simply build their own operators without considering the computation speed loss." - ], - "metadata": { - "collapsed": false - } - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 2 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython2", - "version": "2.7.6" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/docs/tutorial_advanced/operator_custom_with_numba.ipynb b/docs/tutorial_advanced/operator_custom_with_numba.ipynb new file mode 100644 index 000000000..84d4deb79 --- /dev/null +++ b/docs/tutorial_advanced/operator_custom_with_numba.ipynb @@ -0,0 +1,580 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "collapsed": true + }, + "source": [ + "# Operator Customization with Numba" + ] + }, + { + "cell_type": "markdown", + "source": [ + "Brain dynamics is sparse and event-driven, however, proprietary operators for brain dynamics are not well abstracted and summarized. As a result, we are often faced with the need to customize operators. In this tutorial, we will explore how to customize brain dynamics operators using Numba.\n", + "\n", + "Start by importing the relevant Python package." + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 2, + "outputs": [], + "source": [ + "import brainpy as bp\n", + "import brainpy.math as bm\n", + "\n", + "import jax\n", + "from jax import jit\n", + "import jax.numpy as jnp\n", + "from jax.core import ShapedArray\n", + "\n", + "import numba\n", + "\n", + "bm.set_platform('cpu')" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-10-10T22:58:55.444792400Z", + "start_time": "2023-10-10T22:58:55.368614800Z" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "## ``brainpy.math.CustomOpByNumba``\n", + "\n", + "``brainpy.math.CustomOpByNumba`` is also called ``brainpy.math.XLACustomOp``.\n", + "\n", + "BrainPy provides ``brainpy.math.CustomOpByNumba`` for customizing the operator on the CPU device. Two parameters are required to provide in ``CustomOpByNumba``:\n", + "\n", + "- ``eval_shape``: evaluates the *shape* and *datatype* of the output argument based on the *shape* and *datatype* of the input argument.\n", + "- `con_compute`: receives the input parameters and performs a specific computation based on them.\n", + "\n", + "Suppose here we want to customize an operator that does the ``b = a+1`` operation. First, define an ``eval_shape`` function. The arguments to this function are information about all the input parameters, and the return value is information about the output parameters.\n", + "\n", + "```python\n", + "from jax.core import ShapedArray\n", + "\n", + "def eval_shape(a):\n", + " b = ShapedArray(a.shape, dtype=a.dtype)\n", + " return b\n", + "```\n", + "\n", + "Since ``b`` in ``b = a + 1`` has the same type and shape as ``a``, the ``eval_shape`` function returns the same shape and type. Next, we need to define ``con_compute``. ``con_compute`` takes only ``(outs, ins)`` arguments, where all return values are inside ``outs`` and all input arguments are inside ``ins``.\n", + "\n", + "\n", + "```python\n", + "def con_compute(outs, ins):\n", + " b = outs\n", + " a = ins\n", + " b[:] = a + 1\n", + "```\n", + "\n", + "Unlike the ``eval_shape`` function, the ``con_compute`` function does not support any return values. Instead, all output must just be updated in-place. Also, the ``con_compute`` function must follow the specification of Numba's just-in-time compilation, see:\n", + "\n", + "- https://numba.pydata.org/numba-doc/latest/reference/pysupported.html\n", + "- https://numba.pydata.org/numba-doc/latest/reference/numpysupported.html\n", + "\n", + "Also, ``con_compute`` can be customized according to Numba's just-in-time compilation policy. For example, if JIT is just turned on, then you can use:\n", + "\n", + "```python\n", + "@numba.njit\n", + "def con_compute(outs, ins):\n", + " b = outs\n", + " a = ins\n", + " b[:] = a + 1\n", + "```\n", + "\n", + "If the parallel computation with multiple cores is turned on, you can use:\n", + "\n", + "\n", + "```python\n", + "@numba.njit(parallel=True)\n", + "def con_compute(outs, ins):\n", + " b = outs\n", + " a = ins\n", + " b[:] = a + 1\n", + "```\n", + "\n", + "\n", + "For more advanced usage, we encourage readers to read the [Numba online manual](https://numba.pydata.org/numba-doc/latest/index.html).\n", + "\n", + "Finally, this customized operator can be registered and used as:\n", + "\n", + "```bash\n", + "\n", + ">>> op = bm.CustomOpByNumba(eval_shape, con_compute, multiple_results=False)\n", + ">>> op(bm.zeros(10))\n", + "[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", + "```" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "## Return multiple values ``multiple_returns=True``\n", + "\n", + "If the result of our computation needs to return multiple arrays, then we need to use ``multiple_returns=True`` in our use of registering the operator. In this case, ``outs`` will be a list containing multiple arrays, not an array.\n", + "\n", + "\n", + "```python\n", + "def eval_shape2(a, b):\n", + " c = ShapedArray(a.shape, dtype=a.dtype)\n", + " d = ShapedArray(b.shape, dtype=b.dtype)\n", + " return c, d\n", + "\n", + "def con_compute2(outs, ins):\n", + " c, d = outs # 取出所有的输出\n", + " a, b = ins # 取出所有的输入\n", + " c[:] = a + 1\n", + " d[:] = a * 2\n", + "\n", + "op2 = bm.CustomOpByNumba(eval_shape2, con_compute2, multiple_results=True)\n", + "```\n", + "\n", + "```bash\n", + ">>> op2(bm.zeros(10), bm.ones(10))\n", + "([1. 1. 1. 1. 1. 1. 1. 1. 1. 1.],\n", + " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.])\n", + "```" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "## Non-Tracer parameters\n", + "\n", + "In the ``eval_shape`` function, all arguments are abstract information (containing only the shape and type) if they are arguments that can be traced by ``jax.jit``. However, if we infer the output data type requires additional information beyond the input parameter information, then we need to define non-Tracer parameters.\n", + "\n", + "For an operator defined by ``brainpy.math.CustomOpByNumba``, non-Tracer parameters are often then parameters passed in via key-value pairs such as ``key=value``. For example:\n", + "\n", + "```python\n", + "op2(a, b, c, d=d, e=e)\n", + "```\n", + "\n", + "``a, b, c`` are all ``jax.jit`` traceable parameters, and ``d`` and ``e`` are deterministic, non-tracer parameters. Therefore, in the ``eval_shape(a, b, c, d, e)`` function, ``a, b, c`` will be ``SharedArray``, and ``d`` and ``e`` will be concrete values.\n", + "\n", + "For another example, \n", + "\n", + "```python\n", + "\n", + "def eval_shape3(a, *, b):\n", + " return SharedArray(b, a.dtype) # The shape of the return value is determined by the input b\n", + "\n", + "def con_compute3(outs, ins):\n", + " c = outs # Take out all the outputs\n", + " a, b = ins # Take out all inputs\n", + " c[:] = 2.\n", + "\n", + "op3 = bm.CustomOpByNumba(eval_shape3, con_compute3, multiple_results=False)\n", + "```\n", + "\n", + "```bash\n", + ">>> op3(bm.zeros(4), 5)\n", + "[2. 2. 2. 2. 2.]\n", + "```" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "... note:\n", + "\n", + " It is worth noting that all arguments will be converted to arrays. Both Tracer and non-Tracer parameters are arrays in ``con_compute``. For example, ``1`` is passed in, but in ``con_compute`` it's a 0-dimensional array ``1``; ``(1, 2)`` is passed in, and in ``con_compute`` it will be the 1-dimensional array ``array([1, 2])``.\n", + " " + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "## Example: A sparse operator\n", + "\n", + "To illustrate the effectiveness of this approach, we define in this an event-driven sparse computation operator." + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 3, + "outputs": [], + "source": [ + "def abs_eval(data, indices, indptr, vector, shape):\n", + " out_shape = shape[0]\n", + " return ShapedArray((out_shape,), data.dtype),\n", + "\n", + "@numba.njit(fastmath=True)\n", + "def sparse_op(outs, ins):\n", + " res_val = outs[0]\n", + " res_val.fill(0)\n", + " values, col_indices, row_ptr, vector, shape = ins\n", + "\n", + " for row_i in range(shape[0]):\n", + " v = vector[row_i]\n", + " for j in range(row_ptr[row_i], row_ptr[row_i + 1]):\n", + " res_val[col_indices[j]] += values * v\n", + "\n", + "sparse_cus_op = bm.CustomOpByNumba(eval_shape=abs_eval, con_compute=sparse_op)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-10-10T22:58:55.539425400Z", + "start_time": "2023-10-10T22:58:55.398947400Z" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "Let's try to use sparse matrix vector multiplication operator." + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 4, + "outputs": [ + { + "data": { + "text/plain": "[Array([ -2.2834747, -52.950108 , -5.0921535, ..., -40.264236 ,\n -27.219269 , 33.138054 ], dtype=float32)]" + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "size = 5000\n", + "\n", + "vector = bm.random.randn(size)\n", + "sparse_A = bp.conn.FixedProb(prob=0.1, allow_multi_conn=True)(size, size).require('pre2post')\n", + "f = jit(lambda a: sparse_cus_op(a, sparse_A[0], sparse_A[1], vector, shape=(size, size)))\n", + "f(1.)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-10-10T22:58:57.856525300Z", + "start_time": "2023-10-10T22:58:55.414106700Z" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "大脑动力学具有稀疏和事件驱动的特性,然而,大脑动力学的专有算子并没有很好的抽象和总结。因此,我们往往面临着自定义算子的需求。在这个教程中,我们将探索如何使用Numba来自定义脑动力学算子。\n", + "\n", + "首先引入相关的Python包。" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 5, + "outputs": [], + "source": [ + "import brainpy as bp\n", + "import brainpy.math as bm\n", + "\n", + "import jax\n", + "from jax import jit\n", + "import jax.numpy as jnp\n", + "from jax.core import ShapedArray\n", + "\n", + "import numba\n", + "\n", + "bm.set_platform('cpu')" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-10-10T22:58:57.858443100Z", + "start_time": "2023-10-10T22:58:57.842107200Z" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "## ``brainpy.math.CustomOpByNumba``接口\n", + "\n", + "``brainpy.math.CustomOpByNumba`` 也叫做``brainpy.math.XLACustomOp``。\n", + "\n", + "BrainPy提供了``brainpy.math.CustomOpByNumba``用于自定义CPU上的算子。使用``CustomOpByNumba``需要提供两个接口:\n", + "\n", + "- `eval_shape`: 根据输入参数的形状(shape)和数据类型(dtype)来评估输出参数的形状和数据类型。\n", + "- `con_compute`: 接收真正的参数,并根据参数进行具体计算。\n", + "\n", + "假如在这里我们要自定义一个做``b = a+1``操作的算子。首先,定义一个``eval_shape``函数。该函数的参数是所有输入变量的信息,返回值是输出参数的信息。\n", + "\n", + "```python\n", + "from jax.core import ShapedArray\n", + "\n", + "def eval_shape(a):\n", + " b = ShapedArray(a.shape, dtype=a.dtype)\n", + " return b\n", + "```\n", + "\n", + "由于``b = a + 1``中``b``与``a``具有同样的类型和形状,因此``eval_shape``函数返回一样的形状和类型。接下来,我们就需要定义``con_compute``。``con_compute``只接收``(outs, ins)``参数,其中,所有的返回值都在``outs``内,所有的输入参数都在``ins``内。\n", + "\n", + "\n", + "```python\n", + "\n", + "def con_compute(outs, ins):\n", + " b = outs\n", + " a = ins\n", + " b[:] = a + 1\n", + "```\n", + "\n", + "与``eval_shape``函数不同,``con_compute``函数不接收任何返回值。相反,所有的输出都必须通过in-place update的形式就行。另外,``con_compute``函数必须遵循Numba即时编译的规范,见:\n", + "\n", + "- https://numba.pydata.org/numba-doc/latest/reference/pysupported.html\n", + "- https://numba.pydata.org/numba-doc/latest/reference/numpysupported.html\n", + "\n", + "同时,``con_compute``也可以自定义Numba的即时编译策略。比如,如果只是开启JIT,那么可以用:\n", + "\n", + "```python\n", + "@numba.njit\n", + "def con_compute(outs, ins):\n", + " b = outs\n", + " a = ins\n", + " b[:] = a + 1\n", + "```\n", + "\n", + "如果是开始并行计算利用多核,可以使用:\n", + "\n", + "\n", + "```python\n", + "@numba.njit(parallel=True)\n", + "def con_compute(outs, ins):\n", + " b = outs\n", + " a = ins\n", + " b[:] = a + 1\n", + "```\n", + "\n", + "\n", + "更多高级用法,建议读者们阅读[Numba在线手册](https://numba.pydata.org/numba-doc/latest/index.html)。\n", + "\n", + "最后,我们自定义这个算子可以使用:\n", + "\n", + "```bash\n", + "\n", + ">>> op = bm.CustomOpByNumba(eval_shape, con_compute, multiple_results=False)\n", + ">>> op(bm.zeros(10))\n", + "[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", + "```" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "## 返回多个值 ``multiple_returns=True``\n", + "\n", + "如果我们的计算结果需要返回多个数组,那么,我们在注册算子的使用需要使用``multiple_returns=True``。此时,``outs``将会是一个包含多个数组的列表,而不是一个数组。\n", + "\n", + "```python\n", + "def eval_shape2(a, b):\n", + " c = ShapedArray(a.shape, dtype=a.dtype)\n", + " d = ShapedArray(b.shape, dtype=b.dtype)\n", + " return c, d # 返回多个抽象数组信息\n", + "\n", + "def con_compute2(outs, ins):\n", + " c, d = outs # 取出所有的输出\n", + " a, b = ins # 取出所有的输入\n", + " c[:] = a + 1\n", + " d[:] = a * 2\n", + "\n", + "op2 = bm.CustomOpByNumba(eval_shape2, con_compute2, multiple_results=True)\n", + "```\n", + "\n", + "```bash\n", + ">>> op2(bm.zeros(10), bm.ones(10))\n", + "([1. 1. 1. 1. 1. 1. 1. 1. 1. 1.],\n", + " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.])\n", + "```" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "## 非Tracer参数\n", + "\n", + "在``eval_shape``函数中推断数据类型时,如果所有参数都是可以被``jax.jit``追踪的参数,那么所有参数都是抽象信息(只包含形状和类型)。如果有时推断输出数据类型时还需要除输入参数信息以外的额外信息,此时我们需要定义非Tracer参数。\n", + "\n", + "对于一个由``brainpy.math.CustomOpByNumba``定义的算子,非Tracer参数往往那么通过``key=value``等键值对传入的参数。比如,\n", + "\n", + "```python\n", + "op2(a, b, c, d=d, e=e)\n", + "```\n", + "\n", + "``a, b, c``都是可被`jax.jit`追踪的参数,`d`和`e`是确定性的、非Tracer参数。此时,``eval_shape(a, b, c, d, e)``函数中,a,b,c都是``SharedArray``,而d和e都是具体的数值,\n", + "\n", + "举个例子,\n", + "\n", + "```python\n", + "\n", + "def eval_shape3(a, *, b):\n", + " return SharedArray(b, a.dtype) # 返回值的形状由输入b决定\n", + "\n", + "def con_compute3(outs, ins):\n", + " c = outs # 取出所有的输出\n", + " a, b = ins # 取出所有的输入\n", + " c[:] = 2.\n", + "\n", + "op3 = bm.CustomOpByNumba(eval_shape3, con_compute3, multiple_results=False)\n", + "```\n", + "\n", + "```bash\n", + ">>> op3(bm.zeros(4), 5)\n", + "[2. 2. 2. 2. 2.]\n", + "```" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "\n", + "... note::\n", + "\n", + " 值得注意的是,所有的输入值都将被转化成数组。无论是Tracer还是非Tracer参数,在``con_compute``中都是数组。比如传入的是``1``,但在``con_compute``中是0维数组``1``;传入的是``(1, 2)``,在``con_compute``中将是1维数组``array([1, 2])``。\n" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "## 示例:一个稀疏算子\n", + "\n", + "为了说明这种方法的有效性,我们在这个定义一个事件驱动的稀疏计算算子。" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 6, + "outputs": [], + "source": [ + "def abs_eval(data, indices, indptr, vector, shape):\n", + " out_shape = shape[0]\n", + " return [ShapedArray((out_shape,), data.dtype)]\n", + "\n", + "@numba.njit(fastmath=True)\n", + "def sparse_op(outs, ins):\n", + " res_val = outs[0]\n", + " res_val.fill(0)\n", + " values, col_indices, row_ptr, vector, shape = ins\n", + "\n", + " for row_i in range(shape[0]):\n", + " v = vector[row_i]\n", + " for j in range(row_ptr[row_i], row_ptr[row_i + 1]):\n", + " res_val[col_indices[j]] += values * v\n", + "\n", + "sparse_cus_op = bm.CustomOpByNumba(eval_shape=abs_eval, con_compute=sparse_op)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-10-10T22:58:57.858443100Z", + "start_time": "2023-10-10T22:58:57.849184700Z" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "使用该算子我们可以用:" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 7, + "outputs": [ + { + "data": { + "text/plain": "[Array([ 17.464092, -9.924386, -33.09052 , ..., -37.2057 , -12.551924,\n -9.046049], dtype=float32)]" + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "size = 5000\n", + "\n", + "vector = bm.random.randn(size)\n", + "sparse_A = bp.conn.FixedProb(prob=0.1, allow_multi_conn=True)(size, size).require('pre2post')\n", + "f = jit(lambda a: sparse_cus_op(a, sparse_A[0], sparse_A[1], vector, shape=(size, size)))\n", + "f(1.)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-10-10T22:58:58.245683200Z", + "start_time": "2023-10-10T22:58:57.853019500Z" + } + } + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}