Skip to content

Commit

Permalink
[NPU] add npu kernel for softmax_with_cross_entropy (PaddlePaddle#31656)
Browse files Browse the repository at this point in the history
* init

* fix bugs
  • Loading branch information
zhiqiu committed Mar 16, 2021
1 parent 925432d commit 5118968
Show file tree
Hide file tree
Showing 3 changed files with 364 additions and 0 deletions.
2 changes: 2 additions & 0 deletions paddle/fluid/operators/npu_op_runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ class NpuOpRunner {
aclopAttr *attr_{nullptr};
};

aclDataType ConvertToNpuDtype(framework::proto::VarType::Type dtype);

} // namespace operators
} // namespace paddle
#endif
203 changes: 203 additions & 0 deletions paddle/fluid/operators/softmax_with_cross_entropy_op_npu.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/math/softmax.h"
#include <memory>
#include <string>
#include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/operators/npu_op_runner.h"
#include "paddle/fluid/operators/softmax_op.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename DeviceContext, typename T>
class SoftmaxWithCrossEntropyNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* logits = ctx.Input<Tensor>("Logits");
auto* labels = ctx.Input<Tensor>("Label");
auto* softmax = ctx.Output<Tensor>("Softmax");
auto* loss = ctx.Output<Tensor>("Loss");

int cls_num = logits->dims()[1];
const int rank = logits->dims().size();
const int axis = CanonicalAxis(ctx.Attr<int>("axis"), rank);
std::vector<int> axes;
for (auto i = axis; i < logits->dims().size(); ++i) {
axes.push_back(i);
}

auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();

// softmax
softmax->mutable_data<T>(ctx.GetPlace());
auto runner_softmax =
NpuOpRunner("SoftmaxV2", {*logits}, {*softmax}, {{"axes", axes}});
runner_softmax.Run(stream);

// cast label from int64/int32 to int32
Tensor tmp_labels(framework::proto::VarType::INT32);
if (labels->type() != framework::proto::VarType::INT32) {
tmp_labels.Resize(labels->dims());
tmp_labels.mutable_data(ctx.GetPlace(), framework::proto::VarType::INT32);
auto dst_dtype = ConvertToNpuDtype(framework::proto::VarType::INT32);
auto runner_cast_label =
NpuOpRunner("Cast", {*labels}, {tmp_labels},
{{"dst_type", static_cast<int>(dst_dtype)}});
runner_cast_label.Run(stream);
labels = &tmp_labels;
}

// on and off
Tensor on_tensor(framework::proto::VarType::INT32);
on_tensor.mutable_data<int>({1}, ctx.GetPlace());
TensorFromVector(std::vector<int>{static_cast<int>(1)},
ctx.device_context(), &on_tensor);
Tensor off_tensor(framework::proto::VarType::INT32);
off_tensor.mutable_data<int>({1}, ctx.GetPlace());
TensorFromVector(std::vector<int>{static_cast<int>(0)},
ctx.device_context(), &off_tensor);

// one_hot
Tensor tmp_onehot(on_tensor.type());
tmp_onehot.Resize(logits->dims());
tmp_onehot.mutable_data<int>(ctx.GetPlace());

auto runner_onehot =
NpuOpRunner("OneHotD", {*labels, on_tensor, off_tensor}, {tmp_onehot},
{{"axis", -1}, {"depth", cls_num}});
runner_onehot.Run(stream);

// cast one_hot from int32 to T
Tensor cast_onehot(logits->type());
cast_onehot.Resize(tmp_onehot.dims());
cast_onehot.mutable_data<T>(ctx.GetPlace());
auto dst_dtype = ConvertToNpuDtype(logits->type());
auto runner_cast_onehot =
NpuOpRunner("Cast", {tmp_onehot}, {cast_onehot},
{{"dst_type", static_cast<int>(dst_dtype)}});
runner_cast_onehot.Run(stream);

// SoftmaxCrossEntropyWithLogits
Tensor backprop(logits->type());
backprop.Resize(logits->dims());
backprop.mutable_data<T>(ctx.GetPlace());

loss->mutable_data<T>(ctx.GetPlace());

