diff --git a/include/LightGBM/utils/array_args.h b/include/LightGBM/utils/array_args.h index 0183ecc22ddb..3c590bad4030 100644 --- a/include/LightGBM/utils/array_args.h +++ b/include/LightGBM/utils/array_args.h @@ -103,7 +103,9 @@ class ArrayArgs { int j = end - 1; int p = i; int q = j; - if (start >= end) { + if (start >= end - 1) { + *l = start - 1; + *r = end; return; } std::vector& ref = *arr; diff --git a/tests/cpp_tests/test_array_args.cpp b/tests/cpp_tests/test_array_args.cpp new file mode 100644 index 000000000000..2b46f93ae062 --- /dev/null +++ b/tests/cpp_tests/test_array_args.cpp @@ -0,0 +1,52 @@ +/*! + * Copyright (c) 2021 Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See LICENSE file in the project root for license information. + */ + +#include +#include +#include + +#include + +using LightGBM::data_size_t; +using LightGBM::score_t; +using LightGBM::ArrayArgs; + + +TEST(Partition, JustWorks) { + std::vector gradients({0.5f, 5.0f, 1.0f, 2.0f, 2.0f}); + data_size_t middle_begin, middle_end; + + ArrayArgs::Partition(&gradients, 0, gradients.size(), &middle_begin, &middle_end); + + EXPECT_EQ(gradients[middle_begin + 1], gradients[middle_end - 1]); + EXPECT_GT(gradients[0], gradients[middle_begin + 1]); + EXPECT_GT(gradients[middle_begin + 1], gradients.back()); +} + +TEST(Partition, PartitionOneElement) { + std::vector gradients({0.5f}); + data_size_t middle_begin, middle_end; + ArrayArgs::Partition(&gradients, 0, gradients.size(), &middle_begin, &middle_end); + EXPECT_EQ(gradients[middle_begin + 1], gradients[middle_end - 1]); +} + +TEST(Partition, Empty) { + std::vector gradients; + data_size_t middle_begin, middle_end; + ArrayArgs::Partition(&gradients, 0, gradients.size(), &middle_begin, &middle_end); + + EXPECT_EQ(middle_begin, -1); + EXPECT_EQ(middle_end, 0); +} + +TEST(Partition, AllEqual) { + std::vector gradients({0.5f, 0.5f, 0.5f}); + data_size_t middle_begin, middle_end; + ArrayArgs::Partition(&gradients, 0, gradients.size(), &middle_begin, &middle_end); + + EXPECT_EQ(gradients[middle_begin + 1], gradients[middle_end - 1]); + EXPECT_EQ(middle_begin, -1); + EXPECT_EQ(middle_end, 3); +}