Skip to content

Commit

Permalink
[oneDNN] GRU BF16 kernel (#27731)
Browse files Browse the repository at this point in the history
  • Loading branch information
jczaja authored Oct 9, 2020
1 parent 6c1acf3 commit 606611d
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 6 deletions.
3 changes: 2 additions & 1 deletion paddle/fluid/framework/ir/graph_pattern_detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1894,7 +1894,8 @@ PDNode *patterns::QuantizePlacement::operator()(

PDNode *patterns::Bfloat16Placement::operator()(
const std::unordered_set<std::string> &bfloat16_enabled_op_types) {
std::unordered_set<std::string> supported_op_types{"conv2d"};
std::unordered_set<std::string> supported_op_types =
std::unordered_set<std::string>({"conv2d", "fusion_gru"});
if (!bfloat16_enabled_op_types.empty()) {
supported_op_types = bfloat16_enabled_op_types;
}
Expand Down
13 changes: 8 additions & 5 deletions paddle/fluid/operators/fused/mkldnn/fusion_gru_mkldnn_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> {

// Weights for int8 kernel are of a type s8
const auto weights_dt =
is_INT8 ? dnnl::memory::data_type::s8 : dnnl::memory::data_type::f32;
is_INT8 ? dnnl::memory::data_type::s8 : MKLDNNGetDataType<T>();

// oneDNN RNN dimensions
const int64_t D = 1; // Directions
Expand Down Expand Up @@ -226,6 +226,8 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> {
}

// TODO(grygielski) H0 is for now persistable
// TODO(jczaja) H0 should be updated each iter and of T type (Fusion pass does
// not support in yet)
std::shared_ptr<dnnl::memory> AcquireH0Memory(const Tensor* h0) {
const std::string h0_key = memory_key_ + "@h0";
auto memory_p =
Expand Down Expand Up @@ -397,14 +399,14 @@ template <typename T>
class FusionGRUMKLDNNKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const bool is_INT8 = std::is_same<T, uint8_t>::value;
const bool is_bf16 = std::is_same<T, paddle::platform::bfloat16>::value;
const bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");

// TODO(grygielski) Add option for bfloat
if (!is_INT8 || force_fp32_output) {
// BF16 does not support force output
if (!is_bf16 && force_fp32_output) {
RunKernel<float>(ctx);
} else {
RunKernel<uint8_t>(ctx);
RunKernel<T>(ctx);
}
}

Expand Down Expand Up @@ -495,4 +497,5 @@ class FusionGRUMKLDNNKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(fusion_gru, MKLDNN, paddle::platform::CPUPlace,
ops::FusionGRUMKLDNNKernel<float>,
ops::FusionGRUMKLDNNKernel<paddle::platform::bfloat16>,
ops::FusionGRUMKLDNNKernel<uint8_t>);
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
import struct
import paddle.fluid.core as core
from paddle.fluid.tests.unittests.op_test import OpTest, convert_float_to_uint16
from paddle.fluid.tests.unittests.op_test import OpTest
from paddle.fluid.tests.unittests.test_fusion_gru_op import fusion_gru
from paddle.fluid.tests.unittests.test_fusion_lstm_op import fc, ACTIVATION


@unittest.skipIf(not core.supports_bfloat16(),
"place does not support BF16 evaluation")
class TestFusionGRUBF16MKLDNNOp(OpTest):
def set_confs(self):
self.mkldnn_data_type = False

def setUp(self):
self.op_type = "fusion_gru"
self.lod = [[2, 4, 3]]
self.M = 3
self.D = 5
self.is_reverse = False
self.with_h0 = False
self.use_mkldnn = True
self._cpu_only = True
self.with_bias = True
self.act_state = 'tanh'
self.act_gate = 'sigmoid'
self.origin_mode = False
self.use_mkldnn = True
self.force_fp32_output = False
self.set_confs()

T = sum(self.lod[0])
N = len(self.lod[0])

# fp32 X input for reference implementation and
# corressponding bf16 data as input to GRU oneDNN bf16 kernel
x_fp32 = np.random.rand(T, self.M).astype('float32')
x_bf16 = convert_float_to_uint16(x_fp32)

wx_fp32 = np.random.rand(self.M, 3 * self.D).astype('float32')
wh_fp32 = np.random.rand(self.D, 3 * self.D).astype('float32')

# bias is fp32 despite other inputs being in bf16
bias = np.random.rand(
1, 3 * self.D).astype('float32') if self.with_bias else np.zeros(
(1, 3 * self.D), dtype='float32')

h0_fp32 = np.random.rand(
N, self.D).astype('float32') if self.with_h0 else np.zeros(
(N, self.D), dtype='float32')

_, _, _, hidden = fusion_gru(
x_fp32, self.lod, h0_fp32, wx_fp32, wh_fp32, bias, self.is_reverse,
self.origin_mode, ACTIVATION[self.act_state],
ACTIVATION[self.act_gate])

hidden_bf16 = convert_float_to_uint16(hidden)

self.inputs = {
'X': (x_bf16, self.lod),
'WeightX': wx_fp32,
'WeightH': wh_fp32
}

if self.with_bias:
self.inputs['Bias'] = bias

if self.with_h0:
self.inputs['H0'] = h0_bf16

h0_bf16 = convert_float_to_uint16(h0_fp32)
self.outputs = {'Hidden': (hidden_bf16, self.lod)}

self.attrs = {
'activation': self.act_state,
'gate_activation': self.act_gate,
'is_reverse': self.is_reverse,
'origin_mode': self.origin_mode,
'force_fp32_output': self.force_fp32_output,
'use_mkldnn': self.use_mkldnn
}


class TestFusionGRUINT8MKLDNNOp2(TestFusionGRUBF16MKLDNNOp):
def set_confs(self):
self.origin_mode = False


class TestFusionGRUINT8MKLDNNOp3(TestFusionGRUBF16MKLDNNOp):
def set_confs(self):
self.with_bias = False


if __name__ == "__main__":
unittest.main()

0 comments on commit 606611d

Please sign in to comment.