Skip to content

Commit

Permalink
graph: allow access to native graph object
Browse files Browse the repository at this point in the history
The native graph objects should be accessible for advanced users.
  • Loading branch information
romintomasetti committed Sep 27, 2024
1 parent fa09fa8 commit 2a18a62
Show file tree
Hide file tree
Showing 10 changed files with 450 additions and 1 deletion.
2 changes: 2 additions & 0 deletions cmake/kokkos_arch.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,8 @@ IF(KOKKOS_ENABLE_SYCL)
)
ENDIF()
ENDIF()

CHECK_CXX_SYMBOL_EXISTS(SYCL_EXT_ONEAPI_GRAPH "sycl/sycl.hpp" KOKKOS_IMPL_HAVE_SYCL_EXT_ONEAPI_GRAPH)
ENDIF()

SET(CUDA_ARCH_ALREADY_SPECIFIED "")
Expand Down
3 changes: 3 additions & 0 deletions core/src/Cuda/Kokkos_Cuda_Graph_Impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,9 @@ struct GraphImpl<Kokkos::Cuda> {
m_execution_space, _graph_node_kernel_ctor_tag{},
aggregate_kernel_impl_t{});
}

cudaGraph_t cuda_graph() { return m_graph; }
cudaGraphExec_t cuda_graph_exec() { return m_graph_exec; }
};

} // end namespace Impl
Expand Down
3 changes: 3 additions & 0 deletions core/src/HIP/Kokkos_HIP_Graph_Impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ class GraphImpl<Kokkos::HIP> {
KOKKOS_ENSURES(m_graph_exec);
}

hipGraph_t hip_graph() { return m_graph; }
hipGraphExec_t hip_graph_exec() { return m_graph_exec; }

private:
Kokkos::HIP m_execution_space;
hipGraph_t m_graph = nullptr;
Expand Down
41 changes: 40 additions & 1 deletion core/src/Kokkos_Graph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,10 @@ struct [[nodiscard]] Graph {
}

void submit() const { submit(get_execution_space()); }

decltype(auto) native_graph();

decltype(auto) native_graph_exec();
};

// </editor-fold> end Graph }}}1
Expand Down Expand Up @@ -168,6 +172,42 @@ create_graph(Closure&& arg_closure) {
// </editor-fold> end create_graph }}}1
//==============================================================================

template <class ExecutionSpace>
decltype(auto) Graph<ExecutionSpace>::native_graph() {
KOKKOS_EXPECTS(bool(m_impl_ptr));
#if defined(KOKKOS_ENABLE_CUDA)
if constexpr (std::is_same_v<ExecutionSpace, Kokkos::Cuda>) {
return m_impl_ptr->cuda_graph();
}
#elif defined(KOKKOS_ENABLE_HIP) && defined(KOKKOS_IMPL_HIP_NATIVE_GRAPH)
if constexpr (std::is_same_v<ExecutionSpace, Kokkos::HIP>) {
return m_impl_ptr->hip_graph();
}
#elif defined(KOKKOS_ENABLE_SYCL) && defined(SYCL_EXT_ONEAPI_GRAPH)
if constexpr (std::is_same_v<ExecutionSpace, Kokkos::SYCL>) {
return m_impl_ptr->sycl_graph();
}
#endif
}

template <class ExecutionSpace>
decltype(auto) Graph<ExecutionSpace>::native_graph_exec() {
KOKKOS_EXPECTS(bool(m_impl_ptr));
#if defined(KOKKOS_ENABLE_CUDA)
if constexpr (std::is_same_v<ExecutionSpace, Kokkos::Cuda>) {
return m_impl_ptr->cuda_graph_exec();
}
#elif defined(KOKKOS_ENABLE_HIP) && defined(KOKKOS_IMPL_HIP_NATIVE_GRAPH)
if constexpr (std::is_same_v<ExecutionSpace, Kokkos::HIP>) {
return m_impl_ptr->hip_graph_exec();
}
#elif defined(KOKKOS_ENABLE_SYCL) && defined(SYCL_EXT_ONEAPI_GRAPH)
if constexpr (std::is_same_v<ExecutionSpace, Kokkos::SYCL>) {
return m_impl_ptr->sycl_graph_exec();
}
#endif
}

} // end namespace Experimental
} // namespace Kokkos

