Skip to content

Commit

Permalink
Standardizing and generalizing object-oriented transformations (#628)
Browse files Browse the repository at this point in the history
* test improvement

* remove pytorch add

* variable evaluation using `brainpy.math.eval_shape`

* fix bugs

* update transformations

* remove `new_transform` API

* update

* update

* fix

* fix

* fix bugs

* fix bugs

* updates

* updates

* upgrade

* add `brainpy.math.VariableStack`
  • Loading branch information
chaoming0625 authored Feb 22, 2024
1 parent 48455e5 commit 4d74816
Show file tree
Hide file tree
Showing 20 changed files with 208 additions and 686 deletions.
11 changes: 10 additions & 1 deletion brainpy/_src/dnn/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def update(self, x):
nonbatching = False
if x.ndim == self.num_spatial_dims + 1:
nonbatching = True
x = x.unsqueeze(0)
x = bm.unsqueeze(x, 0)
w = self.w.value
if self.mask is not None:
try:
Expand Down Expand Up @@ -190,6 +190,9 @@ def __repr__(self):
class Conv1d(_GeneralConv):
"""One-dimensional convolution.
The input should a 2d array with the shape of ``[H, C]``, or
a 3d array with the shape of ``[B, H, C]``, where ``H`` is the feature size.
Parameters
----------
in_channels: int
Expand Down Expand Up @@ -282,6 +285,9 @@ def _check_input_dim(self, x):
class Conv2d(_GeneralConv):
"""Two-dimensional convolution.
The input should a 3d array with the shape of ``[H, W, C]``, or
a 4d array with the shape of ``[B, H, W, C]``.
Parameters
----------
in_channels: int
Expand Down Expand Up @@ -375,6 +381,9 @@ def _check_input_dim(self, x):
class Conv3d(_GeneralConv):
"""Three-dimensional convolution.
The input should a 3d array with the shape of ``[H, W, D, C]``, or
a 4d array with the shape of ``[B, H, W, D, C]``.
Parameters
----------
in_channels: int
Expand Down
3 changes: 2 additions & 1 deletion brainpy/_src/dnn/tests/test_activation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from absl.testing import parameterized
from absl.testing import absltest
from absl.testing import parameterized

import brainpy as bp
import brainpy.math as bm

Expand Down
11 changes: 6 additions & 5 deletions brainpy/_src/dnn/tests/test_conv_layers.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
# -*- coding: utf-8 -*-

from unittest import TestCase
from absl.testing import absltest
import jax.numpy as jnp
import brainpy.math as bm
from absl.testing import absltest
from absl.testing import parameterized

import brainpy as bp
import brainpy.math as bm


class TestConv(parameterized.TestCase):
def test_Conv2D_img(self):
bm.random.seed()
img = jnp.zeros((2, 200, 198, 4))
for k in range(4):
x = 30 + 60 * k
Expand All @@ -24,21 +22,22 @@ def test_Conv2D_img(self):
strides=(2, 1), padding='VALID', groups=4)
out = net(img)
print("out shape: ", out.shape)
self.assertEqual(out.shape, (2, 99, 196, 32))
# print("First output channel:")
# plt.figure(figsize=(10, 10))
# plt.imshow(np.array(img)[0, :, :, 0])
# plt.show()
bm.clear_buffer_memory()

def test_conv1D(self):
bm.random.seed()
with bp.math.training_environment():
model = bp.layers.Conv1d(in_channels=3, out_channels=32, kernel_size=(3,))

input = bp.math.ones((2, 5, 3))

out = model(input)
print("out shape: ", out.shape)
self.assertEqual(out.shape, (2, 5, 32))
# print("First output channel:")
# plt.figure(figsize=(10, 10))
# plt.imshow(np.array(out)[0, :, :])
Expand All @@ -54,6 +53,7 @@ def test_conv2D(self):

out = model(input)
print("out shape: ", out.shape)
self.assertEqual(out.shape, (2, 5, 5, 32))
# print("First output channel:")
# plt.figure(figsize=(10, 10))
# plt.imshow(np.array(out)[0, :, :, 31])
Expand All @@ -67,6 +67,7 @@ def test_conv3D(self):
input = bp.math.ones((2, 5, 5, 5, 3))
out = model(input)
print("out shape: ", out.shape)
self.assertEqual(out.shape, (2, 5, 5, 5, 32))
bm.clear_buffer_memory()


Expand Down
6 changes: 2 additions & 4 deletions brainpy/_src/dnn/tests/test_function.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
# -*- coding: utf-8 -*-

from unittest import TestCase

import jax.numpy as jnp
import brainpy.math as bm
from absl.testing import absltest
from absl.testing import parameterized

import brainpy as bp
import brainpy.math as bm


class TestFunction(parameterized.TestCase):
Expand Down
5 changes: 3 additions & 2 deletions brainpy/_src/dnn/tests/test_linear.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import brainpy as bp
from absl.testing import parameterized
from absl.testing import absltest
from absl.testing import parameterized

import brainpy as bp
import brainpy.math as bm


Expand Down
5 changes: 3 additions & 2 deletions brainpy/_src/dnn/tests/test_mode.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import brainpy.math as bm
from absl.testing import parameterized
from absl.testing import absltest
from absl.testing import parameterized

import brainpy as bp
import brainpy.math as bm


class Test_Conv(parameterized.TestCase):
Expand Down
5 changes: 3 additions & 2 deletions brainpy/_src/dnn/tests/test_normalization.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import brainpy.math as bm
from absl.testing import parameterized
from absl.testing import absltest
from absl.testing import parameterized

import brainpy as bp
import brainpy.math as bm


class Test_Normalization(parameterized.TestCase):
Expand Down
2 changes: 1 addition & 1 deletion brainpy/_src/dnn/tests/test_pooling_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import jax
import jax.numpy as jnp
import numpy as np
from absl.testing import parameterized
from absl.testing import absltest
from absl.testing import parameterized

import brainpy as bp
import brainpy.math as bm
Expand Down
45 changes: 14 additions & 31 deletions brainpy/_src/math/object_transform/autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,8 @@
get_stack_cache,
cache_stack)
from .base import (BrainPyObject, ObjectTransform)
from .variables import (Variable,
VariableStack,
current_transform_number,
new_transform)
from .variables import (Variable, VariableStack)
from .tools import eval_shape

__all__ = [
'grad', # gradient of scalar function
Expand Down Expand Up @@ -203,36 +201,21 @@ def __call__(self, *args, **kwargs):
elif not self._eval_dyn_vars: # evaluate dynamical variables
stack = get_stack_cache(self.target)
if stack is None:
with new_transform(self):
with VariableStack() as stack:
if current_transform_number() > 1:
rets = self._transform(
[v.value for v in self._grad_vars], # variables for gradients
{}, # dynamical variables
*args,
**kwargs
)
else:
rets = jax.eval_shape(
self._transform,
[v.value for v in self._grad_vars], # variables for gradients
{}, # dynamical variables
*args,
**kwargs
)
with VariableStack() as stack:
rets = eval_shape(self._transform,
[v.value for v in self._grad_vars], # variables for gradients
{}, # dynamical variables
*args,
**kwargs)
cache_stack(self.target, stack)

self._dyn_vars = stack
self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars])
self._eval_dyn_vars = True
self._dyn_vars = stack
self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars])
self._eval_dyn_vars = True

# if not the outermost transformation
if current_transform_number():
return self._return(rets)
else:
self._dyn_vars = stack
self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars])
self._eval_dyn_vars = True
# if not the outermost transformation
if not stack.is_first_stack():
return self._return(rets)

rets = self._transform(
[v.value for v in self._grad_vars], # variables for gradients
Expand Down
4 changes: 1 addition & 3 deletions brainpy/_src/math/object_transform/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,20 @@
"""

import numbers
import os
import warnings
from collections import namedtuple
from typing import Any, Tuple, Callable, Sequence, Dict, Union, Optional

import jax
import numpy as np

from brainpy import errors
from brainpy._src.math.modes import Mode
from brainpy._src.math.ndarray import (Array, )
from brainpy._src.math.object_transform.collectors import (ArrayCollector, Collector)
from brainpy._src.math.object_transform.naming import (get_unique_name,
check_name_uniqueness)
from brainpy._src.math.object_transform.variables import (Variable, VariableView, TrainVar,
VarList, VarDict)
from brainpy._src.math.modes import Mode
from brainpy._src.math.sharding import BATCH_AXIS

variable_ = None
Expand Down
Loading

0 comments on commit 4d74816

Please sign in to comment.