This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-1431] Multiple channel support in Gluon PReLU (#16262)
* Multiple channel support in Gluon PReLU

* Update activations.py
jonatanmil authored and szha committed Dec 12, 2019
1 parent 9092f17 commit f701f3f
Showing 2 changed files with 13 additions and 3 deletions.
11 changes: 8 additions & 3 deletions python/mxnet/gluon/nn/activations.py
@@ -120,18 +120,23 @@ class PReLU(HybridBlock):
     ----------
     alpha_initializer : Initializer
         Initializer for the `embeddings` matrix.
+    in_channels : int, default 1
+        Number of channels (alpha parameters) to learn. Can either be 1
+        or `n` where `n` is the size of the second dimension of the input
+        tensor.
     Inputs:
         - **data**: input tensor with arbitrary shape.
     Outputs:
         - **out**: output tensor with the same shape as `data`.
     """
-    def __init__(self, alpha_initializer=initializer.Constant(0.25), **kwargs):
+    def __init__(self, alpha_initializer=initializer.Constant(0.25),
+                 in_channels=1, **kwargs):
         super(PReLU, self).__init__(**kwargs)
         with self.name_scope():
-            self.alpha = self.params.get('alpha', shape=(1,), init=alpha_initializer)
+            self.alpha = self.params.get('alpha', shape=(in_channels,),
+                                         init=alpha_initializer)

     def hybrid_forward(self, F, x, alpha):
         return F.LeakyReLU(x, gamma=alpha, act_type='prelu', name='fwd')
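For orientation, a minimal usage sketch of the new option (not part of the commit; it assumes the MXNet 1.x Gluon API this change targets). With in_channels=n, one alpha is learned per channel along the input's second axis; the default in_channels=1 keeps the previous single-alpha behaviour.

import mxnet as mx
from mxnet.gluon import nn

# One learnable alpha per channel of an (N, C, H, W) input; here C = 3.
prelu = nn.PReLU(alpha_initializer=mx.init.Constant(0.25), in_channels=3)
prelu.initialize()

x = mx.nd.random.normal(shape=(2, 3, 4, 4))
y = prelu(x)                 # same shape as x
print(prelu.alpha.data())    # three alpha values, one per channel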
5 changes: 5 additions & 0 deletions tests/python/unittest/test_gluon.py
@@ -1389,6 +1389,11 @@ def selu(x):
     x = point_to_validate.reshape((1, 3, 2))
     assert_almost_equal(prelu(x).asnumpy(), mx.nd.where(x >= 0, x, 0.25 * x).asnumpy())

+    multichannel_init = mx.initializer.Constant(mx.nd.array([0.1, 0.25, 0.5]))
+    prelu_multichannel = mx.gluon.nn.PReLU(alpha_initializer=multichannel_init, in_channels=3)
+    prelu_multichannel.initialize()
+    assert_almost_equal(prelu_multichannel(x).asnumpy(), np.array([[-0.01, 0.1], [-0.025, 0.1], [-0.05, 0.1]]))
+
     gelu = mx.gluon.nn.GELU()
     def gelu_test(x):
         CUBE_CONSTANT = 0.044715
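A note on the expected values (a standalone check, not part of the commit): point_to_validate is defined earlier in the test file and is not shown in this diff; assuming it holds [-0.1, 0.1] repeated three times, the reshape to (1, 3, 2) puts (-0.1, 0.1) in each channel, and the per-channel alphas scale only the negative entries.

import numpy as np

alphas = np.array([0.1, 0.25, 0.5]).reshape(1, 3, 1)  # one alpha per channel
x = np.array([-0.1, 0.1] * 3).reshape(1, 3, 2)        # assumed contents of point_to_validate

expected = np.where(x >= 0, x, alphas * x)            # PReLU: x if x >= 0, else alpha * x
print(expected)  # negative entries become -0.01, -0.025, -0.05; positives stay 0.1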
