Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: fp16/bf16 vector support #268

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ jobs:
strategy:
fail-fast: false
matrix:
macos: [11]
macos: [13]
env:
CCACHE_DIR: ${{ github.workspace }}/.ccache
CCACHE_COMPILERCHECK: content
Expand Down
8 changes: 2 additions & 6 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ project(milvus_sdk LANGUAGES CXX)

set(CMAKE_VERBOSE_MAKEFILE OFF)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
set(CMAKE_CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
Expand Down Expand Up @@ -65,13 +65,9 @@ define_option_string(MILVUS_WITH_GRPC "Using gRPC from" "module"
define_option_string(MILVUS_WITH_ZLIB "Using Zlib from" "module" "package" "module")
define_option_string(MILVUS_WITH_NLOHMANN_JSON "nlohmann json from" "module" "package" "module")
define_option_string(MILVUS_WITH_GTEST "Using GTest from" "module" "package" "module")
define_option_string(MILVUS_WITH_EIGEN "Using Eigen from" "module" "package" "module")


set(CMAKE_CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD_REQUIRED on)
set(BUILD_SCRIPTS_DIR ${PROJECT_SOURCE_DIR}/scripts)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)


# load third packages and milvus-proto
include(ThirdPartyPackages)
Expand Down
4 changes: 2 additions & 2 deletions DEVELOPMENT.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ Or `make all-release` to build the release version.
And you could also create a dedicated CMake build directory, then use CMake to build it from the source by yourself

```shell
$ mkdir build
$ cd build
$ mkdir cmake_build
$ cd cmake_build
$ cmake ..
$ make
```
Expand Down
20 changes: 20 additions & 0 deletions cmake/ThirdPartyPackages.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -27,22 +27,31 @@ FetchContent_Declare(
grpc
GIT_REPOSITORY https://github.com/grpc/grpc.git
GIT_TAG v${GRPC_VERSION}
GIT_SHALLOW TRUE
)

# nlohmann_json
FetchContent_Declare(
nlohmann_json
GIT_REPOSITORY https://github.com/nlohmann/json.git
GIT_TAG v${NLOHMANN_JSON_VERSION}
GIT_SHALLOW TRUE
)

# googletest
FetchContent_Declare(
googletest
GIT_REPOSITORY https://github.com/google/googletest.git
GIT_TAG release-${GOOGLETEST_VERSION}
GIT_SHALLOW TRUE
)

FetchContent_Declare(
eigen3
GIT_REPOSITORY https://gitlab.com/libeigen/eigen.git
GIT_TAG 3.4.0
GIT_SHALLOW TRUE
)

# grpc
if ("${MILVUS_WITH_GRPC}" STREQUAL "pakcage")
Expand Down Expand Up @@ -79,3 +88,14 @@ else ()
add_subdirectory(${nlohmann_json_SOURCE_DIR} ${nlohmann_json_BINARY_DIR} EXCLUDE_FROM_ALL)
endif ()
endif ()

# eigen3
if ("${MILVUS_WITH_EIGEN}" STREQUAL "package")
find_package(Eigen3 REQUIRED NO_MODULE)
else ()
if (NOT eigen3_POPULATED)
FetchContent_Populate(eigen3)
set(BUILD_TESTING OFF CACHE INTERNAL "")
add_subdirectory(${eigen3_SOURCE_DIR} ${eigen3_BINARY_DIR} EXCLUDE_FROM_ALL)
endif ()
endif ()
2 changes: 1 addition & 1 deletion examples/simple/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ main(int argc, char* argv[]) {
std::uniform_int_distribution<int64_t> int64_gen(0, row_count - 1);
int64_t q_number = int64_gen(ran);
std::vector<float> q_vector = insert_vectors[q_number];
arguments.AddTargetVector(field_face_name, std::move(q_vector));
arguments.AddTargetVector<milvus::FloatVecFieldData>(field_face_name, std::move(q_vector));
std::cout << "Searching the No." << q_number << " entity..." << std::endl;

milvus::SearchResults search_results{};
Expand Down
4 changes: 2 additions & 2 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ add_library(milvus_sdk ${impl_files} ${impl_types_files})
# add proto gens
add_milvus_protos(milvus_sdk)
set_target_properties(milvus_sdk PROPERTIES OUTPUT_NAME milvus_sdk)
target_link_libraries(milvus_sdk gRPC::grpc++ nlohmann_json::nlohmann_json)
target_link_libraries(milvus_sdk gRPC::grpc++ nlohmann_json::nlohmann_json Eigen3::Eigen)
target_include_directories(milvus_sdk PUBLIC include)


Expand All @@ -37,4 +37,4 @@ install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/include/
install(TARGETS milvus_sdk
EXPORT milvus_sdk
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR})
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR})
155 changes: 117 additions & 38 deletions src/impl/MilvusClientImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,53 @@ MilvusClientImpl::Delete(const std::string& collection_name, const std::string&
post);
}

