Add activation ops (PaddlePaddle#61)
* add base activation ops

* add tanh test

* add gelu, tanh
gglin001 authored Aug 15, 2021
1 parent ff4ba95 commit 1684973
Showing 4 changed files with 198 additions and 11 deletions.
10 changes: 10 additions & 0 deletions paddle/fluid/framework/ipu/ipu_backend.cc
@@ -375,6 +375,16 @@ void IpuBackend::LowerBody(const ir::Graph* graph) {
       auto outputs = op->Output("__outputs__");
       auto result = builder_->aiOnnxOpset11().relu(inputs);
       tensors_.emplace(outputs[0], result);
+    } else if (op_type == "Tanh") {
+      auto inputs = GetOpInputs(op);
+      auto outputs = op->Output("__outputs__");
+      auto result = builder_->aiOnnxOpset11().tanh(inputs);
+      tensors_.emplace(outputs[0], result);
+    } else if (op_type == "Gelu") {
+      auto inputs = GetOpInputs(op);
+      auto outputs = op->Output("__outputs__");
+      auto result = builder_->aiGraphcoreOpset1().gelu(inputs);
+      tensors_.emplace(outputs[0], result);
     } else if (op_type == "BatchNormalization") {
       auto inputs = GetOpInputs(op);
       auto outputs = op->Output("__outputs__");
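Every activation lowered here follows the same three-step recipe: resolve the op's input tensors, emit the matching popart builder call, and record the produced tensor under the op's output name. As an illustration only (not part of this commit's diff), a hypothetical Sigmoid branch would mirror the Tanh case, assuming popart's aiOnnxOpset11() exposes a sigmoid builder for the ONNX Sigmoid op:

    } else if (op_type == "Sigmoid") {
      // Hypothetical branch, mirroring Tanh/Gelu above: resolve inputs,
      // emit the builder call, and map the op's output name to the result.
      auto inputs = GetOpInputs(op);
      auto outputs = op->Output("__outputs__");
      auto result = builder_->aiOnnxOpset11().sigmoid(inputs);
      tensors_.emplace(outputs[0], result);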
48 changes: 37 additions & 11 deletions paddle/fluid/framework/ipu/popart_canonicalization/activation_ops.cc
@@ -13,27 +13,53 @@
 // limitations under the License.

 #include "paddle/fluid/framework/ipu/popart_canonicalization/canonicalization_utils.h"
+#include "paddle/fluid/framework/ipu/popart_canonicalization/op_builder.h"
+#include "paddle/fluid/platform/enforce.h"

 namespace paddle {
 namespace framework {
 namespace ipu {
 namespace {

+ir::Node *activation_op_handler(ir::Graph *graph, ir::Node *node,
+                                const std::string &type) {
+  auto new_node =
+      CreateBaseOp(graph, type, {GetInputNode("X", node)}, node->outputs);
+  ReplaceNodeInputs(node, new_node);
+  ReplaceNodeOutputs(node, new_node);
+  return new_node;
+}
+
 ir::Node *relu_handler(ir::Graph *graph, ir::Node *node) {
-  auto *op = node->Op();
-  auto op_desc = std::make_unique<framework::OpDesc>();
-  op_desc->SetType("Relu");
-  std::vector<std::string> inputs;
-  inputs.push_back(op->Input("X").front());
-  op_desc->SetInput("__inputs__", inputs);
-  std::vector<std::string> outputs;
-  outputs.push_back(op->Output("Out").front());
-  op_desc->SetOutput("__outputs__", outputs);
-  op_desc->Flush();
-  return graph->CreateOpNode(op_desc.get());
+  return activation_op_handler(graph, node, "Relu");
 }

+ir::Node *tanh_handler(ir::Graph *graph, ir::Node *node) {
+  return activation_op_handler(graph, node, "Tanh");
+}
+
+ir::Node *log_handler(ir::Graph *graph, ir::Node *node) {
+  return activation_op_handler(graph, node, "Log");
+}
+
+ir::Node *sigmoid_handler(ir::Graph *graph, ir::Node *node) {
+  return activation_op_handler(graph, node, "Sigmoid");
+}
+
+ir::Node *sqrt_handler(ir::Graph *graph, ir::Node *node) {
+  return activation_op_handler(graph, node, "Sqrt");
+}
+
+ir::Node *gelu_handler(ir::Graph *graph, ir::Node *node) {
+  return activation_op_handler(graph, node, "Gelu");
+}
+
 REGISTER_HANDLER(relu, relu_handler);
+REGISTER_HANDLER(tanh, tanh_handler);
+REGISTER_HANDLER(log, log_handler);
+REGISTER_HANDLER(sigmoid, sigmoid_handler);
+REGISTER_HANDLER(sqrt, sqrt_handler);
+REGISTER_HANDLER(gelu, gelu_handler);

 }  // namespace
 }  // namespace ipu
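With the shared activation_op_handler in place, supporting a further unary activation reduces to a two-line handler plus a registration. A hypothetical exp hookup might look as follows (illustrative only: it assumes a Paddle op named exp and a corresponding "Exp" branch on the lowering side, neither of which is part of this commit):

    ir::Node *exp_handler(ir::Graph *graph, ir::Node *node) {
      // Hypothetical: delegate to the shared helper with popart type "Exp".
      return activation_op_handler(graph, node, "Exp");
    }

    REGISTER_HANDLER(exp, exp_handler);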
76 changes: 76 additions & 0 deletions python/paddle/fluid/tests/unittests/ipu/ipu_gelu_test.py
@@ -0,0 +1,76 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import numpy as np
import unittest
import sys
import paddle
import paddle.fluid as fluid
import paddle.nn
import paddle.fluid.compiler as compiler

paddle.enable_static()
SEED = 2021


@unittest.skipIf(not paddle.is_compiled_with_ipu(),
                 "core is not compiled with IPU")
class TestGelu(unittest.TestCase):
    def _test(self, run_ipu=True):
        main_prog = paddle.static.Program()
        startup_prog = paddle.static.Program()
        main_prog.random_seed = SEED
        startup_prog.random_seed = SEED
        np.random.seed(SEED)

        np_image = np.random.rand(1, 3, 2, 2).astype(np.float32)

        with paddle.static.program_guard(main_prog, startup_prog):
            image = paddle.static.data(
                name='image', shape=[1, 3, 2, 2], dtype='float32')
            out = paddle.nn.functional.gelu(image)

        if run_ipu:
            place = paddle.IPUPlace()
        else:
            place = paddle.CPUPlace()
        exe = paddle.static.Executor(place)
        exe.run(startup_prog)

        if run_ipu:
            feed_list = [image.name]
            fetch_list = [out.name]
            ipu_strategy = compiler.get_ipu_strategy()
            ipu_strategy.is_training = False
            program = compiler.IpuCompiler(
                main_prog, ipu_strategy=ipu_strategy).compile(feed_list,
                                                              fetch_list)
        else:
            program = main_prog
        result = exe.run(program, feed={"image": np_image}, fetch_list=[out])
        return np.array(result)

    def test_gelu(self):
        # CPU and IPU results can differ in rank, so compare flattened arrays.
        cpu_out = self._test(False).flatten()
        ipu_out = self._test(True).flatten()
        self.assertTrue(np.allclose(ipu_out, cpu_out, atol=1e-3))


if __name__ == "__main__":
    unittest.main()
75 changes: 75 additions & 0 deletions python/paddle/fluid/tests/unittests/ipu/ipu_tanh_test.py
@@ -0,0 +1,75 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import numpy as np
import unittest
import sys
import paddle
import paddle.fluid as fluid
import paddle.fluid.compiler as compiler

paddle.enable_static()
SEED = 2021


@unittest.skipIf(not paddle.is_compiled_with_ipu(),
                 "core is not compiled with IPU")
class TestTanh(unittest.TestCase):
    def _test(self, run_ipu=True):
        main_prog = paddle.static.Program()
        startup_prog = paddle.static.Program()
        main_prog.random_seed = SEED
        startup_prog.random_seed = SEED
        np.random.seed(SEED)

        np_image = np.random.rand(1, 3, 2, 2).astype(np.float32)

        with paddle.static.program_guard(main_prog, startup_prog):
            image = paddle.static.data(
                name='image', shape=[1, 3, 2, 2], dtype='float32')
            out = paddle.nn.functional.tanh(image)

        if run_ipu:
            place = paddle.IPUPlace()
        else:
            place = paddle.CPUPlace()
        exe = paddle.static.Executor(place)
        exe.run(startup_prog)

        if run_ipu:
            feed_list = [image.name]
            fetch_list = [out.name]
            ipu_strategy = compiler.get_ipu_strategy()
            ipu_strategy.is_training = False
            program = compiler.IpuCompiler(
                main_prog, ipu_strategy=ipu_strategy).compile(feed_list,
                                                              fetch_list)
        else:
            program = main_prog
        result = exe.run(program, feed={"image": np_image}, fetch_list=[out])
        return np.array(result)

    def test_tanh(self):
        # CPU and IPU results can differ in rank, so compare flattened arrays.
        cpu_out = self._test(False).flatten()
        ipu_out = self._test(True).flatten()
        self.assertTrue(np.allclose(ipu_out, cpu_out, atol=1e-4))


if __name__ == "__main__":
    unittest.main()
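The two new tests are identical except for the activation under test and the allclose tolerance (1e-3 for gelu, 1e-4 for tanh). As an untested sketch, not part of this commit, they could be folded into one parameterized base class; act_fn and atol below are illustrative names:

    import unittest

    import numpy as np
    import paddle
    import paddle.fluid.compiler as compiler

    paddle.enable_static()
    SEED = 2021


    @unittest.skipIf(not paddle.is_compiled_with_ipu(),
                     "core is not compiled with IPU")
    class TestActivationBase(unittest.TestCase):
        act_fn = staticmethod(paddle.nn.functional.tanh)  # illustrative default
        atol = 1e-4

        def _run(self, run_ipu):
            main_prog = paddle.static.Program()
            startup_prog = paddle.static.Program()
            main_prog.random_seed = SEED
            startup_prog.random_seed = SEED
            np.random.seed(SEED)
            np_image = np.random.rand(1, 3, 2, 2).astype(np.float32)

            with paddle.static.program_guard(main_prog, startup_prog):
                image = paddle.static.data(
                    name='image', shape=[1, 3, 2, 2], dtype='float32')
                out = self.act_fn(image)

            place = paddle.IPUPlace() if run_ipu else paddle.CPUPlace()
            exe = paddle.static.Executor(place)
            exe.run(startup_prog)

            if run_ipu:
                ipu_strategy = compiler.get_ipu_strategy()
                ipu_strategy.is_training = False
                program = compiler.IpuCompiler(
                    main_prog, ipu_strategy=ipu_strategy).compile(
                        [image.name], [out.name])
            else:
                program = main_prog
            result = exe.run(program, feed={"image": np_image},
                             fetch_list=[out])
            return np.array(result)

        def test_result(self):
            # Compare flattened arrays; CPU and IPU results can differ in rank.
            cpu_out = self._run(False).flatten()
            ipu_out = self._run(True).flatten()
            self.assertTrue(np.allclose(ipu_out, cpu_out, atol=self.atol))


    class TestGeluParameterized(TestActivationBase):
        act_fn = staticmethod(paddle.nn.functional.gelu)
        atol = 1e-3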
