Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update operators for compatible with brainpylib>=0.1.10 #468

Merged
merged 3 commits into from
Sep 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion brainpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

__version__ = "2.4.4.post2"
__version__ = "2.4.4.post3"

# fundamental supporting modules
from brainpy import errors, check, tools
Expand Down
104 changes: 54 additions & 50 deletions brainpy/_src/connect/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
'SUPPORTED_SYN_STRUCTURE',

# the connection dtypes
'set_default_dtype', 'MAT_DTYPE', 'IDX_DTYPE',
'set_default_dtype', 'MAT_DTYPE', 'IDX_DTYPE', 'get_idx_type',

# brainpy_object class
'Connector', 'TwoEndConnector', 'OneEndConnector',
Expand Down Expand Up @@ -59,6 +59,10 @@
IDX_DTYPE = jnp.int32


def get_idx_type():
return IDX_DTYPE


def set_default_dtype(mat_dtype=None, idx_dtype=None):
"""Set the default dtype.

Expand Down Expand Up @@ -247,44 +251,44 @@ def _return_by_csr(self, structures, csr: tuple, all_data: dict):

if (PRE_IDS in structures) and (PRE_IDS not in all_data):
pre_ids = np.repeat(np.arange(self.pre_num), np.diff(indptr))
all_data[PRE_IDS] = bm.as_jax(pre_ids, dtype=IDX_DTYPE)
all_data[PRE_IDS] = bm.as_jax(pre_ids, dtype=get_idx_type())

if (POST_IDS in structures) and (POST_IDS not in all_data):
all_data[POST_IDS] = bm.as_jax(indices, dtype=IDX_DTYPE)
all_data[POST_IDS] = bm.as_jax(indices, dtype=get_idx_type())

if (COO in structures) and (COO not in all_data):
pre_ids = np.repeat(np.arange(self.pre_num), np.diff(indptr))
all_data[COO] = (bm.as_jax(pre_ids, dtype=IDX_DTYPE),
bm.as_jax(indices, dtype=IDX_DTYPE))
all_data[COO] = (bm.as_jax(pre_ids, dtype=get_idx_type()),
bm.as_jax(indices, dtype=get_idx_type()))

if (PRE2POST in structures) and (PRE2POST not in all_data):
all_data[PRE2POST] = (bm.as_jax(indices, dtype=IDX_DTYPE),
bm.as_jax(indptr, dtype=IDX_DTYPE))
all_data[PRE2POST] = (bm.as_jax(indices, dtype=get_idx_type()),
bm.as_jax(indptr, dtype=get_idx_type()))

if (CSR in structures) and (CSR not in all_data):
all_data[CSR] = (bm.as_jax(indices, dtype=IDX_DTYPE),
bm.as_jax(indptr, dtype=IDX_DTYPE))
all_data[CSR] = (bm.as_jax(indices, dtype=get_idx_type()),
bm.as_jax(indptr, dtype=get_idx_type()))

if (POST2PRE in structures) and (POST2PRE not in all_data):
indc, indptrc = csr2csc((indices, indptr), self.post_num)
all_data[POST2PRE] = (bm.as_jax(indc, dtype=IDX_DTYPE),
bm.as_jax(indptrc, dtype=IDX_DTYPE))
all_data[POST2PRE] = (bm.as_jax(indc, dtype=get_idx_type()),
bm.as_jax(indptrc, dtype=get_idx_type()))

if (CSC in structures) and (CSC not in all_data):
indc, indptrc = csr2csc((indices, indptr), self.post_num)
all_data[CSC] = (bm.as_jax(indc, dtype=IDX_DTYPE),
bm.as_jax(indptrc, dtype=IDX_DTYPE))
all_data[CSC] = (bm.as_jax(indc, dtype=get_idx_type()),
bm.as_jax(indptrc, dtype=get_idx_type()))

if (PRE2SYN in structures) and (PRE2SYN not in all_data):
syn_seq = np.arange(indices.size, dtype=IDX_DTYPE)
all_data[PRE2SYN] = (bm.as_jax(syn_seq, dtype=IDX_DTYPE),
bm.as_jax(indptr, dtype=IDX_DTYPE))
syn_seq = np.arange(indices.size, dtype=get_idx_type())
all_data[PRE2SYN] = (bm.as_jax(syn_seq, dtype=get_idx_type()),
bm.as_jax(indptr, dtype=get_idx_type()))

if (POST2SYN in structures) and (POST2SYN not in all_data):
syn_seq = np.arange(indices.size, dtype=IDX_DTYPE)
syn_seq = np.arange(indices.size, dtype=get_idx_type())
_, indptrc, syn_seqc = csr2csc((indices, indptr), self.post_num, syn_seq)
all_data[POST2SYN] = (bm.as_jax(syn_seqc, dtype=IDX_DTYPE),
bm.as_jax(indptrc, dtype=IDX_DTYPE))
all_data[POST2SYN] = (bm.as_jax(syn_seqc, dtype=get_idx_type()),
bm.as_jax(indptrc, dtype=get_idx_type()))

def _return_by_coo(self, structures, coo: tuple, all_data: dict):
pre_ids, post_ids = coo
Expand All @@ -293,24 +297,24 @@ def _return_by_coo(self, structures, coo: tuple, all_data: dict):
all_data[CONN_MAT] = bm.as_jax(coo2mat(coo, self.pre_num, self.post_num), dtype=MAT_DTYPE)

if (PRE_IDS in structures) and (PRE_IDS not in all_data):
all_data[PRE_IDS] = bm.as_jax(pre_ids, dtype=IDX_DTYPE)
all_data[PRE_IDS] = bm.as_jax(pre_ids, dtype=get_idx_type())

if (POST_IDS in structures) and (POST_IDS not in all_data):
all_data[POST_IDS] = bm.as_jax(post_ids, dtype=IDX_DTYPE)
all_data[POST_IDS] = bm.as_jax(post_ids, dtype=get_idx_type())

if (COO in structures) and (COO not in all_data):
all_data[COO] = (bm.as_jax(pre_ids, dtype=IDX_DTYPE),
bm.as_jax(post_ids, dtype=IDX_DTYPE))
all_data[COO] = (bm.as_jax(pre_ids, dtype=get_idx_type()),
bm.as_jax(post_ids, dtype=get_idx_type()))

if CSC in structures and CSC not in all_data:
csc = coo2csc(coo, self.post_num)
all_data[CSC] = (bm.as_jax(csc[0], dtype=IDX_DTYPE),
bm.as_jax(csc[1], dtype=IDX_DTYPE))
all_data[CSC] = (bm.as_jax(csc[0], dtype=get_idx_type()),
bm.as_jax(csc[1], dtype=get_idx_type()))

if POST2PRE in structures and POST2PRE not in all_data:
csc = coo2csc(coo, self.post_num)
all_data[POST2PRE] = (bm.as_jax(csc[0], dtype=IDX_DTYPE),
bm.as_jax(csc[1], dtype=IDX_DTYPE))
all_data[POST2PRE] = (bm.as_jax(csc[0], dtype=get_idx_type()),
bm.as_jax(csc[1], dtype=get_idx_type()))

if (len([s for s in structures
if s not in [CONN_MAT, PRE_IDS, POST_IDS,
Expand Down Expand Up @@ -350,8 +354,8 @@ def _make_returns(self, structures, conn_data):
# "csr" structure
if csr is not None:
if (PRE2POST in structures) and (PRE2POST not in all_data):
all_data[PRE2POST] = (bm.as_jax(csr[0], dtype=IDX_DTYPE),
bm.as_jax(csr[1], dtype=IDX_DTYPE))
all_data[PRE2POST] = (bm.as_jax(csr[0], dtype=get_idx_type()),
bm.as_jax(csr[1], dtype=get_idx_type()))
self._return_by_csr(structures, csr=csr, all_data=all_data)

# "mat" structure
Expand All @@ -364,9 +368,9 @@ def _make_returns(self, structures, conn_data):
# "coo" structure
if coo is not None:
if (PRE_IDS in structures) and (PRE_IDS not in structures):
all_data[PRE_IDS] = bm.as_jax(coo[0], dtype=IDX_DTYPE)
all_data[PRE_IDS] = bm.as_jax(coo[0], dtype=get_idx_type())
if (POST_IDS in structures) and (POST_IDS not in structures):
all_data[POST_IDS] = bm.as_jax(coo[1], dtype=IDX_DTYPE)
all_data[POST_IDS] = bm.as_jax(coo[1], dtype=get_idx_type())
self._return_by_coo(structures, coo=coo, all_data=all_data)

# return
Expand Down Expand Up @@ -416,34 +420,34 @@ def require(self, *structures):
if len(structures) == 1:
if PRE2POST in structures and _has_csr_imp:
r = self.build_csr()
return bm.as_jax(r[0], dtype=IDX_DTYPE), bm.as_jax(r[1], dtype=IDX_DTYPE)
return bm.as_jax(r[0], dtype=get_idx_type()), bm.as_jax(r[1], dtype=get_idx_type())
elif CSR in structures and _has_csr_imp:
r = self.build_csr()
return bm.as_jax(r[0], dtype=IDX_DTYPE), bm.as_jax(r[1], dtype=IDX_DTYPE)
return bm.as_jax(r[0], dtype=get_idx_type()), bm.as_jax(r[1], dtype=get_idx_type())
elif CONN_MAT in structures and _has_mat_imp:
return bm.as_jax(self.build_mat(), dtype=MAT_DTYPE)
elif PRE_IDS in structures and _has_coo_imp:
return bm.as_jax(self.build_coo()[0], dtype=IDX_DTYPE)
return bm.as_jax(self.build_coo()[0], dtype=get_idx_type())
elif POST_IDS in structures and _has_coo_imp:
return bm.as_jax(self.build_coo()[1], dtype=IDX_DTYPE)
return bm.as_jax(self.build_coo()[1], dtype=get_idx_type())
elif COO in structures and _has_coo_imp:
r = self.build_coo()
return bm.as_jax(r[0], dtype=IDX_DTYPE), bm.as_jax(r[1], dtype=IDX_DTYPE)
return bm.as_jax(r[0], dtype=get_idx_type()), bm.as_jax(r[1], dtype=get_idx_type())

elif len(structures) == 2:
if (PRE_IDS in structures and POST_IDS in structures and _has_coo_imp):
r = self.build_coo()
if structures[0] == PRE_IDS:
return bm.as_jax(r[0], dtype=IDX_DTYPE), bm.as_jax(r[1], dtype=IDX_DTYPE)
return bm.as_jax(r[0], dtype=get_idx_type()), bm.as_jax(r[1], dtype=get_idx_type())
else:
return bm.as_jax(r[1], dtype=IDX_DTYPE), bm.as_jax(r[0], dtype=IDX_DTYPE)
return bm.as_jax(r[1], dtype=get_idx_type()), bm.as_jax(r[0], dtype=get_idx_type())

if ((CSR in structures or PRE2POST in structures)
and _has_csr_imp and COO in structures and _has_coo_imp):
csr = self.build_csr()
csr = (bm.as_jax(csr[0], dtype=IDX_DTYPE), bm.as_jax(csr[1], dtype=IDX_DTYPE))
csr = (bm.as_jax(csr[0], dtype=get_idx_type()), bm.as_jax(csr[1], dtype=get_idx_type()))
coo = self.build_coo()
coo = (bm.as_jax(coo[0], dtype=IDX_DTYPE), bm.as_jax(coo[1], dtype=IDX_DTYPE))
coo = (bm.as_jax(coo[0], dtype=get_idx_type()), bm.as_jax(coo[1], dtype=get_idx_type()))
if structures[0] == COO:
return coo, csr
else:
Expand All @@ -452,7 +456,7 @@ def require(self, *structures):
if ((CSR in structures or PRE2POST in structures)
and _has_csr_imp and CONN_MAT in structures and _has_mat_imp):
csr = self.build_csr()
csr = (bm.as_jax(csr[0], dtype=IDX_DTYPE), bm.as_jax(csr[1], dtype=IDX_DTYPE))
csr = (bm.as_jax(csr[0], dtype=get_idx_type()), bm.as_jax(csr[1], dtype=get_idx_type()))
mat = bm.as_jax(self.build_mat(), dtype=MAT_DTYPE)
if structures[0] == CONN_MAT:
return mat, csr
Expand All @@ -461,7 +465,7 @@ def require(self, *structures):

if (COO in structures and _has_coo_imp and CONN_MAT in structures and _has_mat_imp):
coo = self.build_coo()
coo = (bm.as_jax(coo[0], dtype=IDX_DTYPE), bm.as_jax(coo[1], dtype=IDX_DTYPE))
coo = (bm.as_jax(coo[0], dtype=get_idx_type()), bm.as_jax(coo[1], dtype=get_idx_type()))
mat = bm.as_jax(self.build_mat(), dtype=MAT_DTYPE)
if structures[0] == COO:
return coo, mat
Expand Down Expand Up @@ -612,7 +616,7 @@ def mat2coo(dense):
pre_ids, post_ids = onp.where(dense > 0)
else:
pre_ids, post_ids = jnp.where(bm.as_jax(dense) > 0)
return pre_ids.astype(dtype=IDX_DTYPE), post_ids.astype(dtype=IDX_DTYPE)
return pre_ids.astype(dtype=get_idx_type()), post_ids.astype(dtype=get_idx_type())


def mat2csc(dense):
Expand Down Expand Up @@ -686,7 +690,7 @@ def coo2csr(coo, num_pre):
final_pre_count = bm.as_jax(final_pre_count)
indptr = final_pre_count.cumsum()
indptr = onp.insert(indptr, 0, 0)
return indices.astype(IDX_DTYPE), indptr.astype(IDX_DTYPE)
return indices.astype(get_idx_type()), indptr.astype(get_idx_type())


def coo2csc(coo, post_num, data=None):
Expand All @@ -695,31 +699,31 @@ def coo2csc(coo, post_num, data=None):
if isinstance(indices, onp.ndarray):
# to maintain the original order of the elements with the same value
sort_ids = onp.argsort(indices)
pre_ids_new = onp.asarray(pre_ids[sort_ids], dtype=IDX_DTYPE)
pre_ids_new = onp.asarray(pre_ids[sort_ids], dtype=get_idx_type())

unique_post_ids, count = onp.unique(indices, return_counts=True)
post_count = onp.zeros(post_num, dtype=IDX_DTYPE)
post_count = onp.zeros(post_num, dtype=get_idx_type())
post_count[unique_post_ids] = count

indptr_new = post_count.cumsum()
indptr_new = onp.insert(indptr_new, 0, 0)
indptr_new = onp.asarray(indptr_new, dtype=IDX_DTYPE)
indptr_new = onp.asarray(indptr_new, dtype=get_idx_type())

else:
pre_ids = bm.as_jax(pre_ids)
indices = bm.as_jax(indices)

# to maintain the original order of the elements with the same value
sort_ids = jnp.argsort(indices)
pre_ids_new = jnp.asarray(pre_ids[sort_ids], dtype=IDX_DTYPE)
pre_ids_new = jnp.asarray(pre_ids[sort_ids], dtype=get_idx_type())

unique_post_ids, count = jnp.unique(indices, return_counts=True)
post_count = bm.zeros(post_num, dtype=IDX_DTYPE)
post_count = bm.zeros(post_num, dtype=get_idx_type())
post_count[unique_post_ids] = count

indptr_new = post_count.value.cumsum()
indptr_new = jnp.insert(indptr_new, 0, 0)
indptr_new = jnp.asarray(indptr_new, dtype=IDX_DTYPE)
indptr_new = jnp.asarray(indptr_new, dtype=get_idx_type())

if data is None:
return pre_ids_new, indptr_new
Expand Down
Loading
Loading