Expand All @@ -179,7 +219,6 @@ create_graph(Closure&& arg_closure) {
#include <impl/Kokkos_Default_Graph_Impl.hpp>
#include <Cuda/Kokkos_Cuda_Graph_Impl.hpp>
#if defined(KOKKOS_ENABLE_HIP)
// The implementation of hipGraph in ROCm 5.2 is bugged, so we cannot use it.
#if defined(KOKKOS_IMPL_HIP_NATIVE_GRAPH)
#include <HIP/Kokkos_HIP_Graph_Impl.hpp>
#endif
Expand Down
3 changes: 3 additions & 0 deletions core/src/SYCL/Kokkos_SYCL_Graph_Impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ class GraphImpl<Kokkos::SYCL> {
m_graph_exec = m_graph.finalize();
}

auto& sycl_graph() { return m_graph; }
auto& sycl_graph_exec() { return m_graph_exec; }

private:
Kokkos::SYCL m_execution_space;
sycl::ext::oneapi::experimental::command_graph<
Expand Down
23 changes: 23 additions & 0 deletions core/unit_test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -837,6 +837,13 @@ if(Kokkos_ENABLE_CUDA)
UnitTestMainInit.cpp
cuda/TestCuda_InterOp_StreamsMultiGPU.cpp
)

KOKKOS_ADD_EXECUTABLE_AND_TEST(
CoreUnitTest_CudaInterOpGraph
SOURCES
UnitTestMainInit.cpp
cuda/TestCuda_InterOp_Graph.cpp
)
endif()

if(Kokkos_ENABLE_HIP)
Expand Down Expand Up @@ -864,6 +871,13 @@ if(Kokkos_ENABLE_HIP)
UnitTestMain.cpp
hip/TestHIP_InterOp_Streams.cpp
)

KOKKOS_ADD_EXECUTABLE_AND_TEST(
CoreUnitTest_HIPInterOpGraph
SOURCES
UnitTestMainInit.cpp
hip/TestHIP_InterOp_Graph.cpp
)
endif()

if(Kokkos_ENABLE_SYCL)
Expand Down Expand Up @@ -948,6 +962,15 @@ if(Kokkos_ENABLE_SYCL)
UnitTestMainInit.cpp
sycl/TestSYCL_InterOp_StreamsMultiGPU.cpp
)

if(KOKKOS_IMPL_HAVE_SYCL_EXT_ONEAPI_GRAPH)
KOKKOS_ADD_EXECUTABLE_AND_TEST(
CoreUnitTest_SYCLInterOpGraph
SOURCES
UnitTestMainInit.cpp
sycl/TestSYCL_InterOp_Graph.cpp
)
endif()
endif()

