diff --git a/cmake/external/box_ps.cmake b/cmake/external/box_ps.cmake index a1e97f72b1e41..e85c185aaf4b8 100644 --- a/cmake/external/box_ps.cmake +++ b/cmake/external/box_ps.cmake @@ -20,7 +20,7 @@ IF((NOT DEFINED BOX_PS_VER) OR (NOT DEFINED BOX_PS_URL)) SET(BOX_PS_VER "0.1.1" CACHE STRING "" FORCE) SET(BOX_PS_NAME "box_ps" CACHE STRING "" FORCE) #SET(BOX_PS_URL "http://box-ps.gz.bcebos.com/box_ps.tar.gz" CACHE STRING "" FORCE) - SET(BOX_PS_URL "data-im.baidu.com:/home/work/var/CI_DATA/im/static/box_ps.tar.gz/box_ps.tar.gz.20" CACHE STRING "" FORCE) + SET(BOX_PS_URL "data-im.baidu.com:/home/work/var/CI_DATA/im/static/box_ps.tar.gz/box_ps.tar.gz.30" CACHE STRING "" FORCE) ENDIF() MESSAGE(STATUS "BOX_PS_NAME: ${BOX_PS_NAME}, BOX_PS_URL: ${BOX_PS_URL}") SET(BOX_PS_SOURCE_DIR "${THIRD_PARTY_PATH}/box_ps") diff --git a/paddle/fluid/framework/fleet/box_wrapper.cc b/paddle/fluid/framework/fleet/box_wrapper.cc index 397a72761224f..662b29632f8a5 100644 --- a/paddle/fluid/framework/fleet/box_wrapper.cc +++ b/paddle/fluid/framework/fleet/box_wrapper.cc @@ -424,6 +424,9 @@ void BoxWrapper::PullSparse(const paddle::platform::Place& place, feature_type_ == static_cast(boxps::FEATURE_SHOWCLK)) { \ PullSparseCase>( \ place, keys, values, slot_lengths, hidden_size, expand_embed_dim); \ + } else if (feature_type_ == static_cast(boxps::FEATURE_CONV)) { \ + PullSparseCase>( \ + place, keys, values, slot_lengths, hidden_size, expand_embed_dim); \ } else if (feature_type_ == static_cast(boxps::FEATURE_VARIABLE)) { \ PullSparseCase>( \ place, keys, values, slot_lengths, hidden_size, expand_embed_dim); \ @@ -475,28 +478,33 @@ void BoxWrapper::PushSparseGrad(const paddle::platform::Place& place, } \ } break -#define PUSHSPARSE_CASE(i, ...) \ - case i: { \ - constexpr size_t ExpandDim = i; \ - if (feature_type_ == static_cast(boxps::FEATURE_SHARE_EMBEDDING)) { \ - PushSparseGradCase< \ - boxps::FeaturePushValueGpuShareEmbedding>( \ - place, keys, grad_values, slot_lengths, hidden_size, \ - expand_embed_dim, batch_size); \ - } else if (feature_type_ == static_cast(boxps::FEATURE_PCOC)) { \ - PushSparseGradCase< \ - boxps::FeaturePushValueGpuPCOC>( \ - place, keys, grad_values, slot_lengths, hidden_size, \ - expand_embed_dim, batch_size); \ +#define PUSHSPARSE_CASE(i, ...) 
\ + case i: { \ + constexpr size_t ExpandDim = i; \ + if (feature_type_ == static_cast(boxps::FEATURE_SHARE_EMBEDDING)) { \ + PushSparseGradCase< \ + boxps::FeaturePushValueGpuShareEmbedding>( \ + place, keys, grad_values, slot_lengths, hidden_size, \ + expand_embed_dim, batch_size); \ + } else if (feature_type_ == static_cast(boxps::FEATURE_PCOC)) { \ + PushSparseGradCase< \ + boxps::FeaturePushValueGpuPCOC>( \ + place, keys, grad_values, slot_lengths, hidden_size, \ + expand_embed_dim, batch_size); \ } else if (feature_type_ == static_cast(boxps::FEATURE_VARIABLE)) { \ PushSparseGradCase>( \ place, keys, grad_values, slot_lengths, hidden_size, \ expand_embed_dim, batch_size); \ - } else { \ - PushSparseGradCase>( \ - place, keys, grad_values, slot_lengths, hidden_size, \ - expand_embed_dim, batch_size); \ - } \ + } else if (feature_type_ == static_cast(boxps::FEATURE_CONV)) { \ + PushSparseGradCase< \ + boxps::FeaturePushValueGpuConv>( \ + place, keys, grad_values, slot_lengths, hidden_size, \ + expand_embed_dim, batch_size); \ + } else { \ + PushSparseGradCase>( \ + place, keys, grad_values, slot_lengths, hidden_size, \ + expand_embed_dim, batch_size); \ + } \ } break CheckEmbedSizeIsValid(hidden_size - cvm_offset_, expand_embed_dim); diff --git a/paddle/fluid/framework/fleet/box_wrapper.cu b/paddle/fluid/framework/fleet/box_wrapper.cu index 2aafb60f602e1..be78dd90e0a45 100644 --- a/paddle/fluid/framework/fleet/box_wrapper.cu +++ b/paddle/fluid/framework/fleet/box_wrapper.cu @@ -1189,6 +1189,11 @@ void BoxWrapper::CopyForPull(const paddle::platform::Place& place, stream, gpu_keys, gpu_values, total_values_gpu, hidden_size, \ EmbedxDim, total_length, total_dims, slot_lens, slot_num, key2slot, \ pull_embedx_scale_, cvm_offset_, gpu_restore_idx); \ + } else if (feature_type_ == static_cast(boxps::FEATURE_CONV)) { \ + FeaturePullCopy>( \ + stream, gpu_keys, gpu_values, total_values_gpu, hidden_size, \ + EmbedxDim, total_length, total_dims, slot_lens, slot_num, key2slot, \ + pull_embedx_scale_, cvm_offset_, gpu_restore_idx); \ } else { \ FeaturePullCopy>( \ stream, gpu_keys, gpu_values, total_values_gpu, hidden_size, \ @@ -1219,6 +1224,12 @@ void BoxWrapper::CopyForPull(const paddle::platform::Place& place, stream, gpu_keys, gpu_values, total_values_gpu, hidden_size, \ EmbedxDim, ExpandDim, total_length, total_dims, slot_lens, slot_num, \ key2slot, 1.0, cvm_offset_, gpu_restore_idx); \ + } else if (feature_type_ == static_cast(boxps::FEATURE_CONV)) { \ + FeaturePullCopyNNCross< \ + boxps::FeaturePullValueGpuConv>( \ + stream, gpu_keys, gpu_values, total_values_gpu, hidden_size, \ + EmbedxDim, ExpandDim, total_length, total_dims, slot_lens, slot_num, \ + key2slot, 1.0, cvm_offset_, gpu_restore_idx); \ } else { \ FeaturePullCopyNNCross< \ boxps::FeaturePullValueGpu>( \ @@ -1479,6 +1490,12 @@ void BoxWrapper::CopyForPush(const paddle::platform::Place& place, total_length, batch_size, d_slot_vector, total_dims, slot_lens, \ slot_num, key2slot, cvm_offset_, gpu_sort_idx, gpu_sort_offset, \ gpu_sort_lens); \ + } else if (feature_type_ == static_cast(boxps::FEATURE_CONV)) { \ + FeaturePushCopy>( \ + stream, total_grad_values_gpu, grad_values, hidden_size, EmbedxDim, \ + total_length, batch_size, d_slot_vector, total_dims, slot_lens, \ + slot_num, key2slot, cvm_offset_, gpu_sort_idx, gpu_sort_offset, \ + gpu_sort_lens); \ } else { \ FeaturePushCopy>( \ stream, total_grad_values_gpu, grad_values, hidden_size, EmbedxDim, \ @@ -1505,6 +1522,13 @@ void BoxWrapper::CopyForPush(const 
paddle::platform::Place& place, ExpandDim, total_length, batch_size, d_slot_vector, total_dims, \ slot_lens, slot_num, key2slot, cvm_offset_, gpu_sort_idx, \ gpu_sort_offset, gpu_sort_lens); \ + } else if (feature_type_ == static_cast(boxps::FEATURE_CONV)) { \ + FeaturePushCopyVariable< \ + boxps::FeaturePushValueGpuConv>( \ + stream, total_grad_values_gpu, grad_values, hidden_size, EmbedxDim, \ + ExpandDim, total_length, batch_size, d_slot_vector, total_dims, \ + slot_lens, slot_num, key2slot, cvm_offset_, gpu_sort_idx, \ + gpu_sort_offset, gpu_sort_lens); \ } else { \ FeaturePushCopyNNCross< \ boxps::FeaturePushValueGpu>( \ diff --git a/paddle/fluid/framework/fleet/box_wrapper.h b/paddle/fluid/framework/fleet/box_wrapper.h index c0125a7fc207f..13df3cadc45ee 100644 --- a/paddle/fluid/framework/fleet/box_wrapper.h +++ b/paddle/fluid/framework/fleet/box_wrapper.h @@ -579,6 +579,8 @@ class BoxWrapper { } else if (s_instance_->feature_type_ == static_cast(boxps::FEATURE_PCOC)) { s_instance_->cvm_offset_ = 8; + } else if (s_instance_->feature_type_ == static_cast(boxps::FEATURE_CONV)) { + s_instance_->cvm_offset_ = 4; } else { s_instance_->cvm_offset_ = 3; } diff --git a/paddle/fluid/operators/fused/fused_seqpool_cvm_with_conv_op.cc b/paddle/fluid/operators/fused/fused_seqpool_cvm_with_conv_op.cc new file mode 100644 index 0000000000000..47267b7f5748d --- /dev/null +++ b/paddle/fluid/operators/fused/fused_seqpool_cvm_with_conv_op.cc @@ -0,0 +1,220 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/fused/fused_seqpool_cvm_with_conv_op.h" +#include +namespace paddle { +namespace operators { + +class FusedSeqpoolCVMOpWithConv : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE_GE(ctx->Inputs("X").size(), 1UL, "Inputs(X) of FusedSeqpoolCVMOpWithConv should not be empty."); + PADDLE_ENFORCE_GE(ctx->Outputs("Out").size(), 1UL, "Outputs(Out) of FusedSeqpoolCVMOpWithConv should not be empty."); + + auto cvm_dims = ctx->GetInputDim("CVM"); + PADDLE_ENFORCE_EQ(cvm_dims.size(), 2UL, platform::errors::InvalidArgument("Input(CVM)'s rank should be 2.")); + PADDLE_ENFORCE_EQ(cvm_dims[1], 3UL, + platform::errors::InvalidArgument("The 2nd dimension of Input(CVM) should be 3.")); + + auto ins_dims = ctx->GetInputsDim("X"); + const int cvm_offset = ctx->Attrs().Get("cvm_offset"); + const size_t num_inputs = ins_dims.size(); + std::vector outs_dims; + outs_dims.resize(num_inputs); + bool use_cvm = ctx->Attrs().Get("use_cvm"); + bool show_filter = ctx->Attrs().Get("show_filter"); + + PADDLE_ENFORCE_GT(num_inputs, 0UL, + platform::errors::InvalidArgument( + "Input tensors count should be greater than 0, " + "but received value is %d.", + num_inputs)); + + // The output height should be confirmed in Compute, + // since input lod is not accessible here. 
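For orientation, the output width computed by the shape branch below can be summarized in a short sketch. This is illustrative only; the helper name and the example widths are made up and are not part of the operator:

import numpy as np  # not required; plain Python suffices here

def infer_out_width(D, use_cvm, show_filter, cvm_offset=3):
    # mirrors the branches below; the batch dimension stays -1 until lod is known
    if use_cvm:
        # the show column is dropped from the output when show_filter is set
        return D - 1 if show_filter else D
    # without CVM, the leading cvm_offset statistic columns are stripped
    return D - cvm_offset

# e.g. an 11-wide input (3 statistic columns + 8 embedding dims):
assert infer_out_width(11, True, False) == 11
assert infer_out_width(11, True, True) == 10
assert infer_out_width(11, False, False) == 8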
+ PADDLE_ENFORCE_EQ(ins_dims[0].size(), 2, + platform::errors::InvalidArgument( + "The dims size of first input should be equal to 2, " + "but received value is %d.", + ins_dims[0].size())); + + for (size_t i = 0; i < num_inputs; ++i) { + const auto dims = ins_dims[i]; + int rank = dims.size(); + if (use_cvm) { + PADDLE_ENFORCE_GT( + dims[rank - 1], 2, + "Shape error in %lu id, the last dimension(embedding) of the " + "'X' tensor must be larger than 2.", + i); + } + // input lod is not accessible here + std::vector out_dim; + if (use_cvm) { + if (show_filter) { + out_dim = {-1, dims[rank - 1] - 1}; + } else { + out_dim = {-1, dims[rank - 1]}; + } + } else { + out_dim = {-1, dims[rank - 1] - cvm_offset}; + } + outs_dims[i] = framework::make_ddim(out_dim); + } + ctx->SetOutputsDim("Out", outs_dims); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(framework::proto::VarType::FP32, + ctx.device_context()); + } +}; + +class FusedSeqpoolCVMOpWithConvMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(vector) The input tensors of" + " operator.") + .AsDuplicable(); + AddInput("CVM", + "(Tensor), a 2-D Tensor with shape [N x 2], where N is the batch " + "size, 2 is show and click."); + AddOutput("Out", + "(vector) The output of Op does not contain LoD " + "information.") + .AsDuplicable(); + AddAttr("pooltype", + "(string, default 'SUM') the pooling pooltype of " + "SequencePoolOp, only support SUM now.") + .SetDefault("SUM") + .InEnum({"SUM"}); + AddAttr("pad_value", + "(float, default 0.0) The value to pad for empty sequence.") + .SetDefault(0.0); + AddAttr("use_cvm", "bool, use cvm or not").SetDefault(true); + AddAttr("cvm_offset", "(int, default 3)").SetDefault(3); + AddAttr("show_filter", "(bool, default false)").SetDefault(false); + + AddComment(R"DOC( +Fuse multiple pairs of Sequence Pool and CVM Operator. + +)DOC"); + } +}; + +class FusedSeqpoolCVMGradOpWithConv : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + auto og_dims = ctx->GetInputsDim(framework::GradVarName("Out")); + auto x_dims = ctx->GetInputsDim("X"); + auto cvm_dims = ctx->GetInputDim("CVM"); + const int cvm_offset = ctx->Attrs().Get("cvm_offset"); + bool use_cvm = ctx->Attrs().Get("use_cvm"); + bool show_filter = ctx->Attrs().Get("show_filter"); + + PADDLE_ENFORCE_EQ( + cvm_dims.size(), 2, + platform::errors::InvalidArgument("Input(CVM)'s rank should be 2.")); + + for (size_t i = 0; i < og_dims.size(); i++) { + PADDLE_ENFORCE_EQ( + og_dims[i].size(), x_dims[i].size(), + platform::errors::InvalidArgument( + "The rank of output grad must equal to Input(X). But " + "received: input rank %u, input shape [%s].", + og_dims[i].size(), og_dims[i])); + if (use_cvm) { + auto o_dim = og_dims[i][og_dims[i].size() - 1]; + if (show_filter) { + o_dim += 1; + } + PADDLE_ENFORCE_EQ( + o_dim, x_dims[i][og_dims[i].size() - 1], + platform::errors::InvalidArgument( + "The dimension mismatch between Input(OUT@GRAD) and " + "Input(X). 
Received Input(OUT@GRAD): input rank %u, " + "input shape [%s]; received Input(X): input rank %u, " + "input shape [%s].", + og_dims[i].size(), og_dims[i], x_dims[i].size(), x_dims[i])); + } else { + PADDLE_ENFORCE_EQ( + og_dims[i][og_dims[i].size() - 1], + x_dims[i][og_dims[i].size() - 1] - cvm_offset, + platform::errors::InvalidArgument( + "The dimension mismatch between Input(OUT@GRAD) and " + "Input(X). Received Input(OUT@GRAD): input rank %u, " + "input shape [%s]; received Input(X): input rank %u, " + "input shape [%s].", + og_dims[i].size(), og_dims[i], x_dims[i].size(), x_dims[i])); + } + } + for (size_t i = 0; i < x_dims.size(); ++i) { + ctx->ShareLoD("X", framework::GradVarName("X"), i, i); + ctx->ShareDim("X", framework::GradVarName("X"), i, i); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); + } +}; + +template +class FusedSeqpoolCVMGradOpWithConvMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op_desc_ptr) const override { + op_desc_ptr->SetType("fused_seqpool_cvm_with_conv_grad"); + op_desc_ptr->SetInput("X", this->Input("X")); + op_desc_ptr->SetInput("CVM", this->Input("CVM")); + + op_desc_ptr->SetInput(framework::GradVarName("Out"), + this->OutputGrad("Out")); + op_desc_ptr->SetOutput(framework::GradVarName("X"), + this->InputGrad("X", false)); + op_desc_ptr->SetOutput(framework::GradVarName("CVM"), + this->InputGrad("CVM")); + op_desc_ptr->SetAttrMap(this->Attrs()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OPERATOR(fused_seqpool_cvm_with_conv, ops::FusedSeqpoolCVMOpWithConv, + ops::FusedSeqpoolCVMOpWithConvMaker, + ops::FusedSeqpoolCVMGradOpWithConvMaker, + ops::FusedSeqpoolCVMGradOpWithConvMaker); +REGISTER_OPERATOR(fused_seqpool_cvm_with_conv_grad, ops::FusedSeqpoolCVMGradOpWithConv) + +REGISTER_OP_CPU_KERNEL(fused_seqpool_cvm_with_conv, + ops::FusedSeqpoolCVMOpWithConvCPUKernel) +REGISTER_OP_CPU_KERNEL(fused_seqpool_cvm_with_conv_grad, + ops::FusedSeqpoolCVMGradOpWithConvCPUKernel) diff --git a/paddle/fluid/operators/fused/fused_seqpool_cvm_with_conv_op.cu b/paddle/fluid/operators/fused/fused_seqpool_cvm_with_conv_op.cu new file mode 100644 index 0000000000000..1a92f72e1cfe4 --- /dev/null +++ b/paddle/fluid/operators/fused/fused_seqpool_cvm_with_conv_op.cu @@ -0,0 +1,449 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
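The CUDA file below implements the fused sequence pooling and CVM-with-conv transform. As a rough guide to what the forward kernels compute for one slot and one instance, here is a hedged NumPy sketch; the function name and layout comments are assumptions read off the kernels that follow, not an official API:

import numpy as np

def seqpool_cvm_with_conv_forward(seq_emb, pad_value=0.0, show_filter=False):
    # seq_emb: [seq_len, D] rows of one instance in one slot; columns
    # 0, 1, 2 hold the pooled show, click and conv statistics, the rest
    # is the embedding.
    pooled = seq_emb.sum(axis=0) + pad_value  # empty sequences pool to pad_value
    out = pooled.copy()
    out[0] = np.log(pooled[0] + 1)                          # show
    out[1] = np.log(pooled[1] + 1)                          # click
    out[2] = np.log(pooled[2] + 1) - np.log(pooled[1] + 1)  # conv relative to click
    # with show_filter the show column is removed from the output
    return out[1:] if show_filter else out

When use_cvm is False, the kernels instead drop the leading cvm_offset statistic columns and copy only the remaining embedding part (see FusedCVMWithConvKernelNoCVM below).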
+ +#include +#include "paddle/fluid/operators/fused/fused_seqpool_cvm_with_conv_op.h" +#include "paddle/fluid/platform/cuda_primitives.h" +#include "paddle/fluid/platform/gpu_info.h" + +namespace paddle { +namespace operators { + +using platform::PADDLE_CUDA_NUM_THREADS; + +#define GET_BLOCK(N) \ + ((N + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS) + +#define CUDA_KERNEL_LOOP(i, n) \ + for (auto i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +// normal +template +__global__ void FusedSeqpoolWithConvKernelNormal(const size_t N, T **input_values, + T **seqpool_output_values, + size_t **lods_values, + const int batch_size, + const int embedding_size, + const float pad_value) { + CUDA_KERNEL_LOOP(i, N) { + int key = i / embedding_size; + int offset = i % embedding_size; + int x = key / batch_size; // slot id + int y = key % batch_size; // ins id + auto &start = *(lods_values[x] + y); + auto &end = *(lods_values[x] + y + 1); + + double val = pad_value; + for (auto k = start; k < end; ++k) { + val += *(input_values[x] + k * embedding_size + offset); + } + *(seqpool_output_values[x] + y * embedding_size + offset) = val; + } +} +// join only need show input +template +__global__ void FusedCVMWithConvKernelNormal(const size_t N, T **output_values, + T **seqpool_output_values, + const int batch_size, + const int embedding_size, + const int noclk_embedding_size) { + CUDA_KERNEL_LOOP(i, N) { + int key = i / noclk_embedding_size; + int offset = i % noclk_embedding_size; + int x = key / batch_size; // slot id + int y = key % batch_size; // ins id + if (offset == 0) { // show + *(output_values[x] + y * embedding_size) = + log(*(seqpool_output_values[x] + y * embedding_size) + 1); + } else if (offset == 1) { // click + *(output_values[x] + y * embedding_size + 1) = + log(*(seqpool_output_values[x] + y * embedding_size + 1) + 1); + } else if (offset == 2) { // conv + *(output_values[x] + y * embedding_size + 2) = + log(*(seqpool_output_values[x] + y * embedding_size + 2) + 1) - + log(*(seqpool_output_values[x] + y * embedding_size + 1) + 1); + } else { // filter show, offset - 1 + *(output_values[x] + y * noclk_embedding_size + offset) = + *(seqpool_output_values[x] + y * embedding_size + offset); + } + } +} + +// join only need show input +template +__global__ void FusedCVMWithConvKernelWithOutShow(const size_t N, T **output_values, + T **seqpool_output_values, + const int batch_size, + const int embedding_size, + const int noclk_embedding_size) { + CUDA_KERNEL_LOOP(i, N) { + int key = i / noclk_embedding_size; + int offset = i % noclk_embedding_size; + int x = key / batch_size; // slot id + int y = key % batch_size; // ins id + if (offset == 0) { // show + // do nothing + } else if (offset == 1) { // click + *(output_values[x] + y * embedding_size) = + log(*(seqpool_output_values[x] + y * embedding_size + 1) + 1); + } else if (offset == 2) { // conv + *(output_values[x] + y * embedding_size + 1) = + log(*(seqpool_output_values[x] + y * embedding_size + 2) + 1) - + log(*(seqpool_output_values[x] + y * embedding_size + 1) + 1); + } else { // filter show, offset - 1 + *(output_values[x] + y * noclk_embedding_size + offset - 1) = + *(seqpool_output_values[x] + y * embedding_size + offset); + } + } +} + +// update not need show click input +template +__global__ void FusedCVMWithConvKernelNoCVM(const size_t N, T **output_values, + T **seqpool_output_values, + const int batch_size, + const int no_cvm_embedding_size, + const int cvm_offset) { + 
CUDA_KERNEL_LOOP(i, N) { + int key = i / no_cvm_embedding_size; + int offset = i % no_cvm_embedding_size; + int x = key / batch_size; // slot id + int y = key % batch_size; // ins id + // no cvm + *(output_values[x] + y * no_cvm_embedding_size + offset) = + *(seqpool_output_values[x] + y * (no_cvm_embedding_size + cvm_offset) + + offset + cvm_offset); + } +} + +template +void FusedSeqpoolCVMWithConv(const paddle::platform::Place &place, + const std::vector &input_data, + const std::vector &output_data, + const std::vector &seqpool_output_data, + std::vector lods, const int batch_size, + const int slot_num, const int embedding_size, + const float padding_value, const bool use_cvm, + const int cvm_offset, bool show_filter) { + auto stream = dynamic_cast( + platform::DeviceContextPool::Instance().Get( + BOOST_GET_CONST(platform::CUDAPlace, place))) + ->stream(); + + size_t total_ptr_len = input_data.size() + output_data.size() + + seqpool_output_data.size() + lods.size(); + auto temp_ptr = memory::AllocShared(place, total_ptr_len * sizeof(void *)); + void *ptr = temp_ptr->ptr(); + + T **gpu_input_values = reinterpret_cast(temp_ptr->ptr()); + cudaMemcpyAsync(gpu_input_values, input_data.data(), + input_data.size() * sizeof(T *), cudaMemcpyHostToDevice, + stream); + T **gpu_output_values = + reinterpret_cast(&gpu_input_values[input_data.size()]); + cudaMemcpyAsync(gpu_output_values, output_data.data(), + output_data.size() * sizeof(T *), cudaMemcpyHostToDevice, + stream); + T **gpu_seqpool_output_values = + reinterpret_cast(&gpu_output_values[output_data.size()]); + cudaMemcpyAsync(gpu_seqpool_output_values, seqpool_output_data.data(), + seqpool_output_data.size() * sizeof(T *), + cudaMemcpyHostToDevice, stream); + size_t **lods_values = reinterpret_cast( + &gpu_seqpool_output_values[seqpool_output_data.size()]); + cudaMemcpyAsync(lods_values, lods.data(), lods.size() * sizeof(size_t *), + cudaMemcpyHostToDevice, stream); + + size_t N = static_cast(batch_size * slot_num * embedding_size); + // first sum pool + FusedSeqpoolWithConvKernelNormal<<>>( + N, gpu_input_values, gpu_seqpool_output_values, lods_values, batch_size, + embedding_size, padding_value); + // second log + if (use_cvm) { + if (show_filter) { + N = static_cast(batch_size * slot_num * (embedding_size - 1)); + FusedCVMWithConvKernelWithOutShow<<>>(N, gpu_output_values, + gpu_seqpool_output_values, batch_size, + embedding_size, embedding_size - 1); + } else { + FusedCVMWithConvKernelNormal<<>>(N, gpu_output_values, + gpu_seqpool_output_values, batch_size, + embedding_size, embedding_size); + } + } else { + // not need show click input + N = static_cast(batch_size * slot_num * + (embedding_size - cvm_offset)); + FusedCVMWithConvKernelNoCVM<<>>( + N, gpu_output_values, gpu_seqpool_output_values, batch_size, + (embedding_size - cvm_offset), cvm_offset); + } +} + // join grad + template + __global__ void FusedSeqpoolCVMWithConvGradKernelWithCVM( + const size_t N, T **out_grads_values, T **in_grads_values, T **cvm_values, + size_t **lods_values, const int batch_size, const int embedding_size, + const int cvm_offset) { + CUDA_KERNEL_LOOP(i, N) { + int key = i / embedding_size; + int offset = i % embedding_size; // embedx offset + int x = key / batch_size; // slot id + int y = key % batch_size; // ins id + + T &val = (offset < cvm_offset) + ? 
*(cvm_values[x] + y * cvm_offset + offset) + : *(out_grads_values[x] + y * embedding_size + offset); + + auto &start = *(lods_values[x] + y); + auto &end = *(lods_values[x] + y + 1); + for (auto k = start; k < end; ++k) { + *(in_grads_values[x] + k * embedding_size + offset) = val; + } + } + } + + // join only show not has click + template + __global__ void FusedSeqpoolCVMWithConvGradKernelWithShow( + const size_t N, T **out_grads_values, T **in_grads_values, T **cvm_values, + size_t **lods_values, const int batch_size, const int embedding_size, + const int cvm_offset) { + CUDA_KERNEL_LOOP(i, N) { + int key = i / embedding_size; + int offset = i % embedding_size; // embedx offset + int x = key / batch_size; // slot id + int y = key % batch_size; // ins id + + T &val = + (offset < cvm_offset) + ? *(cvm_values[x] + y * cvm_offset + offset) + : *(out_grads_values[x] + y * (embedding_size - 1) + offset - 1); + + auto &start = *(lods_values[x] + y); + auto &end = *(lods_values[x] + y + 1); + for (auto k = start; k < end; ++k) { + *(in_grads_values[x] + k * embedding_size + offset) = val; + } + } + } + +// update grad +template +__global__ void FusedSeqpoolCVMWithConvGradKernelNoCVM( + const size_t N, T **out_grads_values, T **in_grads_values, T **cvm_values, + size_t **lods_values, const int batch_size, const int embedding_size, + const int cvm_offset) { + CUDA_KERNEL_LOOP(i, N) { + int key = i / embedding_size; + int offset = i % embedding_size; // embedx offset + int x = key / batch_size; // slot id + int y = key % batch_size; // ins id + + T &val = (offset < cvm_offset) + ? *(cvm_values[x] + y * cvm_offset + offset) + : *(out_grads_values[x] + y * (embedding_size - cvm_offset) + + offset - cvm_offset); + + auto &start = *(lods_values[x] + y); + auto &end = *(lods_values[x] + y + 1); + for (auto k = start; k < end; ++k) { + *(in_grads_values[x] + k * embedding_size + offset) = val; + } + } +} +template +void FusedSeqpoolCVMGradWithConv(const paddle::platform::Place &place, + const std::vector &out_grads_data, + const std::vector &in_grads_data, + const std::vector &cvm_data, + const std::vector &lods, + const int batch_size, const int slot_num, + const int embedding_size, const bool use_cvm, + const int cvm_offset, bool show_filter) { + auto stream = dynamic_cast( + platform::DeviceContextPool::Instance().Get( + BOOST_GET_CONST(platform::CUDAPlace, place))) + ->stream(); + size_t total_ptr_len = out_grads_data.size() + in_grads_data.size() + + cvm_data.size() + lods.size(); + auto temp_ptr = memory::AllocShared(place, total_ptr_len * sizeof(void *)); + T **gpu_out_grads_values = reinterpret_cast(temp_ptr->ptr()); + cudaMemcpyAsync(gpu_out_grads_values, out_grads_data.data(), + out_grads_data.size() * sizeof(T *), cudaMemcpyHostToDevice, + stream); + + T **gpu_in_grads_values = + reinterpret_cast(&gpu_out_grads_values[out_grads_data.size()]); + cudaMemcpyAsync(gpu_in_grads_values, in_grads_data.data(), + in_grads_data.size() * sizeof(T *), cudaMemcpyHostToDevice, + stream); + + T **gpu_cvm_values = + reinterpret_cast(&gpu_in_grads_values[in_grads_data.size()]); + cudaMemcpyAsync(gpu_cvm_values, cvm_data.data(), + cvm_data.size() * sizeof(T *), cudaMemcpyHostToDevice, + stream); + + size_t **lods_values = + reinterpret_cast(&gpu_cvm_values[cvm_data.size()]); + cudaMemcpyAsync(lods_values, lods.data(), lods.size() * sizeof(size_t *), + cudaMemcpyHostToDevice, stream); + + size_t N = static_cast(batch_size * slot_num * embedding_size); + if (use_cvm) { + if (show_filter) { + 
FusedSeqpoolCVMWithConvGradKernelWithShow<<>>( + N, gpu_out_grads_values, gpu_in_grads_values, gpu_cvm_values, + lods_values, batch_size, embedding_size, cvm_offset); + + } else { + FusedSeqpoolCVMWithConvGradKernelWithCVM<<>>( + N, gpu_out_grads_values, gpu_in_grads_values, gpu_cvm_values, + lods_values, batch_size, embedding_size, cvm_offset); + } + } else { + // update grad + FusedSeqpoolCVMWithConvGradKernelNoCVM<<>>( + N, gpu_out_grads_values, gpu_in_grads_values, gpu_cvm_values, + lods_values, batch_size, embedding_size, cvm_offset); + } +} + +template +class FusedSeqpoolCVMWithConvCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto inputs = ctx.MultiInput("X"); + auto outputs = ctx.MultiOutput("Out"); + + const auto slot_size = inputs.size(); + std::vector input_data(slot_size); + std::vector lods_data(slot_size); + std::vector output_data(slot_size); + + std::vector seqpool_outputs(slot_size); + std::vector seqpool_output_data(slot_size); + + auto padding_value = ctx.Attr("pad_value"); + auto use_cvm = ctx.Attr("use_cvm"); + const int cvm_offset = ctx.Attr("cvm_offset"); + bool show_filter = ctx.Attr("show_filter"); + + int embedding_size = inputs[0]->numel() / inputs[0]->dims()[0]; + int batch_size = -1; + for (size_t i = 0; i < slot_size; ++i) { + const auto *input = inputs[i]; + + auto lod = input->lod(); + auto lod_level = lod.size(); + + int cur_batch = lod[lod_level - 1].size() - 1; + if (batch_size == -1) { + batch_size = cur_batch; + } else { + CHECK(batch_size == cur_batch) << "batch: " << batch_size << ", current: " << cur_batch; + } + input_data[i] = reinterpret_cast(input->data()); + auto *output = outputs[i]; + if (use_cvm) { + if (show_filter) { + // show will filtered + output->Resize({batch_size, embedding_size - 1}); + } else { + // show will filtered + output->Resize({batch_size, embedding_size}); + } + } else { + output->Resize({batch_size, embedding_size - cvm_offset}); + } + output_data[i] = + reinterpret_cast(output->mutable_data(ctx.GetPlace())); + lods_data[i] = lod[lod_level - 1].CUDAData(ctx.GetPlace()); + + seqpool_output_data[i] = + reinterpret_cast(seqpool_outputs[i].mutable_data( + {batch_size, embedding_size}, ctx.GetPlace())); + } + FusedSeqpoolCVMWithConv(ctx.GetPlace(), input_data, output_data, + seqpool_output_data, lods_data, batch_size, slot_size, + embedding_size, padding_value, use_cvm, cvm_offset, show_filter); + } +}; + +template +class FusedSeqpoolCVMWithConvGradCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto out_grads = ctx.MultiInput(framework::GradVarName("Out")); + auto in_grads = ctx.MultiOutput(framework::GradVarName("X")); + auto *cvm = ctx.Input("CVM"); + + std::string pooltype = ctx.Attr("pooltype"); + auto use_cvm = ctx.Attr("use_cvm"); + const int cvm_offset = ctx.Attr("cvm_offset"); + bool show_filter = ctx.Attr("show_filter"); + + const auto slot_size = in_grads.size(); + std::vector out_grads_data(slot_size); + std::vector in_grads_data(slot_size); + std::vector cvm_data(slot_size); + std::vector lods_data(slot_size); + + int embedding_size = in_grads[0]->numel() / in_grads[0]->dims()[0]; + int batch_size = -1; + for (size_t i = 0; i < slot_size; ++i) { + auto *in_grad = in_grads[i]; + + auto lod = in_grad->lod(); + auto lod_level = lod.size(); + int cur_batch = lod[lod_level - 1].size() - 1; + if (batch_size == -1) { + batch_size = cur_batch; + } else { + 
CHECK(batch_size == cur_batch) << "batch: " << batch_size
+                                       << ", current: " << cur_batch;
+      }
+
+      auto *out_grad = out_grads[i];
+      out_grads_data[i] = reinterpret_cast<const T *>(out_grad->data<T>());
+
+      in_grads_data[i] =
+          reinterpret_cast<T *>(in_grad->mutable_data<T>(ctx.GetPlace()));
+      lods_data[i] = lod[lod_level - 1].CUDAData(ctx.GetPlace());
+      cvm_data[i] = reinterpret_cast<const T *>(cvm->data<T>());
+    }
+    FusedSeqpoolCVMGradWithConv(ctx.GetPlace(), out_grads_data, in_grads_data,
+                                cvm_data, lods_data, batch_size, slot_size,
+                                embedding_size, use_cvm, cvm_offset,
+                                show_filter);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_CUDA_KERNEL(fused_seqpool_cvm_with_conv,
+                        ops::FusedSeqpoolCVMWithConvCUDAKernel<float>);
+
+REGISTER_OP_CUDA_KERNEL(fused_seqpool_cvm_with_conv_grad,
+                        ops::FusedSeqpoolCVMWithConvGradCUDAKernel<float>);
diff --git a/paddle/fluid/operators/fused/fused_seqpool_cvm_with_conv_op.h b/paddle/fluid/operators/fused/fused_seqpool_cvm_with_conv_op.h
new file mode 100644
index 0000000000000..d757fbcd06689
--- /dev/null
+++ b/paddle/fluid/operators/fused/fused_seqpool_cvm_with_conv_op.h
@@ -0,0 +1,48 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include <string>
+#include <vector>
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/tensor.h"
+
+namespace paddle {
+namespace operators {
+
+using LoDTensor = framework::LoDTensor;
+
+template <typename T>
+class FusedSeqpoolCVMOpWithConvCPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    PADDLE_THROW(
+        "Unimplemented CPU kernel for FusedSeqpoolCVMOpWithConv; only the "
+        "GPU kernel is supported now.");
+  }
+};
+
+template <typename T>
+class FusedSeqpoolCVMGradOpWithConvCPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    PADDLE_THROW(
+        "Unimplemented CPU kernel for FusedSeqpoolCVMGradOpWithConv; only "
+        "the GPU kernel is supported now.");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/fluid/contrib/layers/nn.py b/python/paddle/fluid/contrib/layers/nn.py
index fbf13a0ce89cc..a7509fd23dd46 100644
--- a/python/paddle/fluid/contrib/layers/nn.py
+++ b/python/paddle/fluid/contrib/layers/nn.py
@@ -75,6 +75,7 @@
     'correlation',
     'fused_bn_add_act',
     'fused_seqpool_cvm',
+    'fused_seqpool_cvm_with_conv',
     'fused_seqpool_cvm_with_diff_thres',
     'cross_norm_layer_hadamard',
     'fused_seqpool_cvm_with_pcoc',
@@ -1690,6 +1691,60 @@ def fused_seqpool_cvm_with_diff_thres(input,
     return outs
 
+
+def fused_seqpool_cvm_with_conv(input,
+                                pool_type,
+                                cvm,
+                                pad_value=0.0,
+                                use_cvm=True,
+                                show_filter=False,
+                                cvm_offset=3):
+    """
+    **Note**: This Op only accepts a list of LoDTensor as :attr:`input` and
+    only supports SUM pooling now.
+
+    Args:
+        input(Variable|list of Variable): a list of LoDTensor inputs.
+        pool_type(str): the pooling type; only SUM pooling is supported now.
+        cvm(Variable): the CVM statistics, a 2-D Tensor of shape [N x 3].
+        pad_value(float): the value used to pad empty sequences.
+        use_cvm(bool): whether to apply the CVM transform to the pooled
+            result.
+        show_filter(bool): whether to drop the show column from the output
+            when use_cvm is True.
+        cvm_offset(int): the number of leading CVM columns in each input
+            embedding, 3 by default.
+
+    Returns:
+        Variable|list of Variable: the outputs of the fused sequence pool
+        and CVM transform, one per input.
+    """
+    helper = LayerHelper('fused_seqpool_cvm_with_conv', **locals())
+
+    if pool_type.upper() != 'SUM':
+        raise ValueError(
+            "fused_seqpool_cvm_with_conv only supports SUM pooling now, "
+            "but the given type is: " + pool_type)
+
+    check_type(input, 'input', list, 'fused_seqpool_cvm_with_conv')
+    if isinstance(input, list):
+        for _input in input:
+            check_variable_and_dtype(_input, 'input', ['float32'],
+                                     'fused_seqpool_cvm_with_conv')
+
+    dtype = helper.input_dtype()
+    inputs = helper.multiple_input()
+    outs = [
+        helper.create_variable_for_type_inference(dtype)
+        for i in range(len(inputs))
+    ]
+
+    helper.append_op(
+        type="fused_seqpool_cvm_with_conv",
+        inputs={"X": inputs,
+                "CVM": cvm},
+        outputs={"Out": outs},
+        attrs={
+            "pooltype": pool_type.upper(),
+            "pad_value": pad_value,
+            "use_cvm": use_cvm,
+            "cvm_offset": cvm_offset,
+            "show_filter": show_filter
+        })
+
+    return outs
+
 
 def fused_seqpool_cvm_with_pcoc(input,
                                 pool_type,
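Finally, a hedged sketch of how the new Python layer might be wired into a program; the slot count, widths and variable names are invented for illustration and are not taken from this patch:

import paddle.fluid as fluid
from paddle.fluid.contrib.layers import fused_seqpool_cvm_with_conv

# Hypothetical setup: two sparse slots, each embedded to 3 statistic
# columns (show, click, conv) plus 8 embedding dims, i.e. width 11.
slots = [
    fluid.data(name='slot_%d' % i, shape=[-1, 11], dtype='float32',
               lod_level=1) for i in range(2)
]
cvm = fluid.data(name='cvm', shape=[-1, 3], dtype='float32')

outs = fused_seqpool_cvm_with_conv(
    slots, 'SUM', cvm, pad_value=0.0, use_cvm=True, show_filter=False)
# Each element of `outs` has shape [batch_size, 11] here; with
# show_filter=True it would be [batch_size, 10], and with use_cvm=False
# [batch_size, 11 - cvm_offset].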