// SoftmaxCrossEntropyWithLogits requires loss to be of shape [batch_size]
auto loss_dims = loss->dims();
loss->Resize({loss_dims[0]});
auto runner_s = NpuOpRunner("SoftmaxCrossEntropyWithLogits",
{*logits, cast_onehot}, {*loss, backprop}, {});
runner_s.Run(stream);
loss->Resize(loss_dims);
}
};

template <typename DeviceContext, typename T>
class SoftmaxWithCrossEntropyGradNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* labels = ctx.Input<Tensor>("Label");
auto* softmax = ctx.Input<Tensor>("Softmax");
auto* loss_grad = ctx.Input<Tensor>(framework::GradVarName("Loss"));
auto* logits_grad = ctx.Output<Tensor>(framework::GradVarName("Logits"));

int cls_num = softmax->dims()[1];

auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();

// cast label from int64/int32 to int32
Tensor tmp_labels(framework::proto::VarType::INT32);
if (labels->type() != framework::proto::VarType::INT32) {
tmp_labels.Resize(labels->dims());
tmp_labels.mutable_data(ctx.GetPlace(), framework::proto::VarType::INT32);
auto dst_dtype = ConvertToNpuDtype(framework::proto::VarType::INT32);
auto runner_cast_label =
NpuOpRunner("Cast", {*labels}, {tmp_labels},
{{"dst_type", static_cast<int>(dst_dtype)}});
runner_cast_label.Run(stream);
labels = &tmp_labels;
}

// on and off
Tensor on_tensor(framework::proto::VarType::INT32);
on_tensor.mutable_data<int>({1}, ctx.GetPlace());
TensorFromVector(std::vector<int>{static_cast<int>(1)},
ctx.device_context(), &on_tensor);
Tensor off_tensor(framework::proto::VarType::INT32);
off_tensor.mutable_data<int>({1}, ctx.GetPlace());
TensorFromVector(std::vector<int>{static_cast<int>(0)},
ctx.device_context(), &off_tensor);

// one_hot
Tensor tmp_onehot(on_tensor.type());
tmp_onehot.Resize(softmax->dims());
tmp_onehot.mutable_data<int>(ctx.GetPlace());

auto runner_onehot =
NpuOpRunner("OneHotD", {*labels, on_tensor, off_tensor}, {tmp_onehot},
{{"axis", -1}, {"depth", cls_num}});
runner_onehot.Run(stream);

// cast one_hot from int32 to T
Tensor cast_onehot(softmax->type());
cast_onehot.Resize(tmp_onehot.dims());
cast_onehot.mutable_data<T>(ctx.GetPlace());
auto dst_dtype = ConvertToNpuDtype(softmax->type());
auto runner_cast_onehot =
NpuOpRunner("Cast", {tmp_onehot}, {cast_onehot},
{{"dst_type", static_cast<int>(dst_dtype)}});
runner_cast_onehot.Run(stream);

// sub
Tensor tmp_sub(softmax->type());
tmp_sub.Resize(softmax->dims());
tmp_sub.mutable_data<T>(ctx.GetPlace());
auto runner_sub =
NpuOpRunner("Sub", {*softmax, cast_onehot}, {tmp_sub}, {});

runner_sub.Run(stream);
// mul
logits_grad->mutable_data<T>(ctx.GetPlace());
auto runner_mul =
NpuOpRunner("Mul", {*loss_grad, tmp_sub}, {*logits_grad}, {});
runner_mul.Run(stream);
}
};
} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP_NPU_KERNEL(
softmax_with_cross_entropy,
ops::SoftmaxWithCrossEntropyNPUKernel<paddle::platform::NPUDeviceContext,
float>,
ops::SoftmaxWithCrossEntropyNPUKernel<paddle::platform::NPUDeviceContext,
paddle::platform::float16>);
REGISTER_OP_NPU_KERNEL(
softmax_with_cross_entropy_grad,
ops::SoftmaxWithCrossEntropyGradNPUKernel<
paddle::platform::NPUDeviceContext, float>,
ops::SoftmaxWithCrossEntropyGradNPUKernel<
paddle::platform::NPUDeviceContext, paddle::platform::float16>);
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import numpy as np
import unittest
import sys
sys.path.append("..")
from op_test import OpTest
import paddle
import paddle.fluid as fluid
from test_softmax_op import stable_softmax
from test_softmax_with_cross_entropy_op import cross_entropy

