Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Standardizing and generalizing object-oriented transformations #628

Merged
merged 16 commits into from
Feb 22, 2024
11 changes: 10 additions & 1 deletion brainpy/_src/dnn/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def update(self, x):
nonbatching = False
if x.ndim == self.num_spatial_dims + 1:
nonbatching = True
x = x.unsqueeze(0)
x = bm.unsqueeze(x, 0)
w = self.w.value
if self.mask is not None:
try:
Expand Down Expand Up @@ -190,6 +190,9 @@ def __repr__(self):
class Conv1d(_GeneralConv):
"""One-dimensional convolution.

The input should a 2d array with the shape of ``[H, C]``, or
a 3d array with the shape of ``[B, H, C]``, where ``H`` is the feature size.

Parameters
----------
in_channels: int
Expand Down Expand Up @@ -282,6 +285,9 @@ def _check_input_dim(self, x):
class Conv2d(_GeneralConv):
"""Two-dimensional convolution.

The input should a 3d array with the shape of ``[H, W, C]``, or
a 4d array with the shape of ``[B, H, W, C]``.

Parameters
----------
in_channels: int
Expand Down Expand Up @@ -375,6 +381,9 @@ def _check_input_dim(self, x):
class Conv3d(_GeneralConv):
"""Three-dimensional convolution.

The input should a 3d array with the shape of ``[H, W, D, C]``, or
a 4d array with the shape of ``[B, H, W, D, C]``.

Parameters
----------
in_channels: int
Expand Down
3 changes: 2 additions & 1 deletion brainpy/_src/dnn/tests/test_activation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from absl.testing import parameterized
from absl.testing import absltest
from absl.testing import parameterized

import brainpy as bp
import brainpy.math as bm

Expand Down
11 changes: 6 additions & 5 deletions brainpy/_src/dnn/tests/test_conv_layers.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
# -*- coding: utf-8 -*-

from unittest import TestCase
from absl.testing import absltest
import jax.numpy as jnp
import brainpy.math as bm
from absl.testing import absltest
from absl.testing import parameterized

import brainpy as bp
import brainpy.math as bm


class TestConv(parameterized.TestCase):
def test_Conv2D_img(self):
bm.random.seed()
img = jnp.zeros((2, 200, 198, 4))
for k in range(4):
x = 30 + 60 * k
Expand All @@ -24,21 +22,22 @@ def test_Conv2D_img(self):
strides=(2, 1), padding='VALID', groups=4)
out = net(img)
print("out shape: ", out.shape)
self.assertEqual(out.shape, (2, 99, 196, 32))
# print("First output channel:")
# plt.figure(figsize=(10, 10))
# plt.imshow(np.array(img)[0, :, :, 0])
# plt.show()
bm.clear_buffer_memory()

def test_conv1D(self):
bm.random.seed()
with bp.math.training_environment():
model = bp.layers.Conv1d(in_channels=3, out_channels=32, kernel_size=(3,))

input = bp.math.ones((2, 5, 3))

out = model(input)
print("out shape: ", out.shape)
self.assertEqual(out.shape, (2, 5, 32))
# print("First output channel:")
# plt.figure(figsize=(10, 10))
# plt.imshow(np.array(out)[0, :, :])
Expand All @@ -54,6 +53,7 @@ def test_conv2D(self):

out = model(input)
print("out shape: ", out.shape)
self.assertEqual(out.shape, (2, 5, 5, 32))
# print("First output channel:")
# plt.figure(figsize=(10, 10))
# plt.imshow(np.array(out)[0, :, :, 31])
Expand All @@ -67,6 +67,7 @@ def test_conv3D(self):
input = bp.math.ones((2, 5, 5, 5, 3))
out = model(input)
print("out shape: ", out.shape)
self.assertEqual(out.shape, (2, 5, 5, 5, 32))
bm.clear_buffer_memory()


Expand Down
6 changes: 2 additions & 4 deletions brainpy/_src/dnn/tests/test_function.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
# -*- coding: utf-8 -*-

from unittest import TestCase

import jax.numpy as jnp
import brainpy.math as bm
from absl.testing import absltest
from absl.testing import parameterized

import brainpy as bp
import brainpy.math as bm


class TestFunction(parameterized.TestCase):
Expand Down
5 changes: 3 additions & 2 deletions brainpy/_src/dnn/tests/test_linear.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import brainpy as bp
from absl.testing import parameterized
from absl.testing import absltest
from absl.testing import parameterized

import brainpy as bp
import brainpy.math as bm


Expand Down
5 changes: 3 additions & 2 deletions brainpy/_src/dnn/tests/test_mode.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import brainpy.math as bm
from absl.testing import parameterized
from absl.testing import absltest
from absl.testing import parameterized

import brainpy as bp
import brainpy.math as bm


class Test_Conv(parameterized.TestCase):
Expand Down
5 changes: 3 additions & 2 deletions brainpy/_src/dnn/tests/test_normalization.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import brainpy.math as bm
from absl.testing import parameterized
from absl.testing import absltest
from absl.testing import parameterized

import brainpy as bp
import brainpy.math as bm


class Test_Normalization(parameterized.TestCase):
Expand Down
2 changes: 1 addition & 1 deletion brainpy/_src/dnn/tests/test_pooling_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import jax
import jax.numpy as jnp
import numpy as np
from absl.testing import parameterized
from absl.testing import absltest
from absl.testing import parameterized

import brainpy as bp
import brainpy.math as bm
Expand Down
45 changes: 14 additions & 31 deletions brainpy/_src/math/object_transform/autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,8 @@
get_stack_cache,
cache_stack)
from .base import (BrainPyObject, ObjectTransform)
from .variables import (Variable,
VariableStack,
current_transform_number,
new_transform)
from .variables import (Variable, VariableStack)
from .tools import eval_shape

__all__ = [
'grad', # gradient of scalar function
Expand Down Expand Up @@ -203,36 +201,21 @@ def __call__(self, *args, **kwargs):
elif not self._eval_dyn_vars: # evaluate dynamical variables
stack = get_stack_cache(self.target)
if stack is None:
with new_transform(self):
with VariableStack() as stack:
if current_transform_number() > 1:
rets = self._transform(
[v.value for v in self._grad_vars], # variables for gradients
{}, # dynamical variables
*args,
**kwargs
)
else:
rets = jax.eval_shape(
self._transform,
[v.value for v in self._grad_vars], # variables for gradients
{}, # dynamical variables
*args,
**kwargs
)
with VariableStack() as stack:
rets = eval_shape(self._transform,
[v.value for v in self._grad_vars], # variables for gradients
{}, # dynamical variables
*args,
**kwargs)
cache_stack(self.target, stack)

self._dyn_vars = stack
self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars])
self._eval_dyn_vars = True
self._dyn_vars = stack
self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars])
self._eval_dyn_vars = True

# if not the outermost transformation
if current_transform_number():
return self._return(rets)
else:
self._dyn_vars = stack
self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars])
self._eval_dyn_vars = True
# if not the outermost transformation
if not stack.is_first_stack():
return self._return(rets)

rets = self._transform(
[v.value for v in self._grad_vars], # variables for gradients
Expand Down
4 changes: 1 addition & 3 deletions brainpy/_src/math/object_transform/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,20 @@
"""

import numbers
import os
import warnings
from collections import namedtuple
from typing import Any, Tuple, Callable, Sequence, Dict, Union, Optional

import jax
import numpy as np

from brainpy import errors
from brainpy._src.math.modes import Mode
from brainpy._src.math.ndarray import (Array, )
from brainpy._src.math.object_transform.collectors import (ArrayCollector, Collector)
from brainpy._src.math.object_transform.naming import (get_unique_name,
check_name_uniqueness)
from brainpy._src.math.object_transform.variables import (Variable, VariableView, TrainVar,
VarList, VarDict)
from brainpy._src.math.modes import Mode
from brainpy._src.math.sharding import BATCH_AXIS

variable_ = None
Expand Down
Loading