template <DataType Dt>
struct SearchTrait;

template <>
struct SearchTrait<DataType::BINARY_VECTOR> {
static constexpr proto::common::PlaceholderType type = proto::common::PlaceholderType::BinaryVector;
using FieldDataType = BinaryVecFieldData;
static constexpr size_t size_of_element = sizeof(uint8_t);
};

template <>
struct SearchTrait<DataType::FLOAT_VECTOR> {
static constexpr proto::common::PlaceholderType type = proto::common::PlaceholderType::FloatVector;
using FieldDataType = FloatVecFieldData;
static constexpr size_t size_of_element = sizeof(float);
};

template <>
struct SearchTrait<DataType::FLOAT16_VECTOR> {
static constexpr proto::common::PlaceholderType type = proto::common::PlaceholderType::Float16Vector;
using FieldDataType = Float16VecFieldData;
// Float16VecFieldData stores float16 vectors as strings
static constexpr size_t size_of_element = sizeof(uint8_t);
static_assert(sizeof(Eigen::half) == 2, "Eigen::half size is not 2");
};

template <>
struct SearchTrait<DataType::BFLOAT16_VECTOR> {
static constexpr proto::common::PlaceholderType type = proto::common::PlaceholderType::BFloat16Vector;
using FieldDataType = BFloat16VecFieldData;
// BFloat16VecFieldData stores bfloat16 vectors as strings
static constexpr size_t size_of_element = sizeof(uint8_t);
static_assert(sizeof(Eigen::bfloat16) == 2, "Eigen::bfloat16 size is not 2");
};

template <DataType Dt>
static void
SetPlaceHolderValue(proto::common::PlaceholderValue& placeholder_value, FieldDataPtr target) {
placeholder_value.set_type(SearchTrait<Dt>::type);
auto& vec = dynamic_cast<typename SearchTrait<Dt>::FieldDataType&>(*target);
for (const auto& v : vec.Data()) {
std::string placeholder_data(reinterpret_cast<const char*>(v.data()),
v.size() * SearchTrait<Dt>::size_of_element);
placeholder_value.add_values(std::move(placeholder_data));
}
}

