phi move ReshapeToMatrix & GetValue (#50139)
engineer1109 authored Feb 6, 2023
1 parent 1274e73 commit d09962a
Showing 18 changed files with 131 additions and 143 deletions.
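The commit adds phi::ReshapeToMatrix and phi::GetValue to paddle/phi/core/tensor_utils.{h,cc}, deletes the old framework ReshapeToMatrix from paddle/fluid/framework/tensor_util.h (plus the per-kernel GetValue copies in the GPU arange/linspace kernels), and repoints the call sites shown below. The sketch that follows is not part of the commit; it is a hedged illustration of how a caller inside the Paddle source tree uses the relocated helpers, assuming only the headers and signatures visible in this diff. The wrapper names FlattenForMatmul and FirstElement are hypothetical.

// Hypothetical usage sketch (illustration only, not part of this commit).
#include "paddle/phi/core/tensor_utils.h"

// Flatten the first `num_col_dims` dimensions of `t` into rows and the
// remaining dimensions into columns; the result shares `t`'s storage
// (ReshapeToMatrix uses ShareDataWith + Resize, as shown in tensor_utils.cc).
phi::DenseTensor FlattenForMatmul(const phi::DenseTensor& t, int num_col_dims) {
  return phi::ReshapeToMatrix(t, num_col_dims);  // was framework::ReshapeToMatrix
}

// Read the first element of `t`; GetValue copies the tensor to CPU first
// when it does not live in CPU memory.
float FirstElement(const phi::DenseTensor& t) {
  return phi::GetValue<float>(&t);  // explicit float instantiation exists in tensor_utils.cc
}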
12 changes: 6 additions & 6 deletions paddle/fluid/framework/tensor_test.cc
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/phi/core/tensor_utils.h"

#include <gtest/gtest.h>

@@ -25,7 +25,7 @@ namespace platform = paddle::platform;
TEST(DenseTensor, Dims) {
phi::DenseTensor tt;
tt.Resize({2, 3, 4});
- framework::DDim dims = tt.dims();
+ phi::DDim dims = tt.dims();
ASSERT_EQ(arity(dims), 3);
for (int i = 0; i < 3; ++i) {
EXPECT_EQ(i + 2, dims[i]);
@@ -225,7 +225,7 @@ TEST(DenseTensor, Slice) {
src_tensor.mutable_data<int>(phi::make_ddim({5, 3, 4}),
platform::CPUPlace());
phi::DenseTensor slice_tensor = src_tensor.Slice(1, 3);
- framework::DDim slice_dims = slice_tensor.dims();
+ phi::DDim slice_dims = slice_tensor.dims();
ASSERT_EQ(arity(slice_dims), 3);
EXPECT_EQ(slice_dims[0], 2);
EXPECT_EQ(slice_dims[1], 3);
@@ -251,7 +251,7 @@ TEST(DenseTensor, Slice) {
src_tensor.mutable_data<double>(phi::make_ddim({6, 9}),
platform::CUDAPlace(0));
phi::DenseTensor slice_tensor = src_tensor.Slice(2, 6);
- framework::DDim slice_dims = slice_tensor.dims();
+ phi::DDim slice_dims = slice_tensor.dims();
ASSERT_EQ(arity(slice_dims), 2);
EXPECT_EQ(slice_dims[0], 4);
EXPECT_EQ(slice_dims[1], 9);
@@ -278,7 +278,7 @@ TEST(DenseTensor, Slice) {
src_tensor.mutable_data<double>(phi::make_ddim({6, 9}),
platform::NPUPlace(0));
phi::DenseTensor slice_tensor = src_tensor.Slice(2, 6);
- framework::DDim slice_dims = slice_tensor.dims();
+ phi::DDim slice_dims = slice_tensor.dims();
ASSERT_EQ(arity(slice_dims), 2);
EXPECT_EQ(slice_dims[0], 4);
EXPECT_EQ(slice_dims[1], 9);
@@ -306,7 +306,7 @@ TEST(DenseTensor, ReshapeToMatrix) {
for (int i = 0; i < 2 * 3 * 4 * 9; ++i) {
src_ptr[i] = i;
}
- phi::DenseTensor res = framework::ReshapeToMatrix(src, 2);
+ phi::DenseTensor res = phi::ReshapeToMatrix(src, 2);
ASSERT_EQ(res.dims()[0], 2 * 3);
ASSERT_EQ(res.dims()[1], 4 * 9);
}
21 changes: 0 additions & 21 deletions paddle/fluid/framework/tensor_util.h
@@ -560,27 +560,6 @@ inline void TensorToVector(const phi::DenseTensor& src,

std::ostream& operator<<(std::ostream& os, const LoD& lod);

inline phi::DenseTensor ReshapeToMatrix(const phi::DenseTensor& src,
int num_col_dims) {
int rank = src.dims().size();
PADDLE_ENFORCE_GE(
rank,
2,
platform::errors::InvalidArgument(
"'ReshapeToMatrix()' is only used for flatten high rank "
"tensors to matrixs. The dimensions of phi::DenseTensor must be "
"greater or equal than 2. "
"But received dimensions of phi::DenseTensor is %d",
rank));
if (rank == 2) {
return src;
}
phi::DenseTensor res;
res.ShareDataWith(src);
res.Resize(phi::flatten_to_2d(src.dims(), num_col_dims));
return res;
}

template <typename T>
inline T GetValue(const phi::DenseTensor* x) {
T value = static_cast<T>(0);
7 changes: 4 additions & 3 deletions paddle/fluid/operators/bpr_loss_op.h
@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace paddle {
@@ -44,9 +45,9 @@ class BprLossOpKernel : public framework::OpKernel<T> {
y->mutable_data<T>(ctx.GetPlace());
int rank = x->dims().size();

- phi::DenseTensor x_2d = framework::ReshapeToMatrix(*x, rank - 1);
- phi::DenseTensor labels_2d = framework::ReshapeToMatrix(*label, rank - 1);
- phi::DenseTensor y_2d = framework::ReshapeToMatrix(*y, rank - 1);
+ phi::DenseTensor x_2d = phi::ReshapeToMatrix(*x, rank - 1);
+ phi::DenseTensor labels_2d = phi::ReshapeToMatrix(*label, rank - 1);
+ phi::DenseTensor y_2d = phi::ReshapeToMatrix(*y, rank - 1);

const phi::DenseTensor* logits = &x_2d;
const phi::DenseTensor* labels = &labels_2d;
7 changes: 4 additions & 3 deletions paddle/fluid/operators/cross_entropy_op.h
@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/cross_entropy.h"
#include "paddle/phi/kernels/funcs/math.h"
#include "paddle/phi/kernels/funcs/math_function.h"
@@ -34,7 +35,7 @@ class CrossEntropyOpKernel : public framework::OpKernel<T> {

int rank = x->dims().size();
auto label_dims = labels->dims();
- phi::DenseTensor x_2d = framework::ReshapeToMatrix(*x, rank - 1);
+ phi::DenseTensor x_2d = phi::ReshapeToMatrix(*x, rank - 1);
phi::DenseTensor labels_2d, y_2d;
if (label_dims.size() < rank) {
labels_2d.ShareDataWith(*labels);
@@ -44,8 +45,8 @@ class CrossEntropyOpKernel : public framework::OpKernel<T> {
y_2d.Resize({phi::product(y->dims()), 1});

} else {
- labels_2d = framework::ReshapeToMatrix(*labels, rank - 1);
- y_2d = framework::ReshapeToMatrix(*y, rank - 1);
+ labels_2d = phi::ReshapeToMatrix(*labels, rank - 1);
+ y_2d = phi::ReshapeToMatrix(*y, rank - 1);
}

int axis_dim = x->dims()[rank - 1];
5 changes: 3 additions & 2 deletions paddle/fluid/operators/fused/multihead_matmul_op.cu
@@ -21,6 +21,7 @@
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/math/bert_encoder_functor.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"

namespace paddle {
@@ -343,10 +344,10 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {

// (B*S, hidden)
const phi::DenseTensor input_matrix =
- framework::ReshapeToMatrix(*input, 2 /*x_num_col_dims */);
+ phi::ReshapeToMatrix(*input, 2 /*x_num_col_dims */);
// (hidden, 3 * all_head_size)
const phi::DenseTensor w_matrix =
- framework::ReshapeToMatrix(*w, 1 /*y_num_col_dims*/);
+ phi::ReshapeToMatrix(*w, 1 /*y_num_col_dims*/);

phi::DenseTensor temp_out_tensor;
auto temp_out_dims =
6 changes: 3 additions & 3 deletions paddle/fluid/operators/isfinite_op.h
@@ -129,17 +129,17 @@ inline void TensorIsfinite(const phi::DenseTensor& tensor,
inline bool TensorContainsNAN(const phi::DenseTensor& tensor) {
phi::DenseTensor out;
TensorContainsNAN(tensor, &out);
- return GetValue<bool>(&out);
+ return paddle::framework::GetValue<bool>(&out);
}
inline bool TensorContainsInf(const phi::DenseTensor& tensor) {
phi::DenseTensor out;
TensorContainsInf(tensor, &out);
- return GetValue<bool>(&out);
+ return paddle::framework::GetValue<bool>(&out);
}
inline bool TensorIsfinite(const phi::DenseTensor& tensor) {
phi::DenseTensor out;
TensorIsfinite(tensor, &out);
- return GetValue<bool>(&out);
+ return paddle::framework::GetValue<bool>(&out);
}
} // namespace framework
namespace operators {
56 changes: 56 additions & 0 deletions paddle/phi/core/tensor_utils.cc
@@ -867,4 +867,60 @@ template void TensorToVector(const phi::DenseTensor& src,
template void TensorToVector(const phi::DenseTensor& src,
std::vector<phi::dtype::complex<double>>* dst);

phi::DenseTensor ReshapeToMatrix(const phi::DenseTensor& src,
int num_col_dims) {
int rank = src.dims().size();
PADDLE_ENFORCE_GE(
rank,
2,
phi::errors::InvalidArgument(
"'ReshapeToMatrix()' is only used for flatten high rank "
"tensors to matrixs. The dimensions of phi::DenseTensor must be "
"greater or equal than 2. "
"But received dimensions of phi::DenseTensor is %d",
rank));
if (rank == 2) {
return src;
}
phi::DenseTensor res;
res.ShareDataWith(src);
res.Resize(phi::flatten_to_2d(src.dims(), num_col_dims));
return res;
}

template <typename T>
T GetValue(const phi::DenseTensor* x) {
T value = static_cast<T>(0);
if (!paddle::platform::is_cpu_place(x->place())) {
phi::DenseTensor cpu_x{};
phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance();
phi::DeviceContext* dev_ctx = pool.Get(x->place());
phi::Copy(*dev_ctx, *x, phi::CPUPlace(), true, &cpu_x);
value = cpu_x.data<T>()[0];
} else {
value = x->data<T>()[0];
}
return value;
}

template bool GetValue(const phi::DenseTensor* x);

template int16_t GetValue(const phi::DenseTensor* x);

template int GetValue(const phi::DenseTensor* x);

template int64_t GetValue(const phi::DenseTensor* x);

template float GetValue(const phi::DenseTensor* x);

template double GetValue(const phi::DenseTensor* x);

template phi::dtype::bfloat16 GetValue(const phi::DenseTensor* x);

template phi::dtype::float16 GetValue(const phi::DenseTensor* x);

template phi::dtype::complex<float> GetValue(const phi::DenseTensor* x);

template phi::dtype::complex<double> GetValue(const phi::DenseTensor* x);

} // namespace phi
18 changes: 18 additions & 0 deletions paddle/phi/core/tensor_utils.h
@@ -126,4 +126,22 @@ void TensorToVector(const phi::DenseTensor& src,
const phi::DeviceContext& ctx,
std::vector<T>* dst);

phi::DenseTensor ReshapeToMatrix(const phi::DenseTensor& src, int num_col_dims);

template <typename T>
T GetValue(const phi::DenseTensor* x);

template <typename T, typename Context>
inline T GetValue(const Context& dev_ctx, const DenseTensor& x) {
T value = static_cast<T>(0);
if (x.place() != CPUPlace()) {
DenseTensor cpu_x;
Copy(dev_ctx, x, CPUPlace(), true, &cpu_x);
value = cpu_x.data<T>()[0];
} else {
value = x.data<T>()[0];
}
return value;
}

} // namespace phi
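The Context-taking GetValue overload declared just above replaces the identical local helpers that paddle/phi/kernels/gpu/arange_kernel.cu and linspace_kernel.cu delete further down. Below is a minimal sketch of how a phi kernel might call it; the kernel name and parameters are illustrative assumptions, not part of this commit.

// Illustrative only: a phi-style kernel that reads a one-element tensor on the
// host to make a shape decision, using the overload declared above.
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/tensor_utils.h"

namespace phi {

template <typename T, typename Context>
void ExampleResizeKernel(const Context& dev_ctx,
                         const DenseTensor& n,  // one-element tensor, any place
                         DenseTensor* out) {
  // GetValue stages a device-to-CPU copy when `n` is not in CPU memory.
  int64_t size = static_cast<int64_t>(GetValue<T, Context>(dev_ctx, n));
  out->Resize({size});
}

}  // namespace phi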
3 changes: 2 additions & 1 deletion paddle/phi/kernels/fusion/onednn/fused_matmul_kernel.cc
@@ -16,13 +16,14 @@

#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"

using dnnl::engine;
using dnnl::inner_product_forward;
using dnnl::memory;
using dnnl::prop_kind;
using dnnl::stream;
- using paddle::framework::ReshapeToMatrix;
+ using phi::ReshapeToMatrix;

namespace phi {

13 changes: 0 additions & 13 deletions paddle/phi/kernels/gpu/arange_kernel.cu
@@ -23,19 +23,6 @@

namespace phi {

template <typename T, typename Context>
inline T GetValue(const Context& dev_ctx, const DenseTensor& x) {
T value = static_cast<T>(0);
if (x.place() != CPUPlace()) {
DenseTensor cpu_x;
Copy(dev_ctx, x, CPUPlace(), true, &cpu_x);
value = cpu_x.data<T>()[0];
} else {
value = x.data<T>()[0];
}
return value;
}

template <typename T>
__global__ void Range(T start, T step, int64_t size, T* out) {
CUDA_KERNEL_LOOP(index, size) { out[index] = start + step * index; }
13 changes: 0 additions & 13 deletions paddle/phi/kernels/gpu/linspace_kernel.cu
@@ -41,19 +41,6 @@ __global__ void LinspaceSpecialKernel(T start, T* out) {
out[0] = static_cast<T>(start);
}

template <typename T, typename Context>
T GetValue(const Context& ctx, const DenseTensor& x) {
T value = static_cast<T>(0);
if (x.place() != CPUPlace()) {
DenseTensor cpu_x;
Copy(ctx, x, CPUPlace(), true, &cpu_x);
value = cpu_x.data<T>()[0];
} else {
value = x.data<T>()[0];
}
return value;
}

template <typename T, typename Context>
T GetValueOfExpectedType(const Context& ctx, const DenseTensor& x) {
switch (x.dtype()) {
(Diffs for the remaining 7 changed files are not rendered on this page.)
