From b09c7957d26035efbbf62dc4f900f2d3b6b4e154 Mon Sep 17 00:00:00 2001 From: Sam Reeve <6740307+streeve@users.noreply.github.com> Date: Tue, 31 Oct 2023 10:01:30 -0400 Subject: [PATCH] Enable indexing into neighbor views --- core/src/Cabana_Parallel.hpp | 142 +++++++++++++++++++++++++++-------- 1 file changed, 110 insertions(+), 32 deletions(-) diff --git a/core/src/Cabana_Parallel.hpp b/core/src/Cabana_Parallel.hpp index f46d674f8..cd6af658c 100644 --- a/core/src/Cabana_Parallel.hpp +++ b/core/src/Cabana_Parallel.hpp @@ -210,6 +210,98 @@ class TeamVectorOpTag { }; +namespace Impl +{ +template +void executeNeighborParallel( PolicyType policy, FunctorType neigh_func, + std::string str ) +{ + if ( str.empty() ) + Kokkos::parallel_for( policy, neigh_func ); + else + Kokkos::parallel_for( str, policy, neigh_func ); +} + +template +void neighborParallelFirstSerialDirect( const FunctorType& functor, + ListType list, PolicyType policy, + std::string str ) +{ + auto neigh_func = KOKKOS_LAMBDA( const IndexType i ) + { + for ( IndexType n = 0; n < NeighborTraits::numNeighbor( list, i ); ++n ) + Impl::functorTagDispatch( + functor, i, + static_cast( + NeighborTraits::getNeighbor( list, i, n ) ) ); + }; + + executeNeighborParallel( policy, neigh_func, str ); +} +template +void neighborParallelFirstSerialIndirect( const FunctorType& functor, + const ListType& list, + PolicyType policy, std::string str ) +{ + auto neigh_func = KOKKOS_LAMBDA( const IndexType i ) + { + for ( IndexType n = 0; n < NeighborTraits::numNeighbor( list, i ); ++n ) + Impl::functorTagDispatch( functor, i, n ); + }; + + executeNeighborParallel( policy, neigh_func, str ); +} + +template +void neighborParallelFirstTeamDirect( const FunctorType& functor, ListType list, + PolicyType policy, + const IndexType range_begin, + std::string str ) +{ + auto neigh_func = + KOKKOS_LAMBDA( const typename PolicyType::member_type& team ) + { + IndexType i = team.league_rank() + range_begin; + Kokkos::parallel_for( + Kokkos::TeamThreadRange( team, + NeighborTraits::numNeighbor( list, i ) ), + [&]( const IndexType n ) + { + Impl::functorTagDispatch( + functor, i, + static_cast( + NeighborTraits::getNeighbor( list, i, n ) ) ); + } ); + }; + + executeNeighborParallel( policy, neigh_func, str ); +} +template +void neighborParallelFirstTeamIndirect( const FunctorType& functor, + ListType list, PolicyType policy, + const IndexType range_begin, + std::string str ) +{ + auto neigh_func = + KOKKOS_LAMBDA( const typename PolicyType::member_type& team ) + { + IndexType i = team.league_rank() + range_begin; + Kokkos::parallel_for( + Kokkos::TeamThreadRange( team, + NeighborTraits::numNeighbor( list, i ) ), + [&]( const IndexType n ) + { Impl::functorTagDispatch( functor, i, n ); } ); + }; + + executeNeighborParallel( policy, neigh_func, str ); +} + +} // namespace Impl + //---------------------------------------------------------------------------// /*! \brief Execute functor in parallel according to the execution policy over @@ -251,7 +343,8 @@ template inline void neighbor_parallel_for( const Kokkos::RangePolicy& exec_policy, const FunctorType& functor, const NeighborListType& list, - const FirstNeighborsTag, const SerialOpTag, const std::string& str = "" ) + const FirstNeighborsTag, const SerialOpTag, const std::string& str = "", + const bool direct_index = true ) { Kokkos::Profiling::pushRegion( "Cabana::neighbor_parallel_for" ); @@ -274,19 +367,14 @@ inline void neighbor_parallel_for( static_assert( is_accessible_from{}, "" ); - auto neigh_func = KOKKOS_LAMBDA( const index_type i ) - { - for ( index_type n = 0; - n < neighbor_list_traits::numNeighbor( list, i ); ++n ) - Impl::functorTagDispatch( - functor, i, - static_cast( - neighbor_list_traits::getNeighbor( list, i, n ) ) ); - }; - if ( str.empty() ) - Kokkos::parallel_for( linear_exec_policy, neigh_func ); + if ( direct_index ) + Impl::neighborParallelFirstSerialDirect( + functor, list, linear_exec_policy, str ); else - Kokkos::parallel_for( str, linear_exec_policy, neigh_func ); + Impl::neighborParallelFirstSerialIndirect( + functor, list, linear_exec_policy, str ); Kokkos::Profiling::popRegion(); } @@ -386,7 +474,8 @@ template inline void neighbor_parallel_for( const Kokkos::RangePolicy& exec_policy, const FunctorType& functor, const NeighborListType& list, - const FirstNeighborsTag, const TeamOpTag, const std::string& str = "" ) + const FirstNeighborsTag, const TeamOpTag, const std::string& str = "", + const bool direct_index = true ) { Kokkos::Profiling::pushRegion( "Cabana::neighbor_parallel_for" ); @@ -410,25 +499,14 @@ inline void neighbor_parallel_for( const auto range_begin = exec_policy.begin(); - auto neigh_func = - KOKKOS_LAMBDA( const typename kokkos_policy::member_type& team ) - { - index_type i = team.league_rank() + range_begin; - Kokkos::parallel_for( - Kokkos::TeamThreadRange( - team, neighbor_list_traits::numNeighbor( list, i ) ), - [&]( const index_type n ) - { - Impl::functorTagDispatch( - functor, i, - static_cast( - neighbor_list_traits::getNeighbor( list, i, n ) ) ); - } ); - }; - if ( str.empty() ) - Kokkos::parallel_for( team_policy, neigh_func ); + if ( direct_index ) + Impl::neighborParallelFirstTeamDirect( + functor, list, team_policy, range_begin, str ); else - Kokkos::parallel_for( str, team_policy, neigh_func ); + Impl::neighborParallelFirstTeamIndirect( + functor, list, team_policy, range_begin, str ); Kokkos::Profiling::popRegion(); }