Skip to content

Commit

Permalink
fix the modification of set_expected_place (#31177)
Browse files Browse the repository at this point in the history
* revert the modification of set_expected_place

* set device before op run

* add ut
  • Loading branch information
zhiqiu authored Feb 24, 2021
1 parent dc8dfba commit 0f1fde5
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 24 deletions.
15 changes: 7 additions & 8 deletions paddle/fluid/imperative/tests/test_tracer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,13 @@ TEST(test_tracer, test_trace_op) {
framework::AttributeMap mul_attr_map;
mul_attr_map["use_mkldnn"] = false;
tracer.TraceOp("mul", ins, outs, mul_attr_map, place, true);

#ifndef PADDLE_WITH_XPU
ASSERT_THROW(tracer.TraceOp("mul", ins, outs, mul_attr_map,
platform::XPUPlace(0), true);
, platform::EnforceNotMet);
#endif

const auto& out_tensor = vout->Var().Get<framework::LoDTensor>();
for (int i = 0; i < vout->Var().Get<framework::LoDTensor>().numel(); i++) {
ASSERT_EQ(out_tensor.data<float>()[i], 20.0);
Expand Down Expand Up @@ -311,10 +318,6 @@ TEST(test_tracer, test_expected_place) {
platform::CUDAPlace gpu_place(0);
tracer.SetExpectedPlace(gpu_place);
ASSERT_EQ(platform::is_gpu_place(tracer.ExpectedPlace()), true);

// assert throw
platform::XPUPlace xpu_place(0);
ASSERT_THROW(tracer.SetExpectedPlace(xpu_place), platform::EnforceNotMet);
#endif
}
{
Expand All @@ -323,10 +326,6 @@ TEST(test_tracer, test_expected_place) {
platform::XPUPlace xpu_place(0);
tracer.SetExpectedPlace(xpu_place);
ASSERT_EQ(platform::is_xpu_place(tracer.ExpectedPlace()), true);

// assert throw
platform::CUDAPlace cuda_place(0);
ASSERT_THROW(tracer.SetExpectedPlace(cuda_place), platform::EnforceNotMet);
#endif
}
}
Expand Down
33 changes: 17 additions & 16 deletions paddle/fluid/imperative/tracer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,23 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
}

try {
if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
platform::SetDeviceId(BOOST_GET_CONST(platform::CUDAPlace, place).device);
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with GPU if use CUDAPlace."));
#endif
} else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU
platform::SetXPUDeviceId(
BOOST_GET_CONST(platform::XPUPlace, place).device);
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with XPU if use XPUPlace."));
#endif
}

OpBase::Run(*op, new_ins, outs, attrs, place);
} catch (platform::EnforceNotMet& exception) {
framework::AppendErrorOpHint(type, &exception);
Expand Down Expand Up @@ -199,22 +216,6 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
}

void Tracer::SetExpectedPlace(platform::Place place) {
// NOTE(wangxi): set device id before launch device kernel
if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
platform::SetDeviceId(BOOST_GET_CONST(platform::CUDAPlace, place).device);
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with GPU if use CUDAPlace."));
#endif
} else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU
platform::SetXPUDeviceId(BOOST_GET_CONST(platform::XPUPlace, place).device);
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with XPU if use XPUPlace."));
#endif
}
expected_place_ = place;
}

Expand Down

0 comments on commit 0f1fde5

Please sign in to comment.