diff --git a/brainpy/_src/dyn/synapses/delay_couplings.py b/brainpy/_src/dyn/synapses/delay_couplings.py index ef43139da..8a848e646 100644 --- a/brainpy/_src/dyn/synapses/delay_couplings.py +++ b/brainpy/_src/dyn/synapses/delay_couplings.py @@ -64,7 +64,7 @@ def __init__( self.output_var = var_to_output # Connection matrix - self.conn_mat = bm.asarray(conn_mat) + self.conn_mat = conn_mat if self.conn_mat.shape != required_shape: raise ValueError(f'we expect the structural connection matrix has the shape of ' f'(pre.num, post.num), i.e., {required_shape}, ' diff --git a/brainpy/_src/math/event/tests/test_event_csrmv_taichi.py b/brainpy/_src/math/event/tests/test_event_csrmv_taichi.py index bacf4076a..b759a4789 100644 --- a/brainpy/_src/math/event/tests/test_event_csrmv_taichi.py +++ b/brainpy/_src/math/event/tests/test_event_csrmv_taichi.py @@ -1,24 +1,14 @@ # -*- coding: utf-8 -*- -import sys from functools import partial import jax -import pytest from absl.testing import parameterized import brainpy as bp import brainpy.math as bm -# pytestmark = pytest.mark.skip(reason="Skipped due to pytest limitations, manual execution required for testing.") - -is_manual_test = False -if sys.platform.startswith('darwin') and not is_manual_test: - pytest.skip('brainpy.math package may need manual tests.', allow_module_level=True) - -# bm.set_platform('cpu') - seed = 1234 @@ -38,206 +28,6 @@ def func(*args, **kwargs): return func -# ### MANUAL TESTS ### - -# transposes = [True, False] -# shapes = [(100, 200), -# (200, 200), -# (200, 100), -# (10, 1000), -# # (2, 10000), -# # (1000, 10), -# # (10000, 2) -# ] -# homo_datas = [-1., 0., 1.] - -# def test_homo(shape, transpose, homo_data): -# print(f'test_homo: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}') -# rng = bm.random.RandomState() -# indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') -# events = rng.random(shape[0] if transpose else shape[1]) < 0.1 -# heter_data = bm.ones(indices.shape) * homo_data - -# r1 = bm.event.csrmv(homo_data, indices, indptr, events, shape=shape, transpose=transpose) -# r2 = bm.event.csrmv_taichi(homo_data, indices, indptr, events, shape=shape, transpose=transpose) - -# assert (bm.allclose(r1, r2[0])) - -# bm.clear_buffer_memory() - - -# def test_homo_vmap(shape, transpose, homo_data): -# print(f'test_homo_vamp: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}') - -# rng = bm.random.RandomState() -# indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') - -# # vmap 'data' -# events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 -# f1 = jax.vmap(partial(bm.event.csrmv, indices=indices, indptr=indptr, events=events, -# shape=shape, transpose=transpose)) -# f2 = jax.vmap(partial(bm.event.csrmv_taichi, indices=indices, indptr=indptr, events=events, -# shape=shape, transpose=transpose)) -# vmap_data = bm.as_jax([homo_data] * 10) -# assert(bm.allclose(f1(vmap_data), f2(vmap_data)[0])) - -# # vmap 'events' -# f3 = jax.vmap(partial(bm.event.csrmv, homo_data, indices, indptr, -# shape=shape, transpose=transpose)) -# f4 = jax.vmap(partial(bm.event.csrmv_taichi, homo_data, indices, indptr, -# shape=shape, transpose=transpose)) -# vmap_data = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.1 -# assert(bm.allclose(f3(vmap_data), f4(vmap_data)[0])) - -# # vmap 'data' and 'events' -# f5 = jax.vmap(lambda dd, ee: bm.event.csrmv(dd, indices, indptr, ee, shape=shape, transpose=transpose)) -# f6 = jax.vmap(lambda dd, ee: 
bm.event.csrmv_taichi(dd, indices, indptr, ee, shape=shape, transpose=transpose)) - -# vmap_data1 = bm.as_jax([homo_data] * 10) -# vmap_data2 = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.2 -# assert(bm.allclose(f5(vmap_data1, vmap_data2), -# f6(vmap_data1, vmap_data2)[0])) - -# bm.clear_buffer_memory() - - -# def test_homo_grad(shape, transpose, homo_data): -# print(f'test_homo_grad: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}') - -# rng = bm.random.RandomState() -# indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') -# indices = bm.as_jax(indices) -# indptr = bm.as_jax(indptr) -# events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 -# dense_conn = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, indices, indptr, shape=shape) - -# # grad 'data' -# r1 = jax.grad(sum_op(bm.event.csrmv))( -# homo_data, indices, indptr, events, shape=shape, transpose=transpose) -# r2 = jax.grad(sum_op2(bm.event.csrmv_taichi))( -# homo_data, indices, indptr, events, shape=shape, transpose=transpose) -# assert(bm.allclose(r1, r2)) - -# # grad 'events' -# r3 = jax.grad(sum_op(bm.event.csrmv), argnums=3)( -# homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) -# r4 = jax.grad(sum_op2(bm.event.csrmv_taichi), argnums=3)( -# homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) -# assert(bm.allclose(r3, r4)) - -# bm.clear_buffer_memory() - - -# def test_heter(shape, transpose): -# print(f'test_heter: shape = {shape}, transpose = {transpose}') -# rng = bm.random.RandomState() -# indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') -# indices = bm.as_jax(indices) -# indptr = bm.as_jax(indptr) -# events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 -# heter_data = bm.as_jax(rng.random(indices.shape)) - -# r1 = bm.event.csrmv(heter_data, indices, indptr, events, -# shape=shape, transpose=transpose) -# r2 = bm.event.csrmv_taichi(heter_data, indices, indptr, events, -# shape=shape, transpose=transpose) - -# assert(bm.allclose(r1, r2[0])) - -# bm.clear_buffer_memory() - - -# def test_heter_vmap(shape, transpose): -# print(f'test_heter_vamp: shape = {shape}, transpose = {transpose}') - -# rng = bm.random.RandomState() -# indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') -# indices = bm.as_jax(indices) -# indptr = bm.as_jax(indptr) - -# # vmap 'data' -# events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 -# f1 = jax.vmap(partial(bm.event.csrmv, indices=indices, indptr=indptr, events=events, -# shape=shape, transpose=transpose)) -# f2 = jax.vmap(partial(bm.event.csrmv_taichi, indices=indices, indptr=indptr, events=events, -# shape=shape, transpose=transpose)) -# vmap_data = bm.as_jax(rng.random((10, indices.shape[0]))) -# assert(bm.allclose(f1(vmap_data), f2(vmap_data)[0])) - -# # vmap 'events' -# data = bm.as_jax(rng.random(indices.shape)) -# f3 = jax.vmap(partial(bm.event.csrmv, data, indices, indptr, -# shape=shape, transpose=transpose)) -# f4 = jax.vmap(partial(bm.event.csrmv_taichi, data, indices, indptr, -# shape=shape, transpose=transpose)) -# vmap_data = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.1 -# assert(bm.allclose(f3(vmap_data), f4(vmap_data)[0])) - -# # vmap 'data' and 'events' -# f5 = jax.vmap(lambda dd, ee: bm.event.csrmv(dd, indices, indptr, ee, -# shape=shape, transpose=transpose)) -# f6 = jax.vmap(lambda dd, ee: bm.event.csrmv_taichi(dd, indices, indptr, 
ee, -# shape=shape, transpose=transpose)) -# vmap_data1 = bm.as_jax(rng.random((10, indices.shape[0]))) -# vmap_data2 = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.2 -# assert(bm.allclose(f5(vmap_data1, vmap_data2), -# f6(vmap_data1, vmap_data2)[0])) - -# bm.clear_buffer_memory() - - -# def test_heter_grad(shape, transpose): -# print(f'test_heter_grad: shape = {shape}, transpose = {transpose}') - -# rng = bm.random.RandomState() -# indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') -# indices = bm.as_jax(indices) -# indptr = bm.as_jax(indptr) -# events = rng.random(shape[0] if transpose else shape[1]) < 0.1 -# events = bm.as_jax(events) -# dense_conn = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, indices, indptr, shape=shape) - -# # grad 'data' -# data = bm.as_jax(rng.random(indices.shape)) -# r1 = jax.grad(sum_op(bm.event.csrmv))( -# data, indices, indptr, events, shape=shape, transpose=transpose) -# r2 = jax.grad(sum_op2(bm.event.csrmv_taichi))( -# data, indices, indptr, events, shape=shape, transpose=transpose) -# assert(bm.allclose(r1, r2)) - -# # grad 'events' -# r3 = jax.grad(sum_op(bm.event.csrmv), argnums=3)( -# data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) -# r4 = jax.grad(sum_op2(bm.event.csrmv_taichi), argnums=3)( -# data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) -# assert(bm.allclose(r3, r4)) - -# r5 = jax.grad(sum_op(bm.event.csrmv), argnums=(0, 3))( -# data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) -# r6 = jax.grad(sum_op2(bm.event.csrmv_taichi), argnums=(0, 3))( -# data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) -# assert(bm.allclose(r5[0], r6[0])) -# assert(bm.allclose(r5[1], r6[1])) - -# bm.clear_buffer_memory() - -# def test_all(): -# for transpose in transposes: -# for shape in shapes: -# for homo_data in homo_datas: -# test_homo(shape, transpose, homo_data) -# test_homo_vmap(shape, transpose, homo_data) -# test_homo_grad(shape, transpose, homo_data) - -# for transpose in transposes: -# for shape in shapes: -# test_heter(shape, transpose) -# test_heter_vmap(shape, transpose) -# test_heter_grad(shape, transpose) -# test_all() - - -### PYTEST class Test_event_csr_matvec_taichi(parameterized.TestCase): def __init__(self, *args, platform='cpu', **kwargs): super(Test_event_csr_matvec_taichi, self).__init__(*args, **kwargs) diff --git a/brainpy/_src/math/jitconn/tests/test_event_matvec_taichi.py b/brainpy/_src/math/jitconn/tests/test_event_matvec_taichi.py index 8d03fe1e6..e42434e95 100644 --- a/brainpy/_src/math/jitconn/tests/test_event_matvec_taichi.py +++ b/brainpy/_src/math/jitconn/tests/test_event_matvec_taichi.py @@ -1,18 +1,12 @@ # -*- coding: utf-8 -*- -import sys import jax import jax.numpy as jnp -import pytest from absl.testing import parameterized import brainpy.math as bm -is_manual_test = False -if sys.platform.startswith('darwin') and not is_manual_test: - pytest.skip('brainpy.math package may need manual tests.', allow_module_level=True) - shapes = [(100, 200), (10, 1000), (2, 1000), (1000, 10), (1000, 2)] shapes = [(100, 200), (2, 1000), (1000, 2)] diff --git a/brainpy/_src/math/jitconn/tests/test_matvec_taichi.py b/brainpy/_src/math/jitconn/tests/test_matvec_taichi.py index eb56b0bee..380db3cf5 100644 --- a/brainpy/_src/math/jitconn/tests/test_matvec_taichi.py +++ b/brainpy/_src/math/jitconn/tests/test_matvec_taichi.py @@ -1,251 +1,16 @@ # -*- coding: utf-8 -*- -import sys import jax 
import jax.numpy as jnp -import pytest from absl.testing import parameterized import brainpy.math as bm -is_manual_test = False -if sys.platform.startswith('darwin') and not is_manual_test: - pytest.skip('brainpy.math package may need manual tests.', allow_module_level=True) - shapes = [(100, 200), (10, 1000), (2, 1000), (1000, 10), (1000, 2)] shapes = [(100, 200), (2, 1000), (1000, 2)] -# def sum_op(op): -# def func(*args, **kwargs): -# r = op(*args, **kwargs) -# return r.sum() - -# return func - - -# def sum_op2(op): -# def func(*args, **kwargs): -# r = op(*args, **kwargs)[0] -# return r.sum() - -# return func - -# def test_homo(shape, transpose, outdim_parallel, prob, homo_data, seed=None): -# print(f'test_homo: ' -# f'shape = {shape}, ' -# f'transpose = {transpose}, ' -# f'outdim_parallel = {outdim_parallel}, ' -# f'prob={prob}, ' -# f'homo_data = {homo_data}') - -# rng = bm.random.RandomState() -# vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - -# r1 = bm.jitconn.mv_prob_homo_taichi(vector, -# homo_data, -# conn_prob=prob, -# shape=shape, -# seed=seed, -# outdim_parallel=outdim_parallel, -# transpose=transpose) - -# r2 = bm.jitconn.mv_prob_homo_taichi(vector, -# homo_data, -# conn_prob=prob, -# shape=shape, -# seed=seed, -# outdim_parallel=outdim_parallel, -# transpose=transpose) -# assert (jnp.allclose(r1, r2)) - -# r2 = bm.jitconn.mv_prob_homo_taichi(vector, -# homo_data, -# conn_prob=prob, -# shape=(shape[1], shape[0]), -# seed=seed, -# outdim_parallel=outdim_parallel, -# transpose=not transpose) -# assert (jnp.allclose(r1, r2)) - -# bm.clear_buffer_memory() - -# def test_homo_vmap(shape, transpose, outdim_parallel, prob, seed=None): -# print(f'test_homo_vmap: ' -# f'shape = {shape}, ' -# f'transpose = {transpose}, ' -# f'outdim_parallel = {outdim_parallel}, ' -# f'prob={prob}') - -# rng = bm.random.RandomState() -# events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) -# weights = bm.as_jax(rng.random(10)) - -# f1 = jax.vmap( -# lambda event, data: bm.jitconn.mv_prob_homo_taichi( -# event, data, -# conn_prob=prob, shape=shape, seed=seed, -# outdim_parallel=outdim_parallel, transpose=transpose -# )[0] -# ) -# r1 = f1(events, weights) -# r2 = f1(events, weights) -# assert (jnp.allclose(r1, r2)) - -# bm.clear_buffer_memory() - -# def test_uniform(shape, transpose, outdim_parallel, prob, w_low, w_high, seed=None): -# print(f'test_uniform: ' -# f'shape = {shape}, ' -# f'transpose = {transpose}, ' -# f'outdim_parallel = {outdim_parallel}, ' -# f'prob={prob}, ' -# f'w_low = {w_low}, ' -# f'w_high = {w_high}, ') - -# rng = bm.random.RandomState() -# events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - -# r1 = bm.jitconn.mv_prob_uniform_taichi(events, -# w_low=w_low, -# w_high=w_high, -# conn_prob=prob, -# shape=shape, -# seed=seed, -# outdim_parallel=outdim_parallel, -# transpose=transpose) - -# r2 = bm.jitconn.mv_prob_uniform_taichi(events, -# w_low=w_low, -# w_high=w_high, -# conn_prob=prob, -# shape=shape, -# seed=seed, -# outdim_parallel=outdim_parallel, -# transpose=transpose) -# c = jnp.allclose(r1, r2) -# if not c: -# print(r1, r2) -# assert (c) - -# r2 = bm.jitconn.mv_prob_uniform_taichi(events, -# w_low=w_low, -# w_high=w_high, -# conn_prob=prob, -# shape=(shape[1], shape[0]), -# seed=seed, -# outdim_parallel=outdim_parallel, -# transpose=not transpose) -# c = jnp.allclose(r1, r2) -# if not c: -# print(r1, r2) -# assert (c) - -# bm.clear_buffer_memory() - -# test_homo(shape=(100, 200), transpose=True, 
outdim_parallel=True, prob=0.1, homo_data=1., seed=1234) -# test_homo_vmap(shape=(100, 200), transpose=True, outdim_parallel=True, prob=0.1, seed=1234) - -# test_uniform(shape=(100, 200), transpose=True, outdim_parallel=False, prob=0.1, w_low=-1., w_high=0., seed=1234) - -# def test_homo_grad(shape, transpose, outdim_parallel, prob, seed=None): -# print(f'_test_homo_grad: ' -# f'shape = {shape}, ' -# f'transpose = {transpose}, ' -# f'outdim_parallel = {outdim_parallel}, ' -# f'prob={prob}') - -# rng = bm.random.RandomState() -# events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.5 -# events = events.astype(float) - -# f1 = jax.grad( -# lambda event, data: bm.jitconn.mv_prob_homo_taichi( -# event, data, -# conn_prob=prob, -# shape=shape, -# seed=seed, -# outdim_parallel=outdim_parallel, -# transpose=transpose -# )[0].sum(), -# argnums=0 -# ) -# r1 = f1(events, 1.) -# r2 = f1(events, 2.) - -# print(r1 *2 - r2) -# assert (jnp.allclose(r1 * 2., r2)) - -# bm.clear_buffer_memory() - - -# def test_normal_grad(shape, transpose, outdim_parallel, prob, seed=None): -# print(f'_test_normal_grad: ' -# f'shape = {shape}, ' -# f'transpose = {transpose}, ' -# f'outdim_parallel = {outdim_parallel}, ' -# f'prob={prob}') - -# rng = bm.random.RandomState() -# events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 -# events = events.astype(float) - -# f1 = jax.grad( -# lambda e, w_sigma: bm.jitconn.mv_prob_normal_taichi( -# e, -# w_mu=0., -# w_sigma=w_sigma, -# conn_prob=prob, -# shape=shape, -# seed=seed, -# outdim_parallel=outdim_parallel, -# transpose=transpose -# )[0].sum() -# ) -# r1 = f1(events, 1.) -# r2 = f1(events, 2.) -# print(r1 *2 - r2) -# assert (bm.allclose(r1 * 2., r2)) - -# bm.clear_buffer_memory() - -# def test_uniform_grad(shape, transpose, outdim_parallel, prob, seed=None): -# print(f'_test_uniform_grad: ' -# f'shape = {shape}, ' -# f'transpose = {transpose}, ' -# f'outdim_parallel = {outdim_parallel}, ' -# f'prob={prob}') - - -# rng = bm.random.RandomState() -# events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - -# f1 = jax.grad( -# lambda e, w_low, w_high: bm.jitconn.mv_prob_uniform_taichi( -# e, -# w_low=w_low, -# w_high=w_high, -# conn_prob=prob, -# shape=shape, -# seed=seed, -# outdim_parallel=outdim_parallel, -# transpose=transpose -# )[0].sum() -# ) - -# r1 = f1(events, 0., 1.) -# r2 = f1(events, 0., 2.) 
-# print(r1 *2 - r2) -# assert (bm.allclose(r1 * 2., r2)) - -# bm.clear_buffer_memory() - -# test_homo_grad(shape=(100, 200), transpose=True, outdim_parallel=True, prob=0.1, seed=1234) -# test_normal_grad(shape=(100, 200), transpose=True, outdim_parallel=True, prob=0.1, seed=1234) -# test_uniform_grad(shape=(100, 200), transpose=True, outdim_parallel=False, prob=0.1, seed=1234) - - class Test_matvec_prob_conn(parameterized.TestCase): def __init__(self, *args, platform='cpu', **kwargs): super(Test_matvec_prob_conn, self).__init__(*args, **kwargs) diff --git a/brainpy/_src/math/op_register/tests/test_taichi_based.py b/brainpy/_src/math/op_register/tests/test_taichi_based.py index 14ee77a81..7f405ec12 100644 --- a/brainpy/_src/math/op_register/tests/test_taichi_based.py +++ b/brainpy/_src/math/op_register/tests/test_taichi_based.py @@ -1,48 +1,30 @@ import jax import jax.numpy as jnp -import taichi as taichi -import pytest -import platform +import taichi as ti import brainpy.math as bm bm.set_platform('cpu') -if not platform.platform().startswith('Windows'): - pytest.skip(allow_module_level=True) - - -# @ti.kernel -# def event_ell_cpu(indices: ti.types.ndarray(ndim=2), -# vector: ti.types.ndarray(ndim=1), -# weight: ti.types.ndarray(ndim=1), -# out: ti.types.ndarray(ndim=1)): -# weight_0 = weight[0] -# num_rows, num_cols = indices.shape -# ti.loop_config(serialize=True) -# for i in range(num_rows): -# if vector[i]: -# for j in range(num_cols): -# out[indices[i, j]] += weight_0 - -@taichi.func -def get_weight(weight: taichi.types.ndarray(ndim=1)) -> taichi.f32: + +@ti.func +def get_weight(weight: ti.types.ndarray(ndim=1)) -> ti.f32: return weight[0] -@taichi.func -def update_output(out: taichi.types.ndarray(ndim=1), index: taichi.i32, weight_val: taichi.f32): +@ti.func +def update_output(out: ti.types.ndarray(ndim=1), index: ti.i32, weight_val: ti.f32): out[index] += weight_val -@taichi.kernel -def event_ell_cpu(indices: taichi.types.ndarray(ndim=2), - vector: taichi.types.ndarray(ndim=1), - weight: taichi.types.ndarray(ndim=1), - out: taichi.types.ndarray(ndim=1)): +@ti.kernel +def event_ell_cpu(indices: ti.types.ndarray(ndim=2), + vector: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): weight_val = get_weight(weight) num_rows, num_cols = indices.shape - taichi.loop_config(serialize=True) + ti.loop_config(serialize=True) for i in range(num_rows): if vector[i]: for j in range(num_cols): diff --git a/brainpy/_src/math/sparse/tests/test_csrmv_taichi.py b/brainpy/_src/math/sparse/tests/test_csrmv_taichi.py index 1c603da01..2b3d7b5b0 100644 --- a/brainpy/_src/math/sparse/tests/test_csrmv_taichi.py +++ b/brainpy/_src/math/sparse/tests/test_csrmv_taichi.py @@ -1,22 +1,13 @@ # -*- coding: utf-8 -*- -import sys from functools import partial import jax -import pytest from absl.testing import parameterized import brainpy as bp import brainpy.math as bm -# pytestmark = pytest.mark.skip(reason="Skipped due to pytest limitations, manual execution required for testing.") - - -is_manual_test = False -if sys.platform.startswith('darwin') and not is_manual_test: - pytest.skip('brainpy.math package may need manual tests.', allow_module_level=True) - # bm.set_platform('gpu') seed = 1234
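# ---------------------------------------------------------------------------
# A minimal standalone sketch (not part of the patch itself) of the
# consistency check that the removed manual tests performed, and that the
# parameterized pytest classes kept by this diff now cover. The shape,
# connection probability, and event threshold below are illustrative; as in
# the removed code, csrmv_taichi is assumed to return its result as the
# first element of a sequence.
import brainpy as bp
import brainpy.math as bm

shape = (100, 200)
rng = bm.random.RandomState(1234)

# Build a random CSR connectivity and a boolean event vector
# (transpose=False, so the vector length matches shape[1]).
indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post')
events = bm.as_jax(rng.random(shape[1])) < 0.1

# The default and Taichi-backed event CSR matvec should agree.
r1 = bm.event.csrmv(1., indices, indptr, events, shape=shape, transpose=False)
r2 = bm.event.csrmv_taichi(1., indices, indptr, events, shape=shape, transpose=False)
assert bm.allclose(r1, r2[0])

bm.clear_buffer_memory()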