
[doc] update operator customization
chaoming0625 committed Oct 10, 2023
1 parent 460da63 commit 2dc40a9
Showing 10 changed files with 705 additions and 424 deletions.
38 changes: 24 additions & 14 deletions brainpy/_src/math/ndarray.py
@@ -1181,7 +1181,7 @@ def addr_(
*,
beta: float = 1.0,
alpha: float = 1.0
) -> None:
):
vec1 = _as_jax_array_(vec1)
vec2 = _as_jax_array_(vec2)
r = alpha * jnp.outer(vec1, vec2) + beta * self.value
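For context, addr_ performs the in-place rank-1 update self <- beta * self + alpha * outer(vec1, vec2); the dropped -> None annotation suggests it now returns the array itself, like the other in-place methods below. A minimal usage sketch, assuming vec1 and vec2 are the positional arguments as the body implies (values are illustrative):

import brainpy.math as bm

M = bm.zeros((3, 2))
v1 = bm.asarray([1., 2., 3.])
v2 = bm.asarray([4., 5.])
M.addr_(v1, v2, alpha=1.0, beta=1.0)  # M becomes beta * M + alpha * outer(v1, v2)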
@@ -1217,7 +1217,7 @@ def absolute(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = Non
"""
return self.abs(out=out)

def absolute_(self) -> None:
def absolute_(self):
"""
alias of Array.abs_()
"""
@@ -1258,11 +1258,11 @@ def sin(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) ->
_check_out(out)
out.value = r

def sin_(self) -> None:
def sin_(self):
self.value = jnp.sin(self.value)
return self

def cos_(self) -> None:
def cos_(self):
self.value = jnp.cos(self.value)
return self

@@ -1274,7 +1274,7 @@ def cos(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) ->
_check_out(out)
out.value = r

def tan_(self) -> None:
def tan_(self):
self.value = jnp.tan(self.value)
return self

@@ -1286,7 +1286,7 @@ def tan(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) ->
_check_out(out)
out.value = r

def sinh_(self) -> None:
def sinh_(self):
self.value = jnp.sinh(self.value)
return self

@@ -1298,7 +1298,7 @@ def sinh(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -
_check_out(out)
out.value = r

def cosh_(self) -> None:
def cosh_(self):
self.value = jnp.cosh(self.value)
return self

@@ -1310,7 +1310,7 @@ def cosh(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -
_check_out(out)
out.value = r

def tanh_(self) -> None:
def tanh_(self):
self.value = jnp.tanh(self.value)
return self

@@ -1322,7 +1322,7 @@ def tanh(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -
_check_out(out)
out.value = r

def arcsin_(self) -> None:
def arcsin_(self):
self.value = jnp.arcsin(self.value)
return self

@@ -1334,7 +1334,7 @@ def arcsin(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None)
_check_out(out)
out.value = r

def arccos_(self) -> None:
def arccos_(self):
self.value = jnp.arccos(self.value)
return self

@@ -1346,7 +1346,7 @@ def arccos(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None)
_check_out(out)
out.value = r

def arctan_(self) -> None:
def arctan_(self):
self.value = jnp.arctan(self.value)
return self

@@ -1381,7 +1381,7 @@ def clamp(

def clamp_(self,
min_value: Optional[Union['Array', jax.Array, np.ndarray]] = None,
max_value: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> None:
max_value: Optional[Union['Array', jax.Array, np.ndarray]] = None):
"""
Clamp the value between min_value and max_value in place;
if min_value is None, then there is no lower bound,
@@ -1392,7 +1392,7 @@ def clamp_(self,

def clip_(self,
min_value: Optional[Union['Array', jax.Array, np.ndarray]] = None,
max_value: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> None:
max_value: Optional[Union['Array', jax.Array, np.ndarray]] = None):
"""
alias for clamp_
"""
@@ -1402,7 +1402,7 @@ def clip_(self,
def clone(self) -> 'Array':
return Array(self.value.copy())

def copy_(self, src: Union['Array', jax.Array, np.ndarray]) -> None:
def copy_(self, src: Union['Array', jax.Array, np.ndarray]) -> 'Array':
self.value = jnp.copy(_as_jax_array_(src))
return self

@@ -1507,6 +1507,16 @@ def cpu(self):
self.value = jax.device_put(self.value, jax.devices('cpu')[0])
return self

# dtype exchanging #
# ---------------- #

def bool(self): return jnp.asarray(self.value, dtype=jnp.bool_)
def int(self): return jnp.asarray(self.value, dtype=jnp.int32)
def long(self): return jnp.asarray(self.value, dtype=jnp.int64)
def half(self): return jnp.asarray(self.value, dtype=jnp.float16)
def float(self): return jnp.asarray(self.value, dtype=jnp.float32)
def double(self): return jnp.asarray(self.value, dtype=jnp.float64)
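These newly added dtype-exchange helpers cast the underlying buffer; as written they return a plain jax.Array rather than wrapping the result back into an Array. A small illustrative sketch (values are hypothetical, and double() only yields float64 when JAX x64 mode is enabled):

import brainpy.math as bm

x = bm.asarray([0., 1., 2.])
xi = x.int()    # int32 copy of the data, as a raw jax.Array
xb = x.bool()   # nonzero entries become True
xh = x.half()   # float16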


JaxArray = Array
ndarray = Array
1 change: 1 addition & 0 deletions brainpy/_src/math/op_registers/__init__.py
@@ -1,5 +1,6 @@

from .numba_approach import (XLACustomOp,
CustomOpByNumba,
register_op_with_numba,
compile_cpu_signature_with_numba)
from .utils import register_general_batching
65 changes: 64 additions & 1 deletion brainpy/_src/math/op_registers/numba_approach/__init__.py
@@ -16,12 +16,74 @@
from .cpu_translation import _cpu_translation, compile_cpu_signature_with_numba

__all__ = [
'CustomOpByNumba',
'XLACustomOp',
'register_op_with_numba',
'compile_cpu_signature_with_numba',
]


class CustomOpByNumba(BrainPyObject):
"""Creating a XLA custom call operator with Numba JIT on CPU backend.
Parameters
----------
name: str
The name of the operator.
eval_shape: callable
The function to evaluate the shapes and dtypes of the outputs according to the inputs.
It receives the abstract information of the inputs and returns the abstract
information of the outputs. For example:
>>> def eval_shape(inp1_info, inp2_info, inp3_info, ...):
>>> return out1_info, out2_info
con_compute: callable
The function that performs the concrete computation. It receives the inputs
and returns the outputs. For example:
>>> def con_compute(inp1, inp2, inp3, ...):
>>> return out1, out2
"""

def __init__(
self,
eval_shape: Callable = None,
con_compute: Callable = None,
name: str = None,
batching_translation: Callable = None,
jvp_translation: Callable = None,
transpose_translation: Callable = None,
multiple_results: bool = True,
):
super().__init__(name=name)

# abstract evaluation function
if eval_shape is None:
raise ValueError('Must provide "eval_shape" for abstract evaluation.')

# cpu function
cpu_func = con_compute

# register OP
self.op = register_op_with_numba(
self.name,
cpu_func=cpu_func,
out_shapes=eval_shape,
batching_translation=batching_translation,
jvp_translation=jvp_translation,
transpose_translation=transpose_translation,
multiple_results=multiple_results,
)

def __call__(self, *args, **kwargs):
args = tree_map(lambda a: a.value if isinstance(a, Array) else a,
args, is_leaf=lambda a: isinstance(a, Array))
kwargs = tree_map(lambda a: a.value if isinstance(a, Array) else a,
kwargs, is_leaf=lambda a: isinstance(a, Array))
res = self.op.bind(*args, **kwargs)
return res
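Putting the pieces together, a minimal usage sketch of CustomOpByNumba might look as follows. Only the eval_shape/con_compute calling convention is taken from the docstring above; the kernel body, the shapes, and the jax.core.ShapedArray import path are illustrative assumptions and may vary across versions:

import jax.numpy as jnp
from jax.core import ShapedArray
import brainpy.math as bm

def eval_shape(x_info):
  # abstract evaluation: one output with the same shape/dtype as the input
  return ShapedArray(x_info.shape, x_info.dtype)

def con_compute(x):
  # concrete CPU computation (Numba-compiled inside the op)
  return x + 1

op = bm.CustomOpByNumba(eval_shape=eval_shape, con_compute=con_compute,
                        multiple_results=False)
out = op(jnp.ones(5))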


class XLACustomOp(BrainPyObject):
"""Creating a XLA custom call operator.
@@ -175,8 +237,9 @@ def abs_eval_rule(*input_shapes, **info):
shapes = out_shapes

if isinstance(shapes, core.ShapedArray):
pass
assert not multiple_results, "multiple_results is True, but the abstract evaluation returns only a single output."
elif isinstance(shapes, (tuple, list)):
assert multiple_results, "multiple_results is False, but the abstract evaluation returns multiple outputs."
for elem in shapes:
if not isinstance(elem, core.ShapedArray):
raise ValueError(f'Elements in "out_shapes" must be instances of '
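In other words, the result of the abstract evaluation and the multiple_results flag must agree. A hedged illustration of the two valid pairings (shapes and dtypes are hypothetical):

from jax.core import ShapedArray
import jax.numpy as jnp

# a single ShapedArray pairs with multiple_results=False
def eval_one(x_info):
  return ShapedArray(x_info.shape, x_info.dtype)

# a tuple or list of ShapedArray pairs with multiple_results=True
def eval_two(x_info):
  return (ShapedArray(x_info.shape, x_info.dtype),
          ShapedArray((1,), jnp.float32))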
8 changes: 4 additions & 4 deletions brainpy/_src/math/surrogate/_one_input.py
@@ -37,7 +37,7 @@ def __init__(self, forward_use_surrogate=False):
self.forward_use_surrogate = forward_use_surrogate
self._true_call_ = jax.custom_gradient(self.call)

def __call__(self, x: Union[jax.Array, Array]):
def __call__(self, x: jax.Array):
return self._true_call_(as_jax(x))

def call(self, x):
@@ -70,7 +70,7 @@ class Sigmoid(_OneInpSurrogate):
"""

def __init__(self, alpha=4., forward_use_surrogate=False):
def __init__(self, alpha: float = 4., forward_use_surrogate=False):
super().__init__(forward_use_surrogate)
self.alpha = alpha

@@ -154,7 +154,7 @@ class PiecewiseQuadratic(_OneInpSurrogate):
"""

def __init__(self, alpha=1., forward_use_surrogate=False):
def __init__(self, alpha: float = 1., forward_use_surrogate=False):
super().__init__(forward_use_surrogate)
self.alpha = alpha

@@ -258,7 +258,7 @@ class PiecewiseExp(_OneInpSurrogate):
piecewise_exp
"""

def __init__(self, alpha=1., forward_use_surrogate=False):
def __init__(self, alpha: float = 1., forward_use_surrogate=False):
super().__init__(forward_use_surrogate)
self.alpha = alpha

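For reference, these surrogate classes act as spike-generation functions whose backward pass uses a surrogate derivative controlled by alpha. A small hedged sketch of typical usage (parameter value and input are illustrative):

import brainpy.math as bm

spike_fn = bm.surrogate.Sigmoid(alpha=4.)
v = bm.random.randn(10)   # e.g. membrane potential minus threshold
spikes = spike_fn(v)      # forward: spike; backward: sigmoid-shaped surrogate gradient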
1 change: 1 addition & 0 deletions brainpy/math/op_register.py
@@ -2,6 +2,7 @@


from brainpy._src.math.op_registers import (
CustomOpByNumba,
XLACustomOp,
compile_cpu_signature_with_numba,
)
29 changes: 29 additions & 0 deletions docs/apis/brainpy.math.op_register.rst
@@ -0,0 +1,29 @@
Operator Registration
=====================

.. contents::
:local:
:depth: 1


CPU Operator Customization with Numba
-------------------------------------

.. currentmodule:: brainpy.math
.. automodule:: brainpy.math

.. autosummary::
:toctree: generated/
:nosignatures:
:template: classtemplate.rst

CustomOpByNumba
XLACustomOp


.. autosummary::
:toctree: generated/

register_op_with_numba
compile_cpu_signature_with_numba

1 change: 1 addition & 0 deletions docs/apis/math.rst
@@ -35,4 +35,5 @@ dynamics programming. For more information and usage examples, please refer to t
brainpy.math.sharding.rst
brainpy.math.environment.rst
brainpy.math.modes.rst
brainpy.math.op_register.rst

2 changes: 1 addition & 1 deletion docs/tutorial_advanced/3_dedicated_operators.rst
@@ -4,4 +4,4 @@ Brain Dynamics Dedicated Operators
.. toctree::
:maxdepth: 1

low-level_operator_customization.ipynb
operator_custom_with_numba.ipynb