Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support to range filter with json expr #23739

Merged
merged 1 commit into from
Apr 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 41 additions & 33 deletions internal/core/src/common/Json.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@

namespace milvus {
using document = simdjson::ondemand::document;
template <typename T>
using value_result = simdjson::simdjson_result<T>;
class Json {
public:
Json() = default;
Expand All @@ -41,9 +43,6 @@ class Json {
data_ = own_data_.value();
}

explicit Json(simdjson::padded_string_view data) : data_(data) {
}

Json(const char* data, size_t len, size_t cap) : data_(data, len) {
AssertInfo(len + simdjson::SIMDJSON_PADDING <= cap,
fmt::format("create json without enough memory size for "
Expand All @@ -58,28 +57,40 @@ class Json {
Json(const char* data, size_t len) : data_(data, len) {
}

Json(Json&& json) = default;
Json(const Json& json) {
if (json.own_data_.has_value()) {
own_data_ = simdjson::padded_string(
json.own_data_.value().data(), json.own_data_.value().length());
data_ = own_data_.value();
} else {
data_ = json.data_;
}
};
Json(Json&& json) noexcept {
if (json.own_data_.has_value()) {
own_data_ = std::move(json.own_data_);
data_ = own_data_.value();
} else {
data_ = json.data_;
}
}

Json&
operator=(const Json& json) {
if (json.own_data_.has_value()) {
own_data_ = simdjson::padded_string(
json.own_data_.value().data(), json.own_data_.value().length());
data_ = own_data_.value();
} else {
data_ = json.data_;
}

data_ = json.data_;
return *this;
}

operator std::string_view() const {
return data_;
}

void
parse(simdjson::padded_string_view data) {
data_ = data;
}

document
doc() const {
thread_local simdjson::ondemand::parser parser;
Expand All @@ -91,36 +102,33 @@ class Json {
parser.iterate(data_, data_.size() + simdjson::SIMDJSON_PADDING)
.get(doc);
AssertInfo(err == simdjson::SUCCESS,
fmt::format("failed to parse the json: {}", err));
fmt::format("failed to parse the json {}: {}",
data_,
simdjson::error_message(err)));
return doc;
}

simdjson::ondemand::value
operator[](const std::string_view field) const {
simdjson::ondemand::value result;
auto err = doc().get_value()[field].get(result);
AssertInfo(
err == simdjson::SUCCESS,
fmt::format("failed to access the field {}: {}", field, err));
return result;
bool
exist(std::vector<std::string> nested_path) const {
std::for_each(
nested_path.begin(), nested_path.end(), [](std::string& key) {
boost::replace_all(key, "~", "~0");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

leave the comment here, make sure code reviewer understand it's the requirement of simdJson

boost::replace_all(key, "/", "~1");
});
auto pointer = "/" + boost::algorithm::join(nested_path, "/");
return doc().at_pointer(pointer).error() == simdjson::SUCCESS;
}

simdjson::ondemand::value
operator[](std::vector<std::string> nested_path) const {
template <typename T>
value_result<T>
at(std::vector<std::string> nested_path) const {
std::for_each(
nested_path.begin(), nested_path.end(), [](std::string& key) {
boost::replace_all(key, "~", "~0");
boost::replace_all(key, "/", "~1");
});
auto pointer = boost::algorithm::join(nested_path, "/");
simdjson::ondemand::value result;
auto err = doc().at_pointer(pointer).get(result);
AssertInfo(
err == simdjson::SUCCESS,
fmt::format("failed to access the field with json pointer {}: {}",
pointer,
err));
return result;
auto pointer = "/" + boost::algorithm::join(nested_path, "/");
return doc().at_pointer(pointer).get<T>();
}

std::string_view
Expand All @@ -130,7 +138,7 @@ class Json {

private:
std::optional<simdjson::padded_string>
own_data_; // this could be empty, then the Json will be just s view on bytes
simdjson::padded_string_view data_;
own_data_{}; // this could be empty, then the Json will be just s view on bytes
simdjson::padded_string_view data_{};
};
} // namespace milvus
51 changes: 37 additions & 14 deletions internal/core/src/query/Expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <vector>

#include "common/Schema.h"
#include "common/Types.h"
#include "pb/plan.pb.h"

namespace milvus::query {
Expand All @@ -33,6 +34,27 @@ using optype = proto::plan::OpType;

class ExprVisitor;

struct ColumnInfo {
FieldId field_id;
DataType data_type;
std::vector<std::string> nested_path;

ColumnInfo(const proto::plan::ColumnInfo& column_info)
: field_id(column_info.field_id()),
data_type(static_cast<DataType>(column_info.data_type())),
nested_path(column_info.nested_path().begin(),
column_info.nested_path().end()) {
}

ColumnInfo(FieldId field_id,
DataType data_type,
std::vector<std::string> nested_path = {})
: field_id(field_id),
data_type(data_type),
nested_path(std::move(nested_path)) {
}
};

// Base of all Exprs
struct Expr {
public:
Expand Down Expand Up @@ -132,21 +154,22 @@ static const std::map<ArithOpType, std::string> mapping_arith_op_ = {
};

struct BinaryArithOpEvalRangeExpr : Expr {
const FieldId field_id_;
const DataType data_type_;
const ColumnInfo column_;
const proto::plan::GenericValue::ValCase val_case_;
const OpType op_type_;
const ArithOpType arith_op_;

protected:
// prevent accidental instantiation
BinaryArithOpEvalRangeExpr() = delete;

BinaryArithOpEvalRangeExpr(const FieldId field_id,
const DataType data_type,
const OpType op_type,
const ArithOpType arith_op)
: field_id_(field_id),
data_type_(data_type),
BinaryArithOpEvalRangeExpr(
ColumnInfo column,
const proto::plan::GenericValue::ValCase val_case,
const OpType op_type,
const ArithOpType arith_op)
: column_(std::move(column)),
val_case_(val_case),
op_type_(op_type),
arith_op_(arith_op) {
}
Expand Down Expand Up @@ -189,21 +212,21 @@ struct UnaryRangeExpr : Expr {
};

struct BinaryRangeExpr : Expr {
const FieldId field_id_;
const DataType data_type_;
const ColumnInfo column_;
const proto::plan::GenericValue::ValCase val_case_;
const bool lower_inclusive_;
const bool upper_inclusive_;

protected:
// prevent accidental instantiation
BinaryRangeExpr() = delete;

BinaryRangeExpr(const FieldId field_id,
const DataType data_type,
BinaryRangeExpr(ColumnInfo column,
const proto::plan::GenericValue::ValCase val_case,
const bool lower_inclusive,
const bool upper_inclusive)
: field_id_(field_id),
data_type_(data_type),
: column_(std::move(column)),
val_case_(val_case),
lower_inclusive_(lower_inclusive),
upper_inclusive_(upper_inclusive) {
}
Expand Down
28 changes: 17 additions & 11 deletions internal/core/src/query/ExprImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@
#pragma once

#include <tuple>
#include <utility>
#include <vector>
#include <boost/container/vector.hpp>

#include "Expr.h"
#include "pb/plan.pb.h"

namespace milvus::query {

Expand All @@ -40,13 +42,15 @@ struct BinaryArithOpEvalRangeExprImpl : BinaryArithOpEvalRangeExpr {
const T right_operand_;
const T value_;

BinaryArithOpEvalRangeExprImpl(const FieldId field_id,
const DataType data_type,
const ArithOpType arith_op,
const T right_operand,
const OpType op_type,
const T value)
: BinaryArithOpEvalRangeExpr(field_id, data_type, op_type, arith_op),
BinaryArithOpEvalRangeExprImpl(
ColumnInfo column,
const proto::plan::GenericValue::ValCase val_case,
const ArithOpType arith_op,
const T right_operand,
const OpType op_type,
const T value)
: BinaryArithOpEvalRangeExpr(
std::forward<ColumnInfo>(column), val_case, op_type, arith_op),
right_operand_(right_operand),
value_(value) {
}
Expand All @@ -69,14 +73,16 @@ struct BinaryRangeExprImpl : BinaryRangeExpr {
const T lower_value_;
const T upper_value_;

BinaryRangeExprImpl(const FieldId field_id,
const DataType data_type,
BinaryRangeExprImpl(ColumnInfo column,
const proto::plan::GenericValue::ValCase val_case,
const bool lower_inclusive,
const bool upper_inclusive,
const T lower_value,
const T upper_value)
: BinaryRangeExpr(
field_id, data_type, lower_inclusive, upper_inclusive),
: BinaryRangeExpr(std::forward<ColumnInfo>(column),
val_case,
lower_inclusive,
upper_inclusive),
lower_value_(lower_value),
upper_value_(upper_value) {
}
Expand Down
14 changes: 9 additions & 5 deletions internal/core/src/query/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
#include "Plan.h"
#include "generated/ExtractInfoPlanNodeVisitor.h"
#include "generated/VerifyPlanNodeVisitor.h"
#include "pb/plan.pb.h"
#include "query/Expr.h"

namespace milvus::query {

Expand Down Expand Up @@ -92,7 +94,7 @@ Parser::ParseRangeNode(const Json& out_body) {
Assert(out_body.size() == 1);
auto out_iter = out_body.begin();
auto field_name = FieldName(out_iter.key());
auto body = out_iter.value();
auto& body = out_iter.value();
auto data_type = schema[field_name].get_data_type();
Assert(!datatype_is_vector(data_type));

Expand Down Expand Up @@ -302,8 +304,9 @@ Parser::ParseRangeNodeImpl(const FieldName& field_name, const Json& body) {
}

return std::make_unique<BinaryArithOpEvalRangeExprImpl<T>>(
schema.get_field_id(field_name),
schema[field_name].get_data_type(),
ColumnInfo(schema.get_field_id(field_name),
schema[field_name].get_data_type()),
proto::plan::GenericValue::ValCase::VAL_NOT_SET,
arith_op_mapping_.at(arith_op_name),
right_operand,
mapping_.at(op_name),
Expand Down Expand Up @@ -366,8 +369,9 @@ Parser::ParseRangeNodeImpl(const FieldName& field_name, const Json& body) {
AssertInfo(has_lower_value && has_upper_value,
"illegal binary-range node");
return std::make_unique<BinaryRangeExprImpl<T>>(
schema.get_field_id(field_name),
schema[field_name].get_data_type(),
ColumnInfo(schema.get_field_id(field_name),
schema[field_name].get_data_type()),
proto::plan::GenericValue::ValCase::VAL_NOT_SET,
lower_inclusive,
upper_inclusive,
lower_value,
Expand Down
Loading