diff --git a/paddle/fluid/framework/fleet/box_wrapper_kernel.kps b/paddle/fluid/framework/fleet/box_wrapper_kernel.kps index 9d2408d1fea39..12b686dd3d0ec 100644 --- a/paddle/fluid/framework/fleet/box_wrapper_kernel.kps +++ b/paddle/fluid/framework/fleet/box_wrapper_kernel.kps @@ -1238,7 +1238,38 @@ void BoxWrapperKernel::CopyForPull( } } else { EmbedxNormalOp op; - FeaturePullCopy(place, + if (expand_embed_dim > 0 && pull_info_.expand_size > 0) { + FeaturePullCopyNNCross(place, + &op, + pull_embedx_scale_, + pull_offset, + total_dims, + xpu_values, + key2slot, + d_res_idx, + (float*)total_values_xpu, + xpu_merged_idx, + xpu_merged_offsets, + merged_length, + slot_lens, + slot_num, + (int)total_length, + hidden_size, + expand_embed_dim_, + (int)pull_float_num_, + skip_offset, + cvm_offset, + expand_only + ); + + } else if (pull_info_.expand_size < 0 && + expand_embed_dim == cvm_offset + expand_embed_dim_ && + hidden_size == cvm_offset + embedx_dim_) { + // TODO: + CHECK(false) << "FeaturePullCopyVariable not implement"; + + } else { + FeaturePullCopy(place, &op, pull_embedx_scale_, pull_offset, @@ -1256,6 +1287,7 @@ void BoxWrapperKernel::CopyForPull( pull_float_num_, skip_offset, cvm_offset); + } } }