Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main' into fix375
Browse files Browse the repository at this point in the history
  • Loading branch information
JinHai-CN committed Nov 15, 2024
2 parents c0c9c50 + f9463c8 commit 6c068d7
Show file tree
Hide file tree
Showing 10 changed files with 223 additions and 13 deletions.
5 changes: 4 additions & 1 deletion docs/references/http_api_reference.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -1954,7 +1954,10 @@ curl --request GET \
- `"topn"`: `int`, *Required*
An integer indicating the number of nearest neighbours (vector search) or most relevant rows (full-text search) to return.
- `"params"`: `object`, *Optional*
Additional matching or reranking parameters.
Additional matching or reranking parameters.
- For all matching methods, you can provide a index hint when searching, so that the query will run using specified index.
- For `"match_method" : "dense"`, `"match_method" : "sparse"`, `"match_method" : "tensor"`, use `"index_name" : "idx1"` to specify index, note that if the index specified is not present, an error would occur.
- For `"match_method" : "text"`, since the query can run on multiple columns, you should specify multiple indexes using `"index_names" : "idx1,idx2..."`, there will be no error when some of specified indexes are not present.
- If you set `"match_method"` to `"dense"`:
- `"ef"`: `str`, Recommended value: one to ten times the value of `topn`.
- For example, if you set `topn` to `10`, you can set `"ef"` to `"50"`.
Expand Down
3 changes: 3 additions & 0 deletions docs/references/pysdk_api_reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -1834,6 +1834,7 @@ A dictionary representing additional KNN or ANN search parameters.
- `"threshold"`: `str`, *Optional* A threshold value for the search.
- For example, if you use the `"cosine"` distance metric and set `"threshold"` to `"0.5"`, the search will return only those rows with a cosine similarity greater than `0.5`.
- `"nprobe"`: `str`, *Optional* The number of cells to search for the IVF index. The default value is `"1"`.
- `"index_name"` : `str`, *Optional* The name of index on which you would like the database to perform query on.

#### Returns

Expand Down Expand Up @@ -1939,6 +1940,7 @@ A dictionary representing additional parameters for the sparse vector search. Fo
`"0.0"` ~ `"1.0"` (default: `"1.0"`) - A "Termination Conditions" parameter. The smaller the value, the more aggressive the pruning.
- `"beta"`: `str`
`"0.0"` ~ `"1.0"` (default: `"1.0"`) - A "Query Term Pruning" parameter. The smaller the value, the more aggressive the pruning.
- `"index_name"` : `str`, *Optional* The name of index on which you would like the database to perform query on.

#### Returns

Expand Down Expand Up @@ -2061,6 +2063,7 @@ An optional dictionary specifying the following search options:
For example, reinterprets `"A01-233:BC"` as `'"A01" OR "-233" OR "BC"'`.
- `{"operator": "and"}`: Interpolates the `AND` operator between words in `matching_text` to create a new search text.
For example, reinterprets `"A01-233:BC"` as `'"A01" AND "-233" AND "BC"'`.
- **`"index_name"`** : `str`, *Optional* The names of indexes on which you would like the database to perform query on.

#### Returns

