Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve python exposure of determine_point_ownership #3344

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cpp/dolfinx/geometry/BoundingBoxTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ class BoundingBoxTree
/// @param[in] padding Value to pad (extend) the the bounding box of
/// each entity by.
BoundingBoxTree(const mesh::Mesh<T>& mesh, int tdim,
std::span<const std::int32_t> entities, double padding = 0)
double padding, std::span<const std::int32_t> entities)
: _tdim(tdim)
{
if (tdim < 0 or tdim > mesh.topology()->dim())
Expand Down Expand Up @@ -266,7 +266,7 @@ class BoundingBoxTree
/// build the bounding box tree for
/// @param[in] padding Value to pad (extend) the the bounding box of
/// each entity by.
BoundingBoxTree(const mesh::Mesh<T>& mesh, int tdim, T padding = 0)
BoundingBoxTree(const mesh::Mesh<T>& mesh, int tdim, T padding)
: BoundingBoxTree::BoundingBoxTree(
mesh, tdim, range(mesh.topology_mutable(), tdim), padding)
{
Expand Down
30 changes: 15 additions & 15 deletions cpp/dolfinx/geometry/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -663,41 +663,41 @@ graph::AdjacencyList<std::int32_t> compute_colliding_cells(
/// @param[in] mesh The mesh
/// @param[in] points Points to check for collision (`shape=(num_points,
/// 3)`). Storage is row-major.
/// @param[in] cells Cells to check for ownership
jhale marked this conversation as resolved.
Show resolved Hide resolved
/// @param[in] padding Amount of absolute padding of bounding boxes of the mesh.
/// Each bounding box of the mesh is padded with this amount, to increase
/// the number of candidates, avoiding rounding errors in determining the owner
/// of a point if the point is on the surface of a cell in the mesh.
/// @return Tuple `(src_owner, dest_owner, dest_points, dest_cells)`,
/// where src_owner is a list of ranks corresponding to the input
/// points. dest_owner is a list of ranks corresponding to dest_points,
/// the points that this process owns. dest_cells contains the
/// corresponding cell for each entry in dest_points.
/// @return Point ownership data.
///
/// @note `dest_owner` is sorted
/// @note Returns -1 if no colliding process is found
/// @note `src_owner` is -1 if no colliding process is found
/// @note dest_points is flattened row-major, shape `(dest_owner.size(),
/// 3)`
/// @note Only looks through cells owned by the process
/// @note A large padding value can increase the runtime of the function by
/// orders of magnitude, because for non-colliding cells
/// one has to determine the closest cell among all processes with an
/// intersecting bounding box, which is an expensive operation to perform.
template <std::floating_point T>
PointOwnershipData<T> determine_point_ownership(const mesh::Mesh<T>& mesh,
std::span<const T> points,
T padding)
T padding,
std::span<const std::int32_t> cells = {})
{
MPI_Comm comm = mesh.comm();

const int tdim = mesh.topology()->dim();

std::vector<std::int32_t> local_cells;
if (cells.empty()) {
auto cell_map = mesh.topology()->index_map(tdim);
local_cells.resize(cell_map->size_local());
std::iota(local_cells.begin(), local_cells.end(), 0);
cells = std::span<const std::int32_t>(local_cells.data(), local_cells.size());
}
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jhale I made cells a default argument at the c++ level (not at the wrapper level).
Ideally this local_cells variable would be initialized inside of the empty conditional but I could not figure it out since the span that wraps it dangles if the underlying vector goes out of scope. I tried using the static keyword but didn't work.

// Create a global bounding-box tree to find candidate processes with
// cells that could collide with the points
const int tdim = mesh.topology()->dim();
auto cell_map = mesh.topology()->index_map(tdim);
const std::int32_t num_cells = cell_map->size_local();
// NOTE: Should we send the cells in as input?
std::vector<std::int32_t> cells(num_cells, 0);
std::iota(cells.begin(), cells.end(), 0);
BoundingBoxTree bb(mesh, tdim, cells, padding);
BoundingBoxTree bb(mesh, tdim, padding, cells);
BoundingBoxTree global_bbtree = bb.create_global_tree(comm);

// Compute collisions:
Expand Down
2 changes: 1 addition & 1 deletion python/demo/demo_static-condensation.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def tabulate_A(A_, w_, c_, coords_, entity_local_index, permutation=ffi.NULL):
A.assemble()

# Create bounding box for function evaluation
bb_tree = geometry.bb_tree(msh, 2)
bb_tree = geometry.bb_tree(msh, 2, 0.0)

# Check against standard table value
p = np.array([[48.0, 52.0, 0.0]], dtype=np.float64)
Expand Down
5 changes: 3 additions & 2 deletions python/dolfinx/fem/bcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from dolfinx.fem.function import Constant, Function

import numpy as np
import numpy.typing as npt

import dolfinx
from dolfinx import cpp as _cpp
Expand Down Expand Up @@ -57,7 +58,7 @@ def locate_dofs_geometrical(
def locate_dofs_topological(
V: typing.Union[dolfinx.fem.FunctionSpace, typing.Iterable[dolfinx.fem.FunctionSpace]],
entity_dim: int,
entities: numpy.typing.NDArray[np.int32],
entities: npt.NDArray[np.int32],
remote: bool = True,
) -> np.ndarray:
"""Locate degrees-of-freedom belonging to mesh entities topologically.
Expand Down Expand Up @@ -150,7 +151,7 @@ def dof_indices(self) -> tuple[np.ndarray, int]:

def dirichletbc(
value: typing.Union[Function, Constant, np.ndarray],
dofs: numpy.typing.NDArray[np.int32],
dofs: npt.NDArray[np.int32],
V: typing.Optional[dolfinx.fem.FunctionSpace] = None,
) -> DirichletBC:
"""Create a representation of Dirichlet boundary condition which
Expand Down
2 changes: 1 addition & 1 deletion python/dolfinx/fem/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def integral_types(self):

def get_integration_domains(
integral_type: IntegralType,
subdomain: typing.Optional[typing.Union[MeshTags, list[tuple[int, np.ndarray]]]],
subdomain: typing.Optional[typing.Union[MeshTags, list[tuple[int, npt.NDArray[np.int32]]]]],
subdomain_ids: list[int],
) -> list[tuple[int, np.ndarray]]:
"""Get integration domains from subdomain data.
Expand Down
10 changes: 5 additions & 5 deletions python/dolfinx/fem/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class Expression:
def __init__(
self,
e: ufl.core.expr.Expr,
X: np.ndarray,
X: typing.Union[npt.NDArray[np.float32], npt.NDArray[np.float64]],
comm: typing.Optional[_MPI.Comm] = None,
form_compiler_options: typing.Optional[dict] = None,
jit_options: typing.Optional[dict] = None,
Expand Down Expand Up @@ -196,7 +196,7 @@ def _create_expression(dtype):
def eval(
self,
mesh: Mesh,
entities: np.ndarray,
entities: npt.NDArray[np.int32],
values: typing.Optional[np.ndarray] = None,
) -> np.ndarray:
"""Evaluate Expression on entities.
Expand Down Expand Up @@ -412,8 +412,8 @@ def interpolate_nonmatching(
def interpolate(
self,
u0: typing.Union[typing.Callable, Expression, Function],
cells0: typing.Optional[np.ndarray] = None,
cells1: typing.Optional[np.ndarray] = None,
cells0: typing.Optional[npt.NDArray[np.int32]] = None,
cells1: typing.Optional[npt.NDArray[np.int32]] = None,
) -> None:
"""Interpolate an expression.

Expand Down Expand Up @@ -584,7 +584,7 @@ def _create_dolfinx_element(
comm: _MPI.Intracomm,
cell_type: _cpp.mesh.CellType,
ufl_e: ufl.FiniteElementBase,
dtype: np.dtype,
dtype: npt.DTypeLike,
) -> typing.Union[_cpp.fem.FiniteElement_float32, _cpp.fem.FiniteElement_float64]:
"""Create a DOLFINx element from a basix.ufl element."""
if np.issubdtype(dtype, np.float32):
Expand Down
48 changes: 45 additions & 3 deletions python/dolfinx/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ def create_global_tree(self, comm) -> BoundingBoxTree:
def bb_tree(
mesh: Mesh,
dim: int,
padding: float,
entities: typing.Optional[npt.NDArray[np.int32]] = None,
padding: float = 0.0,
) -> BoundingBoxTree:
"""Create a bounding box tree for use in collision detection.

Expand All @@ -128,11 +128,11 @@ def bb_tree(
dtype = mesh.geometry.x.dtype
if np.issubdtype(dtype, np.float32):
return BoundingBoxTree(
_cpp.geometry.BoundingBoxTree_float32(mesh._cpp_object, dim, entities, padding)
_cpp.geometry.BoundingBoxTree_float32(mesh._cpp_object, dim, padding, entities)
)
elif np.issubdtype(dtype, np.float64):
return BoundingBoxTree(
_cpp.geometry.BoundingBoxTree_float64(mesh._cpp_object, dim, entities, padding)
_cpp.geometry.BoundingBoxTree_float64(mesh._cpp_object, dim, padding, entities)
)
else:
raise NotImplementedError(f"Type {dtype} not supported.")
Expand Down Expand Up @@ -270,3 +270,45 @@ def compute_distance_gjk(

"""
return _cpp.geometry.compute_distance_gjk(p, q)


def determine_point_ownership(
mesh: Mesh,
points: npt.NDArray[np.floating],
padding: float,
cells: typing.Optional[npt.NDArray[np.int32]] = None,
) -> PointOwnershipData:
"""Build point ownership data for a mesh-points pair.

First, potential collisions are found by computing intersections
between the bounding boxes of the cells and the set of points.
Then, actual containment pairs are determined using the GJK algorithm.

Args:
mesh: The mesh
points: Points to check for collision (``shape=(num_points, gdim)``)
padding: Amount of absolute padding of bounding boxes of the mesh.
Each bounding box of the mesh is padded with this amount, to increase
the number of candidates, avoiding rounding errors in determining the owner
of a point if the point is on the surface of a cell in the mesh.
cells: Cells to check for ownership
If ``None`` then all cells are considered.

Returns:
Point ownership data
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are the fields of PointOwnershipData documented in the Python wrapped class?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

def src_owner(self) -> npt.NDArray[np.int32]:
"""Ranks owning each point sent into ownership determination for current process"""
return self._cpp_object.src_owner
def dest_owner(self) -> npt.NDArray[np.int32]:
"""Ranks that sent `dest_points` to current process"""
return self._cpp_object.dest_owners
def dest_points(self) -> npt.NDArray[np.floating]:
"""Points owned by current rank"""
return self._cpp_object.dest_points
def dest_cells(self) -> npt.NDArray[np.int32]:
"""Cell indices (local to process) where each entry of `dest_points` is located"""
return self._cpp_object.dest_cells

The doc is the same as on the c++ side.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, thanks for pointing it out to me.

The C++ documentation for the return argument describing PointOwnershipData looks over-specified given the struct is now properly documented.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed most of it.


Note:
``dest_owner`` is sorted

``src_owner`` is -1 if no colliding process is found

A large padding value will increase the run-time of the code by orders
of magnitude. General advice is to use a padding on the scale of the
cell size.
"""
if cells is None:
map = mesh.topology.index_map(mesh.topology.dim)
cells = np.arange(map.size_local, dtype=np.int32)
return PointOwnershipData(
_cpp.geometry.determine_point_ownership(mesh._cpp_object, points, cells, padding)
)
8 changes: 7 additions & 1 deletion python/dolfinx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@

from __future__ import annotations

import typing

import numpy as np
import numpy.typing as npt

from dolfinx import cpp as _cpp
from dolfinx.cpp.graph import partitioner
Expand All @@ -31,7 +34,10 @@
__all__ = ["adjacencylist", "partitioner"]


def adjacencylist(data: np.ndarray, offsets=None):
def adjacencylist(
data: typing.Union[npt.NDArray[np.int32], npt.NDArray[np.int64]],
offsets: typing.Optional[npt.NDArray[np.int32]] = None,
):
"""Create an AdjacencyList for int32 or int64 datasets.

Args:
Expand Down
8 changes: 5 additions & 3 deletions python/dolfinx/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ def compute_midpoints(mesh: Mesh, dim: int, entities: npt.NDArray[np.int32]):
return _cpp.mesh.compute_midpoints(mesh._cpp_object, dim, entities)


def locate_entities(mesh: Mesh, dim: int, marker: typing.Callable) -> np.ndarray:
def locate_entities(mesh: Mesh, dim: int, marker: typing.Callable) -> npt.NDArray[np.int32]:
"""Compute mesh entities satisfying a geometric marking function.

Args:
Expand All @@ -470,7 +470,9 @@ def locate_entities(mesh: Mesh, dim: int, marker: typing.Callable) -> np.ndarray
return _cpp.mesh.locate_entities(mesh._cpp_object, dim, marker)


def locate_entities_boundary(mesh: Mesh, dim: int, marker: typing.Callable) -> np.ndarray:
def locate_entities_boundary(
mesh: Mesh, dim: int, marker: typing.Callable
) -> npt.NDArray[np.int32]:
"""Compute mesh entities that are connected to an owned boundary
facet and satisfy a geometric marking function.

Expand Down Expand Up @@ -534,7 +536,7 @@ def transfer_meshtag(

def refine(
mesh: Mesh,
edges: typing.Optional[np.ndarray] = None,
edges: typing.Optional[npt.NDArray[np.int32]] = None,
redistribute: bool = True,
ghost_mode: GhostMode = GhostMode.shared_facet,
option: RefinementOption = RefinementOption.none,
Expand Down
22 changes: 13 additions & 9 deletions python/dolfinx/wrappers/geometry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,16 @@ void declare_bbtree(nb::module_& m, std::string type)
"__init__",
[](dolfinx::geometry::BoundingBoxTree<T>* bbt,
const dolfinx::mesh::Mesh<T>& mesh, int dim,
double padding,
nb::ndarray<const std::int32_t, nb::ndim<1>, nb::c_contig>
entities,
double padding)
entities)
{
new (bbt) dolfinx::geometry::BoundingBoxTree<T>(
mesh, dim,
std::span<const std::int32_t>(entities.data(), entities.size()),
padding);
padding,
std::span<const std::int32_t>(entities.data(), entities.size()));
},
nb::arg("mesh"), nb::arg("dim"), nb::arg("entities"),
nb::arg("padding") = 0.0)
nb::arg("mesh"), nb::arg("dim"), nb::arg("padding"), nb::arg("entities"))
.def_prop_ro("num_bboxes",
&dolfinx::geometry::BoundingBoxTree<T>::num_bboxes)
.def(
Expand Down Expand Up @@ -180,13 +179,18 @@ void declare_bbtree(nb::module_& m, std::string type)
nb::arg("mesh"), nb::arg("dim"), nb::arg("indices"), nb::arg("points"));
m.def("determine_point_ownership",
[](const dolfinx::mesh::Mesh<T>& mesh,
nb::ndarray<const T, nb::c_contig> points, const T padding)
nb::ndarray<const T, nb::c_contig> points,
nb::ndarray<const std::int32_t, nb::ndim<1>, nb::c_contig> cells,
const T padding)
{
const std::size_t p_s0 = points.ndim() == 1 ? 1 : points.shape(0);
std::span<const T> _p(points.data(), 3 * p_s0);
return dolfinx::geometry::determine_point_ownership<T>(mesh, _p,
padding);
});
padding,
std::span(cells.data(), cells.size()));
},
nb::arg("mesh"), nb::arg("points"), nb::arg("padding"), nb::arg("cells"),
"Compute point ownership data for mesh-points pair.");

std::string pod_pyclass_name = "PointOwnershipData_" + type;
nb::class_<dolfinx::geometry::PointOwnershipData<T>>(m,
Expand Down
2 changes: 1 addition & 1 deletion python/test/unit/fem/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def e3(x):
u3.interpolate(e3)

x0 = (mesh.geometry.x[0] + mesh.geometry.x[1]) / 2.0
tree = bb_tree(mesh, mesh.geometry.dim)
tree = bb_tree(mesh, mesh.geometry.dim, 0.0)
cell_candidates = compute_collisions_points(tree, x0)
cell = compute_colliding_cells(mesh, cell_candidates, x0).array
assert len(cell) > 0
Expand Down
2 changes: 1 addition & 1 deletion python/test/unit/fem/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1025,7 +1025,7 @@ def f_test2(x):
u1_exact.x.scatter_forward()

# Find the single cell in mesh1 which is overlapped by mesh2
tree1 = bb_tree(mesh1, mesh1.topology.dim)
tree1 = bb_tree(mesh1, mesh1.topology.dim, 0.0)
cells_overlapped1 = compute_collisions_points(
tree1, np.array([p0_mesh2, p0_mesh2, 0.0]) / 2
).array
Expand Down
Loading