Skip to content

Commit

Permalink
Refactor t::geometry::TriangleMesh::SelectFacesByMask
Browse files Browse the repository at this point in the history
* Add error handling on empty mesh
* Use DISPATCH_INT_DTYPE_PREFIX_TO_TEMPLATE instead of a conditional
  branch
* Use UpdateTriangleIndicesByVertexMask helper to update triangle
  indices
* Use CopyAttributesByMask helper to copy the mesh attributes
* Add tests
  • Loading branch information
nsaiapova committed Nov 8, 2023
1 parent 0a78111 commit defc9df
Show file tree
Hide file tree
Showing 5 changed files with 226 additions and 40 deletions.
64 changes: 26 additions & 38 deletions cpp/open3d/t/geometry/TriangleMesh.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1042,6 +1042,17 @@ static void CopyAttributesByMasks(TriangleMesh &dst,
}

TriangleMesh TriangleMesh::SelectFacesByMask(const core::Tensor &mask) const {
if (!HasVertexPositions()) {
utility::LogWarning(
"[SelectFacesByMask] mesh has no vertex positions.");
return {};
}
if (!HasTriangleIndices()) {
utility::LogWarning(
"[SelectFacesByMask] mesh has no triangle indices.");
return {};
}

core::AssertTensorShape(mask, {GetTriangleIndices().GetLength()});
core::AssertTensorDtype(mask, core::Bool);
GetTriangleAttr().AssertSizeSynchronized();
Expand All @@ -1050,55 +1061,32 @@ TriangleMesh TriangleMesh::SelectFacesByMask(const core::Tensor &mask) const {
// select triangles
core::Tensor tris = GetTriangleIndices().IndexGet({mask});
core::Tensor tris_cpu = tris.To(core::Device()).Contiguous();
const int64_t num_tris = tris_cpu.GetLength();

// create mask for vertices that are part of the selected faces
const int64_t num_verts = GetVertexPositions().GetLength();
core::Tensor vertex_mask = core::Tensor::Zeros({num_verts}, core::Int32);
std::vector<int64_t> prefix_sum(num_verts + 1, 0);
{
int32_t *vertex_mask_ptr = vertex_mask.GetDataPtr<int32_t>();
if (tris_cpu.GetDtype() == core::Int32) {
int32_t *vert_idx_ptr = tris_cpu.GetDataPtr<int32_t>();
for (int64_t i = 0; i < tris_cpu.GetLength() * 3; ++i) {
vertex_mask_ptr[vert_idx_ptr[i]] = 1;
}
} else {
int64_t *vert_idx_ptr = tris_cpu.GetDataPtr<int64_t>();
for (int64_t i = 0; i < tris_cpu.GetLength() * 3; ++i) {
vertex_mask_ptr[vert_idx_ptr[i]] = 1;
}
}
utility::InclusivePrefixSum(
vertex_mask_ptr, vertex_mask_ptr + num_verts, &prefix_sum[1]);
}

// update triangle indices
if (tris_cpu.GetDtype() == core::Int32) {
int32_t *vert_idx_ptr = tris_cpu.GetDataPtr<int32_t>();
for (int64_t i = 0; i < num_tris * 3; ++i) {
int64_t new_idx = prefix_sum[vert_idx_ptr[i]];
vert_idx_ptr[i] = int32_t(new_idx);
}
} else {
int64_t *vert_idx_ptr = tris_cpu.GetDataPtr<int64_t>();
// empty tensor to further construct the vertex mask
core::Tensor vertex_mask;

DISPATCH_INT_DTYPE_PREFIX_TO_TEMPLATE(tris_cpu.GetDtype(), tris, [&]() {
vertex_mask = core::Tensor::Zeros(
{num_verts}, core::Dtype::FromType<scalar_tris_t>());
const int64_t num_tris = tris_cpu.GetLength();
scalar_tris_t *vertex_mask_ptr =
vertex_mask.GetDataPtr<scalar_tris_t>();
scalar_tris_t *vert_idx_ptr = tris_cpu.GetDataPtr<scalar_tris_t>();
// mask for the vertices, which are used in the triangles
for (int64_t i = 0; i < num_tris * 3; ++i) {
int64_t new_idx = prefix_sum[vert_idx_ptr[i]];
vert_idx_ptr[i] = new_idx;
vertex_mask_ptr[vert_idx_ptr[i]] = 1;
}
}
UpdateTriangleIndicesByVertexMask<scalar_tris_t>(tris_cpu, vertex_mask);
});

tris = tris_cpu.To(GetDevice());
vertex_mask = vertex_mask.To(GetDevice(), core::Bool);
core::Tensor verts = GetVertexPositions().IndexGet({vertex_mask});
TriangleMesh result(verts, tris);

// copy attributes
for (auto item : GetVertexAttr()) {
if (!result.HasVertexAttr(item.first)) {
result.SetVertexAttr(item.first,
item.second.IndexGet({vertex_mask}));
}
CopyAttributesByMasks(result, *this, vertex_mask, mask);

return result;
}
Expand Down
3 changes: 2 additions & 1 deletion cpp/open3d/t/geometry/TriangleMesh.h
Original file line number Diff line number Diff line change
Expand Up @@ -927,7 +927,8 @@ class TriangleMesh : public Geometry, public DrawableGeometry {
/// Returns a new mesh with the faces selected by a boolean mask.
/// \param mask A boolean mask with the shape (N) with N as the number of
/// faces in the mesh.
/// \return A new mesh with the selected faces.
/// \return A new mesh with the selected faces. If the original mesh is
/// empty, return an empty mesh.
TriangleMesh SelectFacesByMask(const core::Tensor &mask) const;

/// Returns a new mesh with the vertices selected by a vector of indices.
Expand Down
2 changes: 1 addition & 1 deletion cpp/pybind/t/geometry/trianglemesh.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -901,7 +901,7 @@ the partition id for each face.
number of faces in the mesh.
Returns:
A new mesh with the selected faces.
A new mesh with the selected faces. If the original mesh is empty, return an empty mesh.
Example:
Expand Down
107 changes: 107 additions & 0 deletions cpp/tests/t/geometry/TriangleMesh.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -941,6 +941,113 @@ TEST_P(TriangleMeshPermuteDevices, CreateMobius) {
triangle_indices_custom));
}

TEST_P(TriangleMeshPermuteDevices, SelectFacesByMask) {
// check that an exception is thrown if the mesh is empty
t::geometry::TriangleMesh mesh_empty;
core::Tensor mask_empty =
core::Tensor::Zeros({12}, core::Bool, mesh_empty.GetDevice());
core::Tensor mask_full =
core::Tensor::Ones({12}, core::Bool, mesh_empty.GetDevice());

// check completely empty mesh
EXPECT_TRUE(mesh_empty.SelectFacesByMask(mask_empty).IsEmpty());
EXPECT_TRUE(mesh_empty.SelectFacesByMask(mask_full).IsEmpty());

// check mesh w/o triangles
core::Tensor cpu_vertices =
core::Tensor::Ones({2, 3}, core::Float32, mesh_empty.GetDevice());
mesh_empty.SetVertexPositions(cpu_vertices);
EXPECT_TRUE(mesh_empty.SelectFacesByMask(mask_empty).IsEmpty());
EXPECT_TRUE(mesh_empty.SelectFacesByMask(mask_full).IsEmpty());

// create box with normals, colors and labels defined.
t::geometry::TriangleMesh box = t::geometry::TriangleMesh::CreateBox();
core::Tensor vertex_colors = core::Tensor::Init<float>({{0.0, 0.0, 0.0},
{1.0, 1.0, 1.0},
{2.0, 2.0, 2.0},
{3.0, 3.0, 3.0},
{4.0, 4.0, 4.0},
{5.0, 5.0, 5.0},
{6.0, 6.0, 6.0},
{7.0, 7.0, 7.0}});
;
core::Tensor vertex_labels = core::Tensor::Init<float>({{0.0, 0.0, 0.0},
{1.0, 1.0, 1.0},
{2.0, 2.0, 2.0},
{3.0, 3.0, 3.0},
{4.0, 4.0, 4.0},
{5.0, 5.0, 5.0},
{6.0, 6.0, 6.0},
{7.0, 7.0, 7.0}}) *
10;
;
core::Tensor triangle_labels =
core::Tensor::Init<float>({{0.0, 0.0, 0.0},
{1.0, 1.0, 1.0},
{2.0, 2.0, 2.0},
{3.0, 3.0, 3.0},
{4.0, 4.0, 4.0},
{5.0, 5.0, 5.0},
{6.0, 6.0, 6.0},
{7.0, 7.0, 7.0},
{8.0, 8.0, 8.0},
{9.0, 9.0, 9.0},
{10.0, 10.0, 10.0},
{11.0, 11.0, 11.0}}) *
100;
box.SetVertexColors(vertex_colors);
box.SetVertexAttr("labels", vertex_labels);
box.ComputeTriangleNormals();
box.SetTriangleAttr("labels", triangle_labels);

// empty index list
EXPECT_TRUE(box.SelectFacesByMask(mask_empty).IsEmpty());

// set the expected value
core::Tensor expected_verts = core::Tensor::Init<float>({{0.0, 0.0, 1.0},
{1.0, 0.0, 1.0},
{0.0, 1.0, 1.0},
{1.0, 1.0, 1.0}});
core::Tensor expected_vert_colors =
core::Tensor::Init<float>({{2.0, 2.0, 2.0},
{3.0, 3.0, 3.0},
{6.0, 6.0, 6.0},
{7.0, 7.0, 7.0}});
core::Tensor expected_vert_labels =
core::Tensor::Init<float>({{20.0, 20.0, 20.0},
{30.0, 30.0, 30.0},
{60.0, 60.0, 60.0},
{70.0, 70.0, 70.0}});
core::Tensor expected_tris =
core::Tensor::Init<int64_t>({{0, 1, 3}, {0, 3, 2}});
core::Tensor tris_mask =
core::Tensor::Init<bool>({0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0});
core::Tensor expected_tri_normals =
box.GetTriangleNormals().IndexGet({tris_mask});
core::Tensor expected_tri_labels = core::Tensor::Init<float>(
{{800.0, 800.0, 800.0}, {900.0, 900.0, 900.0}});

// check basic case
t::geometry::TriangleMesh selected = box.SelectFacesByMask(tris_mask);

EXPECT_TRUE(selected.GetVertexPositions().AllClose(expected_verts));
EXPECT_TRUE(selected.GetVertexColors().AllClose(expected_vert_colors));
EXPECT_TRUE(
selected.GetVertexAttr("labels").AllClose(expected_vert_labels));
EXPECT_TRUE(selected.GetTriangleIndices().AllClose(expected_tris));
EXPECT_TRUE(selected.GetTriangleNormals().AllClose(expected_tri_normals));
EXPECT_TRUE(
selected.GetTriangleAttr("labels").AllClose(expected_tri_labels));

// Check that initial mesh is unchanged.
t::geometry::TriangleMesh box_untouched =
t::geometry::TriangleMesh::CreateBox();
EXPECT_TRUE(box.GetVertexPositions().AllClose(
box_untouched.GetVertexPositions()));
EXPECT_TRUE(box.GetTriangleIndices().AllClose(
box_untouched.GetTriangleIndices()));
}

TEST_P(TriangleMeshPermuteDevices, SelectByIndex) {
// check that an exception is thrown if the mesh is empty
t::geometry::TriangleMesh mesh_empty;
Expand Down
90 changes: 90 additions & 0 deletions python/test/t/geometry/test_trianglemesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,96 @@ def test_pickle(device):
mesh.triangle.indices.cpu().numpy())


@pytest.mark.parametrize("device", list_devices())
def test_select_faces_by_mask_32(device):
sphere_custom = o3d.t.geometry.TriangleMesh.create_sphere(
1, 3, o3c.float64, o3c.int32, device)

expected_verts = o3c.Tensor(
[[0.0, 0.0, 1.0], [0.866025, 0, 0.5], [0.433013, 0.75, 0.5],
[-0.866025, 0.0, 0.5], [-0.433013, -0.75, 0.5], [0.433013, -0.75, 0.5]
], o3c.float64, device)

expected_tris = o3c.Tensor([[0, 1, 2], [0, 3, 4], [0, 4, 5], [0, 5, 1]],
o3c.int32, device)

# check indices shape mismatch
mask_2d = o3c.Tensor([[False, False], [False, False], [False, False]],
o3c.bool, device)
with pytest.raises(RuntimeError):
selected = sphere_custom.select_faces_by_mask(mask_2d)

# check indices type mismatch
mask_float = o3c.Tensor([
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
], o3c.float32, device)
with pytest.raises(RuntimeError):
selected = sphere_custom.select_faces_by_mask(mask_float)

# check the basic case
mask = o3c.Tensor([
True, False, False, False, False, False, True, False, True, False, True,
False, False, False, False, False, False, False, False, False, False,
False, False, False
], o3c.bool, device)
selected = sphere_custom.select_faces_by_mask(mask)
assert selected.vertex.positions.allclose(expected_verts)
assert selected.triangle.indices.allclose(expected_tris)

# check that the original mesh is unmodified
untouched_sphere = o3d.t.geometry.TriangleMesh.create_sphere(
1, 3, o3c.float64, o3c.int32, device)
assert sphere_custom.vertex.positions.allclose(
untouched_sphere.vertex.positions)
assert sphere_custom.triangle.indices.allclose(
untouched_sphere.triangle.indices)


@pytest.mark.parametrize("device", list_devices())
def test_select_faces_by_mask_64(device):
sphere_custom = o3d.t.geometry.TriangleMesh.create_sphere(
1, 3, o3c.float64, o3c.int64, device)

# check indices shape mismatch
mask_2d = o3c.Tensor([[False, False], [False, False], [False, False]],
o3c.bool, device)
with pytest.raises(RuntimeError):
selected = sphere_custom.select_faces_by_mask(mask_2d)

# check indices type mismatch
mask_float = o3c.Tensor([
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
], o3c.float32, device)
with pytest.raises(RuntimeError):
selected = sphere_custom.select_faces_by_mask(mask_float)

expected_verts = o3c.Tensor(
[[0.0, 0.0, 1.0], [0.866025, 0, 0.5], [0.433013, 0.75, 0.5],
[-0.866025, 0.0, 0.5], [-0.433013, -0.75, 0.5], [0.433013, -0.75, 0.5]
], o3c.float64, device)

expected_tris = o3c.Tensor([[0, 1, 2], [0, 3, 4], [0, 4, 5], [0, 5, 1]],
o3c.int64, device)
# check the basic case
mask = o3c.Tensor([
True, False, False, False, False, False, True, False, True, False, True,
False, False, False, False, False, False, False, False, False, False,
False, False, False
], o3c.bool, device)

selected = sphere_custom.select_faces_by_mask(mask)
assert selected.vertex.positions.allclose(expected_verts)
assert selected.triangle.indices.allclose(expected_tris)

# check that the original mesh is unmodified
untouched_sphere = o3d.t.geometry.TriangleMesh.create_sphere(
1, 3, o3c.float64, o3c.int64, device)
assert sphere_custom.vertex.positions.allclose(
untouched_sphere.vertex.positions)
assert sphere_custom.triangle.indices.allclose(
untouched_sphere.triangle.indices)


@pytest.mark.parametrize("device", list_devices())
def test_select_by_index_32(device):
sphere_custom = o3d.t.geometry.TriangleMesh.create_sphere(
Expand Down

0 comments on commit defc9df

Please sign in to comment.