[oneDNN] lookup_table op with support for BF16 data type. (#31558)
arogowie-intel authored Mar 19, 2021
1 parent c86e771 commit a4a2b77
Showing 8 changed files with 213 additions and 10 deletions.
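Throughout this change, BF16 data is carried as uint16 arrays holding raw bfloat16 bit patterns (bfloat16 keeps the sign, the 8 exponent bits, and the top 7 mantissa bits of an fp32 word). Below is a minimal numpy sketch of that representation; it is illustrative only and not part of the diff, and the helper names are hypothetical stand-ins for the convert_float_to_uint16 / convert_uint16_to_float utilities used by the tests, which may round rather than truncate.

    import numpy as np

    def fp32_to_bf16_bits(x):
        # keep the upper 16 bits of each fp32 word (truncation; Paddle's helper may round)
        return (np.asarray(x, dtype=np.float32).view(np.uint32) >> 16).astype(np.uint16)

    def bf16_bits_to_fp32(x):
        # widen the 16-bit pattern back to fp32 by zero-filling the dropped mantissa bits
        return (np.asarray(x, dtype=np.uint16).astype(np.uint32) << 16).view(np.float32)

    table = np.random.random((17, 31)).astype(np.float32)
    table_bf16 = fp32_to_bf16_bits(table)  # dtype uint16, matching VarType::BF16
    np.testing.assert_allclose(bf16_bits_to_fp32(table_bf16), table, rtol=1.5e-2)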
@@ -53,7 +53,7 @@ void CPUBfloat16PlacementPass::SetMkldnnDataType(
gpd(graph, handler);
}

-void CPUBfloat16PlacementPass::RemoveOrhanedOperators(
+void CPUBfloat16PlacementPass::RemoveOrphanedOperators(
ir::Graph* graph, int* bfloat16_operators) const {
// find orphaned bfloat16 operator that is between two float32 operators
// revert mkldnn_data_type attr to float32
@@ -74,7 +74,7 @@ void CPUBfloat16PlacementPass::RemoveOrhanedOperators(
void CPUBfloat16PlacementPass::ApplyImpl(ir::Graph* graph) const {
int bfloat16_operators = 0;
SetMkldnnDataType(graph, &bfloat16_operators);
-RemoveOrhanedOperators(graph, &bfloat16_operators);
+RemoveOrphanedOperators(graph, &bfloat16_operators);
PrettyLogDetail("--- marked %d operators to bfloat16 ",
bfloat16_operators);
}
@@ -28,7 +28,7 @@ class CPUBfloat16PlacementPass : public Pass {
protected:
void SetMkldnnDataType(ir::Graph* graph, int* bfloat16_operators) const;

-void RemoveOrhanedOperators(ir::Graph* graph, int* bfloat16_operators) const;
+void RemoveOrphanedOperators(ir::Graph* graph, int* bfloat16_operators) const;

void ApplyImpl(ir::Graph* graph) const override;
};
7 changes: 5 additions & 2 deletions paddle/fluid/operators/lookup_table_op.cc
@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/fluid/framework/no_need_buffer_vars_inference.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/var_type_inference.h"
#include "paddle/fluid/platform/bfloat16.h"

namespace paddle {
namespace operators {
@@ -222,9 +223,11 @@ REGISTER_OPERATOR(lookup_table_grad, ops::LookupTableOpGrad,

REGISTER_OP_CPU_KERNEL(lookup_table, ops::LookupTableKernel<float>,
ops::LookupTableKernel<double>,
-ops::LookupTableKernel<int8_t>);
+ops::LookupTableKernel<int8_t>,
+ops::LookupTableKernel<paddle::platform::bfloat16>);
REGISTER_OP_CPU_KERNEL(lookup_table_grad, ops::LookupTableGradKernel<float>,
-ops::LookupTableGradKernel<double>);
+ops::LookupTableGradKernel<double>,
+ops::LookupTableGradKernel<paddle::platform::bfloat16>);

/* ========================== register checkpoint ===========================*/

6 changes: 4 additions & 2 deletions paddle/fluid/operators/lookup_table_op.h
@@ -102,7 +102,8 @@ class LookupTableKernel : public framework::OpKernel<T> {
auto id_index = table_t.GetIndexFromId(ids[i]);

if (id_index != -1) {
-if (input_data_type == framework::proto::VarType::INT8) {
+if (input_data_type == framework::proto::VarType::INT8 ||
+input_data_type == framework::proto::VarType::BF16) {
memcpy(output + i * row_width, table + id_index * row_width,
row_width * sizeof(T));
} else {
@@ -128,7 +129,8 @@
"the input key should be exists. But received %d.",
id_index));

-if (input_data_type == framework::proto::VarType::INT8) {
+if (input_data_type == framework::proto::VarType::INT8 ||
+input_data_type == framework::proto::VarType::BF16) {
memcpy(output + i * row_width, table + id_index * row_width,
row_width * sizeof(T));
} else {
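The hunk above routes VarType::BF16 through the same branch as INT8, so the selected embedding row is moved with a raw memcpy rather than Blas VCOPY; the CBlas<platform::bfloat16> stub added to blas_impl.h in the next file appears to exist only so the template instantiation compiles, and it throws if ever reached. A hedged numpy sketch (again illustrative, not part of the diff) of why a bit-for-bit copy is enough for bf16 rows:

    import numpy as np

    to_bf16 = lambda x: (x.astype(np.float32).view(np.uint32) >> 16).astype(np.uint16)
    to_fp32 = lambda x: (x.astype(np.uint32) << 16).view(np.float32)

    table_bf16 = to_bf16(np.random.random((17, 31)))  # uint16 rows holding raw bf16 bits
    ids = np.array([3, 0, 16, 3])

    rows = table_bf16[ids].copy()  # bit-exact row copy, i.e. the memcpy branch
    # copying the bits and widening to fp32 later equals widening the whole table first
    assert np.array_equal(to_fp32(rows), to_fp32(table_bf16)[ids])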
11 changes: 11 additions & 0 deletions paddle/fluid/operators/math/blas_impl.h
@@ -21,6 +21,7 @@
#include <vector>

#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"

@@ -40,6 +41,16 @@ struct CBlas<int8_t> {
}
};

+template <>
+struct CBlas<platform::bfloat16> {
+template <typename... ARGS>
+static void VCOPY(ARGS... args) {
+PADDLE_THROW(platform::errors::Unimplemented(
+"Blas VCOPY do not supported on CPU with bfloat16,"
+" please check your code"));
+}
+};

#ifdef PADDLE_WITH_MKLML
template <>
struct CBlas<float> {
16 changes: 13 additions & 3 deletions python/paddle/fluid/tests/unittests/op_test.py
@@ -33,10 +33,19 @@
from paddle.fluid.op import Operator
from paddle.fluid.executor import Executor
from paddle.fluid.framework import Program, OpProtoHolder, Variable
-from testsuite import create_op, set_input, append_input_output, append_loss_ops
+from paddle.fluid.tests.unittests.testsuite import (
+    create_op,
+    set_input,
+    append_input_output,
+    append_loss_ops, )
from paddle.fluid import unique_name
-from white_list import op_accuracy_white_list, check_shape_white_list, compile_vs_runtime_white_list, no_check_set_white_list
-from white_list import op_threshold_white_list, no_grad_set_white_list
+from paddle.fluid.tests.unittests.white_list import (
+    op_accuracy_white_list,
+    check_shape_white_list,
+    compile_vs_runtime_white_list,
+    no_check_set_white_list,
+    op_threshold_white_list,
+    no_grad_set_white_list, )


def check_out_dtype(api_fn, in_specs, expect_dtypes, target_index=0, **configs):
@@ -1452,6 +1461,7 @@ def check_grad_with_place(self,
analytic_grads = self._get_gradient(inputs_to_check, place,
output_names, no_grad_set,
user_defined_grad_outputs)

# comparison of bf16 results will happen as fp32
# loop over list of grads and convert bf16 to fp32
fp32_grads = []
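# (A hedged sketch, not part of this diff: the hunk is truncated after this line.
#  The loop described by the comment above plausibly continues along these lines,
#  using the convert_uint16_to_float helper defined earlier in op_test.py; the
#  actual code in the commit may differ.)
for grad in analytic_grads:
    if grad.dtype == np.uint16:
        # bf16 gradients arrive as raw uint16 bit patterns; widen to fp32 for comparison
        grad = convert_uint16_to_float(grad)
    fp32_grads.append(grad)
analytic_grads = fp32_grads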
176 changes: 176 additions & 0 deletions python/paddle/fluid/tests/unittests/test_lookup_table_bf16_op.py
@@ -0,0 +1,176 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
from paddle.fluid.tests.unittests.op_test import (
    OpTest, convert_float_to_uint16, convert_uint16_to_float,
    skip_check_grad_ci)
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.op import Operator
from paddle import enable_static


def _lookup(weights, ids, flat_ids):
    w_shape = weights.shape
    out_shape = list(ids.shape[:-1])
    out_shape.append(w_shape[-1])
    out = weights[flat_ids].reshape(out_shape)
    return out


def _get_grad(weights, ids, flat_ids):
    w_shape = weights.shape
    w_grad = np.zeros((w_shape), dtype=weights.dtype)
    out_grad_shape = (np.prod(ids.shape[:-1]), w_shape[-1])
    out_grad = weights[flat_ids].reshape(out_grad_shape)
    for i, idx in enumerate(flat_ids):
        w_grad[idx, :] += out_grad[i]
    return w_grad


@unittest.skipIf(not core.supports_bfloat16(),
                 "place does not support BF16 evaluation")
class TestLookupTableBF16Op(OpTest):
    def setUp(self):
        self.op_type = "lookup_table"
        self.dtype = np.uint16

        table = np.random.random((17, 31)).astype("float32")
        self.ids = np.random.randint(0, 17, (4, 1)).astype("int64")
        self.flat_ids = self.ids.flatten()

        self.w_bf16 = convert_float_to_uint16(table)
        self.out_bf16 = _lookup(self.w_bf16, self.ids, self.flat_ids)
        self.out_fp32 = _lookup(table, self.ids, self.flat_ids)
        self.w_grad_fp32 = _get_grad(table, self.ids, self.flat_ids)

        self.inputs = {'W': self.w_bf16, 'Ids': self.ids}
        self.outputs = {'Out': self.out_fp32}

    def test_check_output(self):
        self.check_output_with_place(core.CPUPlace(), check_dygraph=False)

    def test_check_grad(self):
        self.check_grad_with_place(
            core.CPUPlace(), ['W'],
            'Out',
            no_grad_set=set('Ids'),
            check_dygraph=False,
            max_relative_error=1.5e-2,
            user_defined_grads=[self.w_grad_fp32],
            user_defined_grad_outputs=[self.out_bf16])


@unittest.skipIf(not core.supports_bfloat16(),
                 "place does not support BF16 evaluation")
class TestLookupTableBF16OpIds4D(TestLookupTableBF16Op):
    def setUp(self):
        super(TestLookupTableBF16OpIds4D, self).setUp()
        self.ids = np.random.randint(0, 17, (2, 4, 5, 1)).astype("int64")


@unittest.skipIf(not core.supports_bfloat16(),
                 "place does not support BF16 evaluation")
class TestLookupTableBF16OpWIsSelectedRows(unittest.TestCase):
    def setUp(self):
        self.ids = np.random.randint(
            low=0, high=15, size=(10, 1)).astype("int64")
        self.flat_ids = self.ids.flatten()
        self.w_fp32 = np.random.random((15, 32)).astype("float32")
        self.w_bf16 = convert_float_to_uint16(self.w_fp32)
        self.scope = core.Scope()
        self.place = core.CPUPlace()

    def prepare_w(self):
        rows = [a for a in range(self.w_bf16.shape[0])]
        row_numel = self.w_bf16.shape[1]

        w_selected_rows = self.scope.var('W').get_selected_rows()
        w_selected_rows.set_height(len(rows))
        w_selected_rows.set_rows(rows)
        w_tensor = w_selected_rows.get_tensor()
        w_tensor.set(self.w_bf16, self.place)

    def prepare_ids(self):
        ids_tensor = self.scope.var('Ids').get_tensor()
        ids_tensor.set(self.ids, self.place)

    def _check_output(self, reference, result_array):
        result_array_fp32 = convert_uint16_to_float(result_array)
        np.testing.assert_allclose(result_array_fp32, reference, rtol=1.5e-2)

    def test_check_output(self):
        self.prepare_ids()
        self.prepare_w()
        out_tensor = self.scope.var('Out').get_tensor()

        # create and run lookup_table operator
        lookup_table = Operator("lookup_table", W='W', Ids='Ids', Out='Out')
        lookup_table.run(self.scope, self.place)

        # get result from Out
        result_array = np.array(out_tensor)
        ref = _lookup(self.w_fp32, self.ids, self.flat_ids)
        self._check_output(ref, result_array)


@unittest.skipIf(not core.supports_bfloat16(),
                 "place does not support BF16 evaluation")
class TestLookupTableBF16OpWIsSelectedRows4DIds(
        TestLookupTableBF16OpWIsSelectedRows):
    def setUp(self):
        super(TestLookupTableBF16OpWIsSelectedRows4DIds, self).setUp()
        self.ids = np.random.randint(
            low=0, high=15, size=(3, 4, 5, 1)).astype("int64")
        self.flat_ids = self.ids.flatten()


@skip_check_grad_ci(
    reason="Since paddings are not trainable and fixed in forward,"
    "the gradient of paddings makes no sense and we don't "
    "test the gradient here.")
@unittest.skipIf(not core.supports_bfloat16(),
                 "place does not support BF16 evaluation")
class TestLookupTableBF16OpWithPadding(TestLookupTableBF16Op):
    def test_check_output(self):
        ids = np.squeeze(self.inputs['Ids'])
        padding_idx = np.random.choice(ids, 1)[0]
        self.outputs['Out'][ids == padding_idx] = np.zeros(31)
        self.attrs = {'padding_idx': int(padding_idx)}
        self.check_output_with_place(core.CPUPlace(), check_dygraph=False)


@skip_check_grad_ci(
    reason="Since paddings are not trainable and fixed in forward,"
    "the gradient of paddings makes no sense and we don't "
    "test the gradient here.")
@unittest.skipIf(not core.supports_bfloat16(),
                 "place does not support BF16 evaluation")
class TestLookupTableBF16OpIds4DPadding(TestLookupTableBF16OpIds4D):
    def test_check_output(self):
        ids = self.inputs['Ids']
        flatten_idx = ids.flatten()
        padding_idx = np.random.choice(flatten_idx, 1)[0]
        self.outputs['Out'][np.squeeze(ids == padding_idx)] = np.zeros(31)
        self.attrs = {'padding_idx': int(padding_idx)}
        self.check_output_with_place(core.CPUPlace(), check_dygraph=False)


if __name__ == "__main__":
    enable_static()
    unittest.main()
1 change: 1 addition & 0 deletions tools/static_mode_white_list.py
@@ -21,6 +21,7 @@
'test_linear_chain_crf_op',
'test_lod_reset_op',
'test_lookup_table_op',
+'test_lookup_table_bf16_op',
'test_pad2d_op',
'test_scatter_op',
'test_sequence_concat',
