Skip to content

Commit

Permalink
tests(graph): ensure that the driver type in kernel node is correct
Browse files Browse the repository at this point in the history
  • Loading branch information
romintomasetti committed Sep 16, 2024
1 parent 0ca72be commit 1a31059
Showing 1 changed file with 45 additions and 1 deletion.
46 changes: 45 additions & 1 deletion core/unit_test/cuda/TestCuda_InterOp_Graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
//@HEADER

#include <filesystem>
#include <fstream>
#include <regex>

#include <TestCuda_Category.hpp>
#include <TestGraph_helpers.hpp>
Expand Down Expand Up @@ -71,7 +73,7 @@ TEST(TEST_CATEGORY, graph_instantiate_and_debug_dot_print) {

auto root = Kokkos::Impl::GraphAccess::create_root_ref(graph);

root.then_parallel_for(1, SetViewToValueFunctor<Kokkos::Cuda, int>{data, 1});
auto node_ref = root.then_parallel_for(1, SetViewToValueFunctor<Kokkos::Cuda, int>{data, 1});

auto graph_ptr_impl =
Kokkos::Impl::GraphAccess::get_graph_weak_ptr(root).lock();
Expand All @@ -93,6 +95,48 @@ TEST(TEST_CATEGORY, graph_instantiate_and_debug_dot_print) {

ASSERT_TRUE(std::filesystem::exists(dot));

using root_node_t = Kokkos::Experimental::GraphNodeRef<
Kokkos::Cuda, Kokkos::Experimental::TypeErasedTag, Kokkos::Experimental::TypeErasedTag>;

using graph_node_kernel_t = Kokkos::Impl::GraphNodeKernelImpl<
Kokkos::Cuda,
Kokkos::RangePolicy<Kokkos::Cuda, Kokkos::Impl::IsGraphKernelTag>,
SetViewToValueFunctor<Kokkos::Cuda, int>,
Kokkos::ParallelForTag>;

using graph_node_ref_t = Kokkos::Experimental::GraphNodeRef<
Kokkos::Cuda,
graph_node_kernel_t,
root_node_t>;

static_assert(std::is_same_v<Kokkos::Impl::remove_cvref_t<decltype(node_ref)>, graph_node_ref_t>);

using graph_node_impl_t = Kokkos::Impl::GraphNodeImpl<
Kokkos::Cuda,
graph_node_kernel_t,
root_node_t>;

auto node_ptr = Kokkos::Impl::GraphAccess::get_node_ptr(node_ref);

auto& kernel = static_cast<graph_node_impl_t*>(node_ptr.get())->get_kernel();

static_assert(std::is_same_v<decltype(kernel), graph_node_kernel_t&>);

// We cannot extract the mangled name of 'fct_t', but reading the 'dot' output and demangling it leads to 'fct_t'.
// Moreover, the mangled name contains some unpredicable digits/letters.
// using driver_t = Kokkos::Impl::ParallelFor<
// SetViewToValueFunctor<Kokkos::Cuda, int>,
// Kokkos::RangePolicy<Kokkos::Cuda, Kokkos::Impl::IsGraphKernelTag>,
// Kokkos::Cuda>;
// using fct_t = Kokkos::Impl::cuda_parallel_launch_local_memory<driver_t>;

const std::regex expected("_ZN6Kokkos4Impl33cuda_parallel_launch_local_memoryINS0_11ParallelForIN65_GLOBAL__N__56f1aff9_26_TestCuda_InterOp_Graph_cpp_[a-z0-9]*_[0-9]*SetViewToValueFunctorINS_4CudaEiEENS_11RangePolicyIJS5_NS0_16IsGraphKernelTagEEEES5_EEEEvT_");

std::stringstream buffer;
buffer << std::ifstream(dot).rdbuf();

ASSERT_TRUE(std::regex_search(buffer.str(), expected)) << "Could not find the expected signature in " << dot;

unsigned long long flags =
Kokkos::Experimental::finite_max_v<unsigned long long>;
KOKKOS_IMPL_CUDA_SAFE_CALL(
Expand Down

0 comments on commit 1a31059

Please sign in to comment.