Skip to content

Commit

Permalink
Modify bf16 and fix the elementwise_max (#54799)
Browse files Browse the repository at this point in the history
* modify the accuracy checking framework of bf16 optest, including both of forward and backward
  • Loading branch information
Vvsmile authored Jul 13, 2023
1 parent 4c5ce83 commit 6f7ceca
Show file tree
Hide file tree
Showing 8 changed files with 175 additions and 76 deletions.
73 changes: 39 additions & 34 deletions paddle/phi/kernels/funcs/elementwise_grad_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License. */

#include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
Expand Down Expand Up @@ -114,41 +115,43 @@ static void ElemwiseGradBroadcast1CPU(const T *x,
DY_OP dy_op,
T *dx,
T *dy) {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

if (is_xsize_larger) {
for (int i = 0; i < h; ++i) {
for (int j = 0; j < w; ++j) {
for (int j = 0; j < w; ++j) {
MPType sum_y = static_cast<MPType>(0);
for (int i = 0; i < h; ++i) {
int x_offset = i * w + j;
if (dx != nullptr) {
dx[x_offset] =
dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
}
if (dy != nullptr) {
T tmp = dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
if (i == 0) {
dy[j] = tmp;
} else {
dy[j] += tmp;
}
sum_y += static_cast<MPType>(
dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]));
}
}
if (dy != nullptr) {
dy[j] = static_cast<T>(sum_y);
}
}
} else { // x.dims < y.dims, broadcast for x.
for (int i = 0; i < h; ++i) {
for (int j = 0; j < w; ++j) {
} else {
for (int j = 0; j < w; ++j) {
MPType sum_x = static_cast<MPType>(0);
for (int i = 0; i < h; ++i) {
int y_offset = i * w + j;
if (dy != nullptr) {
dy[y_offset] =
dy_op(x[j], y[y_offset], out[y_offset], dout[y_offset]);
}
if (dx != nullptr) {
T tmp = dx_op(x[j], y[y_offset], out[y_offset], dout[y_offset]);
if (i == 0) {
dx[j] = tmp;
} else {
dx[j] += tmp;
}
sum_x += static_cast<MPType>(
dx_op(x[j], y[y_offset], out[y_offset], dout[y_offset]));
}
}
if (dx != nullptr) {
dx[j] = static_cast<T>(sum_x);
}
}
}
}
Expand All @@ -166,45 +169,47 @@ static void ElemwiseGradBroadcast2CPU(const T *x,
DY_OP dy_op,
T *dx,
T *dy) {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

if (is_xsize_larger) {
for (int i = 0; i < pre; ++i) {
for (int j = 0; j < n; ++j) {
for (int j = 0; j < n; ++j) {
MPType sum_y = static_cast<MPType>(0);
for (int i = 0; i < pre; ++i) {
for (int k = 0; k < post; ++k) {
int x_offset = i * n * post + j * post + k;
if (dx != nullptr) {
dx[x_offset] =
dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
}
if (dy != nullptr) {
T tmp = dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
if (i == 0 && k == 0) {
dy[j] = tmp;
} else {
dy[j] += tmp;
}
sum_y += static_cast<MPType>(
dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]));
}
}
}
if (dy != nullptr) {
dy[j] = static_cast<T>(sum_y);
}
}
} else { // x.dims < y.dims, broadcast for x.
for (int i = 0; i < pre; ++i) {
for (int j = 0; j < n; ++j) {
} else {
for (int j = 0; j < n; ++j) {
MPType sum_x = static_cast<MPType>(0);
for (int i = 0; i < pre; ++i) {
for (int k = 0; k < post; ++k) {
int y_offset = i * n * post + j * post + k;
if (dy != nullptr) {
dy[y_offset] =
dy_op(x[j], y[y_offset], out[y_offset], dout[y_offset]);
}
if (dx != nullptr) {
T tmp = dx_op(x[j], y[y_offset], out[y_offset], dout[y_offset]);
if (i == 0 && k == 0) {
dx[j] = tmp;
} else {
dx[j] += tmp;
}
sum_x += static_cast<MPType>(
dx_op(x[j], y[y_offset], out[y_offset], dout[y_offset]));
}
}
}
if (dx != nullptr) {
dx[j] = static_cast<T>(sum_x);
}
}
}
}
Expand Down
138 changes: 105 additions & 33 deletions test/legacy_test/eager_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,8 +552,20 @@ def is_fp16_compared_with_fp32(self):
not in op_accuracy_white_list.NO_FP16_COMPARED_WITH_FP32_OP_LIST
)

def is_bf16_compared_with_fp32(self):
return self.is_bfloat16_op() and (
self.op_type
not in op_accuracy_white_list.NO_BF16_COMPARED_WITH_FP32_OP_LIST
)

def is_compared_with_fp32(self):
return (
self.is_fp16_compared_with_fp32()
or self.is_bf16_compared_with_fp32()
)

def enable_cal_ref_output(self):
self.is_calc_ref = self.is_fp16_compared_with_fp32()
self.is_calc_ref = True

def disable_cal_ref_output(self):
self.is_calc_ref = False
Expand Down Expand Up @@ -654,46 +666,105 @@ def feed_var(self, input_vars, place):
if isinstance(np_value, tuple):
tensor.set(np_value[0], place)
dtype = np.array(np_value[1]).dtype
if self.is_calc_ref and dtype == np.float16:
if isinstance(np_value[1], list):
tensor.set_recursive_sequence_lengths(
np.array(np_value[1]).astype(np.float32)
)

if self.is_calc_ref:
# convert the float16 to float by numpy.astype
if dtype == np.float16:
if isinstance(np_value[1], list):
tensor.set_recursive_sequence_lengths(
np.array(np_value[1]).astype(np.float32)
)
else:
tensor.set_recursive_sequence_lengths(
np_value[1].astype(np.float32)
)
# convert the bfloat16 to float by convert_uint16_to_float
# provided in this file
elif dtype == np.uint16:
if isinstance(np_value[1], list):
tensor.set_recursive_sequence_lengths(
convert_uint16_to_float(
np.array(np_value[1])
)
)
else:
tensor.set_recursive_sequence_lengths(
convert_uint16_to_float(np_value[1])
)
else:
tensor.set_recursive_sequence_lengths(
np_value[1].astype(np.float32)
np_value[1]
)
else:
tensor.set_recursive_sequence_lengths(np_value[1])
else:
if self.is_calc_ref and np_value.dtype == np.float16:
tensor.set(np_value.astype(np.float32), place)
if self.is_calc_ref:
if np_value.dtype == np.float16:
tensor.set(np_value.astype(np.float32), place)
elif np_value.dtype == np.uint16:
tensor.set(
convert_uint16_to_float(np_value), place
)
else:
tensor.set(np_value, place)
else:
tensor.set(np_value, place)
feed_map[name] = tensor
else:
tensor = core.LoDTensor()
if isinstance(self.inputs[var_name], tuple):
tensor.set(self.inputs[var_name][0], place)
if (
self.is_calc_ref
and self.inputs[var_name][1].dtype == np.float16
):
tensor.set_recursive_sequence_lengths(
self.inputs[var_name][1].astype(np.float32)
)
if self.is_calc_ref:
if isinstance(self.inputs[var_name][1], list):
dtype = np.array(self.inputs[var_name][1]).dtype
if dtype == np.float16:
tensor.set_recursive_sequence_lengths(
np.array(self.inputs[var_name][1]).astype(
np.float32
)
)
elif dtype == np.uint16:
tensor.set_recursive_sequence_lengths(
convert_uint16_to_float(
np.array(self.inputs[var_name][1])
)
)
else:
tensor.set_recursive_sequence_lengths(
self.inputs[var_name][1]
)

elif self.inputs[var_name][1].dtype == np.float16:
tensor.set_recursive_sequence_lengths(
self.inputs[var_name][1].astype(np.float32)
)
elif self.inputs[var_name][1].dtype == np.uint16:
tensor.set_recursive_sequence_lengths(
convert_uint16_to_float(
self.inputs[var_name][1]
)
)
else:
tensor.set_recursive_sequence_lengths(
self.inputs[var_name][1]
)
else:
tensor.set_recursive_sequence_lengths(
self.inputs[var_name][1]
)
else:
if (
self.is_calc_ref
and self.inputs[var_name].dtype == np.float16
):
tensor.set(
self.inputs[var_name].astype(np.float32), place
)
if self.is_calc_ref:
if self.inputs[var_name].dtype == np.float16:
tensor.set(
self.inputs[var_name].astype(np.float32), place
)
elif self.inputs[var_name].dtype == np.uint16:
tensor.set(
convert_uint16_to_float(self.inputs[var_name]),
place,
)
else:
tensor.set(self.inputs[var_name], place)
else:
tensor.set(self.inputs[var_name], place)
feed_map[var_name] = tensor
Expand All @@ -711,7 +782,8 @@ def _append_ops(self, block):
self.__class__.use_xpu = True

op_proto = OpProtoHolder.instance().get_op_proto(self.op_type)
"infer datatype from inputs and outputs for this test case"
# "infer datatype from inputs and outputs for this test case"

if self.is_float16_op():
self.dtype = np.float16
self.__class__.dtype = self.dtype
Expand All @@ -722,6 +794,7 @@ def _append_ops(self, block):
self.output_dtype = np.uint16
else:
self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)

inputs = append_input_output(
block, op_proto, self.inputs, True, self.dtype, self.is_calc_ref
)
Expand Down Expand Up @@ -1809,7 +1882,7 @@ def _compare_list(self, name, actual, expect):
def compare_single_output_with_expect(self, name, expect):
actual, actual_np = self.find_actual_value(name)
# expect_np = expect[0] if isinstance(expect, tuple) else expect
if self.op_test.is_fp16_compared_with_fp32():
if self.op_test.is_compared_with_fp32():
expect, expect_np = self.find_expect_value(name)
else:
expect_np = (
Expand Down Expand Up @@ -1864,7 +1937,7 @@ def calculate_output(self):
)
self.outputs = outs
self.fetch_list = fetch_list
if self.op_test.is_fp16_compared_with_fp32():
if self.op_test.is_compared_with_fp32():
self.op_test.enable_cal_ref_output()
ref_outs, ref_fetch_list = self.op_test._calc_output(
place, no_check_set=no_check_set
Expand Down Expand Up @@ -1931,7 +2004,7 @@ def calculate_output(self):
place, no_check_set=no_check_set
)
self.outputs = dygraph_outs
if self.op_test.is_fp16_compared_with_fp32():
if self.op_test.is_compared_with_fp32():
self.op_test.enable_cal_ref_output()
self.is_python_api_test = True
self.ref_outputs = self.op_test._calc_python_api_output(
Expand Down Expand Up @@ -2460,9 +2533,7 @@ def check_grad_with_place(
if self.is_bfloat16_op():
if self.is_mkldnn_op():
check_dygraph = False
atol = 1e-2 if atol < 1e-2 else atol
else:
atol = 1e-1 if atol < 1e-1 else atol
atol = 1e-2 if atol < 1e-2 else atol

if self.is_float16_op():
atol = 1e-3 if atol < 1e-3 else atol
Expand Down Expand Up @@ -2492,7 +2563,6 @@ def check_grad_with_place(
if "use_mkldnn" in op_attrs and op_attrs["use_mkldnn"]:
op_attrs["use_mkldnn"] = False
use_onednn = True

self.op = create_op(
self.scope,
self.op_type,
Expand Down Expand Up @@ -2538,8 +2608,9 @@ def check_grad_with_place(
if numeric_place is None:
numeric_place = place

if user_defined_grads is None and self.is_fp16_compared_with_fp32():
if user_defined_grads is None and self.is_compared_with_fp32():
self.enable_cal_ref_output()

numeric_grads = self._get_gradient(
inputs_to_check,
place,
Expand Down Expand Up @@ -2573,6 +2644,7 @@ def check_grad_with_place(
)
# comparison of bf16 results will happen as fp32
# loop over list of grads and convert bf16 to fp32

fp32_analytic_grads = []
for grad in analytic_grads:
if grad.dtype == np.uint16:
Expand Down Expand Up @@ -2869,7 +2941,7 @@ def _get_gradient(
feed_dict = self.feed_var(inputs, place)

if user_defined_grad_outputs is None:
if self.dtype == np.uint16:
if self.dtype == np.uint16 and not self.is_calc_ref:
cast_inputs = list(map(block.var, output_names))
if self.op_type in ["broadcast_tensors", "meshgrid"]:
output_names = self.cast_bf16_output(block, cast_inputs)
Expand Down
2 changes: 0 additions & 2 deletions test/legacy_test/test_elementwise_div_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,6 @@ def test_check_gradient(self):
check_args = [check_option['grad'], 'Out']
check_kwargs = {
'no_grad_set': check_option['no_grad'],
'user_defined_grads': check_option['val_grad'],
'user_defined_grad_outputs': [self.grad_out],
'check_dygraph': self.check_dygraph,
}
if self.place is None:
Expand Down
Loading

0 comments on commit 6f7ceca

Please sign in to comment.