Status
MilvusClientImpl::Search(const SearchArguments& arguments, SearchResults& results, int timeout) {
std::string anns_field;
Expand Down Expand Up @@ -706,23 +753,21 @@ MilvusClientImpl::Search(const SearchArguments& arguments, SearchResults& result
auto& placeholder_value = *placeholder_group.add_placeholders();
placeholder_value.set_tag("$0");
auto target = arguments.TargetVectors();
if (target->Type() == DataType::BINARY_VECTOR) {
// bins
placeholder_value.set_type(proto::common::PlaceholderType::BinaryVector);
auto& bins_vec = dynamic_cast<BinaryVecFieldData&>(*target);
for (const auto& bins : bins_vec.Data()) {
std::string placeholder_data(reinterpret_cast<const char*>(bins.data()), bins.size());
placeholder_value.add_values(std::move(placeholder_data));
}
} else {
// floats
placeholder_value.set_type(proto::common::PlaceholderType::FloatVector);
auto& floats_vec = dynamic_cast<FloatVecFieldData&>(*target);
for (const auto& floats : floats_vec.Data()) {
std::string placeholder_data(reinterpret_cast<const char*>(floats.data()),
floats.size() * sizeof(float));
placeholder_value.add_values(std::move(placeholder_data));
}
switch (target->Type()) {
case DataType::BINARY_VECTOR:
SetPlaceHolderValue<DataType::BINARY_VECTOR>(placeholder_value, target);
break;
case DataType::FLOAT_VECTOR:
SetPlaceHolderValue<DataType::FLOAT_VECTOR>(placeholder_value, target);
break;
case DataType::FLOAT16_VECTOR:
SetPlaceHolderValue<DataType::FLOAT16_VECTOR>(placeholder_value, target);
break;
case DataType::BFLOAT16_VECTOR:
SetPlaceHolderValue<DataType::BFLOAT16_VECTOR>(placeholder_value, target);
break;
default:
assert(false);
}
rpc_request.set_placeholder_group(std::move(placeholder_group.SerializeAsString()));

Expand Down Expand Up @@ -837,30 +882,64 @@ MilvusClientImpl::CalcDistance(const CalcDistanceArguments& arguments, DistanceA
if (IsVectorType(arg_vectors->Type())) {
auto data_array = rpc_vectors->mutable_data_array();

if (arg_vectors->Type() == DataType::FLOAT_VECTOR) {
FloatVecFieldDataPtr data_ptr = std::static_pointer_cast<FloatVecFieldData>(arg_vectors);
auto float_vectors = data_array->mutable_float_vector();
auto mutable_data = float_vectors->mutable_data();
auto& vectors = data_ptr->Data();
for (auto& vector : vectors) {
mutable_data->Add(vector.begin(), vector.end());
switch (arg_vectors->Type()) {
case DataType::FLOAT_VECTOR: {
FloatVecFieldDataPtr data_ptr = std::static_pointer_cast<FloatVecFieldData>(arg_vectors);
auto mutable_data = data_array->mutable_float_vector()->mutable_data();
auto& vectors = data_ptr->Data();
for (auto& vector : vectors) {
mutable_data->Add(vector.begin(), vector.end());
}

// suppose vectors is not empty, already checked by Validate()
data_array->set_dim(static_cast<int>(vectors[0].size()));
break;
}

// suppose vectors is not empty, already checked by Validate()
data_array->set_dim(static_cast<int>(vectors[0].size()));
} else {
auto data_ptr = std::static_pointer_cast<BinaryVecFieldData>(arg_vectors);
auto& str = *data_array->mutable_binary_vector();
auto& vectors = data_ptr->Data();
// user specify dimension(only for binary vectors), if not, get it from vectors
auto dimensions = arguments.Dimension() > 0 ? arguments.Dimension() : (vectors.front().size() * 8);
str.reserve(dimensions * vectors.size() / 8);
for (auto& vector : vectors) {
str.append(vector);
case DataType::BINARY_VECTOR: {
auto data_ptr = std::static_pointer_cast<BinaryVecFieldData>(arg_vectors);
auto& str = *data_array->mutable_binary_vector();
auto& vectors = data_ptr->Data();
// user specify dimension(only for binary vectors), if not, get it from vectors
auto dimensions =
arguments.Dimension() > 0 ? arguments.Dimension() : (vectors.front().size() * 8);
str.reserve(dimensions * vectors.size() / 8);
for (auto& vector : vectors) {
str.append(vector);
}
data_array->set_dim(static_cast<int>(dimensions));
break;
}
data_array->set_dim(static_cast<int>(dimensions));
case DataType::FLOAT16_VECTOR: {
Float16VecFieldDataPtr data_ptr = std::static_pointer_cast<Float16VecFieldData>(arg_vectors);
auto& mutable_data = *data_array->mutable_float16_vector();
auto& vectors = data_ptr->Data();
auto dimensions = arguments.Dimension() > 0 ? arguments.Dimension() : vectors.front().size();
mutable_data.reserve(dimensions * vectors.size() * 2);

for (auto& vector : vectors) {
mutable_data.append(vector);
}

data_array->set_dim(static_cast<int>(dimensions));
break;
}
case DataType::BFLOAT16_VECTOR: {
BFloat16VecFieldDataPtr data_ptr = std::static_pointer_cast<BFloat16VecFieldData>(arg_vectors);
auto& mutable_data = *data_array->mutable_bfloat16_vector();
auto& vectors = data_ptr->Data();
auto dimensions = arguments.Dimension() > 0 ? arguments.Dimension() : vectors.front().size();
mutable_data.reserve(dimensions * vectors.size() * 2);

for (auto& vector : vectors) {
mutable_data.append(vector);
}

data_array->set_dim(static_cast<int>(dimensions));
break;
}
default:
assert(false);
}

} else if (arg_vectors->Type() == DataType::INT64) {
auto id_array = rpc_vectors->mutable_id_array();
id_array->set_collection_name(is_left ? arguments.LeftCollection() : arguments.RightCollection());
Expand Down
Loading
Loading