diff --git a/paddle/fluid/framework/fleet/box_wrapper.cc b/paddle/fluid/framework/fleet/box_wrapper.cc
index ec2fe3075b495..9d01357df17e3 100644
--- a/paddle/fluid/framework/fleet/box_wrapper.cc
+++ b/paddle/fluid/framework/fleet/box_wrapper.cc
@@ -406,7 +406,7 @@ void BoxWrapper::PullSparse(const paddle::platform::Place& place,
                             const std::vector<float*>& values,
                             const std::vector<int64_t>& slot_lengths,
                             const int hidden_size, const int expand_embed_dim,
-                            const int skip_offset) {
+                            const int skip_offset, bool expand_only) {
 #define EMBEDX_CASE(i, ...) \
   case i: { \
     constexpr size_t EmbedxDim = i; \
@@ -425,33 +425,33 @@ void BoxWrapper::PullSparse(const paddle::platform::Place& place,
       PullSparseCase< \
           boxps::FeaturePullValueGpuShareEmbedding>( \
           place, keys, values, slot_lengths, hidden_size, expand_embed_dim, \
-          skip_offset); \
+          skip_offset, expand_only); \
     } else if (feature_type_ == static_cast<int>(boxps::FEATURE_PCOC)) { \
       PullSparseCase>( \
           place, keys, values, slot_lengths, hidden_size, expand_embed_dim, \
-          skip_offset); \
+          skip_offset, expand_only); \
     } else if (feature_type_ == static_cast<int>(boxps::FEATURE_QUANT) || \
                feature_type_ == static_cast<int>(boxps::FEATURE_SHOWCLK)) { \
       PullSparseCase>( \
           place, keys, values, slot_lengths, hidden_size, expand_embed_dim, \
-          skip_offset); \
+          skip_offset, expand_only); \
     } else if (feature_type_ == static_cast<int>(boxps::FEATURE_CONV)) { \
       PullSparseCase>( \
           place, keys, values, slot_lengths, hidden_size, expand_embed_dim, \
-          skip_offset); \
+          skip_offset, expand_only); \
     } else if (feature_type_ == static_cast<int>(boxps::FEATURE_VARIABLE)) { \
       PullSparseCase>( \
           place, keys, values, slot_lengths, hidden_size, expand_embed_dim, \
-          skip_offset); \
+          skip_offset, expand_only); \
     } else if (EmbedxDim == 0 && \
                feature_type_ == static_cast<int>(boxps::FEATURE_ADAM)) { \
       PullSparseCase>( \
           place, keys, values, slot_lengths, hidden_size, expand_embed_dim, \
-          skip_offset); \
+          skip_offset, expand_only); \
     } else { \
       PullSparseCase>( \
           place, keys, values, slot_lengths, hidden_size, expand_embed_dim, \
-          skip_offset); \
+          skip_offset, expand_only); \
     } \
   } break
@@ -489,7 +489,9 @@ void BoxWrapper::PushSparseGrad(const paddle::platform::Place& place,
                                 const std::vector<int64_t>& slot_lengths,
                                 const int hidden_size,
                                 const int expand_embed_dim,
-                                const int batch_size, const int skip_offset) {
+                                const int batch_size,
+                                const int skip_offset,
+                                bool expand_only) {
 #define EMBEDX_CASE(i, ...)
\ case i: { \ constexpr size_t EmbedxDim = i; \ @@ -508,30 +510,30 @@ void BoxWrapper::PushSparseGrad(const paddle::platform::Place& place, PushSparseGradCase< \ boxps::FeaturePushValueGpuShareEmbedding>( \ place, keys, grad_values, slot_lengths, hidden_size, \ - expand_embed_dim, batch_size, skip_offset); \ + expand_embed_dim, batch_size, skip_offset, expand_only); \ } else if (feature_type_ == static_cast(boxps::FEATURE_PCOC)) { \ PushSparseGradCase< \ boxps::FeaturePushValueGpuPCOC>( \ place, keys, grad_values, slot_lengths, hidden_size, \ - expand_embed_dim, batch_size, skip_offset); \ + expand_embed_dim, batch_size, skip_offset, expand_only); \ } else if (feature_type_ == static_cast(boxps::FEATURE_VARIABLE)) { \ PushSparseGradCase>( \ place, keys, grad_values, slot_lengths, hidden_size, \ - expand_embed_dim, batch_size, skip_offset); \ + expand_embed_dim, batch_size, skip_offset, expand_only); \ } else if (feature_type_ == static_cast(boxps::FEATURE_CONV)) { \ PushSparseGradCase< \ boxps::FeaturePushValueGpuConv>( \ place, keys, grad_values, slot_lengths, hidden_size, \ - expand_embed_dim, batch_size, skip_offset); \ + expand_embed_dim, batch_size, skip_offset, expand_only); \ } else if (EmbedxDim == 0 && \ feature_type_ == static_cast(boxps::FEATURE_ADAM)) { \ PushSparseGradCase>( \ place, keys, grad_values, slot_lengths, hidden_size, \ - expand_embed_dim, batch_size, skip_offset); \ + expand_embed_dim, batch_size, skip_offset, expand_only); \ } else { \ PushSparseGradCase>( \ place, keys, grad_values, slot_lengths, hidden_size, \ - expand_embed_dim, batch_size, skip_offset); \ + expand_embed_dim, batch_size, skip_offset, expand_only); \ } \ } break diff --git a/paddle/fluid/framework/fleet/box_wrapper.cu b/paddle/fluid/framework/fleet/box_wrapper.cu index e9c5a28186a87..417b08de4feb5 100644 --- a/paddle/fluid/framework/fleet/box_wrapper.cu +++ b/paddle/fluid/framework/fleet/box_wrapper.cu @@ -182,6 +182,39 @@ __global__ void PullCopyBaseNNCross( } // end kernel loop } template +__global__ void PullCopyBaseNNCrossWithEmb( + float** dest, const FEATURE_VALUE_GPU_TYPE* src, const int hidden, + const int expand_dim, const int total_len, uint64_t** keys, int* total_dims, + const int64_t* slot_lens, const int slot_num, const int* key2slot, + const int cvm_offset, const int skip_offset) { + CUDA_KERNEL_LOOP(i, total_len) { + int x = key2slot[i]; + int y = i - slot_lens[x]; + + auto& src_val = src[i]; + if (dest[x] != 0) { + float* dest_ptr = dest[x] + y * hidden; + const float* src_ptr = reinterpret_cast(&src_val.show); + for (int k = 0; k < cvm_offset; ++k) { + dest_ptr[k] = src_ptr[k + skip_offset]; + } + } + + if (dest[x + slot_num] != 0) { + float* dest_ptr = dest[x + slot_num] + y * (hidden + expand_dim); + const float* src_ptr = reinterpret_cast(&src_val.show); + for (int k = 0; k < cvm_offset; ++k) { + dest_ptr[k] = src_ptr[k]; + } + } + + // embedx flags + expand flags && *(keys[x] + y) != 0 && *(keys[x] + y) + // != 0 + total_dims[i] = static_cast(src_val.embedding_size > 0) + + (static_cast(src_val.embed_expand_size[0] > 0) << 1); + } // end kernel loop +} +template __global__ void PullCopyExpandNNCross( float** dest, const FEATURE_VALUE_GPU_TYPE* src, const int total_embedx_dim, const int embedx_dim, const int expand_dim, const int total_len, @@ -220,6 +253,56 @@ __global__ void PullCopyExpandNNCross( } // end kernel loop } template +__global__ void PullCopyExpandNNCrossWithEmb( + float** dest, const FEATURE_VALUE_GPU_TYPE* src, const int total_embedx_dim, + const int 
embedx_dim, const int expand_dim, const int total_len, + const int* total_dims, const int64_t* slot_lens, const int slot_num, + const int* key2slot, const float scale, const int cvm_offset) { + CUDA_KERNEL_LOOP(i, total_len) { + int idx = i / total_embedx_dim; + int col = i % total_embedx_dim; + + int x = key2slot[idx]; + int y = idx - slot_lens[x]; + + auto& src_val = src[idx]; + if (col < embedx_dim) { // embedx + if (dest[x] == 0) { + return; + } + int offset = y * (embedx_dim + cvm_offset) + cvm_offset + col; + if (total_dims[idx] & 0x01) { + *(dest[x] + offset) = src_val.embedx[col] * scale; + } else { + *(dest[x] + offset) = 0; + } + + if (dest[x + slot_num] == 0) { + return; + } + int offset_2 = y * (embedx_dim + cvm_offset + expand_dim) + + cvm_offset + col; + if (total_dims[idx] & 0x02) { + *(dest[x + slot_num] + offset_2) = src_val.embedx[col] * scale; + } else { + *(dest[x + slot_num] + offset_2) = 0; + } + } else { // expand + if (dest[x + slot_num] == 0) { + return; + } + int offset = y * (embedx_dim + cvm_offset + expand_dim) + + cvm_offset + col; + if (total_dims[idx] & 0x02) { + *(dest[x + slot_num] + offset) = + src_val.embed_expand[col - embedx_dim] * scale; + } else { + *(dest[x + slot_num] + offset) = 0; + } + } + } // end kernel loop +} +template __global__ void PullDedupCopyBaseNNCross( float** dest, const FEATURE_VALUE_GPU_TYPE* src, const int hidden, const int expand_dim, const int total_len, uint64_t** keys, int* total_dims, @@ -244,6 +327,39 @@ __global__ void PullDedupCopyBaseNNCross( } // end kernel loop } template +__global__ void PullDedupCopyBaseNNCrossWithEmb( + float** dest, const FEATURE_VALUE_GPU_TYPE* src, const int hidden, + const int expand_dim, const int total_len, uint64_t** keys, int* total_dims, + const int64_t* slot_lens, const int slot_num, const int* key2slot, + const int cvm_offset, const uint32_t* restore_idx, const int skip_offset) { + CUDA_KERNEL_LOOP(i, total_len) { + int x = key2slot[i]; + int y = i - slot_lens[x]; + + auto& src_val = src[restore_idx[i]]; + if (dest[x] != 0) { + float* dest_ptr = dest[x] + y * hidden; + const float* src_ptr = reinterpret_cast(&src_val.show); + for (int k = 0; k < cvm_offset; ++k) { + dest_ptr[k] = src_ptr[k + skip_offset]; + } + } + + if (dest[x + slot_num] != 0) { + float* dest_ptr = dest[x + slot_num] + y * (hidden + expand_dim); + const float* src_ptr = reinterpret_cast(&src_val.show); + for (int k = 0; k < cvm_offset; ++k) { + dest_ptr[k] = src_ptr[k]; + } + } + + // embedx flags + expand flags && *(keys[x] + y) != 0 && *(keys[x] + y) + // != 0 + total_dims[i] = static_cast(src_val.embedding_size > 0) + + (static_cast(src_val.embed_expand_size[0] > 0) << 1); + } // end kernel loop +} +template __global__ void PullDedupCopyExpandNNCross( float** dest, const FEATURE_VALUE_GPU_TYPE* src, const int total_embedx_dim, const int embedx_dim, const int expand_dim, const int total_len, @@ -282,6 +398,57 @@ __global__ void PullDedupCopyExpandNNCross( } } // end kernel loop } +template +__global__ void PullDedupCopyExpandNNCrossWithEmb( + float** dest, const FEATURE_VALUE_GPU_TYPE* src, const int total_embedx_dim, + const int embedx_dim, const int expand_dim, const int total_len, + const int* total_dims, const int64_t* slot_lens, const int slot_num, + const int* key2slot, const float scale, const int cvm_offset, + const uint32_t* restore_idx) { + CUDA_KERNEL_LOOP(i, total_len) { + int idx = i / total_embedx_dim; + int col = i % total_embedx_dim; + + int x = key2slot[idx]; + int y = idx - slot_lens[x]; + + auto& 
src_val = src[restore_idx[idx]]; + if (col < embedx_dim) { // embedx + if (dest[x] == 0) { + return; + } + int offset = y * (embedx_dim + cvm_offset) + cvm_offset + col; + if (total_dims[idx] & 0x01) { + *(dest[x] + offset) = src_val.embedx[col] * scale; + } else { + *(dest[x] + offset) = 0; + } + + if (dest[x + slot_num] == 0) { + return; + } + int offset_2 = y * (embedx_dim + cvm_offset + expand_dim) + + cvm_offset + col; + if (total_dims[idx] & 0x02) { + *(dest[x + slot_num] + offset_2) = src_val.embedx[col] * scale; + } else { + *(dest[x + slot_num] + offset_2) = 0; + } + } else { // expand + if (dest[x + slot_num] == 0) { + return; + } + int offset = y * (embedx_dim + cvm_offset + expand_dim) + + cvm_offset + col; + if (total_dims[idx] & 0x02) { + *(dest[x + slot_num] + offset) = + src_val.embed_expand[col - embedx_dim] * scale; + } else { + *(dest[x + slot_num] + offset) = 0; + } + } + } // end kernel loop +} //========================== feature var pull ======================== template __global__ void PullCopyBaseVariable( @@ -750,6 +917,42 @@ __global__ void PushCopyExpandNNCross( } } +template +__global__ void PushCopyExpandNNCrossWithEmb( + FeaturePushValueGpuType* dest, float** src, const int total_embedx_dim, + const int embedx_dim, const int expand_dim, const int total_len, + const int bs, const int* slot_vector, const int* total_dims, + const int64_t* slot_lens, const int slot_num, const int* key2slot, + const int cvm_offset) { + CUDA_KERNEL_LOOP(i, total_len) { + int idx = i / total_embedx_dim; + int col = i % total_embedx_dim; + + int x = key2slot[idx]; + int y = idx - slot_lens[x]; + + auto& dest_val = dest[idx]; + if (col < embedx_dim) { // embedx + if ((total_dims[idx] & 0x01) && src[x] != 0) { + dest_val.embedx_g[col] = + *(src[x] + y * (embedx_dim + cvm_offset) + cvm_offset + col) * -1. * + bs; + } else { + dest_val.embedx_g[col] = 0; + } + } else { // expand + int offset = y * (embedx_dim + cvm_offset + expand_dim) + + cvm_offset + col; + if ((total_dims[idx] & 0x02) && src[x + slot_num] != 0) { + dest_val.embed_expand_g[col - embedx_dim] = + *(src[x + slot_num] + offset) * -1. * bs; + } else { + dest_val.embed_expand_g[col - embedx_dim] = 0; + } + } + } +} + template __global__ void PushMergeCopyBaseNNCross( FeaturePushValueGpuType* dest, float** src, const int hidden, @@ -846,6 +1049,53 @@ __global__ void PushMergeCopyExpandNNCross( } } } +template +__global__ void PushMergeCopyExpandNNCrossWithEmb( + FeaturePushValueGpuType* dest, float** src, const int total_embedx_dim, + const int embedx_dim, const int expand_dim, const int total_len, + const int bs, const int* slot_vector, const int* total_dims, + const int64_t* slot_lens, const int slot_num, const int* key2slot, + const int cvm_offset, const uint32_t* d_sort_idx, + const uint32_t* d_sort_offset, const uint32_t* d_sort_cnt) { + CUDA_KERNEL_LOOP(i, total_len) { + int id = i / total_embedx_dim; + int col = i % total_embedx_dim; + + const uint32_t& start = d_sort_offset[id]; + const uint32_t& count = d_sort_cnt[id]; + const uint32_t& pos = d_sort_idx[start]; + + const int& x = key2slot[pos]; + int y = pos - slot_lens[x]; + + auto& dest_val = dest[id]; + if (col < embedx_dim) { // embedx + double val = 0.0; + for (uint32_t j = 0; j < count; ++j) { + const uint32_t& pos = d_sort_idx[start + j]; + const int& x = key2slot[pos]; + if ((total_dims[pos] & 0x01) && src[x] != 0) { + y = pos - slot_lens[x]; + val += *(src[x] + y * (embedx_dim + cvm_offset) + cvm_offset + col); + } + } + dest_val.embedx_g[col] = val * -1. 
* bs; + } else { // expand + int offset = y * (embedx_dim + cvm_offset + expand_dim) + + cvm_offset + col; + double val = 0.0; + for (uint32_t j = 0; j < count; ++j) { + const uint32_t& pos = d_sort_idx[start + j]; + const int& x = key2slot[pos]; + if ((total_dims[pos] & 0x02) && src[x + slot_num] != 0) { + y = pos - slot_lens[x]; + val += *(src[x + slot_num] + offset); + } + } + dest_val.embed_expand_g[col - embedx_dim] = val * -1 * bs; + } + } +} //========================== feature variable push ============================ template __global__ void PushCopyBaseVariable( @@ -1143,6 +1393,49 @@ void FeaturePullCopyNNCross(cudaStream_t stream, uint64_t** gpu_keys, } } +template +void FeaturePullCopyNNCrossWithEmb(cudaStream_t stream, uint64_t** gpu_keys, + float** gpu_values, void* src, + const int hidden_size, const size_t embedx_dim, + const size_t expand_dim, const int total_length, + int* total_dims, const int64_t* slot_lens, + const int slot_num, const int* key2slot, + const float scale, const int cvm_offset, + const uint32_t* gpu_restore_idx, + const int skip_offset) { + FeaturePullValueType* pull_values_gpu = + reinterpret_cast(src); + if (gpu_restore_idx != nullptr) { + // nncross + PullDedupCopyBaseNNCrossWithEmb< + FeaturePullValueType><<>>( + gpu_values, pull_values_gpu, hidden_size, expand_dim, total_length, + gpu_keys, total_dims, slot_lens, slot_num, key2slot, cvm_offset, + gpu_restore_idx, skip_offset); + // embedx + expand_embedx + int embedx_total_length = total_length * (embedx_dim + expand_dim); + PullDedupCopyExpandNNCrossWithEmb< + FeaturePullValueType><<>>( + gpu_values, pull_values_gpu, (embedx_dim + expand_dim), embedx_dim, + expand_dim, embedx_total_length, total_dims, slot_lens, slot_num, + key2slot, scale, cvm_offset, gpu_restore_idx); + } else { + // nncross + PullCopyBaseNNCrossWithEmb< + FeaturePullValueType><<>>( + gpu_values, pull_values_gpu, hidden_size, expand_dim, total_length, + gpu_keys, total_dims, slot_lens, slot_num, key2slot, cvm_offset, + skip_offset); + // embedx + expand_embedx + int embedx_total_length = total_length * (embedx_dim + expand_dim); + PullCopyExpandNNCrossWithEmb< + FeaturePullValueType><<>>( + gpu_values, pull_values_gpu, (embedx_dim + expand_dim), embedx_dim, + expand_dim, embedx_total_length, total_dims, slot_lens, slot_num, + key2slot, scale, cvm_offset); + } +} + template void FeaturePullCopyVariable(cudaStream_t stream, uint64_t** gpu_keys, float** gpu_values, void* src, @@ -1193,6 +1486,7 @@ void BoxWrapper::CopyForPull(const paddle::platform::Place& place, const int hidden_size, const int expand_embed_dim, const int64_t total_length, int* total_dims, const int skip_offset, + bool expand_only=true, const uint32_t* gpu_restore_idx) { auto stream = dynamic_cast( platform::DeviceContextPool::Instance().Get( @@ -1249,12 +1543,21 @@ void BoxWrapper::CopyForPull(const paddle::platform::Place& place, skip_offset); \ } else if (feature_type_ == static_cast(boxps::FEATURE_QUANT) || \ feature_type_ == static_cast(boxps::FEATURE_SHOWCLK)) { \ - FeaturePullCopyNNCross< \ + if (expand_only) { \ + FeaturePullCopyNNCross< \ boxps::FeaturePullValueGpuQuant>( \ stream, gpu_keys, gpu_values, total_values_gpu, hidden_size, \ EmbedxDim, ExpandDim, total_length, total_dims, slot_lens, slot_num, \ key2slot, pull_embedx_scale_, cvm_offset, gpu_restore_idx, \ skip_offset); \ + } else { \ + FeaturePullCopyNNCrossWithEmb< \ + boxps::FeaturePullValueGpuQuant>( \ + stream, gpu_keys, gpu_values, total_values_gpu, hidden_size, \ + EmbedxDim, ExpandDim, 
total_length, total_dims, slot_lens, slot_num, \ + key2slot, pull_embedx_scale_, cvm_offset, gpu_restore_idx, \ + skip_offset); \ + } \ } else if (feature_type_ == static_cast(boxps::FEATURE_VARIABLE)) { \ FeaturePullCopyVariable< \ boxps::FeatureVarPullValueGpu>( \ @@ -1433,6 +1736,47 @@ void FeaturePushCopyNNCross( } } +template +void FeaturePushCopyNNCrossWithEmb( + cudaStream_t stream, void* dest, float** grad_values, const int hidden_size, + const int embedx_dim, const int expand_dim, const int total_length, + const int batch_size, const int* slot_vector, const int* total_dims, + const int64_t* slot_lens, const int slot_num, const int* key2slot, + const int cvm_offset, const uint32_t* d_sort_idx, + const uint32_t* d_sort_offset, const uint32_t* d_sort_lens, + const int skip_offset) { + FeaturePushValueGpuType* push_grad_values = + reinterpret_cast(dest); + if (d_sort_idx != nullptr) { + // nncross + PushMergeCopyBaseNNCross< + FeaturePushValueGpuType><<>>( + push_grad_values, grad_values, hidden_size, total_length, batch_size, + slot_vector, total_dims, slot_lens, slot_num, key2slot, cvm_offset, + d_sort_idx, d_sort_offset, d_sort_lens, skip_offset); + int embedx_total_length = total_length * (embedx_dim + expand_dim); + PushMergeCopyExpandNNCrossWithEmb< + FeaturePushValueGpuType><<>>( + push_grad_values, grad_values, (embedx_dim + expand_dim), embedx_dim, + expand_dim, embedx_total_length, batch_size, slot_vector, total_dims, + slot_lens, slot_num, key2slot, cvm_offset, d_sort_idx, d_sort_offset, + d_sort_lens); + } else { + // nncross + PushCopyBaseNNCross< + FeaturePushValueGpuType><<>>( + push_grad_values, grad_values, hidden_size, total_length, batch_size, + slot_vector, total_dims, slot_lens, slot_num, key2slot, cvm_offset, + skip_offset); + int embedx_total_length = total_length * (embedx_dim + expand_dim); + PushCopyExpandNNCrossWithEmb< + FeaturePushValueGpuType><<>>( + push_grad_values, grad_values, (embedx_dim + expand_dim), embedx_dim, + expand_dim, embedx_total_length, batch_size, slot_vector, total_dims, + slot_lens, slot_num, key2slot, cvm_offset); + } +} + template void FeaturePushCopyShareEmbedding( cudaStream_t stream, void* dest, float** grad_values, const int hidden_size, @@ -1519,7 +1863,7 @@ void BoxWrapper::CopyForPush( const int64_t* slot_lens, const int slot_num, const int hidden_size, const int expand_embed_dim, const int64_t total_length, const int batch_size, const int* total_dims, const int* key2slot, - const int skip_offset, const uint32_t* gpu_sort_idx, + const int skip_offset, bool expand_only, const uint32_t* gpu_sort_idx, const uint32_t* gpu_sort_offset, const uint32_t* gpu_sort_lens) { auto stream = dynamic_cast( platform::DeviceContextPool::Instance().Get( @@ -1565,12 +1909,21 @@ void BoxWrapper::CopyForPush( case i: { \ constexpr size_t ExpandDim = i; \ if (feature_type_ == static_cast(boxps::FEATURE_PCOC)) { \ - FeaturePushCopyNNCross< \ + if (expand_only) { \ + FeaturePushCopyNNCross< \ + boxps::FeaturePushValueGpuPCOC>( \ + stream, total_grad_values_gpu, grad_values, hidden_size, EmbedxDim, \ + ExpandDim, total_length, batch_size, d_slot_vector, total_dims, \ + slot_lens, slot_num, key2slot, cvm_offset, gpu_sort_idx, \ + gpu_sort_offset, gpu_sort_lens, skip_offset); \ + } else { \ + FeaturePushCopyNNCrossWithEmb< \ boxps::FeaturePushValueGpuPCOC>( \ stream, total_grad_values_gpu, grad_values, hidden_size, EmbedxDim, \ ExpandDim, total_length, batch_size, d_slot_vector, total_dims, \ slot_lens, slot_num, key2slot, cvm_offset, gpu_sort_idx, \ 
gpu_sort_offset, gpu_sort_lens, skip_offset); \ + } \ } else if (feature_type_ == static_cast(boxps::FEATURE_VARIABLE)) { \ FeaturePushCopyVariable< \ boxps::FeatureVarPushValueGpu>( \ diff --git a/paddle/fluid/framework/fleet/box_wrapper.h b/paddle/fluid/framework/fleet/box_wrapper.h index 44d9c9de69c10..c09772ed48517 100644 --- a/paddle/fluid/framework/fleet/box_wrapper.h +++ b/paddle/fluid/framework/fleet/box_wrapper.h @@ -429,14 +429,14 @@ class BoxWrapper { const std::vector& values, const std::vector& slot_lengths, const int hidden_size, const int expand_embed_dim, - const int skip_offset); + const int skip_offset, bool expand_only); void PullSparse(const paddle::platform::Place& place, const std::vector& keys, const std::vector& values, const std::vector& slot_lengths, const int hidden_size, const int expand_embed_dim, - const int skip_offset); + const int skip_offset, bool expand_only); template void PushSparseGradCase(const paddle::platform::Place& place, @@ -444,14 +444,16 @@ class BoxWrapper { const std::vector& grad_values, const std::vector& slot_lengths, const int hidden_size, const int expand_embed_dim, - const int batch_size, const int skip_offset); + const int batch_size, const int skip_offset, + bool expand_only); void PushSparseGrad(const paddle::platform::Place& place, const std::vector& keys, const std::vector& grad_values, const std::vector& slot_lengths, const int hidden_size, const int expand_embed_dim, - const int batch_size, const int skip_offset); + const int batch_size, const int skip_offset, + bool expand_only); void CopyForPull(const paddle::platform::Place& place, uint64_t** gpu_keys, float** gpu_values, void* total_values_gpu, @@ -459,6 +461,7 @@ class BoxWrapper { const int* key2slot, const int hidden_size, const int expand_embed_dim, const int64_t total_length, int* total_dims, const int skip_offset, + bool expand_only, const uint32_t* gpu_restore_idx = nullptr); void CopyForPush(const paddle::platform::Place& place, float** grad_values, @@ -468,6 +471,7 @@ class BoxWrapper { const int64_t total_length, const int batch_size, const int* total_dims, const int* key2slot, const int skip_offset, + bool expand_only, const uint32_t* gpu_sort_idx = nullptr, const uint32_t* gpu_sort_offset = nullptr, const uint32_t* gpu_sort_lens = nullptr); diff --git a/paddle/fluid/framework/fleet/box_wrapper_impl.h b/paddle/fluid/framework/fleet/box_wrapper_impl.h index e62c6aa6b0a71..506ced264d93b 100644 --- a/paddle/fluid/framework/fleet/box_wrapper_impl.h +++ b/paddle/fluid/framework/fleet/box_wrapper_impl.h @@ -28,7 +28,8 @@ void BoxWrapper::PullSparseCase(const paddle::platform::Place& place, const std::vector& slot_lengths, const int hidden_size, const int expand_embed_dim, - const int skip_offset) { + const int skip_offset, + bool expand_only) { if (!platform::is_gpu_place(place)) { PADDLE_THROW(platform::errors::Unimplemented( "Warning:: CPUPlace is not supported in PaddleBox now.")); @@ -134,7 +135,7 @@ void BoxWrapper::PullSparseCase(const paddle::platform::Place& place, this->CopyForPull(place, gpu_keys, gpu_values, total_values_gpu, slot_lens, slot_num, key2slot, hidden_size, expand_embed_dim, - total_length, total_dims, skip_offset, d_restore_idx); + total_length, total_dims, skip_offset, expand_only, d_restore_idx); } else { int64_t total_bytes = total_length * sizeof(FEATURE_VALUE_GPU_TYPE); FEATURE_VALUE_GPU_TYPE* total_values_gpu = @@ -157,7 +158,7 @@ void BoxWrapper::PullSparseCase(const paddle::platform::Place& place, this->CopyForPull(place, gpu_keys, 
                      gpu_values, total_values_gpu, slot_lens,
                      slot_num, key2slot, hidden_size, expand_embed_dim,
-                     total_length, total_dims, skip_offset);
+                     total_length, total_dims, skip_offset, expand_only);
   }
   all_timer.Pause();
 }
@@ -168,7 +169,8 @@ void BoxWrapper::PushSparseGradCase(
     const std::vector<const uint64_t*>& keys,
     const std::vector<const float*>& grad_values,
     const std::vector<int64_t>& slot_lengths, const int hidden_size,
-    const int expand_embed_dim, const int batch_size, const int skip_offset) {
+    const int expand_embed_dim, const int batch_size, const int skip_offset,
+    bool expand_only) {
   if (!platform::is_gpu_place(place)) {
     PADDLE_THROW(platform::errors::Unimplemented(
         "Warning:: CPUPlace is not supported in PaddleBox now."));
@@ -223,7 +225,7 @@ void BoxWrapper::PushSparseGradCase(
         place);
     this->CopyForPush(place, gpu_values, total_grad_values_gpu, d_slot_vector,
                       slot_lens, slot_num, hidden_size, expand_embed_dim,
-                      dedup_size, batch_size, total_dims, key2slot, skip_offset,
+                      dedup_size, batch_size, total_dims, key2slot, skip_offset, expand_only,
                       d_sorted_idx, d_offset, d_merged_cnts);
     push_boxps_timer.Resume();
@@ -241,7 +243,7 @@ void BoxWrapper::PushSparseGradCase(
     this->CopyForPush(place, gpu_values, total_grad_values_gpu, d_slot_vector,
                       slot_lens, slot_num, hidden_size, expand_embed_dim,
                       total_length, batch_size, total_dims, key2slot,
-                      skip_offset);
+                      skip_offset, expand_only);
     push_boxps_timer.Resume();
     int ret = boxps_ptr_->PushSparseGPU(
diff --git a/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cc b/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cc
index 08ba487b87f9a..c9353e363f757 100644
--- a/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cc
+++ b/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cc
@@ -41,6 +41,7 @@ class FusedSeqpoolCVMOp : public framework::OperatorWithKernel {
     outs_dims.resize(num_inputs);
     bool use_cvm = ctx->Attrs().Get<bool>("use_cvm");
     bool clk_filter = ctx->Attrs().Get<bool>("clk_filter");
+    const int embed_thres_size = ctx->Attrs().Get<int>("embed_thres_size");
     // need filter quant_ratio more than zero
     if (ctx->Attrs().Get<bool>("need_filter")) {
@@ -84,7 +85,7 @@ class FusedSeqpoolCVMOp : public framework::OperatorWithKernel {
         out_dim = {-1, dims[rank - 1]};
       }
     } else {
-      out_dim = {-1, dims[rank - 1] - cvm_offset};
+      out_dim = {-1, dims[rank - 1] - cvm_offset - embed_thres_size};
     }
     outs_dims[i] = framework::make_ddim(out_dim);
   }
@@ -132,6 +133,7 @@ class FusedSeqpoolCVMOpMaker : public framework::OpProtoAndCheckerMaker {
     AddAttr<int>("cvm_offset", "(int, default 2)").SetDefault(2);
     AddAttr<int>("quant_ratio", "(int, default 128)").SetDefault(0);
     AddAttr<bool>("clk_filter", "(bool, default false)").SetDefault(false);
+    AddAttr<int>("embed_thres_size", "(int, default 0)").SetDefault(0);
     AddComment(R"DOC(
 Fuse multiple pairs of Sequence Pool and CVM Operator.
@@ -151,6 +153,7 @@ class FusedSeqpoolCVMGradOp : public framework::OperatorWithKernel {
     const int cvm_offset = ctx->Attrs().Get<int>("cvm_offset");
     bool use_cvm = ctx->Attrs().Get<bool>("use_cvm");
     bool clk_filter = ctx->Attrs().Get<bool>("clk_filter");
+    const int embed_thres_size = ctx->Attrs().Get<int>("embed_thres_size");
     PADDLE_ENFORCE_EQ(
         cvm_dims.size(), 2,
@@ -179,7 +182,7 @@ class FusedSeqpoolCVMGradOp : public framework::OperatorWithKernel {
     } else {
       PADDLE_ENFORCE_EQ(
           og_dims[i][og_dims[i].size() - 1],
-          x_dims[i][og_dims[i].size() - 1] - cvm_offset,
+          x_dims[i][og_dims[i].size() - 1] - cvm_offset - embed_thres_size,
          platform::errors::InvalidArgument(
              "The dimension mismatch between Input(OUT@GRAD) and "
              "Input(X). 
Received Input(OUT@GRAD): input rank %u, " diff --git a/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cu b/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cu index 115bfe7a64fb3..31f463c7c9437 100644 --- a/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cu +++ b/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cu @@ -57,7 +57,8 @@ template __global__ void FusedSeqpoolKernelQuant( const size_t N, T **input_values, T **seqpool_output_values, size_t **lods_values, const int batch_size, const int embedding_size, - const float pad_value, const int cvm_offset, const int quant_ratio) { + const float pad_value, const int cvm_offset, const int quant_ratio, + const int embed_thres_size) { CUDA_KERNEL_LOOP(i, N) { int key = i / embedding_size; int offset = i % embedding_size; @@ -88,7 +89,8 @@ __global__ void FusedSeqpoolKernelQuantFilter( const size_t N, T **input_values, T **seqpool_output_values, size_t **lods_values, const int batch_size, const int embedding_size, const float pad_value, const int cvm_offset, const float show_coeff, - const float clk_coeff, const float threshold, const int quant_ratio) { + const float clk_coeff, const float threshold, const int quant_ratio, + const int embed_thres_size) { CUDA_KERNEL_LOOP(i, N) { int key = i / embedding_size; int offset = i % embedding_size; // embedx id @@ -124,7 +126,7 @@ __global__ void FusedSeqpoolKernelEmbedQuantFilter( size_t **lods_values, const int batch_size, const int embedding_size, const float pad_value, const int cvm_offset, const float show_coeff, const float clk_coeff, const float threshold, const int quant_ratio, - const float embed_threshold) { + const float embed_threshold, const int embed_thres_size) { CUDA_KERNEL_LOOP(i, N) { int key = i / embedding_size; int offset = i % embedding_size; // embedx id @@ -142,7 +144,11 @@ __global__ void FusedSeqpoolKernelEmbedQuantFilter( } T &embedw = *(input_values[x] + k * embedding_size + cvm_offset); T embedx_weight_score = 0.0; - for (int i = cvm_offset + 1; i < embedding_size; i++) { + int embed_thres_size_ = embed_thres_size; + if (embed_thres_size == 0) { + embed_thres_size_ = embedding_size - cvm_offset; + } + for (int i = cvm_offset + 1; i < cvm_offset + embed_thres_size_; i++) { embedx_weight_score += pow(*(input_values[x] + k * embedding_size + i), 2); } @@ -239,7 +245,8 @@ void FusedSeqpoolCVM(const paddle::platform::Place &place, const int cvm_offset, float need_filter, const bool embed_threshold_filter, float show_coeff, float clk_coeff, float threshold, float embed_threshold, - const int quant_ratio, const bool clk_filter) { + const int quant_ratio, const bool clk_filter, + const int embed_thres_size) { auto stream = dynamic_cast( platform::DeviceContextPool::Instance().Get( BOOST_GET_CONST(platform::CUDAPlace, place))) @@ -276,18 +283,18 @@ void FusedSeqpoolCVM(const paddle::platform::Place &place, 0, stream>>>( N, gpu_input_values, gpu_seqpool_output_values, lods_values, batch_size, embedding_size, padding_value, cvm_offset, show_coeff, clk_coeff, - threshold, quant_ratio, embed_threshold); + threshold, quant_ratio, embed_threshold, embed_thres_size); } else if (need_filter) { // quant need filter FusedSeqpoolKernelQuantFilter<<>>( N, gpu_input_values, gpu_seqpool_output_values, lods_values, batch_size, embedding_size, padding_value, cvm_offset, show_coeff, clk_coeff, - threshold, quant_ratio); + threshold, quant_ratio, embed_thres_size); } else if (quant_ratio > 0) { // quant not filter FusedSeqpoolKernelQuant<<>>( N, gpu_input_values, gpu_seqpool_output_values, 
lods_values, batch_size, - embedding_size, padding_value, cvm_offset, quant_ratio); + embedding_size, padding_value, cvm_offset, quant_ratio, embed_thres_size); } else { // normal FusedSeqpoolKernelNormal<<>>( @@ -311,10 +318,11 @@ void FusedSeqpoolCVM(const paddle::platform::Place &place, } else { // not need show click input N = static_cast(batch_size * slot_num * - (embedding_size - cvm_offset)); + (embedding_size - cvm_offset - embed_thres_size)); FusedCVMKernelNoCVM<<>>( N, gpu_output_values, gpu_seqpool_output_values, batch_size, - (embedding_size - cvm_offset), cvm_offset); + (embedding_size - cvm_offset - embed_thres_size), + (cvm_offset + embed_thres_size)); } } // join grad @@ -369,17 +377,26 @@ template __global__ void FusedSeqpoolCVMGradKernelNoCVM( const size_t N, T **out_grads_values, T **in_grads_values, T **cvm_values, size_t **lods_values, const int batch_size, const int embedding_size, - const int cvm_offset) { + const int cvm_offset, const int embed_thres_size) { CUDA_KERNEL_LOOP(i, N) { int key = i / embedding_size; int offset = i % embedding_size; // embedx offset int x = key / batch_size; // slot id int y = key % batch_size; // ins id - T &val = (offset < cvm_offset) - ? *(cvm_values[x] + y * cvm_offset + offset) - : *(out_grads_values[x] + y * (embedding_size - cvm_offset) + - offset - cvm_offset); + T val = 0; + if (embed_thres_size == 0) { + val = (offset < cvm_offset) + ? *(cvm_values[x] + y * cvm_offset + offset) + : *(out_grads_values[x] + y * (embedding_size + - cvm_offset) + offset - cvm_offset); + } else { + val = (offset < cvm_offset + embed_thres_size) + ? 0 + : *(out_grads_values[x] + y * (embedding_size + - cvm_offset - embed_thres_size) + offset + - cvm_offset - embed_thres_size); + } auto &start = *(lods_values[x] + y); auto &end = *(lods_values[x] + y + 1); @@ -396,7 +413,8 @@ void FusedSeqpoolCVMGrad(const paddle::platform::Place &place, const std::vector &lods, const int batch_size, const int slot_num, const int embedding_size, const bool use_cvm, - const int cvm_offset, const bool clk_filter) { + const int cvm_offset, const bool clk_filter, + const int embed_thres_size) { auto stream = dynamic_cast( platform::DeviceContextPool::Instance().Get( BOOST_GET_CONST(platform::CUDAPlace, place))) @@ -445,7 +463,8 @@ void FusedSeqpoolCVMGrad(const paddle::platform::Place &place, FusedSeqpoolCVMGradKernelNoCVM<<>>( N, gpu_out_grads_values, gpu_in_grads_values, gpu_cvm_values, - lods_values, batch_size, embedding_size, cvm_offset); + lods_values, batch_size, embedding_size, cvm_offset, + embed_thres_size); } } @@ -475,6 +494,7 @@ class FusedSeqpoolCVMCUDAKernel : public framework::OpKernel { const int cvm_offset = ctx.Attr("cvm_offset"); const int quant_ratio = ctx.Attr("quant_ratio"); bool clk_filter = ctx.Attr("clk_filter"); + const int embed_thres_size = ctx.Attr("embed_thres_size"); int embedding_size = inputs[0]->numel() / inputs[0]->dims()[0]; int batch_size = -1; @@ -500,7 +520,7 @@ class FusedSeqpoolCVMCUDAKernel : public framework::OpKernel { output->Resize({batch_size, embedding_size}); } } else { - output->Resize({batch_size, embedding_size - cvm_offset}); + output->Resize({batch_size, embedding_size - cvm_offset - embed_thres_size}); } output_data[i] = reinterpret_cast(output->mutable_data(ctx.GetPlace())); @@ -514,7 +534,8 @@ class FusedSeqpoolCVMCUDAKernel : public framework::OpKernel { seqpool_output_data, lods_data, batch_size, slot_size, embedding_size, padding_value, use_cvm, cvm_offset, need_filter, embed_threshold_filter, show_coeff, 
         clk_coeff,
-        threshold, embed_threshold, quant_ratio, clk_filter);
+        threshold, embed_threshold, quant_ratio, clk_filter,
+        embed_thres_size);
   }
 };
@@ -530,6 +551,7 @@ class FusedSeqpoolCVMGradCUDAKernel : public framework::OpKernel<T> {
     auto use_cvm = ctx.Attr<bool>("use_cvm");
     const int cvm_offset = ctx.Attr<int>("cvm_offset");
     bool clk_filter = ctx.Attr<bool>("clk_filter");
+    const int embed_thres_size = ctx.Attr<int>("embed_thres_size");
     const auto slot_size = in_grads.size();
     std::vector out_grads_data(slot_size);
@@ -562,7 +584,7 @@ class FusedSeqpoolCVMGradCUDAKernel : public framework::OpKernel<T> {
     }
     FusedSeqpoolCVMGrad(ctx.GetPlace(), out_grads_data, in_grads_data, cvm_data,
                         lods_data, batch_size, slot_size, embedding_size,
-                        use_cvm, cvm_offset, clk_filter);
+                        use_cvm, cvm_offset, clk_filter, embed_thres_size);
   }
 };
diff --git a/paddle/fluid/operators/pull_box_extended_sparse_op.cc b/paddle/fluid/operators/pull_box_extended_sparse_op.cc
index 03ae7a21a9a68..0d64fc426fe2b 100644
--- a/paddle/fluid/operators/pull_box_extended_sparse_op.cc
+++ b/paddle/fluid/operators/pull_box_extended_sparse_op.cc
@@ -41,6 +41,7 @@ class PullBoxExtendedSparseOp : public framework::OperatorWithKernel {
     const size_t n_ids = all_ids_dim.size();
     std::vector<framework::DDim> outs_dims;
     std::vector<framework::DDim> outs_extended_dims;
+    auto expand_only = ctx->Attrs().Get<bool>("expand_only");
     auto flags = ctx->Attrs().Get<std::vector<int>>("mask");
     if (flags.empty()) {
       for (size_t i = 0; i < n_ids; ++i) {
@@ -58,7 +59,11 @@ class PullBoxExtendedSparseOp : public framework::OperatorWithKernel {
         outs_dims.push_back(framework::make_ddim(out_dim));
         auto out_extended_dim = framework::vectorize(
             framework::slice_ddim(ids_dims, 0, ids_rank - 1));
-        out_extended_dim.push_back(emb_extended_size);
+        if (expand_only) {
+          out_extended_dim.push_back(emb_extended_size);
+        } else {
+          out_extended_dim.push_back(emb_size + emb_extended_size);
+        }
         outs_extended_dims.push_back(framework::make_ddim(out_extended_dim));
       }
       ctx->SetOutputsDim("Out", outs_dims);
@@ -87,7 +92,11 @@ class PullBoxExtendedSparseOp : public framework::OperatorWithKernel {
         if (flags[i] & 0x02) {
           auto out_extended_dim = framework::vectorize(
               framework::slice_ddim(ids_dims, 0, ids_rank - 1));
-          out_extended_dim.push_back(emb_extended_size);
+          if (expand_only) {
+            out_extended_dim.push_back(emb_extended_size);
+          } else {
+            out_extended_dim.push_back(emb_size + emb_extended_size);
+          }
           outs_extended_dims.push_back(framework::make_ddim(out_extended_dim));
         }
       }
@@ -135,6 +144,7 @@ class PullBoxExtendedSparseOpMaker : public framework::OpProtoAndCheckerMaker {
     AddAttr<std::vector<int>>("mask", "The embedx expand mask.").SetDefault({});
     AddAttr<int>("offset", "(int, the skip pull value cvm offset")
         .SetDefault(0);
+    AddAttr<bool>("expand_only", "(bool, default true) whether the extended output contains only the expand embedding; if false, the show/clk/embed columns are included as well").SetDefault(true);
     AddComment(R"DOC(
 Pull Box Extended Sparse Operator.
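Note (illustrative, not part of the patch): expand_only controls the width of the extended output of pull_box_extended_sparse. When it is false, the base embedding columns are concatenated in front of the expand embedding, so the last dimension of OutExtend grows from emb_extended_size to emb_size + emb_extended_size. A minimal Python sketch of the shape rule encoded by the InferShape change above, with made-up sizes:

    def pull_box_extended_out_widths(emb_size, emb_extended_size, expand_only=True):
        # Last-dim sizes of (Out, OutExtend), mirroring PullBoxExtendedSparseOp::InferShape.
        out_w = emb_size
        out_extend_w = emb_extended_size if expand_only else emb_size + emb_extended_size
        return out_w, out_extend_w

    assert pull_box_extended_out_widths(11, 64) == (11, 64)
    assert pull_box_extended_out_widths(11, 64, expand_only=False) == (11, 75)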
diff --git a/paddle/fluid/operators/pull_box_extended_sparse_op.h b/paddle/fluid/operators/pull_box_extended_sparse_op.h
index e40077c56e9e3..c7a9111c394c2 100644
--- a/paddle/fluid/operators/pull_box_extended_sparse_op.h
+++ b/paddle/fluid/operators/pull_box_extended_sparse_op.h
@@ -78,9 +78,10 @@ static void PullBoxExtendedSparseFunctor(
   int skip_offset = ctx.Attr<int>("offset");
   auto emb_size = ctx.Attr<int>("emb_size");
   auto emb_extended_size = ctx.Attr<int>("emb_extended_size");
+  auto expand_only = ctx.Attr<bool>("expand_only");
   auto box_ptr = paddle::framework::BoxWrapper::GetInstance();
   box_ptr->PullSparse(ctx.GetPlace(), all_keys, all_values, slot_lengths,
-                      emb_size, emb_extended_size, skip_offset);
+                      emb_size, emb_extended_size, skip_offset, expand_only);
 #endif
 }
@@ -166,10 +167,11 @@ static void PushBoxExtendedSparseFunctor(
   int skip_offset = ctx.Attr<int>("offset");
   auto emb_size = ctx.Attr<int>("emb_size");
   auto emb_extended_size = ctx.Attr<int>("emb_extended_size");
+  auto expand_only = ctx.Attr<bool>("expand_only");
   auto box_ptr = paddle::framework::BoxWrapper::GetInstance();
   box_ptr->PushSparseGrad(ctx.GetPlace(), all_keys, all_grad_values,
                           slot_lengths, emb_size, emb_extended_size, batch_size,
-                          skip_offset);
+                          skip_offset, expand_only);
 #endif
 }
diff --git a/paddle/fluid/operators/pull_box_sparse_op.h b/paddle/fluid/operators/pull_box_sparse_op.h
index aa57ca22964c9..e4742a772b277 100644
--- a/paddle/fluid/operators/pull_box_sparse_op.h
+++ b/paddle/fluid/operators/pull_box_sparse_op.h
@@ -154,7 +154,7 @@ static void PullBoxSparseFunctor(const framework::ExecutionContext &ctx) {
   auto box_ptr = paddle::framework::BoxWrapper::GetInstance();
   auto expand_dim = box_ptr->GetExpandEmbedDim();
   box_ptr->PullSparse(ctx.GetPlace(), all_keys, all_values, slot_lengths,
-                      hidden_size, expand_dim, skip_offset);
+                      hidden_size, expand_dim, skip_offset, true);
 #endif
 }
@@ -209,7 +209,7 @@ static void PushBoxSparseFunctor(const framework::ExecutionContext &ctx) {
   auto expand_dim = box_ptr->GetExpandEmbedDim();
   box_ptr->PushSparseGrad(ctx.GetPlace(), all_keys, all_grad_values,
                           slot_lengths, hidden_size, expand_dim, batch_size,
-                          skip_offset);
+                          skip_offset, true);
 #endif
 }
diff --git a/python/paddle/fluid/contrib/layers/nn.py b/python/paddle/fluid/contrib/layers/nn.py
index 3809942f7cbda..0396131f28a28 100644
--- a/python/paddle/fluid/contrib/layers/nn.py
+++ b/python/paddle/fluid/contrib/layers/nn.py
@@ -1514,7 +1514,8 @@ def _pull_box_extended_sparse(input,
                               extend_size=64,
                               dtype='float32',
                               mask=[],
-                              offset=0):
+                              offset=0,
+                              expand_only=True):
     r"""
     **Pull Box Extended Sparse Layer**
     This layer is used to lookup embeddings of IDs, provided by :attr:`input`, in
@@ -1570,7 +1571,8 @@ def _pull_box_extended_sparse(input,
             'emb_size': size,
             'emb_extended_size': extend_size,
             'mask': mask,
-            'offset': offset
+            'offset': offset,
+            'expand_only': expand_only
         })
     if len(outs) == 1:
         return outs[0], outs_extend[0]
@@ -1590,7 +1592,8 @@ def fused_seqpool_cvm(input,
                       embed_threshold=0,
                       cvm_offset=2,
                       quant_ratio=0,
-                      clk_filter=False):
+                      clk_filter=False,
+                      embed_thres_size=0):
     """
     **Notes: The Op only receives List of LoDTensor as input, only support SUM pooling now.
     :attr:`input`.
@@ -1645,7 +1648,8 @@ def fused_seqpool_cvm(input,
            "threshold": threshold,
            "embed_threshold": embed_threshold,
            "quant_ratio": quant_ratio,
-            "clk_filter": clk_filter
+            "clk_filter": clk_filter,
+            "embed_thres_size": embed_thres_size
        })
    return outs
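Note (illustrative, not part of the patch): embed_thres_size also changes the output width of fused_seqpool_cvm when use_cvm is false; the pooled output now drops cvm_offset + embed_thres_size leading columns instead of cvm_offset. A small Python sketch of that shape rule, with made-up numbers (clk_filter handling is omitted):

    def fused_seqpool_cvm_out_width(embedding_size, use_cvm=True, cvm_offset=2,
                                    embed_thres_size=0):
        # Mirrors the FusedSeqpoolCVMOp::InferShape change above.
        if use_cvm:
            return embedding_size
        return embedding_size - cvm_offset - embed_thres_size

    assert fused_seqpool_cvm_out_width(9, use_cvm=False) == 7
    assert fused_seqpool_cvm_out_width(9, use_cvm=False, embed_thres_size=1) == 6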