diff --git a/Project.toml b/Project.toml index 8c56fad..8a41d67 100644 --- a/Project.toml +++ b/Project.toml @@ -13,9 +13,16 @@ Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +[weakdeps] +CellListMap = "69e1c6dd-3888-40e6-b3c8-31ac5f578864" + +[extensions] +PointNeighborsCellListMapExt = "CellListMap" + [compat] Adapt = "4" Atomix = "0.1" +CellListMap = "0.9" GPUArraysCore = "0.1" KernelAbstractions = "0.9" LinearAlgebra = "1" diff --git a/ext/PointNeighborsCellListMapExt.jl b/ext/PointNeighborsCellListMapExt.jl new file mode 100644 index 0000000..5ccb7bc --- /dev/null +++ b/ext/PointNeighborsCellListMapExt.jl @@ -0,0 +1,84 @@ +module PointNeighborsCellListMapExt + +using PointNeighbors +using CellListMap: CellListMap + +mutable struct CellListMapNeighborhoodSearch{CL, B} + cell_list :: CL + box :: B + + function PointNeighbors.CellListMapNeighborhoodSearch(NDIMS, search_radius) + # Create a cell list with only one point and resize it later + x = zeros(NDIMS, 1) + box = CellListMap.Box(CellListMap.limits(x, x), search_radius) + cell_list = CellListMap.CellList(x, x, box) + + return new{typeof(cell_list), typeof(box)}(cell_list, box) + end +end + +function PointNeighbors.search_radius(neighborhood_search::CellListMapNeighborhoodSearch) + return neighborhood_search.box.cutoff +end + +function Base.ndims(neighborhood_search::CellListMapNeighborhoodSearch) + return length(neighborhood_search.box.cell_size) +end + +function PointNeighbors.initialize!(neighborhood_search::CellListMapNeighborhoodSearch, + x::AbstractMatrix, y::AbstractMatrix) + PointNeighbors.update!(neighborhood_search, x, y) +end + +function PointNeighbors.update!(neighborhood_search::CellListMapNeighborhoodSearch, + x::AbstractMatrix, y::AbstractMatrix; + points_moving = (true, true)) + (; cell_list) = neighborhood_search + + # Resize box + box = CellListMap.Box(CellListMap.limits(x, y), neighborhood_search.box.cutoff) + neighborhood_search.box = box + + # Resize and update cell list + CellListMap.UpdateCellList!(x, y, box, cell_list) + + # Recalculate number of batches for multithreading + CellListMap.set_number_of_batches!(cell_list) + + return neighborhood_search +end + +# The type annotation is to make Julia specialize on the type of the function. +# Otherwise, unspecialized code will cause a lot of allocations +# and heavily impact performance. +# See https://docs.julialang.org/en/v1/manual/performance-tips/#Be-aware-of-when-Julia-avoids-specializing +function PointNeighbors.foreach_point_neighbor(f::T, system_coords, neighbor_coords, + neighborhood_search::CellListMapNeighborhoodSearch; + points = axes(system_coords, 2), + parallel = true) where {T} + (; cell_list, box) = neighborhood_search + + # `0` is the returned output, which we don't use + CellListMap.map_pairwise!(0, box, cell_list, + parallel = parallel) do x, y, i, j, d2, output + # Skip all indices not in `points` + i in points || return output + + pos_diff = x - y + distance = sqrt(d2) + + @inline f(i, j, pos_diff, distance) + + return output + end + + return nothing +end + +function PointNeighbors.copy_neighborhood_search(nhs::CellListMapNeighborhoodSearch, + search_radius, n_points; + eachpoint = 1:n_points) + return PointNeighbors.CellListMapNeighborhoodSearch(ndims(nhs), search_radius) +end + +end diff --git a/src/PointNeighbors.jl b/src/PointNeighbors.jl index b7a8203..8314d15 100644 --- a/src/PointNeighbors.jl +++ b/src/PointNeighbors.jl @@ -20,7 +20,8 @@ include("nhs_precomputed.jl") include("gpu.jl") export foreach_point_neighbor, foreach_neighbor -export TrivialNeighborhoodSearch, GridNeighborhoodSearch, PrecomputedNeighborhoodSearch +export TrivialNeighborhoodSearch, GridNeighborhoodSearch, PrecomputedNeighborhoodSearch, + CellListMapNeighborhoodSearch export DictionaryCellList, FullGridCellList export ParallelUpdate, SemiParallelUpdate, SerialUpdate export initialize!, update!, initialize_grid!, update_grid! diff --git a/src/neighborhood_search.jl b/src/neighborhood_search.jl index dacbc10..3299b98 100644 --- a/src/neighborhood_search.jl +++ b/src/neighborhood_search.jl @@ -236,3 +236,5 @@ end @inline function periodic_coords(coords, periodic_box::Nothing) return coords end + +CellListMapNeighborhoodSearch() = error("CellListMap.jl has to be imported to use this") diff --git a/test/Project.toml b/test/Project.toml index f57fdf2..3e2bd50 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,10 +1,12 @@ [deps] BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +CellListMap = "69e1c6dd-3888-40e6-b3c8-31ac5f578864" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] BenchmarkTools = "1" +CellListMap = "0.9" Plots = "1" Test = "1" diff --git a/test/neighborhood_search.jl b/test/neighborhood_search.jl index a49ad2b..c54d62c 100644 --- a/test/neighborhood_search.jl +++ b/test/neighborhood_search.jl @@ -178,6 +178,7 @@ search_radius, backend = Vector{Vector{Int}})), PrecomputedNeighborhoodSearch{NDIMS}(; search_radius, n_points), + CellListMapNeighborhoodSearch(NDIMS, search_radius), ] names = [ @@ -187,6 +188,7 @@ "`GridNeighborhoodSearch` with `FullGridCellList` with `DynamicVectorOfVectors` and `SemiParallelUpdate`", "`GridNeighborhoodSearch` with `FullGridCellList` with `Vector{Vector}`", "`PrecomputedNeighborhoodSearch`", + "`CellListMapNeighborhoodSearch`", ] # Also test copied templates @@ -202,6 +204,7 @@ max_corner, backend = Vector{Vector{Int32}})), PrecomputedNeighborhoodSearch{NDIMS}(), + CellListMapNeighborhoodSearch(NDIMS, 1.0), ] copied_nhs = copy_neighborhood_search.(template_nhs, search_radius, n_points) append!(neighborhood_searches, copied_nhs) diff --git a/test/test_util.jl b/test/test_util.jl index 07a3ac9..a5c1911 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -3,6 +3,9 @@ using Test: @test, @testset, @test_throws using PointNeighbors +# Load `PointNeighborsCellListMapExt` +import CellListMap + """ @trixi_testset "name of the testset" #= code to test #=