diff --git a/paddle/framework/eigen_test.cc b/paddle/framework/eigen_test.cc index dc1957691b1a2..c9a82c9e011bb 100644 --- a/paddle/framework/eigen_test.cc +++ b/paddle/framework/eigen_test.cc @@ -46,6 +46,29 @@ TEST(Eigen, Tensor) { } } +TEST(Eigen, TensorSum) { + Tensor t_in; + float* in = + t_in.mutable_data(make_ddim({2, 3, 4}), platform::CPUPlace()); + for (int i = 0; i < 2 * 3 * 4; i++) { + in[i] = static_cast(i); + } + + Tensor t_out; + float* out = t_out.mutable_data({2}, platform::CPUPlace()); + + EigenTensor::Type et_in = EigenTensor::From(t_in); + EigenTensor::Type et_out = EigenTensor::From(t_out); + + Eigen::DSizes along_class(1, 2); + + Eigen::DefaultDevice dd; + et_out.device(dd) = et_in.sum(along_class); + + ASSERT_NEAR(66, out[0], 1e-6f); + ASSERT_NEAR(210, out[1], 1e-6f); +} + TEST(Eigen, ScalarFrom) { Tensor t; int* p = t.mutable_data(make_ddim({1}), platform::CPUPlace());