Skip to content

Commit

Permalink
[PHI] Migrate matmul kernel (PaddlePaddle#48162)
Browse files Browse the repository at this point in the history
* cleanup unused code

* unify is_int8 is_bfloat16

* Simplify matmul_v2 FWD kernel

* remove RunKernel methods

* remove import namespace

* remove headers

* clean fluid/phi cross imports

* remove fluid axpy_handler

* delete fluid methods

* activations

* OneDNNMemDesc

* MKLDNNFormatForSize

* MatchShapeToLayout

* MKLDNNMemoryFormat

* MKLDNNFormat

* ReorderMKLDNNHandler

* to_void_cast

* review suggestions

* interpolate

* remove fluid depedency

* init

* ExecuteMatMulV2

* rm fluid kernel

* matmul_grad

* remove mutable_data

* mul_grad

* matmul fwd

* add extra attr

* temp disable passes

* re-enable passes

* workaround for matmul+act

* fix for matmul+eltwise_add

* fix typo

* merge bugfix PaddlePaddle#48364

* remove merge conflict
  • Loading branch information
Silv3S authored Nov 29, 2022
1 parent 37a445c commit ad04a4b
Show file tree
Hide file tree
Showing 4 changed files with 186 additions and 21 deletions.
18 changes: 5 additions & 13 deletions paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ void ExecuteMatMulV2(const ExecutionContext &ctx,
}

template <typename T>
class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel<T> {
class MatMulMKLDNNKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const ExecutionContext &ctx) const override {
if (ctx.HasAttr("head_number")) {
Expand Down Expand Up @@ -696,21 +696,13 @@ class MatMulGradMKLDNNKernel : public paddle::framework::OpKernel<T> {
REGISTER_OP_KERNEL(matmul,
MKLDNN,
::paddle::platform::CPUPlace,
MatMulV2MKLDNNKernel<float>,
MatMulV2MKLDNNKernel<paddle::platform::bfloat16>,
MatMulV2MKLDNNKernel<int8_t>,
MatMulV2MKLDNNKernel<uint8_t>);
MatMulMKLDNNKernel<float>,
MatMulMKLDNNKernel<paddle::platform::bfloat16>,
MatMulMKLDNNKernel<int8_t>,
MatMulMKLDNNKernel<uint8_t>);

REGISTER_OP_KERNEL(matmul_grad,
MKLDNN,
::paddle::platform::CPUPlace,
MatMulGradMKLDNNKernel<float>,
MatMulGradMKLDNNKernel<paddle::platform::bfloat16>);

REGISTER_OP_KERNEL(matmul_v2,
MKLDNN,
::paddle::platform::CPUPlace,
MatMulV2MKLDNNKernel<float>,
MatMulV2MKLDNNKernel<paddle::platform::bfloat16>,
MatMulV2MKLDNNKernel<int8_t>,
MatMulV2MKLDNNKernel<uint8_t>);
4 changes: 3 additions & 1 deletion paddle/fluid/operators/ops_extra_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ const std::unordered_map<std::string, ExtraAttrPropertySet>
{"fuse_alpha", ExtraAttrProperty::ONEDNN},
{"fuse_beta", ExtraAttrProperty::ONEDNN},
{"fuse_relu", ExtraAttrProperty::ONEDNN},
{"fused_output_scale", ExtraAttrProperty::ONEDNN},
{"fuse_residual_connection", ExtraAttrProperty::ONEDNN},
{"fuse_with_relu", ExtraAttrProperty::ONEDNN},
{"fused_reshape_Out", ExtraAttrProperty::ONEDNN},
Expand Down Expand Up @@ -221,7 +222,8 @@ class ExtraInfoUtils {
std::unordered_map<std::string, std::vector<std::string>>
g_extra_input_names_map_ = {{"conv2d", {"Bias", "ResidualData"}},
{"conv2d_transpose", {"Bias"}},
{"conv2d_grad", {"Bias"}}};
{"conv2d_grad", {"Bias"}},
{"matmul_v2", {"ResidualData"}}};
std::vector<std::string> empty_extra_input_names_;
};

Expand Down
21 changes: 14 additions & 7 deletions paddle/phi/backends/onednn/onednn_reuse.h
Original file line number Diff line number Diff line change
Expand Up @@ -1874,9 +1874,11 @@ class MatmulOneDNNHandler : public OneDNNHandlerNoCachingT<XT, dnnl::matmul> {
if (scale_out != 1.0f) {
matmul_attrs.set_output_scales(0, {scale_out});
}
const auto* residual_data = dev_ctx.HasDnnInput("ResidualData")
? dev_ctx.GetDnnInput("ResidualData")
: nullptr;

if (dev_ctx.HasDnnInput("ResidualData")) {
auto* residual_data = dev_ctx.GetDnnInput("ResidualData");
if (residual_data) {
auto residual_data_tz = vectorize(residual_data->dims());
auto residual_data_md = memory::desc(residual_data_tz,
OneDNNGetDataType<OT>(),
Expand All @@ -1893,9 +1895,11 @@ class MatmulOneDNNHandler : public OneDNNHandlerNoCachingT<XT, dnnl::matmul> {

AppendActivation(dev_ctx, post_operations);

if (dev_ctx.HasDnnAttr("fused_output_scale")) {
float scale_alpha =
PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("fused_output_scale"));
const float scale_alpha =
dev_ctx.HasDnnAttr("fused_output_scale")
? PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("fused_output_scale"))
: 1.0f;
if (scale_alpha != 1.0f) {
post_operations.append_eltwise(
1.0, dnnl::algorithm::eltwise_linear, scale_alpha, 0.0f);
}
Expand Down Expand Up @@ -2014,8 +2018,11 @@ void ExecuteMatmul(const OneDNNContext& dev_ctx,
{DNNL_ARG_WEIGHTS, *weights_memory_p},
{DNNL_ARG_DST, *dst_memory_p}};

if (dev_ctx.HasDnnInput("ResidualData")) {
auto* residual_data = dev_ctx.GetDnnInput("ResidualData");
const auto* residual_data = dev_ctx.HasDnnInput("ResidualData")
? dev_ctx.GetDnnInput("ResidualData")
: nullptr;

if (residual_data) {
const auto residual_data_memory_p = handler.AcquireSrcMemory(residual_data);
matmul_args.insert({DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1,
*residual_data_memory_p});
Expand Down
164 changes: 164 additions & 0 deletions paddle/phi/kernels/onednn/matmul_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/matmul_kernel.h"

#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

DDim GetDimsForInput(const OneDNNContext &dev_ctx,
DDim input_dims,
std::string input_name) {
auto shape =
dev_ctx.HasDnnAttr("fused_reshape_" + input_name)
? PADDLE_GET_CONST(std::vector<int>,
dev_ctx.GetDnnAttr("fused_reshape_" + input_name))
: std::vector<int>();
auto axis = dev_ctx.HasDnnAttr("fused_transpose_" + input_name)
? PADDLE_GET_CONST(
std::vector<int>,
dev_ctx.GetDnnAttr("fused_transpose_" + input_name))
: std::vector<int>();
if (!shape.empty() && !axis.empty()) {
return input_dims.reshape(shape).transpose(axis);
}
return input_dims;
}

void CalculateMatrixDims(const std::vector<int64_t> &x_dims,
const std::vector<int64_t> &y_dims,
std::vector<int64_t> *x_bd_dims,
std::vector<int64_t> *y_bd_dims,
DenseTensor *out,
const bool is_output_fused) {
if (x_dims.size() == 1) {
(*x_bd_dims)[(*x_bd_dims).size() - 1] = x_dims[0];
} else if (x_dims.size() == 2) {
(*x_bd_dims)[(*x_bd_dims).size() - 1] = x_dims[1];
(*x_bd_dims)[(*x_bd_dims).size() - 2] = x_dims[0];
} else {
for (size_t i = 0; i < x_dims.size(); ++i) {
(*x_bd_dims)[(*x_bd_dims).size() - x_dims.size() + i] = x_dims[i];
}
}
if (y_dims.size() == 1) {
(*y_bd_dims)[(*x_bd_dims).size() - 2] = y_dims[0];
} else if (y_dims.size() == 2) {
(*y_bd_dims)[(*y_bd_dims).size() - 1] = y_dims[1];
(*y_bd_dims)[(*y_bd_dims).size() - 2] = y_dims[0];
} else {
for (size_t i = 0; i < y_dims.size(); ++i) {
(*y_bd_dims)[(*y_bd_dims).size() - y_dims.size() + i] = y_dims[i];
}
}

if (!is_output_fused && x_dims.size() > 2 && y_dims.size() > 2) {
auto out_dims = vectorize(out->dims());
for (size_t i = 0; i < (*x_bd_dims).size() - 2; ++i) {
PADDLE_ENFORCE_EQ(
(*x_bd_dims)[i] == (*y_bd_dims)[i] || (*x_bd_dims)[i] == 1 ||
(*y_bd_dims)[i] == 1,
true,
errors::InvalidArgument(
"Tensor dimensions are incorrect for broadcasting."
"Dimensions in X and Y must be same or equal to 1, but "
"received x_dim[%d]=%d and y_dims[%d]= %d",
i,
(*x_bd_dims)[i],
i,
(*y_bd_dims)[i]));
(out_dims)[i] = std::max((*x_bd_dims)[i], (*y_bd_dims)[i]);
}
out->Resize(make_ddim((out_dims)));
}
}

template <typename T, typename Context>
void MatmulKernel(const Context &dev_ctx,
const DenseTensor &x,
const DenseTensor &y,
bool transpose_x,
bool transpose_y,
DenseTensor *out) {
if (dev_ctx.HasDnnAttr("head_number")) {
const auto head_number =
PADDLE_GET_CONST(int, dev_ctx.GetDnnAttr("head_number"));
PADDLE_ENFORCE_EQ(
head_number,
1,
errors::Unimplemented(
"oneDNN matmul doesn't support multiple heads. Expected "
"head_number=1. But received `head_number` is %d",
head_number));
}

constexpr bool is_int8 = funcs::is_int8<T>();
constexpr bool is_bfloat16 = funcs::is_bfloat16<T>();
const bool force_fp32_output =
dev_ctx.HasDnnAttr("force_fp32_output")
? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("force_fp32_output"))
: false;

bool fuse_relu = false;
if (dev_ctx.HasDnnAttr("fuse_activation")) {
auto act_type =
PADDLE_GET_CONST(std::string, dev_ctx.GetDnnAttr("fuse_activation"));
if (act_type == "relu" || act_type == "relu6") {
fuse_relu = true;
}
}

auto x_dims = vectorize(GetDimsForInput(dev_ctx, x.dims(), "X"));
auto y_dims = vectorize(GetDimsForInput(dev_ctx, y.dims(), "Y"));

int ndims = std::max(x_dims.size(), y_dims.size());
ndims = std::max(ndims, 3);

std::vector<int64_t> x_bd_dims(ndims, 1);
std::vector<int64_t> y_bd_dims(ndims, 1);

CalculateMatrixDims(x_dims,
y_dims,
&x_bd_dims,
&y_bd_dims,
out,
funcs::IsOutputFused(dev_ctx));

if (force_fp32_output || ((!is_int8) && (!is_bfloat16))) {
funcs::ExecuteMatmul<T, float>(
dev_ctx, x, y, x_bd_dims, y_bd_dims, transpose_x, transpose_y, out);
} else if (is_bfloat16) {
funcs::ExecuteMatmul<T, paddle::platform::bfloat16>(
dev_ctx, x, y, x_bd_dims, y_bd_dims, transpose_x, transpose_y, out);
} else if (fuse_relu) {
funcs::ExecuteMatmul<T, uint8_t>(
dev_ctx, x, y, x_bd_dims, y_bd_dims, transpose_x, transpose_y, out);
} else {
funcs::ExecuteMatmul<T, int8_t>(
dev_ctx, x, y, x_bd_dims, y_bd_dims, transpose_x, transpose_y, out);
}
}

} // namespace phi

PD_REGISTER_KERNEL(matmul,
OneDNN,
ONEDNN,
phi::MatmulKernel,
float,
phi::dtype::bfloat16,
int8_t,
uint8_t) {}

0 comments on commit ad04a4b

Please sign in to comment.