[Hackathon 5 No.16] Add EmbeddingBag API to Paddle (WIP) #58027

Closed
wants to merge 39 commits
Commits
a611cd9
make the new branch from the develop and transfer my modification
JianmingGuo Dec 12, 2022
99d481d
fix the problem of indent
JianmingGuo Dec 13, 2022
64b821a
fix the problem of gpu build
JianmingGuo Dec 17, 2022
50d8421
delete the operators/xx and modify the ops.yaml & backward.yaml
JianmingGuo Jan 5, 2023
051152a
delete the operators/xx and modify the ops.yaml & backward.yaml
JianmingGuo Feb 2, 2023
4268d82
Merge branch 'develop' into embeddingbag1
JianmingGuo Feb 2, 2023
d1cfb03
Merge branch 'PaddlePaddle:develop' into embeddingbag1
JianmingGuo Feb 2, 2023
a70c447
delete the ops/compat/xx.sig to solve the problem of multiple definition
JianmingGuo Feb 2, 2023
42edc6b
unit test: initializer from fluid.initializer to paddle.nn.initializer
JianmingGuo Feb 2, 2023
038c3a9
unit test: label from fluid.layers.data to paddle.static.data
JianmingGuo Feb 2, 2023
b1c7f47
unit test: label from fluid.layers.data to paddle.static.data
JianmingGuo Feb 2, 2023
cc7c380
update the examples
JianmingGuo Feb 6, 2023
86bc1b7
change the datatype for windows error
JianmingGuo Feb 6, 2023
d9d2a6e
update the static check
JianmingGuo Feb 7, 2023
e18a93a
codestyle check except clang-format
JianmingGuo Feb 7, 2023
f2dd59e
update the static check containing clang-format
JianmingGuo Feb 8, 2023
1bb6818
Merge branch 'PaddlePaddle:develop' into embeddingbag1
JianmingGuo Feb 8, 2023
3f57c57
modify unit tests for coverage CI
JianmingGuo Feb 9, 2023
76584f8
Merge branch 'embeddingbag1' of github.com:JianmingGuo/Paddle into em…
JianmingGuo Feb 9, 2023
6a51a45
modify unit tests for coverage CI
JianmingGuo Feb 9, 2023
318702f
modify unit tests for coverage CI
JianmingGuo Feb 9, 2023
53cd75c
Merge branch 'PaddlePaddle:develop' into embeddingbag1
JianmingGuo Feb 9, 2023
20c4653
initialize params_grad with 0
JianmingGuo Feb 14, 2023
9fe3724
Merge branch 'PaddlePaddle:develop' into embeddingbag1
JianmingGuo Feb 14, 2023
c2ae8e8
add op unitest
JianmingGuo Feb 14, 2023
83c1760
Merge branch 'embeddingbag1' of github.com:JianmingGuo/Paddle into em…
JianmingGuo Feb 14, 2023
17c72a4
fix the problem in windows
JianmingGuo Feb 14, 2023
3ef914c
fix the problem of unused variables of seq
JianmingGuo Feb 20, 2023
fbdde28
add comments in embedding_bag_grad_kernel.cu
JianmingGuo Feb 23, 2023
9e4aa94
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xlcjz Oct 2, 2023
3e98492
🐞 fix(continue-embeddingbag): rm fluid
xlcjz Oct 11, 2023
1798b63
fix(continue-embeddingbag): rm fluid
xlcjz Oct 11, 2023
72a58f7
fix(continue-embeddingbag): replace memory header
xlcjz Oct 20, 2023
7d9a6a8
feat(embedding_bag): refactor op gpu implementation
xlcjz Nov 3, 2023
d0644f6
Merge branch 'develop' into continue-embeddingbag
xlcjz Nov 3, 2023
d33024a
fix(embedding_bag): code style
xlcjz Nov 4, 2023
8ddfdda
fix(kernel): refine forward kernel
xlcjz Dec 8, 2023
0ffc4e8
Merge branch 'develop' into continue-embeddingbag
xlcjz Dec 8, 2023
cc9df1e
fix typo
xlcjz Dec 8, 2023
10 changes: 10 additions & 0 deletions paddle/phi/api/yaml/backward.yaml
@@ -683,6 +683,16 @@
backward : elu_double_grad
inplace : (out_grad -> x_grad)

- backward_op : embedding_bag_grad
forward : embedding_bag (Tensor input, Tensor weight, Tensor per_sample_weight, int padding_idx, str mode, bool sparse=false) -> Tensor(out)
args : (Tensor input, Tensor weight, Tensor per_sample_weight, Tensor out_grad, str mode)
output : Tensor(weight_grad), Tensor(per_sample_weight_grad)
infer_meta :
func : EmbeddingBagGradInferMeta
param : [input, weight, per_sample_weight]
kernel :
func : embedding_bag_grad

- backward_op : erf_grad
forward : erf (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
10 changes: 10 additions & 0 deletions paddle/phi/api/yaml/ops.yaml
@@ -793,6 +793,16 @@
inplace : (x -> out)
backward : elu_grad

- op : embedding_bag
args : (Tensor input, Tensor weight, Tensor per_sample_weight, int padding_idx=-1, str mode="sum", bool sparse=false)
output : Tensor
infer_meta :
func : EmbeddingBagInferMeta
param : [input, weight, per_sample_weight, padding_idx, mode]
kernel :
func : embedding_bag
param : [input, weight, per_sample_weight, padding_idx, mode]

- op : equal_all
args : (Tensor x, Tensor y)
output : Tensor(out)
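For context on the embedding_bag op defined above: it looks up each index of input in the weight table, scales the row by the matching entry of per_sample_weight, and reduces every bag by sum or mean. A minimal usage sketch follows, assuming a Python-side wrapper named paddle.nn.functional.embedding_bag; the wrapper is not part of this diff and its final name may differ.

import paddle

# Hypothetical call -- only the C++/YAML op definition exists in this diff.
ids = paddle.to_tensor([[0, 2, 1], [3, 3, 0]], dtype="int64")  # [bag_number, sequence_length]
table = paddle.rand([10, 8])                                   # [vocab_size, output_dim]
psw = paddle.ones([2, 3])                                      # same shape as ids

out = paddle.nn.functional.embedding_bag(
    ids, table, psw, padding_idx=-1, mode="sum", sparse=False)
print(out.shape)  # [2, 8]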
1 change: 0 additions & 1 deletion paddle/phi/core/device_context.cc
@@ -158,7 +158,6 @@ struct DeviceContext::Impl {
ClearHolder(tensor);
}
}

auto* allocator =
(fake_alloc || tensor->numel() == 0) && requested_size == 0
? zero_allocator_
13 changes: 13 additions & 0 deletions paddle/phi/infermeta/backward.cc
@@ -1285,4 +1285,17 @@ void SetValueGradInferMeta(const MetaTensor& out_grad,
}
}

void EmbeddingBagGradInferMeta(const MetaTensor& input,
const MetaTensor& weight,
const MetaTensor& per_sample_weight,
MetaTensor* weight_grad,
MetaTensor* per_sample_weight_grad) {
if (weight_grad) {
weight_grad->share_meta(weight);
}
if (per_sample_weight_grad) {
per_sample_weight_grad->share_meta(per_sample_weight);
}
}

} // namespace phi
6 changes: 6 additions & 0 deletions paddle/phi/infermeta/backward.h
@@ -494,4 +494,10 @@ void SetValueGradInferMeta(const MetaTensor& out_grad,
const MetaTensor& values,
MetaTensor* x_grad,
MetaTensor* value_grad);

void EmbeddingBagGradInferMeta(const MetaTensor& input,
const MetaTensor& weight,
const MetaTensor& per_sample_weight,
MetaTensor* weight_grad,
MetaTensor* per_sample_weight_grad);
} // namespace phi
39 changes: 39 additions & 0 deletions paddle/phi/infermeta/ternary.cc
@@ -1526,4 +1526,43 @@ void QuantLinearInferMeta(const MetaTensor& x,
y->set_dtype(x.dtype());
}

void EmbeddingBagInferMeta(const MetaTensor& input,
const MetaTensor& weight,
const MetaTensor& per_sample_weight,
int64_t padding_idx,
const std::string& mode,
MetaTensor* out) {
const auto& table_dims = weight.dims();
const auto& ids_dims = input.dims();
auto ids_dims_size = ids_dims.size();
const auto& weight_dims = per_sample_weight.dims();
int ids_rank = ids_dims.size();
VLOG(5) << "ids rank is " << ids_rank;
PADDLE_ENFORCE_EQ(
ids_dims,
weight_dims,
phi::errors::InvalidArgument("ShapeError: The shapes of 'input' and "
"'per_sample_weight' must be the same. "
"But received input's shape = [%s], "
"per_sample_weight's shape = [%s].",
ids_dims,
weight_dims));
PADDLE_ENFORCE_EQ(
table_dims.size(),
2,
phi::errors::InvalidArgument(
"ShapeError: The dimensions of the 'lookup table' tensor must be 2. "
"But received lookup table's dimensions = %d, "
"lookup table's shape = [%s].",
table_dims.size(),
table_dims));

auto output_dims =
phi::vectorize(phi::slice_ddim(ids_dims, 0, ids_dims_size - 1));
output_dims.push_back(table_dims[1]);
out->set_dims(phi::make_ddim(output_dims));
out->set_dtype(weight.dtype());
out->share_lod(input);
}

} // namespace phi
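The shape rule above drops the trailing sequence dimension of input and appends the embedding width, so input [bag_number, sequence_length] with weight [vocab_size, output_dim] yields out [bag_number, output_dim]. A small Python sketch of the same computation, purely illustrative:

# Mirrors slice_ddim(ids_dims, 0, size - 1) followed by push_back(table_dims[1]).
ids_dims = [4, 7]       # input:  [bag_number, sequence_length]
table_dims = [100, 16]  # weight: [vocab_size, output_dim]
output_dims = ids_dims[:-1] + [table_dims[1]]
assert output_dims == [4, 16]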
7 changes: 7 additions & 0 deletions paddle/phi/infermeta/ternary.h
@@ -63,6 +63,13 @@ void BoxCoderInferMeta(const MetaTensor& prior_box,
MetaTensor* output_box,
MetaConfig config = MetaConfig());

void EmbeddingBagInferMeta(const MetaTensor& input,
const MetaTensor& weight,
const MetaTensor& per_sample_weight,
int64_t padding_idx,
const std::string& mode,
MetaTensor* out);

void DpsgdInferMeta(const MetaTensor& param,
const MetaTensor& grad,
const MetaTensor& learning_rate,
156 changes: 156 additions & 0 deletions paddle/phi/kernels/cpu/embedding_bag_grad_kernel.cc
@@ -0,0 +1,156 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/embedding_bag_grad_kernel.h"

#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace phi {

template <typename T, typename Context>
struct EmbeddingBagGradCPUFunctor {
EmbeddingBagGradCPUFunctor(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& params,
const DenseTensor& weight,
const DenseTensor& out_grad,
const std::string& mode,
DenseTensor* params_grad,
DenseTensor* weight_grad)

: dev_ctx_(dev_ctx),
input_(input),
params_(params),
weight_(weight),
out_grad_(out_grad),
mode_(mode),
params_grad_(params_grad),
weight_grad_(weight_grad) {}

using EigenVectorMap = Eigen::Map<Eigen::Vector<T, Eigen::Dynamic>>;
using ConstEigenVectorMap =
Eigen::Map<const Eigen::Vector<T, Eigen::Dynamic>>;
using EigenIndex = Eigen::Index;

template <typename IdT>
void apply() {
dev_ctx_.template Alloc<T>(params_grad_);
dev_ctx_.template Alloc<T>(weight_grad_);

const EigenIndex sequence_length = input_.dims()[1];
const EigenIndex output_dim = params_.dims()[1];

std::unordered_map<IdT, EigenIndex> index_map;
std::vector<std::pair<IdT, std::vector<EigenIndex>>> index_vec;

auto* d_grad = out_grad_.data<T>();
auto* d_weights = weight_.data<T>();
auto* d_params = params_.data<T>();
auto* d_inputs = input_.data<IdT>();

phi::funcs::SetConstant<Context, T>()(
dev_ctx_, params_grad_, static_cast<T>(0));
auto* d_params_grad = params_grad_->data<T>();
auto* d_weight_grad = weight_grad_->data<T>();

EigenIndex bags = input_.dims()[0];

for (EigenIndex i = 0; i < bags * sequence_length; ++i) {
auto index = d_inputs[i];
if (index_map.find(index) == index_map.end()) {
index_map[index] = index_vec.size();
index_vec.push_back({index, {}});
}
index_vec[index_map[index]].second.push_back(i);
}

auto ids_num = static_cast<int64_t>(index_vec.size());
for (EigenIndex i = 0; i < ids_num; ++i) {
EigenVectorMap params_grads_slice(
&d_params_grad[index_vec[i].first * output_dim], output_dim);

for (EigenIndex index : index_vec[i].second) {
const EigenIndex bag = index / sequence_length;
const ConstEigenVectorMap grads_slice(&d_grad[bag * output_dim],
output_dim);
params_grads_slice += grads_slice * d_weights[index];
}
if (mode_ == "mean") {
params_grads_slice /= static_cast<T>(sequence_length);
}
}

for (EigenIndex i = 0; i < bags; ++i) {
for (EigenIndex j = 0; j < sequence_length; ++j) {
const ConstEigenVectorMap grads_slice(&d_grad[i * output_dim],
output_dim);
const ConstEigenVectorMap params_slice(
&d_params[d_inputs[i * sequence_length + j] * output_dim],
output_dim);
if (mode_ == "sum") {
d_weight_grad[i * sequence_length + j] =
params_slice.dot(grads_slice);
} else {
d_weight_grad[i * sequence_length + j] =
params_slice.dot(grads_slice) / static_cast<T>(sequence_length);
}
}
}
}

private:
const Context& dev_ctx_;
const DenseTensor& input_;
const DenseTensor& params_;
const DenseTensor& weight_;
const DenseTensor& out_grad_;
const std::string& mode_;
DenseTensor* params_grad_;
DenseTensor* weight_grad_;
};

template <typename T, typename Context>
void EmbeddingBagGradKernel(const Context& ctx,
const DenseTensor& input,
const DenseTensor& params,
const DenseTensor& weight,
const DenseTensor& out_grad,
const std::string& mode,
DenseTensor* params_grad,
DenseTensor* weight_grad) {
EmbeddingBagGradCPUFunctor<T, Context> functor(
ctx, input, params, weight, out_grad, mode, params_grad, weight_grad);

if (input.dtype() == phi::DataType::INT32) {
functor.template apply<int>();
} else if (input.dtype() == phi::DataType::INT64) {
functor.template apply<int64_t>();
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"embedding_bag input only supports int32 and int64"));
}
}

} // namespace phi

PD_REGISTER_KERNEL(embedding_bag_grad,
CPU,
ALL_LAYOUT,
phi::EmbeddingBagGradKernel,
float,
double,
phi::dtype::bfloat16) {}
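To summarize the functor above: with mode "sum", every occurrence of an id adds out_grad[bag] * per_sample_weight into that id's row of params_grad, and each per-sample-weight gradient is the dot product of the looked-up row with the bag's out_grad; with mode "mean", both are additionally scaled by 1 / sequence_length. A NumPy re-statement of that math, offered as an independent checking sketch rather than part of the diff:

import numpy as np

def embedding_bag_grad_ref(input, params, weight, out_grad, mode="sum"):
    # input: [bags, seq] int; params: [vocab, dim]; weight: [bags, seq]; out_grad: [bags, dim]
    bags, seq_len = input.shape
    params_grad = np.zeros_like(params)
    weight_grad = np.zeros_like(weight)
    scale = seq_len if mode == "mean" else 1
    for b in range(bags):
        for s in range(seq_len):
            idx = input[b, s]
            # Accumulate into the embedding-table grad; dividing each term by
            # scale matches the kernel's single division after accumulation.
            params_grad[idx] += out_grad[b] * weight[b, s] / scale
            # Grad of the per-sample weight is a row/grad dot product.
            weight_grad[b, s] = params[idx] @ out_grad[b] / scale
    return params_grad, weight_grad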
112 changes: 112 additions & 0 deletions paddle/phi/kernels/cpu/embedding_bag_kernel.cc
@@ -0,0 +1,112 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/embedding_bag_kernel.h"

#include <string>

#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/embedding_util.h"

namespace phi {

template <typename T, typename Context>
struct EmbeddingBagCPUFunctor {
EmbeddingBagCPUFunctor(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& weight,
const DenseTensor& per_sample_weight,
const int64_t padding_idx,
const std::string& mode,
DenseTensor* out)
: dev_ctx_(dev_ctx),
input_(input),
weight_(weight),
per_sample_weight_(per_sample_weight),
padding_idx_(padding_idx),
mode_(mode),
out_(out) {}

using EigenArrayMap =
Eigen::Map<Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>>;
using EigenVectorMap = Eigen::Map<Eigen::Vector<T, Eigen::Dynamic>>;
using ConstEigenVectorMap =
Eigen::Map<const Eigen::Vector<T, Eigen::Dynamic>>;
using EigenIndex = Eigen::Index;

template <typename IdT>
void apply() {
dev_ctx_.template Alloc<T>(out_);
const EigenIndex bag_number = input_.dims()[0];
const EigenIndex sequence_length = input_.dims()[1];
const EigenIndex output_dim = weight_.dims()[1];

auto* input_d = input_.data<IdT>();

auto* weight_d = weight_.data<T>();
auto* per_sample_weight_d = per_sample_weight_.data<T>();

auto* output_d = out_->data<T>();

for (EigenIndex bag = 0; bag < bag_number; ++bag) {
EigenVectorMap output_slice(&output_d[bag * output_dim], output_dim);
output_slice.setZero();
for (EigenIndex seq = 0; seq < sequence_length; ++seq) {
const ConstEigenVectorMap weight_slice(
&weight_d[input_d[bag * sequence_length + seq] * output_dim],
output_dim);
output_slice +=
weight_slice * per_sample_weight_d[bag * sequence_length + seq];
}
if (mode_ == "mean") {
output_slice /= static_cast<T>(sequence_length);
}
}
}

private:
const Context& dev_ctx_;
const DenseTensor& input_;
const DenseTensor& weight_;
const DenseTensor& per_sample_weight_;
const int64_t padding_idx_;
const std::string& mode_;
DenseTensor* out_;
};
template <typename T, typename Context>
void EmbeddingBagKernel(const Context& ctx,
const DenseTensor& input,
const DenseTensor& weight,
const DenseTensor& per_sample_weight,
int64_t padding_idx,
const std::string& mode,
DenseTensor* out) {
EmbeddingBagCPUFunctor<T, Context> functor(
ctx, input, weight, per_sample_weight, padding_idx, mode, out);
if (input.dtype() == phi::DataType::INT32) {
functor.template apply<int>();
} else if (input.dtype() == phi::DataType::INT64) {
functor.template apply<int64_t>();
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"embedding_bag input only supports int32 and int64"));
}
}

} // namespace phi

PD_REGISTER_KERNEL(embedding_bag,
CPU,
ALL_LAYOUT,
phi::EmbeddingBagKernel,
float,
double,
phi::dtype::bfloat16) {}
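A matching NumPy sketch of the forward pass, again for checking only (note the CPU functor stores padding_idx_ but does not yet use it in apply(), so it is omitted here as well):

import numpy as np

def embedding_bag_ref(input, weight, per_sample_weight, mode="sum"):
    # input: [bags, seq] int; weight: [vocab, dim]; per_sample_weight: [bags, seq]
    bags, seq_len = input.shape
    out = np.zeros((bags, weight.shape[1]), dtype=weight.dtype)
    for b in range(bags):
        for s in range(seq_len):
            # Weighted row lookup, accumulated per bag.
            out[b] += weight[input[b, s]] * per_sample_weight[b, s]
    if mode == "mean":
        out /= seq_len
    return out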