diff --git a/python/taichi/misc/util.py b/python/taichi/misc/util.py
index 460ec1ba68c8e..3f183ef9b7604 100644
--- a/python/taichi/misc/util.py
+++ b/python/taichi/misc/util.py
@@ -269,6 +269,12 @@ def dot_to_pdf(dot, filepath):
         fh.write(pdf_contents)
 
 
+def get_kernel_stats():
+    """Return the global kernel statistics object exposed by ti_core."""
+    from taichi.core import ti_core
+    return ti_core.get_kernel_stats()
+
+
 __all__ = [
     'vec',
     'veci',
@@ -278,6 +284,7 @@ def dot_to_pdf(dot, filepath):
     'dump_dot',
     'dot_to_pdf',
     'obsolete',
+    'get_kernel_stats',
     'get_traceback',
     'set_gdb_trigger',
     'print_profile_info',
diff --git a/taichi/python/export_misc.cpp b/taichi/python/export_misc.cpp
index a802a4074e3ce..0d6d4ae27a547 100644
--- a/taichi/python/export_misc.cpp
+++ b/taichi/python/export_misc.cpp
@@ -3,18 +3,19 @@
     The use of this software is governed by the LICENSE file.
 *******************************************************************************/
 
+#include "taichi/backends/metal/api.h"
+#include "taichi/backends/opengl/opengl_api.h"
 #include "taichi/common/core.h"
 #include "taichi/common/task.h"
 #include "taichi/math/math.h"
 #include "taichi/python/exception.h"
-#include "taichi/python/print_buffer.h"
 #include "taichi/python/export.h"
+#include "taichi/python/print_buffer.h"
 #include "taichi/system/benchmark.h"
-#include "taichi/system/profiler.h"
-#include "taichi/system/memory_usage_monitor.h"
 #include "taichi/system/dynamic_loader.h"
-#include "taichi/backends/metal/api.h"
-#include "taichi/backends/opengl/opengl_api.h"
+#include "taichi/system/memory_usage_monitor.h"
+#include "taichi/system/profiler.h"
+#include "taichi/util/statistics.h"
 #if defined(TI_WITH_CUDA)
 #include "taichi/backends/cuda/cuda_driver.h"
 #endif
@@ -173,6 +174,13 @@ void export_misc(py::module &m) {
 #else
   m.def("with_cc", []() { return false; });
 #endif
+
+  py::class_<Statistics>(m, "Statistics")
+      .def(py::init<>())
+      .def("clear", &Statistics::clear)
+      .def("get_counters", &Statistics::get_counters);
+  m.def("get_kernel_stats", []() -> Statistics & { return stat; },
+        py::return_value_policy::reference);
 }
 
 TI_NAMESPACE_END
diff --git a/taichi/util/statistics.cpp b/taichi/util/statistics.cpp
index 2190afabfa4fd..d50bbe40e21fa 100644
--- a/taichi/util/statistics.cpp
+++ b/taichi/util/statistics.cpp
@@ -5,19 +5,19 @@ TI_NAMESPACE_BEGIN
 Statistics stat;
 
 void Statistics::add(std::string key, Statistics::value_type value) {
-  counters[key] += value;
+  counters_[key] += value;
 }
 
 void Statistics::print(std::string *output) {
   std::vector<std::string> keys;
-  for (auto const &item : counters)
+  for (auto const &item : counters_)
     keys.push_back(item.first);
 
   std::sort(keys.begin(), keys.end());
 
   std::stringstream ss;
   for (auto const &k : keys)
-    ss << fmt::format("{:20}: {:.2f}\n", k, counters[k]);
+    ss << fmt::format("{:20}: {:.2f}\n", k, counters_[k]);
 
   if (output) {
     *output = ss.str();
@@ -27,7 +27,7 @@ void Statistics::print(std::string *output) {
 }
 
 void Statistics::clear() {
-  counters.clear();
+  counters_.clear();
 }
 
 TI_NAMESPACE_END
diff --git a/taichi/util/statistics.h b/taichi/util/statistics.h
index 117328c460ae8..9196117ad45d5 100644
--- a/taichi/util/statistics.h
+++ b/taichi/util/statistics.h
@@ -5,12 +5,10 @@
 TI_NAMESPACE_BEGIN
 
 class Statistics {
+ public:
   using value_type = float64;
+  using counters_map = std::unordered_map<std::string, value_type>;
 
- private:
-  std::unordered_map<std::string, value_type> counters;
-
- public:
   Statistics() = default;
 
   void add(std::string key, value_type value = value_type(1));
@@ -18,6 +16,14 @@ class Statistics {
   void print(std::string *output = nullptr);
 
   void clear();
+
+  // Read-only access to all accumulated counters, keyed by name.
+  inline const counters_map &get_counters() const {
+    return counters_;
+  }
+
+ private:
+  counters_map counters_;
 };
 
 extern Statistics stat;
diff --git a/tests/python/test_sfg.py b/tests/python/test_sfg.py
index d3858282f4131..6334e348cd2e4 100644
--- a/tests/python/test_sfg.py
+++ b/tests/python/test_sfg.py
@@ -9,8 +9,8 @@ def test_remove_clear_list_from_fused_serial():
     z = ti.field(ti.i32, shape=())
 
     n = 32
-    ti.root.pointer(ti.i, n).place(x)
-    ti.root.pointer(ti.i, n).place(y)
+    ti.root.pointer(ti.i, n).dense(ti.i, 1).place(x)
+    ti.root.pointer(ti.i, n).dense(ti.i, 1).place(y)
 
     @ti.kernel
     def init_xy():
@@ -23,6 +23,9 @@ def init_xy():
     init_xy()
     ti.sync()
 
+    stats = ti.get_kernel_stats()
+    stats.clear()
+
     @ti.kernel
     def inc(f: ti.template()):
         for i in f:
@@ -40,6 +43,12 @@ def serial_z():
     inc(y)
     ti.sync()
 
+    counters = stats.get_counters()
+    # each of x and y has two listgens: root -> pointer -> dense
+    assert int(counters['launched_tasks_list_gen']) == 4
+    # clear list tasks have been fused into serial_z
+    assert int(counters['launched_tasks_serial']) == 1
+
     xs = x.to_numpy()
     ys = y.to_numpy()
     for i in range(n):