diff --git a/taichi/analysis/offline_cache_util.cpp b/taichi/analysis/offline_cache_util.cpp index 140a3915df67a..bcffd829f3940 100644 --- a/taichi/analysis/offline_cache_util.cpp +++ b/taichi/analysis/offline_cache_util.cpp @@ -71,9 +71,16 @@ static std::vector get_offline_cache_key_of_compile_config( static void get_offline_cache_key_of_snode_impl( SNode *snode, - BinaryOutputSerializer &serializer) { + BinaryOutputSerializer &serializer, + std::unordered_set &visited) { + if (auto iter = visited.find(snode->id); iter != visited.end()) { + serializer(snode->id); // Use snode->id as placeholder to identify a snode + return; + } + + visited.insert(snode->id); for (auto &c : snode->ch) { - get_offline_cache_key_of_snode_impl(c.get(), serializer); + get_offline_cache_key_of_snode_impl(c.get(), serializer, visited); } for (int i = 0; i < taichi_max_num_indices; ++i) { auto &extractor = snode->extractors[i]; @@ -106,21 +113,21 @@ static void get_offline_cache_key_of_snode_impl( } if (snode->grad_info && !snode->grad_info->is_primal()) { if (auto *grad_snode = snode->grad_info->grad_snode()) { - get_offline_cache_key_of_snode_impl(grad_snode, serializer); + get_offline_cache_key_of_snode_impl(grad_snode, serializer, visited); } } if (snode->exp_snode) { - get_offline_cache_key_of_snode_impl(snode->exp_snode, serializer); + get_offline_cache_key_of_snode_impl(snode->exp_snode, serializer, visited); } serializer(snode->bit_offset); serializer(snode->placing_shared_exp); serializer(snode->owns_shared_exponent); for (auto s : snode->exponent_users) { - get_offline_cache_key_of_snode_impl(s, serializer); + get_offline_cache_key_of_snode_impl(s, serializer, visited); } if (snode->currently_placing_exp_snode) { get_offline_cache_key_of_snode_impl(snode->currently_placing_exp_snode, - serializer); + serializer, visited); } if (snode->currently_placing_exp_snode_dtype) { serializer(snode->currently_placing_exp_snode_dtype->to_string()); @@ -138,7 +145,10 @@ std::string get_hashed_offline_cache_key_of_snode(SNode *snode) { BinaryOutputSerializer serializer; serializer.initialize(); - get_offline_cache_key_of_snode_impl(snode, serializer); + { + std::unordered_set visited; + get_offline_cache_key_of_snode_impl(snode, serializer, visited); + } serializer.finalize(); picosha2::hash256_one_by_one hasher;