paddle.enable_static()
SEED = 2021


@unittest.skipIf(not paddle.is_compiled_with_npu(),
"core is not compiled with NPU")
class TestSoftmaxWithCrossEntropyOp(OpTest):
def set_npu(self):
self.__class__.use_npu = True

def init_dtype(self):
self.dtype = np.float32

def initParams(self):
self.set_npu()
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = False
self.place = paddle.NPUPlace(0)
self.soft_label = False
self.init_dtype()
self.axis = -1
self.ignore_index = -1
self.shape = [41, 37]
np.random.seed(SEED)

def setUp(self):
self.initParams()

logits = getattr(
self, "logits",
np.random.uniform(0.1, 1.0, self.shape).astype(self.dtype))
softmax = np.apply_along_axis(stable_softmax, self.axis, logits)

if self.soft_label:
labels = np.random.uniform(0.1, 1.0, self.shape).astype(self.dtype)
labels /= np.sum(labels, axis=self.axis, keepdims=True)
else:
axis_dim = self.shape[self.axis]
self.shape[self.axis] = 1
labels = np.random.randint(0, axis_dim, self.shape, dtype="int64")

loss = cross_entropy(softmax, labels, self.soft_label, self.axis,
self.ignore_index)

self.inputs = {"Logits": logits, "Label": labels}
self.outputs = {
"Softmax": softmax.astype(self.dtype),
"Loss": loss.astype(self.dtype)
}
self.attrs = {
"numeric_stable_mode": self.numeric_stable_mode,
"soft_label": self.soft_label,
"ignore_index": self.ignore_index,
}

if self.axis != -1:
self.attrs['axis'] = self.axis

def test_check_output(self):
self.check_output_with_place(self.place, check_dygraph=False)

# TODO(ascendrc): Add grad test
# def test_check_grad(self):
# if self.dtype == np.float16:
# return
# self.check_grad(['X'], 'Out')
#


@unittest.skipIf(not paddle.is_compiled_with_npu(),
"core is not compiled with NPU")
class TestPowNet(unittest.TestCase):
def _test(self, run_npu=True):
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
main_prog.random_seed = SEED
startup_prog.random_seed = SEED
np.random.seed(SEED)

a_np = np.random.random(size=(32, 32)).astype('float32')
b_np = np.random.random(size=(32, 32)).astype('float32')
label_np = np.random.randint(2, size=(32, 1)).astype('int64')

with paddle.static.program_guard(main_prog, startup_prog):
a = paddle.static.data(name="a", shape=[32, 32], dtype='float32')
b = paddle.static.data(name="b", shape=[32, 32], dtype='float32')
label = paddle.static.data(
name="label", shape=[32, 1], dtype='int64')

sum = paddle.add(a, b)
z = paddle.pow(sum, 2.0)

fc_1 = fluid.layers.fc(input=z, size=128)
prediction = fluid.layers.fc(input=fc_1, size=2)

cost = fluid.layers.softmax_with_cross_entropy(prediction, label)
loss = fluid.layers.reduce_mean(cost)
sgd = fluid.optimizer.SGD(learning_rate=0.01)
sgd.minimize(loss)

if run_npu:
place = paddle.NPUPlace(0)
else:
place = paddle.CPUPlace()

exe = paddle.static.Executor(place)
exe.run(startup_prog)

print("Start run on {}".format(place))
for epoch in range(100):

pred_res, loss_res = exe.run(
main_prog,
feed={"a": a_np,
"b": b_np,
"label": label_np},
fetch_list=[prediction, loss])
if epoch % 10 == 0:
print("Epoch {} | Prediction[0]: {}, Loss: {}".format(
epoch, pred_res[0], loss_res))

return pred_res, loss_res

def test_npu(self):
cpu_pred, cpu_loss = self._test(False)
npu_pred, npu_loss = self._test(True)

self.assertTrue(np.allclose(npu_pred, cpu_pred))
self.assertTrue(np.allclose(npu_loss, cpu_loss))


if __name__ == '__main__':
unittest.main()

0 comments on commit 5118968

Please sign in to comment.