[PTen] Refactor scale kernel that has selected_rows input (#39278)
* refactor scale kernel whose input is selected_rows

* upload the missing file
YuanRisheng authored Jan 28, 2022
1 parent 848ae7d commit abfc2fe
Showing 9 changed files with 151 additions and 40 deletions.
67 changes: 47 additions & 20 deletions cmake/pten.cmake
@@ -88,6 +88,7 @@ function(kernel_library TARGET)
set(cpu_srcs)
set(gpu_srcs)
set(xpu_srcs)
set(selected_rows_srcs)
# parse and save the deps kernel targets
set(all_srcs)
set(kernel_deps)
@@ -106,6 +107,9 @@ function(kernel_library TARGET)
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/cpu/${TARGET}.cc)
list(APPEND cpu_srcs ${CMAKE_CURRENT_SOURCE_DIR}/cpu/${TARGET}.cc)
endif()
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/selected_rows/${TARGET}.cc)
list(APPEND selected_rows_srcs ${CMAKE_CURRENT_SOURCE_DIR}/selected_rows/${TARGET}.cc)
endif()
if (WITH_GPU OR WITH_ROCM)
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/gpu/${TARGET}.cu)
list(APPEND gpu_srcs ${CMAKE_CURRENT_SOURCE_DIR}/gpu/${TARGET}.cu)
@@ -144,27 +148,30 @@ function(kernel_library TARGET)
list(LENGTH cpu_srcs cpu_srcs_len)
list(LENGTH gpu_srcs gpu_srcs_len)
list(LENGTH xpu_srcs xpu_srcs_len)
list(LENGTH selected_rows_srcs selected_rows_srcs_len)

# Build target according to different src organization
if((${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR
${xpu_srcs_len} GREATER 0) AND ${common_srcs_len} GREATER 0)
# If the common_srcs depends on specific device srcs, build target using this rule.
${xpu_srcs_len} GREATER 0) AND (${common_srcs_len} GREATER 0 OR
${selected_rows_srcs_len} GREATER 0))
# If the common_srcs/selected_rows_srcs depend on specific device srcs, build target using this rule.
if (WITH_GPU)
if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0)
nv_library(${TARGET}_part SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
nv_library(${TARGET} SRCS ${common_srcs} DEPS ${TARGET}_part)
nv_library(${TARGET} SRCS ${common_srcs} ${selected_rows_srcs} DEPS ${TARGET}_part)
endif()
elseif (WITH_ROCM)
if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0)
hip_library(${TARGET}_part SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
hip_library(${TARGET} SRCS ${common_srcs} DEPS ${TARGET}_part)
hip_library(${TARGET} SRCS ${common_srcs} ${selected_rows_srcs} DEPS ${TARGET}_part)
endif()
else()
if (${cpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0)
cc_library(${TARGET}_part SRCS ${cpu_srcs} ${xpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
cc_library(${TARGET} SRCS ${common_srcs} DEPS ${TARGET}_part)
cc_library(${TARGET} SRCS ${common_srcs} ${selected_rows_srcs} DEPS ${TARGET}_part)
endif()
endif()
# If there are only specific device srcs, build target using this rule.
elseif (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0)
if (WITH_GPU)
if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0)
@@ -179,25 +186,42 @@ function(kernel_library TARGET)
cc_library(${TARGET} SRCS ${cpu_srcs} ${xpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
endif()
endif()
else()
if (${common_srcs_len} EQUAL 0)
message(FATAL_ERROR "Cannot find any implementation for ${TARGET}")
# If the selected_rows_srcs depends on common_srcs, build target using this rule.
elseif (${common_srcs_len} GREATER 0 AND ${selected_rows_srcs_len} GREATER 0)
if (WITH_GPU)
nv_library(${TARGET}_part SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
nv_library(${TARGET} SRCS ${selected_rows_srcs} DEPS ${TARGET}_part)
elseif (WITH_ROCM)
hip_library(${TARGET}_part SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
hip_library(${TARGET} SRCS ${selected_rows_srcs} DEPS ${TARGET}_part)
else()
# If the kernel has a device-independent public implementation,
# we will use this implementation and will not adopt the implementation
# under specific devices
if (WITH_GPU)
nv_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
elseif (WITH_ROCM)
hip_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
else()
cc_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
endif()
endif()
cc_library(${TARGET}_part SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
cc_library(${TARGET} SRCS ${selected_rows_srcs} DEPS ${TARGET}_part)
endif()
# If there are only common_srcs or selected_rows_srcs, build target using the rules below.
elseif (${common_srcs_len} GREATER 0)
if (WITH_GPU)
nv_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
elseif (WITH_ROCM)
hip_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
else()
cc_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
endif()
elseif (${selected_rows_srcs_len} GREATER 0)
if (WITH_GPU)
nv_library(${TARGET} SRCS ${selected_rows_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
elseif (WITH_ROCM)
hip_library(${TARGET} SRCS ${selected_rows_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
else()
cc_library(${TARGET} SRCS ${selected_rows_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
endif()
else()
message(FATAL_ERROR "Cannot find any implementation for ${TARGET}")
endif()

if (${common_srcs_len} GREATER 0 OR ${cpu_srcs_len} GREATER 0 OR
${gpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0)
${gpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0 OR
${selected_rows_srcs_len} GREATER 0)
# append target into PTEN_KERNELS property
get_property(pten_kernels GLOBAL PROPERTY PTEN_KERNELS)
set(pten_kernels ${pten_kernels} ${TARGET})
@@ -219,6 +243,9 @@ function(kernel_library TARGET)
if (${xpu_srcs_len} GREATER 0)
kernel_declare(${xpu_srcs})
endif()
if (${selected_rows_srcs_len} GREATER 0)
kernel_declare(${selected_rows_srcs})
endif()
endfunction()

function(register_kernels)
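For orientation, a minimal sketch of the layout this function scans for and how a kernel with a SelectedRows variant is built under the new first rule (illustrative only: the target name matches this commit, but the DEPS shown are placeholders, not the real dependency list):

  # Source layout scanned by kernel_library, relative to CMAKE_CURRENT_SOURCE_DIR:
  #   scale_kernel.h                     public declarations
  #   cpu/scale_kernel.cc                -> cpu_srcs
  #   gpu/scale_kernel.cu                -> gpu_srcs
  #   selected_rows/scale_kernel.cc      -> selected_rows_srcs
  # With device srcs plus selected_rows srcs present, the device kernels are
  # built into ${TARGET}_part and the selected_rows wrapper is linked on top:
  kernel_library(scale_kernel DEPS dense_tensor selected_rows kernel_context)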
28 changes: 15 additions & 13 deletions paddle/fluid/operators/scale_op.h
@@ -43,34 +43,36 @@ class ScaleKernel : public framework::OpKernel<T> {
public:
virtual void Compute(const framework::ExecutionContext& ctx) const {
auto* in_var = ctx.InputVar("X");
auto* in = framework::GetLoDTensorOrSelectedRowsValueFromVar(*in_var);

auto bias = ctx.Attr<float>("bias");
auto bias_after_scale = ctx.Attr<bool>("bias_after_scale");

auto scale = ctx.Attr<float>("scale");
auto* out_var = ctx.OutputVar("Out");

if (ctx.HasInput("ScaleTensor")) {
auto* scale_tensor = ctx.Input<framework::Tensor>("ScaleTensor");
scale = static_cast<float>(GetAttrFromTensor<T>(scale_tensor));
}

auto* out_var = ctx.OutputVar("Out");
if (in_var->IsType<pten::SelectedRows>() && in_var != out_var) {
auto& in_slr = in_var->Get<pten::SelectedRows>();
auto* out_slr = out_var->GetMutable<pten::SelectedRows>();
out_slr->set_rows(in_slr.rows());
out_slr->set_height(in_slr.height());
}
auto* in = framework::GetLoDTensorOrSelectedRowsValueFromVar(*in_var);
auto* out =
framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(out_var);
out->mutable_data<T>(in->place());
auto& dev_ctx = ctx.device_context<DeviceContext>();

// call new kernel
pten::ScaleKernel<T>(
static_cast<const typename framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
*in, scale, bias, bias_after_scale, out);
if (in_var->IsType<pten::SelectedRows>()) {
pten::ScaleSR<T>(
static_cast<const typename framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
in_var->Get<pten::SelectedRows>(), scale, bias, bias_after_scale,
out_var->GetMutable<pten::SelectedRows>());
} else {
pten::ScaleKernel<T>(
static_cast<const typename framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
*in, scale, bias, bias_after_scale, out);
}
}
};

6 changes: 6 additions & 0 deletions paddle/pten/core/kernel_registry.h
@@ -74,13 +74,19 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> {
std::type_index(typeid(const std::vector<DenseTensor>&))) {
args_def->AppendInput(
default_key.backend(), default_tensor_layout, default_key.dtype());
} else if (arg_type == std::type_index(typeid(const SelectedRows&))) {
args_def->AppendInput(
default_key.backend(), default_tensor_layout, default_key.dtype());
} else if (arg_type == std::type_index(typeid(DenseTensor*))) {
args_def->AppendOutput(
default_key.backend(), default_tensor_layout, default_key.dtype());
} else if (arg_type ==
std::type_index(typeid(std::vector<DenseTensor*>))) {
args_def->AppendOutput(
default_key.backend(), default_tensor_layout, default_key.dtype());
} else if (arg_type == std::type_index(typeid(SelectedRows*))) {
args_def->AppendOutput(
default_key.backend(), default_tensor_layout, default_key.dtype());
} else {
// Deal with attributes
// TODO(chenweihang): now here allow any types of attribute, maybe
7 changes: 3 additions & 4 deletions paddle/pten/core/kernel_utils.h
@@ -20,6 +20,7 @@
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_context.h"
#include "paddle/pten/core/kernel_def.h"
#include "paddle/pten/core/selected_rows.h"
#include "paddle/pten/core/sparse_coo_tensor.h"
#include "paddle/pten/core/sparse_csr_tensor.h"

@@ -215,6 +216,7 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(DenseTensor);
PT_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(DenseTensor);
PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(DenseTensor);
PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(SelectedRows);

PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(SparseCooTensor);
PT_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(SparseCooTensor);
@@ -223,8 +225,6 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(SparseCsrTensor);
PT_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(SparseCsrTensor);
PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(SparseCsrTensor);
// TODO(chenweihang): adapt SelectedRows
// PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(SelectedRowsTensor);

/* Attribute Helpers */

@@ -244,14 +244,13 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {

PT_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(DenseTensor);
PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_OUTPUT(DenseTensor);
PT_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(SelectedRows);

PT_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(SparseCooTensor);
PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_OUTPUT(SparseCooTensor);

PT_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(SparseCsrTensor);
PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_OUTPUT(SparseCsrTensor);
// TODO(chenweihang): adapt SelectedRows
// PT_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(SelectedRowsTensor);

/* End case */
template <typename T>
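For readers new to these helper macros: each PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(T) line stamps out a KernelCallHelper partial specialization for a const T& parameter, so the single SelectedRows lines added above are what let a pten kernel signature take SelectedRows arguments. A simplified schematic of what the input macro generates (not the verbatim macro body; the real expansion also carries static_asserts and input-range bookkeeping):

  // Schematic expansion of PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(SelectedRows):
  template <typename... Tail>
  struct KernelCallHelper<const SelectedRows&, Tail...> {
    template <int dev_ctx_count, int in_idx, int attr_idx, int out_idx,
              typename... PreviousArgs>
    static void Compute(KernelContext* ctx, PreviousArgs&... pargs) {
      // Take the next input slot from the kernel context as a SelectedRows ...
      const SelectedRows& arg = ctx->InputAt<SelectedRows>(in_idx);
      // ... and recurse on the remaining parameter types, accumulating args.
      KernelCallHelper<Tail...>::template Compute<dev_ctx_count, in_idx + 1,
                                                  attr_idx, out_idx>(
          ctx, pargs..., arg);
    }
  };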
2 changes: 1 addition & 1 deletion paddle/pten/kernels/cpu/scale_kernel.cc
@@ -57,7 +57,7 @@ PT_REGISTER_KERNEL(scale,
pten::ScaleKernel,
float,
double,
paddle::platform::bfloat16,
pten::dtype::bfloat16,
uint8_t,
int8_t,
int16_t,
2 changes: 1 addition & 1 deletion paddle/pten/kernels/gpu/scale_kernel.cu
@@ -72,7 +72,7 @@ PT_REGISTER_KERNEL(scale,
pten::ScaleKernel,
float,
double,
paddle::platform::float16,
pten::dtype::float16,
uint8_t,
int8_t,
int16_t,
9 changes: 9 additions & 0 deletions paddle/pten/kernels/scale_kernel.h
@@ -16,6 +16,7 @@ limitations under the License. */

#include "paddle/pten/common/scalar.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/selected_rows.h"
#include "paddle/pten/infermeta/unary.h"
#include "paddle/pten/kernels/empty_kernel.h"
namespace pten {
@@ -28,6 +29,14 @@ void ScaleKernel(const Context& dev_ctx,
bool bias_after_scale,
DenseTensor* out);

template <typename T, typename Context>
void ScaleSR(const Context& dev_ctx,
const SelectedRows& x,
const Scalar& scale,
float bias,
bool bias_after_scale,
SelectedRows* out);

template <typename T, typename Context>
DenseTensor Scale(const Context& dev_ctx,
const DenseTensor& x,
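A minimal usage sketch for the new ScaleSR entry point (a sketch only, assuming an already-initialized pten::CPUContext named dev_ctx, and eliding how the value tensor is allocated and filled):

  pten::SelectedRows x;
  x.set_height(4);     // the logical table has 4 rows
  x.set_rows({0, 2});  // only rows 0 and 2 are materialized
  // ... allocate and fill x.mutable_value() as a 2 x N DenseTensor ...

  pten::SelectedRows out;
  // out.value() = x.value() * 2 + 1; rows/height metadata is copied over.
  pten::ScaleSR<float>(dev_ctx, x, /*scale=*/2.0f, /*bias=*/1.0f,
                       /*bias_after_scale=*/true, &out);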
68 changes: 68 additions & 0 deletions paddle/pten/kernels/selected_rows/scale_kernel.cc
@@ -0,0 +1,68 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/pten/kernels/scale_kernel.h"

#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/backends/gpu/gpu_context.h"
#include "paddle/pten/core/kernel_registry.h"

// See Note [ Why still include the fluid headers? ]
#include "paddle/pten/common/bfloat16.h"
namespace pten {

template <typename T, typename Context>
void ScaleSR(const Context& dev_ctx,
const SelectedRows& x,
const Scalar& scale,
float bias,
bool bias_after_scale,
SelectedRows* out) {
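// Copy the rows/height metadata only when out does not alias x (the out-of-place case).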
if (x.value().data() != out->value().data()) {
out->set_rows(x.rows());
out->set_height(x.height());
}
pten::ScaleKernel<T>(
dev_ctx, x.value(), scale, bias, bias_after_scale, out->mutable_value());
}

} // namespace pten

PT_REGISTER_KERNEL(scale_sr,
CPU,
ALL_LAYOUT,
pten::ScaleSR,
float,
double,
pten::dtype::bfloat16,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_REGISTER_KERNEL(scale_sr,
GPU,
ALL_LAYOUT,
pten::ScaleSR,
float,
double,
pten::dtype::float16,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}
#endif
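Note that the SelectedRows variant is registered under its own kernel name, scale_sr, next to the existing scale kernel, and the fluid-side ScaleKernel in scale_op.h above chooses between pten::ScaleKernel and pten::ScaleSR at runtime based on the input variable's type.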
2 changes: 1 addition & 1 deletion paddle/pten/tests/api/scale_api.h
@@ -101,7 +101,7 @@ static void ScaleCPU(DataType kernel_dtype,
break;
}
case pten::DataType::BFLOAT16: {
pten::ScaleKernel<paddle::platform::bfloat16>(
pten::ScaleKernel<pten::dtype::bfloat16>(
dev_ctx, x, pten::Scalar(scale), bias, bias_after_scale, dense_out);
break;
}
