Skip to content

Commit

Permalink
elu4d
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Sep 21, 2023
1 parent fdf2c48 commit 546bfd0
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 18 deletions.
3 changes: 2 additions & 1 deletion src/layer/elu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@ int ELU::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
{
int w = bottom_top_blob.w;
int h = bottom_top_blob.h;
int d = bottom_top_blob.d;
int channels = bottom_top_blob.c;
int size = w * h;
int size = w * h * d;

#pragma omp parallel for num_threads(opt.num_threads)
for (int q = 0; q < channels; q++)
Expand Down
3 changes: 2 additions & 1 deletion src/layer/x86/elu_x86.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@ int ELU_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
{
int w = bottom_top_blob.w;
int h = bottom_top_blob.h;
int d = bottom_top_blob.d;
int channels = bottom_top_blob.c;
int elempack = bottom_top_blob.elempack;
int size = w * h * elempack;
int size = w * h * d * elempack;

#pragma omp parallel for num_threads(opt.num_threads)
for (int q = 0; q < channels; q++)
Expand Down
22 changes: 14 additions & 8 deletions tools/pnnx/tests/ncnn/test_F_elu.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,30 +20,36 @@ class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x, y, z):
def forward(self, x, y, z, w):
x = x * 2 - 1
y = y * 2 - 1
z = z * 2 - 1
w = w * 2 - 1
x = F.elu(x)
y = F.elu(y, 1.2)
z = F.elu(z, -0.6)
return x, y, z
w = F.elu(w, 0.1)
return x, y, z, w

def test():
net = Model()
net.eval()

torch.manual_seed(0)
x = torch.rand(1, 16)
y = torch.rand(1, 2, 16)
z = torch.rand(1, 3, 12, 16)
x = torch.rand(16)
y = torch.rand(2, 16)
z = torch.rand(3, 12, 16)
w = torch.rand(5, 7, 9, 11)

a = net(x, y, z)
a = net(x, y, z, w)

# export torchscript
mod = torch.jit.trace(net, (x, y, z))
mod = torch.jit.trace(net, (x, y, z, w))
mod.save("test_F_elu.pt")

# torchscript to pnnx
import os
os.system("../../src/pnnx test_F_elu.pt inputshape=[1,16],[1,2,16],[1,3,12,16]")
os.system("../../src/pnnx test_F_elu.pt inputshape=[16],[2,16],[3,12,16],[5,7,9,11]")

# ncnn inference
import test_F_elu_ncnn
Expand Down
22 changes: 14 additions & 8 deletions tools/pnnx/tests/ncnn/test_nn_ELU.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,30 +23,36 @@ def __init__(self):
self.act_0 = nn.ELU()
self.act_1 = nn.ELU(alpha=1.3)

def forward(self, x, y, z):
def forward(self, x, y, z, w):
x = x * 2 - 1
y = y * 2 - 1
z = z * 2 - 1
w = w * 2 - 1
x = self.act_0(x)
y = self.act_0(y)
z = self.act_1(z)
return x, y, z
w = self.act_1(w)
return x, y, z, w

def test():
net = Model()
net.eval()

torch.manual_seed(0)
x = torch.rand(1, 12)
y = torch.rand(1, 12, 64)
z = torch.rand(1, 12, 24, 64)
x = torch.rand(12)
y = torch.rand(12, 64)
z = torch.rand(12, 24, 64)
w = torch.rand(12, 24, 32, 64)

a = net(x, y, z)
a = net(x, y, z, w)

# export torchscript
mod = torch.jit.trace(net, (x, y, z))
mod = torch.jit.trace(net, (x, y, z, w))
mod.save("test_nn_ELU.pt")

# torchscript to pnnx
import os
os.system("../../src/pnnx test_nn_ELU.pt inputshape=[1,12],[1,12,64],[1,12,24,64]")
os.system("../../src/pnnx test_nn_ELU.pt inputshape=[12],[12,64],[12,24,64],[12,24,32,64]")

# ncnn inference
import test_nn_ELU_ncnn
Expand Down
4 changes: 4 additions & 0 deletions tools/pnnx/tests/test_F_elu.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ def __init__(self):
super(Model, self).__init__()

def forward(self, x, y, z, w):
x = x * 2 - 1
y = y * 2 - 1
z = z * 2 - 1
w = w * 2 - 1
x = F.elu(x)
y = F.elu(y, 1.2)
z = F.elu(z, -0.6)
Expand Down
4 changes: 4 additions & 0 deletions tools/pnnx/tests/test_nn_ELU.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ def __init__(self):
self.act_1 = nn.ELU(alpha=1.3)

def forward(self, x, y, z, w):
x = x * 2 - 1
y = y * 2 - 1
z = z * 2 - 1
w = w * 2 - 1
x = self.act_0(x)
y = self.act_0(y)
z = self.act_1(z)
Expand Down

0 comments on commit 546bfd0

Please sign in to comment.