Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix several bugs for enabling Paddle to train with CINN. #36739

Merged
merged 8 commits into from
Oct 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmake/external/cinn.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ add_definitions(-w)
include(ExternalProject)
set(CINN_SOURCE_DIR ${THIRD_PARTY_PATH}/CINN)
# TODO(zhhsplendid): Modify git tag after we have release tag
set(CINN_GIT_TAG e422c01b7875301996a2baf67a14ba61b0e6192a)
set(CINN_GIT_TAG cb030430d76f42f7310d09608f9b22959ecbcb51)
set(CINN_OPTIONAL_ARGS -DPY_VERSION=${PY_VERSION} -DWITH_CUDA=${WITH_GPU} -DWITH_CUDNN=${WITH_GPU} -DPUBLISH_LIBS=ON -DWITH_TESTING=ON)
set(CINN_BUILD_COMMAND $(MAKE) cinnapi -j)
ExternalProject_Add(
Expand Down
16 changes: 9 additions & 7 deletions paddle/fluid/framework/details/build_strategy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,15 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
ResolveOptionConfliction();

AppendPrintGraphPass("graph_viz_pass", "_original_graph");

#ifdef PADDLE_WITH_CINN
if (FLAGS_use_cinn) {
// Note: This pass is used to enable cinn.
AppendPass("build_cinn_pass");
AppendPrintGraphPass("graph_viz_pass", "_build_cinn_graph");
}
#endif

AppendPassWithCheck(strategy_.enable_sequential_execution_,
"sequential_execution_pass");
AppendPassWithCheck(strategy_.sync_batch_norm_, "sync_batch_norm_pass");
Expand All @@ -74,13 +83,6 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// Note: This pass is used to check whether the multi_device_graph is right.
AppendPass("multi_devices_check_pass");

#ifdef PADDLE_WITH_CINN
if (FLAGS_use_cinn) {
// Note: This pass is used to enable cinn.
AppendPass("build_cinn_pass");
}
#endif

SetCollectiveContext();
}

Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/framework/paddle2cinn/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
cc_library(cinn_cache_key SRCS cinn_cache_key.cc DEPS boost graph graph_helper lod_tensor proto_desc)
cc_library(build_cinn_pass SRCS build_cinn_pass.cc DEPS pass subgraph_detector cinn_compiler)
cc_library(build_cinn_pass SRCS build_cinn_pass.cc DEPS pass subgraph_detector graph_pattern_detector cinn_compiler errors enforce)
cc_library(transform_desc SRCS transform_desc.cc DEPS proto_desc cinn)
cc_library(cinn_graph_symbolization SRCS cinn_graph_symbolization.cc DEPS lod_tensor graph graph_helper transform_desc cinn)
cc_library(cinn_graph_symbolization SRCS cinn_graph_symbolization.cc DEPS lod_tensor graph transform_desc cinn)
cc_library(cinn_compiler SRCS cinn_compiler.cc DEPS graph lod_tensor cinn_cache_key cinn_graph_symbolization cinn)

cc_test(cinn_cache_key_test SRCS cinn_cache_key_test.cc DEPS cinn_cache_key)
Expand Down
Loading