Skip to content

Commit

Permalink
[SpatialPartitioning] Add test for KdTree datastructure
Browse files Browse the repository at this point in the history
Aim at detecting problems with
 - the structure KdTreeDefaultTraits (copies, references, ...)
 - duplicated samples
  • Loading branch information
nmellado committed Sep 18, 2024
1 parent 5cb11da commit ff12255
Show file tree
Hide file tree
Showing 3 changed files with 156 additions and 20 deletions.
52 changes: 32 additions & 20 deletions tests/common/kdtree_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,11 @@ class MyPoint {
};

template<typename Scalar, typename VectorContainer>
bool check_range_neighbors(const VectorContainer& points, const std::vector<int>& sampling, int index, Scalar r, const std::vector<int>& neighbors)
bool check_range_neighbors(const VectorContainer& points, const std::vector<int>& sampling, int index, Scalar r,
const std::vector<int>& neighbors,
bool allow_duplicates = false)
{
if (has_duplicate(neighbors))
if (!allow_duplicates && has_duplicate(neighbors))
{
return false;
}
Expand Down Expand Up @@ -103,9 +105,11 @@ bool check_range_neighbors(const VectorContainer& points, const std::vector<int>
}

template<typename Scalar, typename VectorType, typename VectorContainer>
bool check_range_neighbors(const VectorContainer& points, const std::vector<int>& sampling, const VectorType& point, Scalar r, const std::vector<int>& neighbors)
bool check_range_neighbors(const VectorContainer& points, const std::vector<int>& sampling,
const VectorType& point, Scalar r, const std::vector<int>& neighbors,
bool allow_duplicates = false)
{
if (has_duplicate(neighbors))
if (!allow_duplicates && has_duplicate(neighbors))
{
return false;
}
Expand Down Expand Up @@ -147,14 +151,15 @@ bool check_range_neighbors(const VectorContainer& points, const std::vector<int>
}

template<typename Scalar, typename VectorContainer>
bool check_k_nearest_neighbors(const VectorContainer& points, int index, int k, const std::vector<int>& neighbors)
bool check_k_nearest_neighbors(const VectorContainer& points, int index, int k,
const std::vector<int>& neighbors, bool allow_duplicates = false)
{
if (int(points.size()) > k && int(neighbors.size()) != k)
{
return false;
}

if (has_duplicate(neighbors))
if (!allow_duplicates && has_duplicate(neighbors))
{
return false;
}
Expand Down Expand Up @@ -190,14 +195,15 @@ bool check_k_nearest_neighbors(const VectorContainer& points, int index, int k,
}

template<typename Scalar, typename VectorContainer>
bool check_k_nearest_neighbors(const VectorContainer& points, const std::vector<int>& sampling, int index, int k, const std::vector<int>& neighbors)
bool check_k_nearest_neighbors(const VectorContainer& points, const std::vector<int>& sampling, int index, int k,
const std::vector<int>& neighbors, bool allow_duplicates = false)
{
if (int(points.size()) > k && int(neighbors.size()) != k)
{
return false;
}

if (has_duplicate(neighbors))
if (!allow_duplicates && has_duplicate(neighbors))
{
return false;
}
Expand Down Expand Up @@ -242,14 +248,16 @@ bool check_k_nearest_neighbors(const VectorContainer& points, const std::vector<
}

template<typename Scalar, typename VectorType, typename VectorContainer>
bool check_k_nearest_neighbors(const VectorContainer& points, const std::vector<int>& sampling, const VectorType& point, int k, const std::vector<int>& neighbors)
bool check_k_nearest_neighbors(const VectorContainer& points, const std::vector<int>& sampling,
const VectorType& point, int k, const std::vector<int>& neighbors,
bool allow_duplicates = false)
{
if (int(sampling.size()) >= k && int(neighbors.size()) != k)
{
return false;
}

if (has_duplicate(neighbors))
if (!allow_duplicates && has_duplicate(neighbors))
{
return false;
}
Expand Down Expand Up @@ -287,14 +295,15 @@ bool check_k_nearest_neighbors(const VectorContainer& points, const std::vector<


template<typename Scalar, typename VectorType, typename VectorContainer>
bool check_k_nearest_neighbors(const VectorContainer& points, const VectorType& point, int k, const std::vector<int>& neighbors)
bool check_k_nearest_neighbors(const VectorContainer& points, const VectorType& point, int k,
const std::vector<int>& neighbors, bool allow_duplicates = false)
{
if (int(points.size()) >= k && int(neighbors.size()) != k)
{
return false;
}

if (has_duplicate(neighbors))
if (!allow_duplicates && has_duplicate(neighbors))
{
return false;
}
Expand Down Expand Up @@ -323,24 +332,27 @@ bool check_k_nearest_neighbors(const VectorContainer& points, const VectorType&


template<typename Scalar, typename VectorContainer>
bool check_nearest_neighbor(const VectorContainer& points, int index, int nearest)
bool check_nearest_neighbor(const VectorContainer& points, int index, int nearest, bool allow_duplicates = false)
{
return check_k_nearest_neighbors<Scalar, VectorContainer>(points, index, 1, { nearest });
return check_k_nearest_neighbors<Scalar, VectorContainer>(points, index, 1, { nearest }, allow_duplicates);
}
template<typename Scalar, typename VectorType, typename VectorContainer>
bool check_nearest_neighbor(const VectorContainer& points, const VectorType& point, int nearest)
bool check_nearest_neighbor(const VectorContainer& points, const VectorType& point, int nearest,
bool allow_duplicates = false)
{
return check_k_nearest_neighbors<Scalar, VectorType, VectorContainer>(points, point, 1, { nearest });
return check_k_nearest_neighbors<Scalar, VectorType, VectorContainer>(points, point, 1, { nearest }, allow_duplicates);
}

template<typename Scalar, typename VectorType, typename VectorContainer>
bool check_nearest_neighbor(const VectorContainer& points, const std::vector<int>& sampling, const VectorType& point, int nearest)
bool check_nearest_neighbor(const VectorContainer& points, const std::vector<int>& sampling,
const VectorType& point, int nearest, bool allow_duplicates = false)
{
return check_k_nearest_neighbors<Scalar, VectorContainer>(points, sampling, point, 1, { nearest });
return check_k_nearest_neighbors<Scalar, VectorContainer>(points, sampling, point, 1, { nearest }, allow_duplicates);
}

template<typename Scalar, typename VectorType, typename VectorContainer>
bool check_nearest_neighbor(const VectorContainer& points, const std::vector<int>& sampling, int index, int nearest)
bool check_nearest_neighbor(const VectorContainer& points, const std::vector<int>& sampling, int index, int nearest,
bool allow_duplicates = false)
{
return check_k_nearest_neighbors<Scalar, VectorType, VectorContainer>(points, sampling, index, 1, { nearest });
return check_k_nearest_neighbors<Scalar, VectorType, VectorContainer>(points, sampling, index, 1, { nearest }, allow_duplicates);
}
1 change: 1 addition & 0 deletions tests/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@ add_multi_test(weight_kernel.cpp)
add_multi_test(queries_range.cpp)
add_multi_test(queries_nearest.cpp)
add_multi_test(queries_knearest.cpp)
add_multi_test(kdtree.cpp)
123 changes: 123 additions & 0 deletions tests/src/kdtree.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
/*
This Source Code Form is subject to the terms of the Mozilla Public
License, v. 2.0. If a copy of the MPL was not distributed with this
file, You can obtain one at http://mozilla.org/MPL/2.0/.
\file Test general properties of the KdTree
*/

#include "../common/testing.h"
#include "../common/testUtils.h"
#include "../common/has_duplicate.h"
#include "../common/kdtree_utils.h"

#include <Ponca/src/SpatialPartitioning/KdTree/kdTree.h>
#include <Ponca/src/SpatialPartitioning/KdTree/kdTreeTraits.h>

using namespace Ponca;


template<typename DataPoint>
void testKdtreeWithDuplicate()
{
using Scalar = typename DataPoint::Scalar;
using VectorContainer = typename KdTreeSparse<DataPoint>::PointContainer;
using VectorType = typename DataPoint::VectorType;

// Number of point samples in each KdTree leaf
#ifdef PONCA_DEBUG
const int cellSize = 6;
const int nbCells = 2;
const int N = nbCells*cellSize;
#else
const int cellSize = 64;
const int nbCells = 100;
const int N = nbCells*cellSize;
#endif

const Scalar r = 0.001;

auto test_tree = [r] (const auto& points, const auto&indices, const int cellSize) -> void
{
KdTreeSparse<DataPoint> tree(points, indices, cellSize);

#ifndef PONCA_DEBUG
#pragma omp parallel for default(none) shared(tree, points, indices, g_test_stack, r)
#endif
for (int i = 0; i < points.size(); ++i)
{
VectorType point = points[i].pos();//VectorType::Random(); // values between [-1:1]
std::vector<int> results;

for (int j : tree.range_neighbors(point, r)) {
results.push_back(j);
}

bool res = check_range_neighbors<Scalar, VectorType, VectorContainer>(points, indices, point, r, results, true);
VERIFY(res);
}
};


// Generate N random points
typename KdTreeDense<DataPoint>::IndexContainer ids(N);
std::iota(ids.begin(), ids.end(), 0);

auto points = VectorContainer(N);
std::generate(points.begin(), points.end(), []() {return DataPoint(VectorType::Random()); });

// Test on 100% random points
{
test_tree(points, ids, cellSize);
}

}

template<typename NodeType>
void testKdTreeNode() {
std::vector<NodeType> buffer;

buffer.resize(10);

// simple predicate that only check if a node is a leaf or not
auto nodePredicate = [](const NodeType& n1, const NodeType& n2) -> bool {
return n1.is_leaf() == n2.is_leaf();
};

auto checkProperties = [nodePredicate](std::vector<NodeType>& buf, bool targetLeafState) -> void{
// Check that references works well:
for (auto& b : buf ) b.set_is_leaf(targetLeafState);
for (const auto& b : buf ) VERIFY( b.is_leaf() == targetLeafState );

// Check that copies are working well
std::vector<NodeType> other;
other.reserve(buf.size());
other = buf;
VERIFY(std::equal(buf.cbegin(), buf.cend(), other.cbegin(), other.cend(), nodePredicate));

// Check that reallocation works well
other.resize(buf.size()*2);
VERIFY(std::equal(buf.cbegin(), buf.cend(), other.cbegin(), other.cbegin()+buf.size(), nodePredicate));
};

checkProperties(buffer, true);
checkProperties(buffer, false);
}

int main(int argc, char** argv)
{
if (!init_testing(argc, argv))
{
return EXIT_FAILURE;
}

using PointType = TestPoint<float, 3>;
using KdTreeTraits = KdTreeDefaultTraits<PointType>;

cout << "Test KdTreeDefaultNode" << endl;
testKdTreeNode<typename KdTreeTraits::NodeType>();

cout << "Test KdTreeRange with large number of duplicated points" << endl;
testKdtreeWithDuplicate<PointType>();

}

0 comments on commit ff12255

Please sign in to comment.