SET(DEFAULT_DEVICE_SOURCES
Expand Down
10 changes: 10 additions & 0 deletions core/unit_test/TestGraph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,16 @@ TEST_F(TEST_CATEGORY_FIXTURE(graph), force_global_launch) {
#endif
}

// Ensure that an empty graph on the default host execution space
// can be submitted.
TEST_F(TEST_CATEGORY_FIXTURE(graph), empty_graph_default_host_exec) {
auto graph =
Kokkos::Experimental::create_graph(Kokkos::DefaultHostExecutionSpace{});
graph.instantiate();
graph.submit();
graph.get_execution_space().fence();
}

template <typename ViewType, size_t TargetIndex, size_t NumIndices = 0>
struct FetchValuesAndContribute {
static_assert(std::is_same_v<typename ViewType::value_type,
Expand Down
143 changes: 143 additions & 0 deletions core/unit_test/cuda/TestCuda_InterOp_Graph.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER

#include <filesystem>
#include <fstream>
#include <regex>

#include <TestCuda_Category.hpp>
#include <Kokkos_Core.hpp>
#include <Kokkos_Graph.hpp>

#include <gtest/gtest.h>

namespace {

template <typename ViewType>
struct Increment {
ViewType data;

KOKKOS_FUNCTION
void operator()(const int) const { ++data(); }
};

class TEST_CATEGORY_FIXTURE(GraphInterOp) : public ::testing::Test {
public:
using execution_space = Kokkos::Cuda;
using view_t =
Kokkos::View<int, execution_space, Kokkos::MemoryTraits<Kokkos::Atomic>>;
using graph_t = Kokkos::Experimental::Graph<execution_space>;

void SetUp() override {
data = view_t(Kokkos::view_alloc(exec, "witness"));

graph = Kokkos::Experimental::create_graph(exec);

auto root = Kokkos::Impl::GraphAccess::create_root_ref(*graph);

root.then_parallel_for(1, Increment<view_t>{data});
}
protected:
const execution_space exec{};
view_t data;
std::optional<graph_t> graph;
};

// This test checks the promises of Kokkos::Graph against its
// underlying Cuda native objects.
TEST_F(TEST_CATEGORY_FIXTURE(GraphInterOp), promises_on_native_objects) {
// Before instantiation, the Cuda graph is valid, but the Cuda executable
// graph is still null.
cudaGraph_t cuda_graph = graph->native_graph();

ASSERT_NE(cuda_graph, nullptr);
ASSERT_EQ(graph->native_graph_exec(), nullptr);

// After instantiation, both native objects are valid.
graph->instantiate();

cudaGraphExec_t cuda_graph_exec = graph->native_graph_exec();

ASSERT_EQ(graph->native_graph(), cuda_graph);
ASSERT_NE(cuda_graph_exec, nullptr);

// Submission should not affect the underlying objects.
graph->submit();

ASSERT_EQ(graph->native_graph(), cuda_graph);
ASSERT_EQ(graph->native_graph_exec(), cuda_graph_exec);
}

// Count the number of nodes. This is useful to ensure no spurious
// (possibly empty) node is added.
TEST_F(TEST_CATEGORY_FIXTURE(GraphInterOp), count_nodes) {
graph->instantiate();

size_t num_nodes = 0;

KOKKOS_IMPL_CUDA_SAFE_CALL(
cudaGraphGetNodes(graph->native_graph(), nullptr, &num_nodes));

ASSERT_EQ(num_nodes, 2u);
}

// Use native Cuda graph to generate a DOT representation.
TEST_F(TEST_CATEGORY_FIXTURE(GraphInterOp), debug_dot_print) {
#if CUDA_VERSION < 11600
GTEST_SKIP() << "Export a graph to DOT requires Cuda 11.6.";
#else
graph->instantiate();

const auto dot = std::filesystem::temp_directory_path() / "cuda_graph.dot";

// Convert path to string then to const char * to make it work on Windows.
KOKKOS_IMPL_CUDA_SAFE_CALL(
cudaGraphDebugDotPrint(graph->native_graph(), dot.string().c_str(),
cudaGraphDebugDotFlagsVerbose));

ASSERT_TRUE(std::filesystem::exists(dot));
ASSERT_GT(std::filesystem::file_size(dot), 0u);

// We could write a check against the full kernel's function signature, but
// it would make the test rely too much on internal implementation details.
// Therefore, we just look for the functor and policy. Note that the
// signature is mangled in the 'dot' output.
const std::string expected("[A-Za-z0-9_]+Increment[A-Za-z0-9_]+RangePolicy");

std::stringstream buffer;
buffer << std::ifstream(dot).rdbuf();

ASSERT_TRUE(std::regex_search(buffer.str(), std::regex(expected)))
<< "Could not find expected signature regex " << std::quoted(expected)
<< " in " << dot;
#endif
}

// Ensure that the graph has been instantiated with the default flag.
TEST_F(TEST_CATEGORY_FIXTURE(GraphInterOp), instantiation_flags) {
#if CUDA_VERSION < 12000
GTEST_SKIP() << "Graph instantiation flag inspection requires Cuda 12.";
#else
unsigned long long flags =
Kokkos::Experimental::finite_max_v<unsigned long long>;
KOKKOS_IMPL_CUDA_SAFE_CALL(
cudaGraphExecGetFlags(graph->native_graph_exec(), &flags));

ASSERT_EQ(flags, 0u);
#endif
}

} // namespace
Loading

0 comments on commit 2a18a62

Please sign in to comment.