Skip to content

Commit

Permalink
optimized kernels/sparse/gpu/full_kernel (PaddlePaddle#57290)
Browse files Browse the repository at this point in the history
  • Loading branch information
tianhaodongbd authored Sep 15, 2023
1 parent d83044a commit f0ab3be
Showing 1 changed file with 5 additions and 33 deletions.
38 changes: 5 additions & 33 deletions paddle/phi/kernels/sparse/gpu/full_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,11 @@ limitations under the License. */
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"

namespace phi {

/**
 * Nullary functor that yields one fixed constant on every invocation.
 *
 * InT is not referenced inside the struct; OutT (which defaults to InT) is
 * the type of the stored constant and of the call result.
 */
template <typename InT, typename OutT = InT>
struct FullFunctor {
  // Constant returned by operator(); converted to OutT once, at construction.
  OutT value;

  // Accepts any source type that static_cast can convert to OutT.
  template <typename VType>
  explicit inline FullFunctor(VType val) : value(static_cast<OutT>(val)) {}

  // Device-side call operator: takes no arguments and returns the constant.
  __device__ __forceinline__ OutT operator()() const { return value; }
};

template <typename T, typename Context>
void FullLikeCooKernel(const Context& dev_ctx,
const SparseCooTensor& x,
Expand All @@ -46,16 +33,8 @@ void FullLikeCooKernel(const Context& dev_ctx,
dev_ctx, x.indices(), dev_ctx.GetPlace(), false, out->mutable_indices());

DenseTensor* values = out->mutable_values();
values->Resize(x.values().dims());
dev_ctx.template Alloc<T>(values);

std::vector<const DenseTensor*> inputs = {};
std::vector<DenseTensor*> outputs = {values};
int numel = values->numel();
if (numel > 0) {
phi::funcs::ElementwiseKernel<T>(
dev_ctx, inputs, &outputs, FullFunctor<T>(val.to<T>()));
}
phi::Full<T, Context>(
dev_ctx, phi::vectorize(x.values().dims()), val, values);
out->set_dims(x.dims());
}

Expand All @@ -72,16 +51,9 @@ void FullLikeCsrKernel(const Context& dev_ctx,
dev_ctx, x.cols(), dev_ctx.GetPlace(), false, out->mutable_cols());

DenseTensor* values = out->mutable_values();
values->Resize(x.values().dims());
dev_ctx.template Alloc<T>(values);
phi::Full<T, Context>(
dev_ctx, phi::vectorize(x.values().dims()), val, values);

std::vector<const DenseTensor*> inputs = {};
std::vector<DenseTensor*> outputs = {values};
int numel = values->numel();
if (numel > 0) {
phi::funcs::ElementwiseKernel<T>(
dev_ctx, inputs, &outputs, FullFunctor<T>(val.to<T>()));
}
out->set_dims(x.dims());
}

Expand Down

0 comments on commit f0ab3be

Please sign in to comment.