Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CINN] add paddle.framework.core.is_run_with_cinn() api for running check #54355

Merged
10 changes: 10 additions & 0 deletions paddle/fluid/pybind/pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,15 @@ bool IsCompiledWithCINN() {
#endif
}

bool IsRunWithCINN() {
#ifndef PADDLE_WITH_CINN
return false;
#else
return framework::paddle2cinn::CinnCompiler::GetInstance()
->real_compiled_num() > 0;
#endif
}

bool IsCompiledWithHETERPS() {
#ifndef PADDLE_WITH_HETERPS
return false;
Expand Down Expand Up @@ -1909,6 +1918,7 @@ All parameter, weight, gradient are variables in Paddle.
m.def("is_compiled_with_mpi", IsCompiledWithMPI);
m.def("is_compiled_with_mpi_aware", IsCompiledWithMPIAWARE);
m.def("is_compiled_with_cinn", IsCompiledWithCINN);
m.def("is_run_with_cinn", IsRunWithCINN);
m.def("_is_compiled_with_heterps", IsCompiledWithHETERPS);
m.def("supports_bfloat16", SupportsBfloat16);
m.def("supports_bfloat16_fast_performance", SupportsBfloat16FastPerformance);
Expand Down
4 changes: 4 additions & 0 deletions python/paddle/fluid/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,10 @@ def to_list(s):
from .libpaddle import _cleanup_mmap_fds
from .libpaddle import _remove_tensor_list_mmap_fds
from .libpaddle import _set_max_memory_map_allocation_pool_size

# CINN
from .libpaddle import is_run_with_cinn

except Exception as e:
if has_paddle_dy_lib:
sys.stderr.write(
Expand Down
12 changes: 12 additions & 0 deletions test/dygraph_to_static/test_cinn.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,18 @@ def train(self, use_cinn):
sgd.clear_grad()

res.append(out.numpy())

if use_cinn and paddle.device.is_compiled_with_cinn():
self.assertTrue(
paddle.framework.core.is_run_with_cinn(),
msg="The test was not running with CINN! Please check.",
)
else:
self.assertFalse(
paddle.framework.core.is_run_with_cinn(),
msg="The test should not running with CINN when the whl package was not compiled with CINN! Please check.",
)

return res

def test_cinn(self):
Expand Down