Skip to content

Commit

Permalink
add range check
Browse files Browse the repository at this point in the history
Signed-off-by: xianliang.li <xianliang.li@zilliz.com>
  • Loading branch information
foxspy committed Oct 31, 2024
1 parent 8e0150b commit d33e629
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 19 deletions.
54 changes: 37 additions & 17 deletions include/knowhere/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,32 @@ typedef nlohmann::json Json;
#define CFG_MATERIALIZED_VIEW_SEARCH_INFO_TYPE std::optional<knowhere::MaterializedViewSearchInfo>
#endif

template <typename T>
struct Range {
T left;
T right;
bool include_left;
bool include_right;

Range(T left, T right, bool includeLeft, bool includeRight)
: left(left), right(right), include_left(includeLeft), include_right(includeRight) {
}

bool
within(T val) {
bool left_range_check = left < val || (include_left && left <= val);
bool right_range_check = val < right || (include_right && val <= right);
return left_range_check && right_range_check;
}

std::string
to_string() {
std::string left_mark = include_left ? "[" : "(";
std::string right_mark = include_right ? "]" : ")";
return left_mark + std::to_string(left) + ", " + std::to_string(right) + right_mark;
}
};

template <typename T>
struct Entry {};

Expand Down Expand Up @@ -114,7 +140,7 @@ struct Entry<CFG_FLOAT> {
CFG_FLOAT* val;
std::optional<CFG_FLOAT::value_type> default_val;
uint32_t type;
std::optional<std::pair<CFG_FLOAT::value_type, CFG_FLOAT::value_type>> range;
std::optional<Range<CFG_FLOAT::value_type>> range;
std::optional<std::string> desc;
bool allow_empty_without_default = false;
};
Expand All @@ -139,7 +165,7 @@ struct Entry<CFG_INT> {
CFG_INT* val;
std::optional<CFG_INT::value_type> default_val;
uint32_t type;
std::optional<std::pair<CFG_INT::value_type, CFG_INT::value_type>> range;
std::optional<Range<CFG_INT::value_type>> range;
std::optional<std::string> desc;
bool allow_empty_without_default = false;
};
Expand All @@ -164,7 +190,7 @@ struct Entry<CFG_INT64> {
CFG_INT64* val;
std::optional<CFG_INT64::value_type> default_val;
uint32_t type;
std::optional<std::pair<CFG_INT64::value_type, CFG_INT64::value_type>> range;
std::optional<Range<CFG_INT64::value_type>> range;
std::optional<std::string> desc;
bool allow_empty_without_default = false;
};
Expand Down Expand Up @@ -228,8 +254,8 @@ class EntryAccess {
}

EntryAccess&
set_range(typename T::value_type a, typename T::value_type b) {
entry->range = std::make_pair(a, b);
set_range(typename T::value_type a, typename T::value_type b, bool include_left = true, bool include_right = true) {
entry->range = Range<typename T::value_type>(a, b, include_left, include_right);
return *this;
}

Expand Down Expand Up @@ -360,13 +386,11 @@ class Config {
}
CFG_INT::value_type v = json[it.first];
auto range_val = ptr->range.value();
if (range_val.first <= v && v <= range_val.second) {
if (range_val.within(v)) {
*ptr->val = v;
} else {
std::string msg = "Out of range in json: param '" + it.first + "' (" +
to_string(json[it.first]) + ") should be in range [" +
std::to_string(range_val.first) + ", " + std::to_string(range_val.second) +
"]";
to_string(json[it.first]) + ") should be in range " + range_val.to_string();
show_err_msg(msg);
return Status::out_of_range_in_json;
}
Expand Down Expand Up @@ -408,13 +432,11 @@ class Config {
}
CFG_INT64::value_type v = json[it.first];
auto range_val = ptr->range.value();
if (range_val.first <= v && v <= range_val.second) {
if (range_val.within(v)) {
*ptr->val = v;
} else {
std::string msg = "Out of range in json: param '" + it.first + "' (" +
to_string(json[it.first]) + ") should be in range [" +
std::to_string(range_val.first) + ", " + std::to_string(range_val.second) +
"]";
to_string(json[it.first]) + ") should be in range " + range_val.to_string();
show_err_msg(msg);
return Status::out_of_range_in_json;
}
Expand Down Expand Up @@ -456,13 +478,11 @@ class Config {
}
CFG_FLOAT::value_type v = json[it.first];
auto range_val = ptr->range.value();
if (range_val.first <= v && v <= range_val.second) {
if (range_val.within(v)) {
*ptr->val = v;
} else {
std::string msg = "Out of range in json: param '" + it.first + "' (" +
to_string(json[it.first]) + ") should be in range [" +
std::to_string(range_val.first) + ", " + std::to_string(range_val.second) +
"]";
to_string(json[it.first]) + ") should be in range " + range_val.to_string();
show_err_msg(msg);
return Status::out_of_range_in_json;
}
Expand Down
4 changes: 2 additions & 2 deletions src/index/sparse/sparse_inverted_index_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ class SparseInvertedIndexConfig : public BaseConfig {
KNOWHERE_CONFIG_DECLARE_FIELD(drop_ratio_build)
.description("drop ratio for build")
.set_default(0.0f)
.set_range(0.0f, 1.0f)
.set_range(0.0f, 1.0f, true, false)
.for_train();
KNOWHERE_CONFIG_DECLARE_FIELD(drop_ratio_search)
.description("drop ratio for search")
.set_default(0.0f)
.set_range(0.0f, 1.0f)
.set_range(0.0f, 1.0f, true, false)
.for_search()
.for_range_search()
.for_iterator();
Expand Down
26 changes: 26 additions & 0 deletions tests/ut/test_config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "knowhere/version.h"
#ifdef KNOWHERE_WITH_DISKANN
#include "index/diskann/diskann_config.h"
#include "index/sparse/sparse_inverted_index_config.h"
#endif
#ifdef KNOWHERE_WITH_RAFT
#include "index/gpu_raft/gpu_raft_cagra_config.h"
Expand Down Expand Up @@ -110,6 +111,31 @@ TEST_CASE("Test config json parse", "[config]") {
CHECK(test_config.dim.value() == 10000000000L);
}

SECTION("check range data values") {
auto sparse_valid = GENERATE(as<std::string>{},
R"({
"drop_ratio_build": 0.0
})");
knowhere::BaseConfig test_config;
knowhere::Json test_json = knowhere::Json::parse(sparse_valid);
s = knowhere::Config::FormatAndCheck(test_config, test_json);
CHECK(s == knowhere::Status::success);
s = knowhere::Config::Load(test_config, test_json, knowhere::TRAIN);
CHECK(s == knowhere::Status::success);

auto sparse_invalid = GENERATE(as<std::string>{},
R"({
"drop_ratio_build": 1.0
})");

knowhere::SparseInvertedIndexConfig test_invalid_config;
knowhere::Json test_invalid_json = knowhere::Json::parse(sparse_invalid);
s = knowhere::Config::FormatAndCheck(test_invalid_config, test_invalid_json);
CHECK(s == knowhere::Status::success);
s = knowhere::Config::Load(test_invalid_config, test_invalid_json, knowhere::TRAIN);
CHECK(s == knowhere::Status::out_of_range_in_json);
}

SECTION("check invalid json values") {
auto invalid_json_str = GENERATE(as<std::string>{},
R"({
Expand Down

0 comments on commit d33e629

Please sign in to comment.