shrink operator (#5022)
brightening-eyes authored Sep 20, 2023
1 parent 0ddc34f commit 7e5fa3a
Showing 7 changed files with 196 additions and 0 deletions.
16 changes: 16 additions & 0 deletions docs/developer-guide/operators.md
@@ -72,6 +72,7 @@
* [RNN](#rnn)
* [Scale](#scale)
* [SELU](#selu)
* [Shrink](#shrink)
* [ShuffleChannel](#shufflechannel)
* [Sigmoid](#sigmoid)
* [Slice](#slice)
@@ -1675,6 +1676,21 @@ else y = x * lambda
| 0 | alpha | float | 1.67326324f| |
| 1 | lambda | float | 1.050700987f| |

# Shrink
```
if x < -lambd y = x + bias
if x > lambd y = x - bias
else y = x
```

* one_blob_only
* support_inplace

| param id | name | type | default | description |
| --------- | ------------- | ----- | --------- | ----------------- |
| 0 | bias | float | 0.0f | |
| 1 | lambd | float | 0.5f | |
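
As an illustrative standalone sketch (plain C++, not part of this commit), here is the formula above applied to a few sample values, using the lambd=0.5f / bias=5.0f pair that also appears in the tests below:

```
#include <cstdio>

// standalone sketch of the Shrink formula above; the lambd/bias values are arbitrary
static float shrink(float x, float lambd, float bias)
{
    if (x < -lambd) return x + bias;
    if (x > lambd) return x - bias;
    return x;
}

int main()
{
    const float lambd = 0.5f;
    const float bias = 5.0f;
    const float xs[] = {-2.0f, -0.3f, 0.0f, 0.3f, 2.0f};
    for (int i = 0; i < 5; i++)
        printf("shrink(%+.1f) = %+.1f\n", xs[i], shrink(xs[i], lambd, bias));
    // -2.0 -> +3.0, -0.3 -> -0.3, 0.0 -> 0.0, 0.3 -> 0.3, 2.0 -> -3.0
    return 0;
}
```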

# ShuffleChannel
```
if reverse == 0 y = shufflechannel(x) by group
1 change: 1 addition & 0 deletions src/CMakeLists.txt
@@ -165,6 +165,7 @@ ncnn_add_layer(CopyTo)
ncnn_add_layer(Erf)
ncnn_add_layer(Diag)
ncnn_add_layer(CELU)
ncnn_add_layer(Shrink)

if(NCNN_VULKAN)
ncnn_add_shader(${CMAKE_CURRENT_SOURCE_DIR}/convert_ycbcr.comp)
54 changes: 54 additions & 0 deletions src/layer/shrink.cpp
@@ -0,0 +1,54 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "shrink.h"

namespace ncnn {

Shrink::Shrink()
{
    one_blob_only = true;
    support_inplace = true;
}

int Shrink::load_param(const ParamDict& pd)
{
    bias = pd.get(0, 0.0f);
    lambd = pd.get(1, 0.5f);

    return 0;
}

int Shrink::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
{
    int w = bottom_top_blob.w;
    int h = bottom_top_blob.h;
    int d = bottom_top_blob.d;
    int channels = bottom_top_blob.c;
    int size = w * h * d; // include the depth dimension so the 4-rank blobs exercised by the tests are fully processed

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int q = 0; q < channels; q++)
    {
        float* ptr = bottom_top_blob.channel(q);

        for (int i = 0; i < size; i++)
        {
            ptr[i] = ptr[i] < -lambd ? ptr[i] + bias : ptr[i] > lambd ? ptr[i] - bias : ptr[i];
        }
    }

    return 0;
}

} // namespace ncnn
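
For orientation, here is a minimal usage sketch showing how a layer like this is typically driven through ncnn's generic layer API (`create_layer`, `ParamDict`, `forward_inplace`). This is an illustration under those assumptions, not code from this commit; error handling and include-path details are omitted:

```
#include "layer.h" // ncnn public header; adjust the include path to your build

int run_shrink_example()
{
    // assumes the Shrink layer is registered, which ncnn_add_layer(Shrink) above takes care of
    ncnn::Layer* shrink = ncnn::create_layer("Shrink");

    ncnn::ParamDict pd;
    pd.set(0, 5.0f); // bias
    pd.set(1, 0.5f); // lambd
    shrink->load_param(pd);

    ncnn::Option opt;
    shrink->create_pipeline(opt);

    ncnn::Mat m(4, 4, 2); // w=4, h=4, c=2
    m.fill(2.0f);         // 2.0 > lambd, so every element becomes 2.0f - bias = -3.0f

    shrink->forward_inplace(m, opt); // support_inplace is true, so the in-place path is valid

    shrink->destroy_pipeline(opt);
    delete shrink;
    return 0;
}
```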
37 changes: 37 additions & 0 deletions src/layer/shrink.h
@@ -0,0 +1,37 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#ifndef LAYER_SHRINK_H
#define LAYER_SHRINK_H

#include "layer.h"

namespace ncnn {

class Shrink : public Layer
{
public:
    Shrink();

    virtual int load_param(const ParamDict& pd);
    virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const;

public:
    float bias;
    float lambd;
};

} // namespace ncnn

#endif // LAYER_SHRINK_H
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
@@ -143,6 +143,7 @@ ncnn_add_layer_test(ROIPooling)
ncnn_add_layer_test(ROIAlign)
ncnn_add_layer_test(Scale)
ncnn_add_layer_test(SELU)
ncnn_add_layer_test(Shrink)
ncnn_add_layer_test(ShuffleChannel)
ncnn_add_layer_test(Sigmoid)
ncnn_add_layer_test(Slice)
76 changes: 76 additions & 0 deletions tests/test_shrink.cpp
@@ -0,0 +1,76 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "layer/shrink.h"
#include "testutil.h"

static int test_shrink(const ncnn::Mat& a, float lambd, float bias)
{
    ncnn::ParamDict pd;
    pd.set(0, bias);
    pd.set(1, lambd);

    std::vector<ncnn::Mat> weights(0);

    int ret = test_layer<ncnn::Shrink>("Shrink", pd, weights, a);
    if (ret != 0)
    {
        fprintf(stderr, "test_shrink failed a.dims=%d a=(%d %d %d %d)\n", a.dims, a.w, a.h, a.d, a.c);
    }

    return ret;
}

static int test_shrink_0()
{
    return 0
           || test_shrink(RandomMat(12, 24, 8, 3), 0.5f, 5.0f)
           || test_shrink(RandomMat(9, 7, 16, 4), 1.0f, 0.3f)
           || test_shrink(RandomMat(6, 9, 4, 3), 4.5f, 6.1f);
}

static int test_shrink_1()
{
    return 0
           || test_shrink(RandomMat(12, 6, 24), 0.5f, 5.0f)
           || test_shrink(RandomMat(7, 8, 24), 1.0f, 0.3f)
           || test_shrink(RandomMat(3, 4, 5), 4.5f, 6.1f);
}

static int test_shrink_2()
{
    return 0
           || test_shrink(RandomMat(5, 7), 3.4f, 0.3f)
           || test_shrink(RandomMat(7, 9), 3.1f, 4.0f)
           || test_shrink(RandomMat(3, 5), 2.0f, 4.0f);
}

static int test_shrink_3()
{
    return 0
           || test_shrink(RandomMat(25), 3.4f, 0.3f)
           || test_shrink(RandomMat(63), 3.1f, 4.0f)
           || test_shrink(RandomMat(1024), 2.0f, 4.0f);
}

int main()
{
    SRAND(7767517);

    return 0
           || test_shrink_0()
           || test_shrink_1()
           || test_shrink_2()
           || test_shrink_3();
}
11 changes: 11 additions & 0 deletions tools/onnx/onnx2ncnn.cpp
@@ -3894,6 +3894,10 @@ int main(int argc, char** argv)
{
fprintf(pp, "%-16s", "BinaryOp");
}
else if (op == "Shrink")
{
fprintf(pp, "%-16s", "Shrink");
}
else if (op == "ShuffleChannel")
{
fprintf(pp, "%-16s", "ShuffleChannel");
@@ -5746,6 +5750,13 @@
fprintf(pp, " 2=%e", b);
}
}
else if (op == "Shrink")
{
float bias = get_node_attr_f(node, "bias", 0.0f);
float lambd = get_node_attr_f(node, "lambd", 0.5f);
fprintf(pp, " 0=%e", bias);
fprintf(pp, " 1=%e", lambd);
}
else if (op == "ShuffleChannel")
{
int group = get_node_attr_i(node, "group", 1);
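
Given the two fragments above, the converted layer ends up in the generated .param file as a line roughly like the one below; the layer name, blob counts, and blob names are hypothetical placeholders, while the trailing `0=` and `1=` fields are the `bias` and `lambd` attributes written with `%e` (defaults shown):

```
Shrink           shrink_0                 1 1 in0 out0 0=0.000000e+00 1=5.000000e-01
```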
