diff --git a/paddle/fluid/operators/npu_op_runner.h b/paddle/fluid/operators/npu_op_runner.h index a7ce3f27f3325..97468364eec79 100644 --- a/paddle/fluid/operators/npu_op_runner.h +++ b/paddle/fluid/operators/npu_op_runner.h @@ -82,6 +82,8 @@ class NpuOpRunner { aclopAttr *attr_{nullptr}; }; +aclDataType ConvertToNpuDtype(framework::proto::VarType::Type dtype); + } // namespace operators } // namespace paddle #endif diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op_npu.cc b/paddle/fluid/operators/softmax_with_cross_entropy_op_npu.cc new file mode 100644 index 0000000000000..c777a02f96bd9 --- /dev/null +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op_npu.cc @@ -0,0 +1,203 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/math/softmax.h" +#include +#include +#include "paddle/fluid/operators/math/cross_entropy.h" +#include "paddle/fluid/operators/npu_op_runner.h" +#include "paddle/fluid/operators/softmax_op.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class SoftmaxWithCrossEntropyNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* logits = ctx.Input("Logits"); + auto* labels = ctx.Input("Label"); + auto* softmax = ctx.Output("Softmax"); + auto* loss = ctx.Output("Loss"); + + int cls_num = logits->dims()[1]; + const int rank = logits->dims().size(); + const int axis = CanonicalAxis(ctx.Attr("axis"), rank); + std::vector axes; + for (auto i = axis; i < logits->dims().size(); ++i) { + axes.push_back(i); + } + + auto stream = + ctx.template device_context() + .stream(); + + // softmax + softmax->mutable_data(ctx.GetPlace()); + auto runner_softmax = + NpuOpRunner("SoftmaxV2", {*logits}, {*softmax}, {{"axes", axes}}); + runner_softmax.Run(stream); + + // cast label from int64/int32 to int32 + Tensor tmp_labels(framework::proto::VarType::INT32); + if (labels->type() != framework::proto::VarType::INT32) { + tmp_labels.Resize(labels->dims()); + tmp_labels.mutable_data(ctx.GetPlace(), framework::proto::VarType::INT32); + auto dst_dtype = ConvertToNpuDtype(framework::proto::VarType::INT32); + auto runner_cast_label = + NpuOpRunner("Cast", {*labels}, {tmp_labels}, + {{"dst_type", static_cast(dst_dtype)}}); + runner_cast_label.Run(stream); + labels = &tmp_labels; + } + + // on and off + Tensor on_tensor(framework::proto::VarType::INT32); + on_tensor.mutable_data({1}, ctx.GetPlace()); + TensorFromVector(std::vector{static_cast(1)}, + ctx.device_context(), &on_tensor); + Tensor off_tensor(framework::proto::VarType::INT32); + off_tensor.mutable_data({1}, ctx.GetPlace()); + TensorFromVector(std::vector{static_cast(0)}, + ctx.device_context(), &off_tensor); + + // one_hot + Tensor tmp_onehot(on_tensor.type()); + tmp_onehot.Resize(logits->dims()); + tmp_onehot.mutable_data(ctx.GetPlace()); + + auto runner_onehot = + NpuOpRunner("OneHotD", {*labels, on_tensor, off_tensor}, {tmp_onehot}, + {{"axis", -1}, {"depth", cls_num}}); + runner_onehot.Run(stream); + + // cast one_hot from int32 to T + Tensor cast_onehot(logits->type()); + cast_onehot.Resize(tmp_onehot.dims()); + cast_onehot.mutable_data(ctx.GetPlace()); + auto dst_dtype = ConvertToNpuDtype(logits->type()); + auto runner_cast_onehot = + NpuOpRunner("Cast", {tmp_onehot}, {cast_onehot}, + {{"dst_type", static_cast(dst_dtype)}}); + runner_cast_onehot.Run(stream); + + // SoftmaxCrossEntropyWithLogits + Tensor backprop(logits->type()); + backprop.Resize(logits->dims()); + backprop.mutable_data(ctx.GetPlace()); + + loss->mutable_data(ctx.GetPlace()); + + // SoftmaxCrossEntropyWithLogits requires loss to be of shape [batch_size] + auto loss_dims = loss->dims(); + loss->Resize({loss_dims[0]}); + auto runner_s = NpuOpRunner("SoftmaxCrossEntropyWithLogits", + {*logits, cast_onehot}, {*loss, backprop}, {}); + runner_s.Run(stream); + loss->Resize(loss_dims); + } +}; + +template +class SoftmaxWithCrossEntropyGradNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* labels = ctx.Input("Label"); + auto* softmax = ctx.Input("Softmax"); + auto* loss_grad = ctx.Input(framework::GradVarName("Loss")); + auto* logits_grad = ctx.Output(framework::GradVarName("Logits")); + + int cls_num = softmax->dims()[1]; + + auto stream = + ctx.template device_context() + .stream(); + + // cast label from int64/int32 to int32 + Tensor tmp_labels(framework::proto::VarType::INT32); + if (labels->type() != framework::proto::VarType::INT32) { + tmp_labels.Resize(labels->dims()); + tmp_labels.mutable_data(ctx.GetPlace(), framework::proto::VarType::INT32); + auto dst_dtype = ConvertToNpuDtype(framework::proto::VarType::INT32); + auto runner_cast_label = + NpuOpRunner("Cast", {*labels}, {tmp_labels}, + {{"dst_type", static_cast(dst_dtype)}}); + runner_cast_label.Run(stream); + labels = &tmp_labels; + } + + // on and off + Tensor on_tensor(framework::proto::VarType::INT32); + on_tensor.mutable_data({1}, ctx.GetPlace()); + TensorFromVector(std::vector{static_cast(1)}, + ctx.device_context(), &on_tensor); + Tensor off_tensor(framework::proto::VarType::INT32); + off_tensor.mutable_data({1}, ctx.GetPlace()); + TensorFromVector(std::vector{static_cast(0)}, + ctx.device_context(), &off_tensor); + + // one_hot + Tensor tmp_onehot(on_tensor.type()); + tmp_onehot.Resize(softmax->dims()); + tmp_onehot.mutable_data(ctx.GetPlace()); + + auto runner_onehot = + NpuOpRunner("OneHotD", {*labels, on_tensor, off_tensor}, {tmp_onehot}, + {{"axis", -1}, {"depth", cls_num}}); + runner_onehot.Run(stream); + + // cast one_hot from int32 to T + Tensor cast_onehot(softmax->type()); + cast_onehot.Resize(tmp_onehot.dims()); + cast_onehot.mutable_data(ctx.GetPlace()); + auto dst_dtype = ConvertToNpuDtype(softmax->type()); + auto runner_cast_onehot = + NpuOpRunner("Cast", {tmp_onehot}, {cast_onehot}, + {{"dst_type", static_cast(dst_dtype)}}); + runner_cast_onehot.Run(stream); + + // sub + Tensor tmp_sub(softmax->type()); + tmp_sub.Resize(softmax->dims()); + tmp_sub.mutable_data(ctx.GetPlace()); + auto runner_sub = + NpuOpRunner("Sub", {*softmax, cast_onehot}, {tmp_sub}, {}); + + runner_sub.Run(stream); + // mul + logits_grad->mutable_data(ctx.GetPlace()); + auto runner_mul = + NpuOpRunner("Mul", {*loss_grad, tmp_sub}, {*logits_grad}, {}); + runner_mul.Run(stream); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_NPU_KERNEL( + softmax_with_cross_entropy, + ops::SoftmaxWithCrossEntropyNPUKernel, + ops::SoftmaxWithCrossEntropyNPUKernel); +REGISTER_OP_NPU_KERNEL( + softmax_with_cross_entropy_grad, + ops::SoftmaxWithCrossEntropyGradNPUKernel< + paddle::platform::NPUDeviceContext, float>, + ops::SoftmaxWithCrossEntropyGradNPUKernel< + paddle::platform::NPUDeviceContext, paddle::platform::float16>); diff --git a/python/paddle/fluid/tests/unittests/npu/test_softmax_with_cross_entropy_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_softmax_with_cross_entropy_op_npu.py new file mode 100644 index 0000000000000..1b48268b0e77e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/npu/test_softmax_with_cross_entropy_op_npu.py @@ -0,0 +1,159 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import unittest +import sys +sys.path.append("..") +from op_test import OpTest +import paddle +import paddle.fluid as fluid +from test_softmax_op import stable_softmax +from test_softmax_with_cross_entropy_op import cross_entropy + +paddle.enable_static() +SEED = 2021 + + +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class TestSoftmaxWithCrossEntropyOp(OpTest): + def set_npu(self): + self.__class__.use_npu = True + + def init_dtype(self): + self.dtype = np.float32 + + def initParams(self): + self.set_npu() + self.op_type = "softmax_with_cross_entropy" + self.numeric_stable_mode = False + self.place = paddle.NPUPlace(0) + self.soft_label = False + self.init_dtype() + self.axis = -1 + self.ignore_index = -1 + self.shape = [41, 37] + np.random.seed(SEED) + + def setUp(self): + self.initParams() + + logits = getattr( + self, "logits", + np.random.uniform(0.1, 1.0, self.shape).astype(self.dtype)) + softmax = np.apply_along_axis(stable_softmax, self.axis, logits) + + if self.soft_label: + labels = np.random.uniform(0.1, 1.0, self.shape).astype(self.dtype) + labels /= np.sum(labels, axis=self.axis, keepdims=True) + else: + axis_dim = self.shape[self.axis] + self.shape[self.axis] = 1 + labels = np.random.randint(0, axis_dim, self.shape, dtype="int64") + + loss = cross_entropy(softmax, labels, self.soft_label, self.axis, + self.ignore_index) + + self.inputs = {"Logits": logits, "Label": labels} + self.outputs = { + "Softmax": softmax.astype(self.dtype), + "Loss": loss.astype(self.dtype) + } + self.attrs = { + "numeric_stable_mode": self.numeric_stable_mode, + "soft_label": self.soft_label, + "ignore_index": self.ignore_index, + } + + if self.axis != -1: + self.attrs['axis'] = self.axis + + def test_check_output(self): + self.check_output_with_place(self.place, check_dygraph=False) + + # TODO(ascendrc): Add grad test + # def test_check_grad(self): + # if self.dtype == np.float16: + # return + # self.check_grad(['X'], 'Out') + # + + +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class TestPowNet(unittest.TestCase): + def _test(self, run_npu=True): + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + main_prog.random_seed = SEED + startup_prog.random_seed = SEED + np.random.seed(SEED) + + a_np = np.random.random(size=(32, 32)).astype('float32') + b_np = np.random.random(size=(32, 32)).astype('float32') + label_np = np.random.randint(2, size=(32, 1)).astype('int64') + + with paddle.static.program_guard(main_prog, startup_prog): + a = paddle.static.data(name="a", shape=[32, 32], dtype='float32') + b = paddle.static.data(name="b", shape=[32, 32], dtype='float32') + label = paddle.static.data( + name="label", shape=[32, 1], dtype='int64') + + sum = paddle.add(a, b) + z = paddle.pow(sum, 2.0) + + fc_1 = fluid.layers.fc(input=z, size=128) + prediction = fluid.layers.fc(input=fc_1, size=2) + + cost = fluid.layers.softmax_with_cross_entropy(prediction, label) + loss = fluid.layers.reduce_mean(cost) + sgd = fluid.optimizer.SGD(learning_rate=0.01) + sgd.minimize(loss) + + if run_npu: + place = paddle.NPUPlace(0) + else: + place = paddle.CPUPlace() + + exe = paddle.static.Executor(place) + exe.run(startup_prog) + + print("Start run on {}".format(place)) + for epoch in range(100): + + pred_res, loss_res = exe.run( + main_prog, + feed={"a": a_np, + "b": b_np, + "label": label_np}, + fetch_list=[prediction, loss]) + if epoch % 10 == 0: + print("Epoch {} | Prediction[0]: {}, Loss: {}".format( + epoch, pred_res[0], loss_res)) + + return pred_res, loss_res + + def test_npu(self): + cpu_pred, cpu_loss = self._test(False) + npu_pred, npu_loss = self._test(True) + + self.assertTrue(np.allclose(npu_pred, cpu_pred)) + self.assertTrue(np.allclose(npu_loss, cpu_loss)) + + +if __name__ == '__main__': + unittest.main()