Expand Down
143 changes: 142 additions & 1 deletion python/test_pysdk/test_knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1833,4 +1833,145 @@ def test_binary_embedding_hamming_distance(self, check_data, knn_distance_type,
pd.testing.assert_frame_equal(res, pd.DataFrame(
{'c1' : (0, 1, 2), 'c2' : (['0100000000000000'], ['0000000000000001'], ['0000000000000011']), 'DISTANCE' : (1.0, 1.0, 2.0)}
).astype({'c1': dtype('int32'), 'DISTANCE' : dtype('float32')}))
db_obj.drop_table("test_binary_knn_hamming_distance" + suffix, ConflictType.Error)
db_obj.drop_table("test_binary_knn_hamming_distance" + suffix, ConflictType.Error)

@pytest.mark.parametrize("check_data", [{"file_name": "sparse_knn.csv",
"data_dir": common_values.TEST_TMP_DIR}], indirect=True)
def test_match_sparse_index_hint(self, check_data, suffix):
db_obj = self.infinity_obj.get_database("default_db")
db_obj.drop_table("test_sparse_knn_with_index"+suffix, ConflictType.Ignore)
table_obj = db_obj.create_table("test_sparse_knn_with_index"+suffix,
{"c1": {"type": "int"}, "c2": {"type": "sparse,100,float,int8"}},
ConflictType.Error)
if not check_data:
copy_data("sparse_knn.csv")
test_csv_dir = common_values.TEST_TMP_DIR + "sparse_knn.csv"
table_obj.import_data(test_csv_dir, import_options={"delimiter": ","})
table_obj.create_index("idx1",
index.IndexInfo("c2",
index.IndexType.BMP,
{"block_size": "8", "compress_type": "compress"}), ConflictType.Error)
table_obj.create_index("idx2",
index.IndexInfo("c2",
index.IndexType.BMP,
{"block_size": "8", "compress_type": "raww"}), ConflictType.Error)

table_obj.optimize("idx1", {"topk": "3"})

res = (table_obj
.output(["*", "_row_id", "_similarity"])
.match_sparse("c2", SparseVector(**{"indices": [0, 20, 80], "values": [1.0, 2.0, 3.0]}), "ip", 3,
{"alpha": "1.0", "beta": "1.0"})
.to_pl())
print(res)
pd.testing.assert_frame_equal(table_obj.output(["c1", "_similarity"])
.match_sparse("c2",
SparseVector(**{"indices": [0, 20, 80], "values": [1.0, 2.0, 3.0]}),
"ip", 3, {"alpha": "1.0", "beta": "1.0", "index_name":"idx1"})
.to_df(), pd.DataFrame({'c1': [4, 2, 1], 'SIMILARITY': [16.0, 12.0, 6.0]}).astype(
{'c1': dtype('int32'), 'SIMILARITY': dtype('float32')}))
pd.testing.assert_frame_equal(table_obj.output(["c1", "_similarity"])
.match_sparse("c2",
SparseVector(**{"indices": [0, 20, 80], "values": [1.0, 2.0, 3.0]}),
"ip", 3, {"alpha": "1.0", "beta": "1.0", "threshold": "10", "index_name" : "idx2"})
.to_df(), pd.DataFrame({'c1': [4, 2], 'SIMILARITY': [16.0, 12.0]}).astype(
{'c1': dtype('int32'), 'SIMILARITY': dtype('float32')}))

# non-existent index
with pytest.raises(InfinityException) as e:
res = table_obj.output(["c1", "_similarity"]).match_sparse("c2",
SparseVector(**{"indices": [0, 20, 80], "values": [1.0, 2.0, 3.0]}),
"ip", 3,
{"alpha": "1.0", "beta": "1.0", "threshold": "10", "index_name" : "idx8"}
).to_pl()
assert e.value.args[0] == ErrorCode.INDEX_NOT_EXIST

res = table_obj.drop_index("idx2", ConflictType.Error)
assert res.error_code == ErrorCode.OK

res = table_obj.drop_index("idx1", ConflictType.Error)
assert res.error_code == ErrorCode.OK

res = db_obj.drop_table("test_sparse_knn_with_index"+suffix, ConflictType.Error)
assert res.error_code == ErrorCode.OK

@pytest.mark.parametrize("match_param_1", ["body^5"])
@pytest.mark.parametrize("check_data", [{"file_name": "enwiki_embedding_99_commas.csv",
"data_dir": common_values.TEST_TMP_DIR}], indirect=True)
def test_match_text_index_hints(self, check_data, match_param_1, suffix):
db_obj = self.infinity_obj.get_database("default_db")
db_obj.drop_table(
"test_with_fulltext_match_with_valid_columns"+suffix, ConflictType.Ignore)
table_obj = db_obj.create_table("test_with_fulltext_match_with_valid_columns"+suffix,
{"doctitle": {"type": "varchar"},
"docdate": {"type": "varchar"},
"body": {"type": "varchar"},
"num": {"type": "int"},
"vec": {"type": "vector, 4, float"}})
table_obj.create_index("my_index",
index.IndexInfo("body",
index.IndexType.FullText,
{"ANALYZER": "standard"}),
ConflictType.Error)

if not check_data:
generate_commas_enwiki(
"enwiki_99.csv", "enwiki_embedding_99_commas.csv", 1)
copy_data("enwiki_embedding_99_commas.csv")

test_csv_dir = common_values.TEST_TMP_DIR + "enwiki_embedding_99_commas.csv"
table_obj.import_data(test_csv_dir, import_options={"delimiter": ","})
res = (table_obj
.output(["*"])
.match_dense("vec", [3.0, 2.8, 2.7, 3.1], "float", "ip", 1)
.match_text(match_param_1, "black", 1, {"index_names" : "my_index"})
.fusion(method='rrf', topn=10)
.to_pl())
print(res)
res_filter_1 = (table_obj
.output(["*"])
.match_dense("vec", [3.0, 2.8, 2.7, 3.1], "float", "ip", 1)
.match_text(match_param_1, "black", 1, {"index_names" : "index1"})
.fusion(method='rrf', topn=10)
.filter("num!=98 AND num != 12")
.to_pl())
print(res_filter_1)
pl_assert_frame_not_equal(res, res_filter_1)
res_filter_2 = (table_obj
.output(["*"])
.match_dense("vec", [3.0, 2.8, 2.7, 3.1], "float", "ip", 1, {"filter": "num!=98 AND num != 12"})
.match_text(match_param_1, "black", 1, {"filter": "num!=98 AND num != 12"})
.fusion(method='rrf', topn=10)
.to_pl())
print(res_filter_2)
pl_assert_frame_equal(res_filter_1, res_filter_2)

# filter_fulltext = "num!=98 AND num != 12 AND filter_fulltext('body', 'harmful chemical')"
# filter_fulltext = """num!=98 AND num != 12 AND filter_fulltext('body', '(harmful OR chemical)')"""
# filter_fulltext = """num!=98 AND num != 12 AND filter_fulltext('body', '("harmful" OR "chemical")')"""
# filter_fulltext = """(num!=98 AND num != 12) AND filter_fulltext('body', '(("harmful" OR "chemical"))')"""
filter_fulltext = """(num!=98 AND num != 12) AND filter_fulltext('body^3,body,body^2', '(("harmful" OR "chemical"))', 'indexes=my_index')"""
_ = (table_obj
.output(["*"])
.match_dense("vec", [3.0, 2.8, 2.7, 3.1], "float", "ip", 1, {"filter": filter_fulltext})
.match_text(match_param_1, "black", 1, {"filter": "num!=98 AND num != 12"})
.fusion(method='rrf', topn=10)
.to_pl())

with pytest.raises(InfinityException) as e_info:
res_filter_3 = (table_obj
.output(["*"])
.match_dense("vec", [3.0, 2.8, 2.7, 3.1], "float", "ip", 1, {"filter": "num!=98 AND num != 12"})
.match_text(match_param_1, "black", 1, {"filter": "num!=98 AND num != 12", "index_names" : 'my_index'})
.fusion(method='rrf', topn=10)
.filter("num!=98 AND num != 12")
.to_pl())
print(e_info)

res = table_obj.drop_index("my_index", ConflictType.Error)
assert res.error_code == ErrorCode.OK

res = db_obj.drop_table(
"test_with_fulltext_match_with_valid_columns"+suffix, ConflictType.Error)
assert res.error_code == ErrorCode.OK

21 changes: 19 additions & 2 deletions src/embedded_infinity/wrap_infinity.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ import table_def;
import third_party;
import logger;
import query_options;
import search_options;
import defer_op;

namespace infinity {
Expand Down Expand Up @@ -357,6 +358,10 @@ ParsedExpr *WrapMatchExpr::GetParsedExpr(Status &status) {
match_expr->fields_ = fields;
match_expr->matching_text_ = matching_text;
match_expr->options_text_ = options_text;
SearchOptions options(match_expr->options_text_);
if (options.options_.find("index_names") != options.options_.end()) {
match_expr->index_names_ = options.options_["index_names"];
}
if (filter_expr) {
match_expr->filter_expr_.reset(filter_expr->GetParsedExpr(status));
}
Expand Down Expand Up @@ -389,6 +394,10 @@ ParsedExpr *WrapMatchTensorExpr::GetParsedExpr(Status &status) {
match_tensor_expr->column_expr_.reset(column_expr.GetParsedExpr(status));
match_tensor_expr->options_text_ = options_text;
match_tensor_expr->embedding_data_type_ = embedding_data_type;
SearchOptions options(match_tensor_expr->options_text_);
if (options.options_.find("index_name") != options.options_.end()) {
match_tensor_expr->index_name_ = options.options_["index_name"];
}

if (status.code_ != ErrorCode::kOk) {
return nullptr;
Expand Down Expand Up @@ -433,10 +442,18 @@ ParsedExpr *WrapMatchSparseExpr::GetParsedExpr(Status &status) {

auto *opt_params_ptr = new Vector<InitParameter *>();
for (auto &param : opt_params) {
auto *init_parameter = new InitParameter();
if (param.param_name_ == "index_name") {
match_sparse_expr->index_name_ = param.param_value_;
continue;
}
if (param.param_name_ == "ignore_index" && param.param_value_ == "true") {
match_sparse_expr->ignore_index_ = true;
continue;
}
auto init_parameter = MakeUnique<InitParameter>();
init_parameter->param_name_ = param.param_name_;
init_parameter->param_value_ = param.param_value_;
opt_params_ptr->emplace_back(init_parameter);
opt_params_ptr->emplace_back(init_parameter.release());
}
match_sparse_expr->SetOptParams(topn, opt_params_ptr);

Expand Down
5 changes: 3 additions & 2 deletions src/expression/match_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,9 @@ void ParseMultiIndexHints(const String &index_hints, Vector<String> &index_names
if (comma_idx == String::npos) {
auto index_name = index_hints.substr(begin_idx);
index_names.emplace_back(index_name);
LOG_TRACE(fmt::format("new index hint : {}", index_name));
break;
} else {
auto index_name = index_hints.substr(begin_idx, comma_idx - begin_idx);
LOG_TRACE(fmt::format("new index hint : {}", index_name));
begin_idx = comma_idx + 1;
}
}
Expand Down Expand Up @@ -70,6 +68,9 @@ u64 MatchExpression::Hash() const {
h ^= std::hash<String>()(fields_);
h ^= std::hash<String>()(matching_text_);
h ^= std::hash<String>()(options_text_);
for (SizeT i = 0; i < index_names_.size(); i++) {
h ^= std::hash<String>()(index_names_[i]);
}
return h;
}

Expand Down
8 changes: 5 additions & 3 deletions src/expression/match_sparse_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ u64 MatchSparseExpression::Hash() const {
h ^= std::hash<SparseMetricType>()(metric_type_);
h ^= std::hash<SizeT>()(query_n_);
h ^= std::hash<SizeT>()(topn_);
h ^= std::hash<String>()(index_name_);
if (optional_filter_) {
h ^= optional_filter_->Hash();
}
Expand All @@ -126,7 +127,7 @@ bool MatchSparseExpression::Eq(const BaseExpression &other_base) const {
}
}
bool eq = column_expr_->Eq(*other.column_expr_) && query_sparse_expr_->Eq(*other.query_sparse_expr_) && metric_type_ == other.metric_type_ &&
query_n_ == other.query_n_ && topn_ == other.topn_;
query_n_ == other.query_n_ && topn_ == other.topn_ && index_name_ == other.index_name_;
if (!eq) {
return false;
}
Expand Down Expand Up @@ -160,13 +161,14 @@ String MatchSparseExpression::ToString() const {
}
String opt_str = ss.str();

return fmt::format("MATCH SPARSE ({}, [{}], {}, {}{}) WITH ({})",
return fmt::format("MATCH SPARSE ({}, [{}], {}, {}{}) WITH ({}) USING INDEX ({})",
column_expr_->Name(),
sparse_str,
MatchSparseExpr::MetricTypeToString(metric_type_),
topn_,
optional_filter_ ? fmt::format(", WHERE {}", optional_filter_->ToString()) : "",
opt_str);
opt_str,
index_name_);
}

} // namespace infinity
4 changes: 3 additions & 1 deletion src/expression/match_tensor_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ u64 MatchTensorExpression::Hash() const {
h ^= std::hash<u32>()(dimension_);
h ^= std::hash<u32>()(tensor_basic_embedding_dimension_);
h ^= std::hash<String>()(options_text_);
h ^= std::hash<String>()(index_name_);
return h;
}

Expand All @@ -92,7 +93,8 @@ bool MatchTensorExpression::Eq(const BaseExpression &other_base) const {
}
bool eq = search_method_ == other.search_method_ && column_expr_->Eq(*other.column_expr_) && embedding_data_type_ == other.embedding_data_type_ &&
dimension_ == other.dimension_ && query_embedding_.Eq(other.query_embedding_, embedding_data_type_, dimension_) &&
tensor_basic_embedding_dimension_ == other.tensor_basic_embedding_dimension_ && options_text_ == other.options_text_;
tensor_basic_embedding_dimension_ == other.tensor_basic_embedding_dimension_ && options_text_ == other.options_text_ &&
index_name_ == other.index_name_;
return eq;
}

Expand Down
Loading

0 comments on commit 6c068d7

Please sign in to comment.