-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
138 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
diff --git a/src/backend/graph_compiler/core/src/compiler/config/context.cpp b/src/backend/graph_compiler/core/src/compiler/config/context.cpp | ||
index 6e03ae30df..522687e5ab 100644 | ||
--- a/src/backend/graph_compiler/core/src/compiler/config/context.cpp | ||
+++ b/src/backend/graph_compiler/core/src/compiler/config/context.cpp | ||
@@ -208,6 +208,8 @@ context_ptr get_default_context() { | ||
parse_value(env_names[SC_INDEX2VAR], flags.index2var_); | ||
parse_value(env_names[SC_PRINT_IR], flags.print_ir_); | ||
parse_value(env_names[SC_MIXED_FUSION], flags.mixed_fusion_); | ||
+ parse_value( | ||
+ env_names[SC_COARSE_GRAIN_FUSION], flags.coarse_grain_fusion_); | ||
parse_value(env_names[SC_COST_MODEL], flags.use_cost_model_); | ||
parse_value(env_names[SC_SSA_PASSES], flags.ssa_passes_); | ||
parse_value( | ||
diff --git a/src/backend/graph_compiler/core/src/compiler/config/context.hpp b/src/backend/graph_compiler/core/src/compiler/config/context.hpp | ||
index b105e094be..736c420a9d 100644 | ||
--- a/src/backend/graph_compiler/core/src/compiler/config/context.hpp | ||
+++ b/src/backend/graph_compiler/core/src/compiler/config/context.hpp | ||
@@ -54,6 +54,7 @@ struct scflags_t { | ||
std::string graph_dump_results_; | ||
bool value_check_ = false; | ||
bool mixed_fusion_ = true; | ||
+ bool coarse_grain_fusion_ = true; | ||
bool use_cost_model_ = true; | ||
bool debug_info_ = false; | ||
bool xbyak_jit_save_obj_ = false; | ||
diff --git a/src/backend/graph_compiler/core/src/compiler/ir/graph/mixed_partition.cpp b/src/backend/graph_compiler/core/src/compiler/ir/graph/mixed_partition.cpp | ||
index 9f4986c89b..6f27f4100f 100644 | ||
--- a/src/backend/graph_compiler/core/src/compiler/ir/graph/mixed_partition.cpp | ||
+++ b/src/backend/graph_compiler/core/src/compiler/ir/graph/mixed_partition.cpp | ||
@@ -2959,8 +2959,10 @@ void do_mixed_partition(const context_ptr &ctx, sc_graph_t &graph) { | ||
if (do_partition(ctx, graph, dep, op_2_partition)) break; | ||
} | ||
|
||
- std::vector<crossover_alg> algs | ||
- = {horizontal_crossover, parallel_crossover, vertical_crossover}; | ||
+ std::vector<crossover_alg> algs = ctx->flags_.coarse_grain_fusion_ | ||
+ ? std::vector<crossover_alg> {horizontal_crossover, | ||
+ parallel_crossover, vertical_crossover} | ||
+ : std::vector<crossover_alg> {}; | ||
crossover_partition(op_2_partition, algs); | ||
|
||
std::vector<sc_op_ptr> fused_ops; | ||
diff --git a/src/backend/graph_compiler/core/src/runtime/env_vars.cpp b/src/backend/graph_compiler/core/src/runtime/env_vars.cpp | ||
index b1c2f755fd..ca72ac4ef7 100644 | ||
--- a/src/backend/graph_compiler/core/src/runtime/env_vars.cpp | ||
+++ b/src/backend/graph_compiler/core/src/runtime/env_vars.cpp | ||
@@ -36,9 +36,9 @@ const char *env_names[] = {"SC_CPU_JIT", "SC_TRACE", "SC_DUMP_GRAPH", | ||
"SC_VERBOSE", "SC_RUN_THREADS", "SC_TRACE_INIT_CAP", | ||
"SC_EXECUTION_VERBOSE", "SC_LOGGING_FILTER", "SC_HOME", "SC_SSA_PASSES", | ||
"SC_PRINT_PASS_TIME", "SC_PRINT_PASS_RESULT", "SC_JIT_PROFILE", | ||
- "SC_MIXED_FUSION", "SC_COST_MODEL", "SC_DEBUG_INFO", "SC_PREFETCH", | ||
- "SC_XBYAK_JIT_SAVE_OBJ", "SC_XBYAK_JIT_ASM_LISTING", | ||
- "SC_XBYAK_JIT_LOG_STACK_FRAME_MODEL", | ||
+ "SC_MIXED_FUSION", "SC_COARSE_GRAIN_FUSION", "SC_COST_MODEL", | ||
+ "SC_DEBUG_INFO", "SC_PREFETCH", "SC_XBYAK_JIT_SAVE_OBJ", | ||
+ "SC_XBYAK_JIT_ASM_LISTING", "SC_XBYAK_JIT_LOG_STACK_FRAME_MODEL", | ||
"SC_XBYAK_JIT_PAUSE_AFTER_CODEGEN", "SC_MANAGED_THREAD_POOL", | ||
"SC_TENSOR_INPLACE"}; | ||
|
||
diff --git a/src/backend/graph_compiler/core/src/runtime/env_vars.hpp b/src/backend/graph_compiler/core/src/runtime/env_vars.hpp | ||
index 02ecce42fa..a3a1ba6cac 100644 | ||
--- a/src/backend/graph_compiler/core/src/runtime/env_vars.hpp | ||
+++ b/src/backend/graph_compiler/core/src/runtime/env_vars.hpp | ||
@@ -50,6 +50,7 @@ enum key { | ||
SC_PRINT_PASS_RESULT, | ||
SC_JIT_PROFILE, | ||
SC_MIXED_FUSION, | ||
+ SC_COARSE_GRAIN_FUSION, | ||
SC_COST_MODEL, | ||
SC_DEBUG_INFO, | ||
SC_PREFETCH, | ||
diff --git a/src/backend/graph_compiler/core/src/runtime/trace.cpp b/src/backend/graph_compiler/core/src/runtime/trace.cpp | ||
index e101f2617f..bd0e2ae137 100644 | ||
--- a/src/backend/graph_compiler/core/src/runtime/trace.cpp | ||
+++ b/src/backend/graph_compiler/core/src/runtime/trace.cpp | ||
@@ -134,7 +134,7 @@ void write_traces(const std::list<thread_local_buffer_t *> &tls_buffers) { | ||
} | ||
} // namespace runtime | ||
|
||
-SC_INTERNAL_API void generate_trace_file() { | ||
+SC_API void generate_trace_file() { | ||
sc::release_runtime_memory(nullptr); | ||
} | ||
|
||
diff --git a/src/backend/graph_compiler/core/src/runtime/trace.hpp b/src/backend/graph_compiler/core/src/runtime/trace.hpp | ||
index 2569070e6c..93e14d9d95 100644 | ||
--- a/src/backend/graph_compiler/core/src/runtime/trace.hpp | ||
+++ b/src/backend/graph_compiler/core/src/runtime/trace.hpp | ||
@@ -21,6 +21,8 @@ | ||
#include <string> | ||
#include <vector> | ||
|
||
+#include "util/def.hpp" | ||
+ | ||
namespace sc { | ||
namespace runtime { | ||
|
||
@@ -38,6 +40,7 @@ struct trace_manager_t { | ||
void write_traces(const std::list<thread_local_buffer_t *> &tls_buffers); | ||
|
||
} // namespace runtime | ||
+SC_API void generate_trace_file(); | ||
int register_traced_func(const std::string &name); | ||
} // namespace sc | ||
|
||
diff --git a/tests/benchdnn/CMakeLists.txt b/tests/benchdnn/CMakeLists.txt | ||
index 9b86ffd8b6..d736ee226b 100644 | ||
--- a/tests/benchdnn/CMakeLists.txt | ||
+++ b/tests/benchdnn/CMakeLists.txt | ||
@@ -34,6 +34,7 @@ include_directories_with_host_compiler( | ||
${PROJECT_SOURCE_DIR}/include | ||
${PROJECT_SOURCE_DIR}/src/utils/ | ||
${PROJECT_SOURCE_DIR}/tests/ | ||
+ ${PROJECT_SOURCE_DIR}/src/backend/graph_compiler/core/src | ||
) | ||
|
||
if(BENCHDNN_USE_RDPMC) | ||
diff --git a/tests/benchdnn/graph/bench_graph.cpp b/tests/benchdnn/graph/bench_graph.cpp | ||
index e5007cd5e6..6672676416 100644 | ||
--- a/tests/benchdnn/graph/bench_graph.cpp | ||
+++ b/tests/benchdnn/graph/bench_graph.cpp | ||
@@ -21,6 +21,8 @@ | ||
#include "parser.hpp" | ||
#include "utils/parser.hpp" | ||
|
||
+#include "runtime/trace.hpp" | ||
+ | ||
namespace graph { | ||
|
||
void check_correctness(const settings_t &s) { | ||
@@ -45,6 +47,7 @@ void check_correctness(const settings_t &s) { | ||
pr.report(&res, pstr); | ||
} | ||
} | ||
+ sc::generate_trace_file(); | ||
} | ||
|
||
int bench(int argc, char **argv) { |