support sparse output for lookup table grad op (#5145)
* add sparse support for sum op

* typo fix

* fix gpu build error

* fix unittest error

* typo fix

* infer var type and shape in op_test

* follow comments

* fix build error

* bypass some unittests depend on NetOp

* support sparse output for lookup table grad op

* refine codes

* fix gpu build error

* fix lookup table grad gpu kernel

* fix ci

* fix ci

* fix ci

* fix bug in lookup_table_grad op

* fix bug in test_word2vec

* register double kernel for some operators

* set is_sparse=True in test_word2vec

* fix lookup table grad op CUDA kernel bug

* disable test_modified_huber_loss_op temporarily

* disable test_lstm_unit_op temporarily
QiJune authored Oct 28, 2017
1 parent 3ecad8a commit 008f40c
Showing 23 changed files with 218 additions and 114 deletions.
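
At a high level, the patch makes the type of lookup_table's W gradient depend on a new is_sparse attribute: when it is set, the gradient is produced as a SelectedRows (the looked-up ids as rows plus a value tensor holding one gradient row per id) rather than a dense LoDTensor covering the whole table. A minimal stand-alone sketch of that row-sparse representation (illustrative only: RowSparseGrad and MakeLookupTableGrad are not part of this commit, and the real framework::SelectedRows API differs):

#include <cstdint>
#include <vector>

// Illustrative stand-in for a row-sparse gradient: one entry per looked-up id.
struct RowSparseGrad {
  std::vector<int64_t> rows;  // ids seen in the forward pass (may repeat)
  std::vector<float> value;   // rows.size() x emb_dim gradient rows, row-major
  int64_t emb_dim = 0;
};

// The backward pass only has to copy the output gradient row-for-row,
// which is what the sparse branches of the kernels below do.
RowSparseGrad MakeLookupTableGrad(const std::vector<int64_t>& ids,
                                  const std::vector<float>& d_out,  // ids.size() x emb_dim
                                  int64_t emb_dim) {
  RowSparseGrad g;
  g.emb_dim = emb_dim;
  g.rows = ids;
  g.value = d_out;
  return g;
}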
8 changes: 4 additions & 4 deletions paddle/operators/cross_entropy_op.cu
@@ -21,7 +21,7 @@ namespace {

template <typename T>
__global__ void CrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
const int* label, const int N,
const int64_t* label, const int N,
const int D) {
// TOOD(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file.
// CUDA_1D_KERNEL_LOOP(i, N) {
@@ -77,8 +77,8 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel<T> {
T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
const T* x_data = x->data<T>();

int batch_size = x->dims()[0];
int class_num = x->dims()[1];
int64_t batch_size = x->dims()[0];
int64_t class_num = x->dims()[1];

int block = 512;
int grid = (batch_size * class_num + block - 1) / block;
@@ -93,7 +93,7 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel<T> {
} else {
math::SetConstant<platform::GPUPlace, T> functor;
functor(ctx.device_context(), dx, 0);
auto* label_data = label->data<int>();
auto* label_data = label->data<int64_t>();
grid = (batch_size + block - 1) / block;
CrossEntropyGradientKernel<T><<<
grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
14 changes: 7 additions & 7 deletions paddle/operators/cross_entropy_op.h
@@ -54,28 +54,28 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel<T> {
Tensor* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
T* dx_data = dx->mutable_data<T>(ctx.GetPlace());

int class_num = x->dims()[1];
int64_t class_num = x->dims()[1];
if (ctx.Attr<bool>("soft_label")) {
auto x_mat = EigenMatrix<T>::From(*x);
auto dy_mat = EigenMatrix<T>::From(*dy);
auto lbl_mat = EigenMatrix<T>::From(*label);
auto dx_mat = EigenMatrix<T>::From(*dx);

dx_mat.device(ctx.GetEigenDevice<platform::CPUPlace>()) =
-(lbl_mat * dy_mat.broadcast(Eigen::DSizes<int, 2>(1, class_num)) /
x_mat);
-(lbl_mat *
dy_mat.broadcast(Eigen::DSizes<int64_t, 2>(1, class_num)) / x_mat);
} else {
int batch_size = x->dims()[0];
int64_t batch_size = x->dims()[0];
const T* dy_data = dy->data<T>();
const T* x_data = x->data<T>();
const int* label_data = label->data<int>();
const int64_t* label_data = label->data<int64_t>();

math::SetConstant<platform::CPUPlace, T> functor;
functor(ctx.device_context(), dx, 0);

for (int i = 0; i < batch_size; ++i) {
for (int64_t i = 0; i < batch_size; ++i) {
PADDLE_ASSERT(label_data[i] >= 0 || label_data[i] < class_num);
int index = i * class_num + label_data[i];
int64_t index = i * class_num + label_data[i];
dx_data[index] = -dy_data[i] / x_data[index];
}
}
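For reference, the hard-label branch above implements the standard cross-entropy gradient: with per-sample loss l_i = -log(x[i, label[i]]), the only non-zero entry of the input gradient for sample i is dx[i, label[i]] = -dy[i] / x[i, label[i]], which is exactly the dx_data[index] assignment; the soft-label branch is the elementwise form dx = -(label * dy) / x with dy broadcast across the class dimension. The int -> int64_t changes only widen the index and label types; the math is unchanged.
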
2 changes: 1 addition & 1 deletion paddle/operators/feed_op.cc
@@ -41,7 +41,7 @@ class FeedOp : public framework::OperatorBase {

auto col = Attr<int>("col");

VLOG(3) << "Feed Var " << feed_var_name << "'s " << col << " column to var"
VLOG(3) << "Feed Var " << feed_var_name << "'s " << col << " column to var "
<< out_name;

auto &feed_list = feed_var->Get<framework::FeedFetchList>();
44 changes: 39 additions & 5 deletions paddle/operators/lookup_table_op.cc
@@ -13,6 +13,7 @@
limitations under the License. */

#include "paddle/operators/lookup_table_op.h"
#include "paddle/framework/var_type_inference.h"

namespace paddle {
namespace operators {
@@ -60,6 +61,7 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
"Ids must be a column vector with rank = 2."
"The 2nd dimension size must be 1");
AddOutput("Out", "The lookup results, which have the same type with W.");
AddAttr<bool>("is_sparse", "Sparse update").SetDefault(false);
AddComment(R"DOC(
This operator is used to perform lookups on the parameter W,
then concatenated into a dense tensor.
@@ -70,6 +72,15 @@ or not. And the output only shares the LoD with input `Ids`.
}
};

class LookupTableOpGradDescMaker
: public framework::DefaultGradOpDescMaker<true> {
using ::paddle::framework::DefaultGradOpDescMaker<
true>::DefaultGradOpDescMaker;

protected:
virtual std::string GradOpType() const { return "lookup_table_grad"; }
};

class LookupTableOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -86,12 +97,35 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
}
};

class LookupTableOpGradVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDescBind& op_desc,
framework::BlockDescBind* block) const override {
auto out_var_name = op_desc.Output(framework::GradVarName("W")).front();
auto attr = op_desc.GetAttr("is_sparse");
bool is_sparse = boost::get<bool>(attr);
if (is_sparse) {
VLOG(3) << "lookup_table_grad op " << framework::GradVarName("W")
<< " is set to SelectedRows";
block->Var(out_var_name)->SetType(framework::VarDesc::SELECTED_ROWS);
} else {
VLOG(3) << "lookup_table_grad op " << framework::GradVarName("W")
<< " is set to LoDTensor";
block->Var(out_var_name)->SetType(framework::VarDesc::LOD_TENSOR);
}
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP(lookup_table, ops::LookupTableOp, ops::LookupTableOpMaker,
lookup_table_grad, ops::LookupTableOpGrad);

REGISTER_OP_CPU_KERNEL(lookup_table, ops::LookupTableKernel<float>);
REGISTER_OP_CPU_KERNEL(lookup_table_grad, ops::LookupTableGradKernel<float>);
REGISTER_OPERATOR(lookup_table, ops::LookupTableOp,
ops::LookupTableOpGradDescMaker, ops::LookupTableOpMaker);
REGISTER_OPERATOR(lookup_table_grad, ops::LookupTableOpGrad,
ops::LookupTableOpGradVarTypeInference);

REGISTER_OP_CPU_KERNEL(lookup_table, ops::LookupTableKernel<float>,
ops::LookupTableKernel<double>);
REGISTER_OP_CPU_KERNEL(lookup_table_grad, ops::LookupTableGradKernel<float>,
ops::LookupTableGradKernel<double>);
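
The reason the grad op needs its own VarTypeInference is that the choice made here is what lets an optimizer update only the rows that were actually looked up, instead of writing a gradient the size of the whole table. A hedged sketch of such a row-wise update, assuming plain SGD (ApplyRowSparseGrad and its parameters are illustrative names, not Paddle APIs):

#include <cstddef>
#include <cstdint>
#include <vector>

// Apply a row-sparse gradient to a dense embedding table with SGD.
// Repeated ids simply accumulate, matching what a dense scatter-add would produce.
void ApplyRowSparseGrad(std::vector<float>* table,         // N x emb_dim, row-major
                        const std::vector<int64_t>& rows,  // looked-up ids
                        const std::vector<float>& value,   // rows.size() x emb_dim
                        int64_t emb_dim, float lr) {
  for (std::size_t i = 0; i < rows.size(); ++i) {
    float* dst = table->data() + rows[i] * emb_dim;
    const float* src = value.data() + static_cast<int64_t>(i) * emb_dim;
    for (int64_t d = 0; d < emb_dim; ++d) {
      dst[d] -= lr * src[d];
    }
  }
}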
100 changes: 67 additions & 33 deletions paddle/operators/lookup_table_op.cu
@@ -1,11 +1,8 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -14,22 +11,21 @@

#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/lookup_table_op.h"
#include "paddle/platform/assert.h"
#include "paddle/platform/cuda_helper.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
__global__ void LookupTable(T* output, const T* table, const int32_t* ids,
const int N, const int K, const int D) {
__global__ void LookupTable(T* output, const T* table, const int64_t* ids,
const int64_t N, const int64_t K, const int64_t D) {
int idx = threadIdx.x;
int idy = blockIdx.x + threadIdx.y * GridDimX;

while (idy < K) {
int id = ids[idy];
int64_t id = ids[idy];
PADDLE_ASSERT(id >= 0);
PADDLE_ASSERT(id < N);
T* out = output + idy * D;
@@ -42,8 +38,9 @@ __global__ void LookupTable(T* output, const T* table, const int32_t* ids,
}

template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
__global__ void LookupTableGrad(T* table, const T* output, const int32_t* ids,
const int N, const int K, const int D) {
__global__ void LookupTableGrad(T* table, const T* output, const int64_t* ids,
const int64_t N, const int64_t K,
const int64_t D) {
int idx = threadIdx.x;
int idy = blockIdx.x + threadIdx.y * GridDimX;

@@ -71,7 +68,7 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
size_t N = table_t->dims()[0];
size_t D = table_t->dims()[1];
size_t K = ids_t->numel();
auto ids = ids_t->data<int32_t>();
auto ids = ids_t->data<int64_t>();
auto table = table_t->data<T>();
auto output = output_t->mutable_data<T>(context.GetPlace());

@@ -88,34 +85,71 @@ template <typename T>
class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto ids_t = context.Input<Tensor>("Ids");
auto d_output_t = context.Input<Tensor>(framework::GradVarName("Out"));
auto d_table_t = context.Output<Tensor>(framework::GradVarName("W"));

int N = d_table_t->dims()[0];
int D = d_table_t->dims()[1];
int K = ids_t->numel();
const int32_t* ids = ids_t->data<int32_t>();
const T* d_output = d_output_t->data<T>();
T* d_table = d_table_t->mutable_data<T>(context.GetPlace());

auto t = framework::EigenVector<T>::Flatten(*d_table_t);
t.device(context.GetEigenDevice<platform::GPUPlace>()) =
t.constant(static_cast<T>(0));

dim3 threads(128, 8);
dim3 grids(8, 1);
LookupTableGrad<T, 128, 8, 8><<<
grids, threads, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
bool is_sparse = context.Attr<bool>("is_sparse");
if (is_sparse) {
auto* ids = context.Input<Tensor>("Ids");
auto* table = context.Input<Tensor>("W");
auto* d_output = context.Input<Tensor>(framework::GradVarName("Out"));
auto* d_table = context.Output<SelectedRows>(framework::GradVarName("W"));

auto* ids_data = ids->data<int64_t>();
auto ids_dim = ids->dims();

auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
context.device_context())
.stream();
// copy GPU memory to CPU pinned memory
framework::Vector<int64_t> new_rows;
new_rows.resize(ids_dim[0]);
auto gpu_place = boost::get<platform::GPUPlace>(context.GetPlace());

memory::Copy(platform::CPUPlace(), new_rows.data(), gpu_place, ids_data,
ids_dim[0] * sizeof(int64_t), stream);

d_table->set_rows(new_rows);

auto* d_table_value = d_table->mutable_value();
d_table_value->Resize({ids_dim[0], table->dims()[1]});
d_table_value->mutable_data<T>(context.GetPlace());

auto* d_table_data = d_table_value->data<T>();
auto* d_output_data = d_output->data<T>();
PADDLE_ENFORCE_EQ(d_table_value->dims(), d_output->dims());
memory::Copy(gpu_place, d_table_data, gpu_place, d_output_data,
d_output->numel(), stream);

} else {
auto ids_t = context.Input<Tensor>("Ids");
auto d_output_t = context.Input<Tensor>(framework::GradVarName("Out"));
auto d_table_t = context.Output<Tensor>(framework::GradVarName("W"));

int N = d_table_t->dims()[0];
int D = d_table_t->dims()[1];
int K = ids_t->numel();
const int64_t* ids = ids_t->data<int64_t>();
const T* d_output = d_output_t->data<T>();
T* d_table = d_table_t->mutable_data<T>(context.GetPlace());

auto t = framework::EigenVector<T>::Flatten(*d_table_t);
t.device(context.GetEigenDevice<platform::GPUPlace>()) =
t.constant(static_cast<T>(0));

dim3 threads(128, 8);
dim3 grids(8, 1);
LookupTableGrad<T, 128, 8,
8><<<grids, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(
context.device_context())
.stream()>>>(d_table, d_output, ids, N, K, D);
}
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(lookup_table, ops::LookupTableCUDAKernel<float>);
REGISTER_OP_GPU_KERNEL(lookup_table_grad,
ops::LookupTableGradCUDAKernel<float>);
REGISTER_OP_GPU_KERNEL(lookup_table, ops::LookupTableCUDAKernel<float>,
ops::LookupTableCUDAKernel<double>);
REGISTER_OP_GPU_KERNEL(lookup_table_grad, ops::LookupTableGradCUDAKernel<float>,
ops::LookupTableGradCUDAKernel<double>);
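
In the dense (is_sparse == false) branch above, the kernel zeroes the full table gradient and then accumulates each output-gradient row into the table row selected by its id. A plain CPU reference of that computation (not the CUDA kernel itself; the function name is illustrative):

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// CPU reference for the dense lookup_table_grad path: d_table starts at zero
// and each output-gradient row is added into the table row picked by its id.
void LookupTableGradDenseRef(std::vector<float>* d_table,      // N x D, row-major
                             const std::vector<float>& d_out,  // K x D
                             const std::vector<int64_t>& ids,  // K
                             int64_t D) {
  std::fill(d_table->begin(), d_table->end(), 0.0f);
  for (std::size_t i = 0; i < ids.size(); ++i) {
    float* dst = d_table->data() + ids[i] * D;
    const float* src = d_out.data() + static_cast<int64_t>(i) * D;
    for (int64_t d = 0; d < D; ++d) dst[d] += src[d];
  }
}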