From e1e863bf0944f19603b3553ac2bef0090b575e8d Mon Sep 17 00:00:00 2001 From: Yinzuo Jiang Date: Fri, 13 Dec 2024 22:58:51 +0800 Subject: [PATCH] feat!: fp16/bf16 vector support issue: milvus-io/milvus#37448 Signed-off-by: Yinzuo Jiang Signed-off-by: Yinzuo Jiang --- .github/workflows/main.yaml | 2 +- CMakeLists.txt | 8 +- DEVELOPMENT.md | 4 +- cmake/ThirdPartyPackages.cmake | 20 ++ examples/simple/main.cpp | 2 +- src/CMakeLists.txt | 4 +- src/impl/MilvusClientImpl.cpp | 155 +++++++++++---- src/impl/TypeUtils.cpp | 180 +++++++++++++++++- src/impl/TypeUtils.h | 25 ++- src/impl/types/CalcDistanceArguments.cpp | 20 ++ src/impl/types/CollectionSchema.cpp | 4 +- src/impl/types/FieldData.cpp | 171 ++++++++++++----- src/impl/types/FloatUtils.cpp | 176 +++++++++++++++++ src/impl/types/SearchArguments.cpp | 128 +++++++------ .../milvus/types/CalcDistanceArguments.h | 24 +++ src/include/milvus/types/DataType.h | 4 + src/include/milvus/types/FieldData.h | 119 +++++++++--- src/include/milvus/types/FloatUtils.h | 65 +++++++ src/include/milvus/types/SearchArguments.h | 69 ++++--- test/it/TestInsert.cpp | 37 ++-- test/it/TestSearch.cpp | 54 ++++-- test/st/PythonMilvusServer.cpp | 10 +- test/st/PythonMilvusServer.h | 8 +- test/st/TestCollection.cpp | 124 +++++++----- test/st/TestSearch.cpp | 126 ++++++------ test/st/TestSearchWithBinaryVectors.cpp | 62 +++++- test/ut/TestCalcDistanceArguments.cpp | 30 ++- test/ut/TestCollectionSchema.cpp | 15 +- test/ut/TestFieldData.cpp | 38 +++- test/ut/TestFieldSchema.cpp | 12 +- test/ut/TestFloatUtils.cpp | 48 +++++ test/ut/TestSearchArguments.cpp | 94 +++++++-- test/ut/TestTypeUtils.cpp | 135 ++++++++++--- 33 files changed, 1559 insertions(+), 414 deletions(-) create mode 100644 src/impl/types/FloatUtils.cpp create mode 100644 src/include/milvus/types/FloatUtils.h create mode 100644 test/ut/TestFloatUtils.cpp diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index 721a2c2..379fe0d 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -132,7 +132,7 @@ jobs: strategy: fail-fast: false matrix: - macos: [11] + macos: [13] env: CCACHE_DIR: ${{ github.workspace }}/.ccache CCACHE_COMPILERCHECK: content diff --git a/CMakeLists.txt b/CMakeLists.txt index 2be4880..9b8d34b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -19,7 +19,7 @@ project(milvus_sdk LANGUAGES CXX) set(CMAKE_VERBOSE_MAKEFILE OFF) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) -set(CMAKE_CXX_STANDARD 14) +set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake") @@ -65,13 +65,9 @@ define_option_string(MILVUS_WITH_GRPC "Using gRPC from" "module" define_option_string(MILVUS_WITH_ZLIB "Using Zlib from" "module" "package" "module") define_option_string(MILVUS_WITH_NLOHMANN_JSON "nlohmann json from" "module" "package" "module") define_option_string(MILVUS_WITH_GTEST "Using GTest from" "module" "package" "module") +define_option_string(MILVUS_WITH_EIGEN "Using Eigen from" "module" "package" "module") - -set(CMAKE_CXX_STANDARD 14) -set(CMAKE_CXX_STANDARD_REQUIRED on) set(BUILD_SCRIPTS_DIR ${PROJECT_SOURCE_DIR}/scripts) -set(CMAKE_EXPORT_COMPILE_COMMANDS ON) - # load third packages and milvus-proto include(ThirdPartyPackages) diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md index 4989341..3e0fa52 100644 --- a/DEVELOPMENT.md +++ b/DEVELOPMENT.md @@ -42,8 +42,8 @@ Or `make all-release` to build the release version. And you could also create a dedicated CMake build directory, then use CMake to build it from the source by yourself ```shell -$ mkdir build -$ cd build +$ mkdir cmake_build +$ cd cmake_build $ cmake .. $ make ``` diff --git a/cmake/ThirdPartyPackages.cmake b/cmake/ThirdPartyPackages.cmake index ab9c619..02f9c5e 100644 --- a/cmake/ThirdPartyPackages.cmake +++ b/cmake/ThirdPartyPackages.cmake @@ -27,6 +27,7 @@ FetchContent_Declare( grpc GIT_REPOSITORY https://github.com/grpc/grpc.git GIT_TAG v${GRPC_VERSION} + GIT_SHALLOW TRUE ) # nlohmann_json @@ -34,6 +35,7 @@ FetchContent_Declare( nlohmann_json GIT_REPOSITORY https://github.com/nlohmann/json.git GIT_TAG v${NLOHMANN_JSON_VERSION} + GIT_SHALLOW TRUE ) # googletest @@ -41,8 +43,15 @@ FetchContent_Declare( googletest GIT_REPOSITORY https://github.com/google/googletest.git GIT_TAG release-${GOOGLETEST_VERSION} + GIT_SHALLOW TRUE ) +FetchContent_Declare( + eigen3 + GIT_REPOSITORY https://gitlab.com/libeigen/eigen.git + GIT_TAG 3.4.0 + GIT_SHALLOW TRUE +) # grpc if ("${MILVUS_WITH_GRPC}" STREQUAL "pakcage") @@ -79,3 +88,14 @@ else () add_subdirectory(${nlohmann_json_SOURCE_DIR} ${nlohmann_json_BINARY_DIR} EXCLUDE_FROM_ALL) endif () endif () + +# eigen3 +if ("${MILVUS_WITH_EIGEN}" STREQUAL "package") + find_package(Eigen3 REQUIRED NO_MODULE) +else () + if (NOT eigen3_POPULATED) + FetchContent_Populate(eigen3) + set(BUILD_TESTING OFF CACHE INTERNAL "") + add_subdirectory(${eigen3_SOURCE_DIR} ${eigen3_BINARY_DIR} EXCLUDE_FROM_ALL) + endif () +endif () diff --git a/examples/simple/main.cpp b/examples/simple/main.cpp index 3b4c840..32796e3 100644 --- a/examples/simple/main.cpp +++ b/examples/simple/main.cpp @@ -122,7 +122,7 @@ main(int argc, char* argv[]) { std::uniform_int_distribution int64_gen(0, row_count - 1); int64_t q_number = int64_gen(ran); std::vector q_vector = insert_vectors[q_number]; - arguments.AddTargetVector(field_face_name, std::move(q_vector)); + arguments.AddTargetVector(field_face_name, std::move(q_vector)); std::cout << "Searching the No." << q_number << " entity..." << std::endl; milvus::SearchResults search_results{}; diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index f50a7d7..0c02e7c 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -21,7 +21,7 @@ add_library(milvus_sdk ${impl_files} ${impl_types_files}) # add proto gens add_milvus_protos(milvus_sdk) set_target_properties(milvus_sdk PROPERTIES OUTPUT_NAME milvus_sdk) -target_link_libraries(milvus_sdk gRPC::grpc++ nlohmann_json::nlohmann_json) +target_link_libraries(milvus_sdk gRPC::grpc++ nlohmann_json::nlohmann_json Eigen3::Eigen) target_include_directories(milvus_sdk PUBLIC include) @@ -37,4 +37,4 @@ install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/include/ install(TARGETS milvus_sdk EXPORT milvus_sdk LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}) \ No newline at end of file + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}) diff --git a/src/impl/MilvusClientImpl.cpp b/src/impl/MilvusClientImpl.cpp index db66182..0704afc 100644 --- a/src/impl/MilvusClientImpl.cpp +++ b/src/impl/MilvusClientImpl.cpp @@ -666,6 +666,53 @@ MilvusClientImpl::Delete(const std::string& collection_name, const std::string& post); } +template +struct SearchTrait; + +template <> +struct SearchTrait { + static constexpr proto::common::PlaceholderType type = proto::common::PlaceholderType::BinaryVector; + using FieldDataType = BinaryVecFieldData; + static constexpr size_t size_of_element = sizeof(uint8_t); +}; + +template <> +struct SearchTrait { + static constexpr proto::common::PlaceholderType type = proto::common::PlaceholderType::FloatVector; + using FieldDataType = FloatVecFieldData; + static constexpr size_t size_of_element = sizeof(float); +}; + +template <> +struct SearchTrait { + static constexpr proto::common::PlaceholderType type = proto::common::PlaceholderType::Float16Vector; + using FieldDataType = Float16VecFieldData; + // Float16VecFieldData stores float16 vectors as strings + static constexpr size_t size_of_element = sizeof(uint8_t); + static_assert(sizeof(Eigen::half) == 2, "Eigen::half size is not 2"); +}; + +template <> +struct SearchTrait { + static constexpr proto::common::PlaceholderType type = proto::common::PlaceholderType::BFloat16Vector; + using FieldDataType = BFloat16VecFieldData; + // BFloat16VecFieldData stores bfloat16 vectors as strings + static constexpr size_t size_of_element = sizeof(uint8_t); + static_assert(sizeof(Eigen::bfloat16) == 2, "Eigen::bfloat16 size is not 2"); +}; + +template +static void +SetPlaceHolderValue(proto::common::PlaceholderValue& placeholder_value, FieldDataPtr target) { + placeholder_value.set_type(SearchTrait
::type); + auto& vec = dynamic_cast::FieldDataType&>(*target); + for (const auto& v : vec.Data()) { + std::string placeholder_data(reinterpret_cast(v.data()), + v.size() * SearchTrait
::size_of_element); + placeholder_value.add_values(std::move(placeholder_data)); + } +} + Status MilvusClientImpl::Search(const SearchArguments& arguments, SearchResults& results, int timeout) { std::string anns_field; @@ -706,23 +753,21 @@ MilvusClientImpl::Search(const SearchArguments& arguments, SearchResults& result auto& placeholder_value = *placeholder_group.add_placeholders(); placeholder_value.set_tag("$0"); auto target = arguments.TargetVectors(); - if (target->Type() == DataType::BINARY_VECTOR) { - // bins - placeholder_value.set_type(proto::common::PlaceholderType::BinaryVector); - auto& bins_vec = dynamic_cast(*target); - for (const auto& bins : bins_vec.Data()) { - std::string placeholder_data(reinterpret_cast(bins.data()), bins.size()); - placeholder_value.add_values(std::move(placeholder_data)); - } - } else { - // floats - placeholder_value.set_type(proto::common::PlaceholderType::FloatVector); - auto& floats_vec = dynamic_cast(*target); - for (const auto& floats : floats_vec.Data()) { - std::string placeholder_data(reinterpret_cast(floats.data()), - floats.size() * sizeof(float)); - placeholder_value.add_values(std::move(placeholder_data)); - } + switch (target->Type()) { + case DataType::BINARY_VECTOR: + SetPlaceHolderValue(placeholder_value, target); + break; + case DataType::FLOAT_VECTOR: + SetPlaceHolderValue(placeholder_value, target); + break; + case DataType::FLOAT16_VECTOR: + SetPlaceHolderValue(placeholder_value, target); + break; + case DataType::BFLOAT16_VECTOR: + SetPlaceHolderValue(placeholder_value, target); + break; + default: + assert(false); } rpc_request.set_placeholder_group(std::move(placeholder_group.SerializeAsString())); @@ -837,30 +882,64 @@ MilvusClientImpl::CalcDistance(const CalcDistanceArguments& arguments, DistanceA if (IsVectorType(arg_vectors->Type())) { auto data_array = rpc_vectors->mutable_data_array(); - if (arg_vectors->Type() == DataType::FLOAT_VECTOR) { - FloatVecFieldDataPtr data_ptr = std::static_pointer_cast(arg_vectors); - auto float_vectors = data_array->mutable_float_vector(); - auto mutable_data = float_vectors->mutable_data(); - auto& vectors = data_ptr->Data(); - for (auto& vector : vectors) { - mutable_data->Add(vector.begin(), vector.end()); + switch (arg_vectors->Type()) { + case DataType::FLOAT_VECTOR: { + FloatVecFieldDataPtr data_ptr = std::static_pointer_cast(arg_vectors); + auto mutable_data = data_array->mutable_float_vector()->mutable_data(); + auto& vectors = data_ptr->Data(); + for (auto& vector : vectors) { + mutable_data->Add(vector.begin(), vector.end()); + } + + // suppose vectors is not empty, already checked by Validate() + data_array->set_dim(static_cast(vectors[0].size())); + break; } - - // suppose vectors is not empty, already checked by Validate() - data_array->set_dim(static_cast(vectors[0].size())); - } else { - auto data_ptr = std::static_pointer_cast(arg_vectors); - auto& str = *data_array->mutable_binary_vector(); - auto& vectors = data_ptr->Data(); - // user specify dimension(only for binary vectors), if not, get it from vectors - auto dimensions = arguments.Dimension() > 0 ? arguments.Dimension() : (vectors.front().size() * 8); - str.reserve(dimensions * vectors.size() / 8); - for (auto& vector : vectors) { - str.append(vector); + case DataType::BINARY_VECTOR: { + auto data_ptr = std::static_pointer_cast(arg_vectors); + auto& str = *data_array->mutable_binary_vector(); + auto& vectors = data_ptr->Data(); + // user specify dimension(only for binary vectors), if not, get it from vectors + auto dimensions = + arguments.Dimension() > 0 ? arguments.Dimension() : (vectors.front().size() * 8); + str.reserve(dimensions * vectors.size() / 8); + for (auto& vector : vectors) { + str.append(vector); + } + data_array->set_dim(static_cast(dimensions)); + break; } - data_array->set_dim(static_cast(dimensions)); + case DataType::FLOAT16_VECTOR: { + Float16VecFieldDataPtr data_ptr = std::static_pointer_cast(arg_vectors); + auto& mutable_data = *data_array->mutable_float16_vector(); + auto& vectors = data_ptr->Data(); + auto dimensions = arguments.Dimension() > 0 ? arguments.Dimension() : vectors.front().size(); + mutable_data.reserve(dimensions * vectors.size() * 2); + + for (auto& vector : vectors) { + mutable_data.append(vector); + } + + data_array->set_dim(static_cast(dimensions)); + break; + } + case DataType::BFLOAT16_VECTOR: { + BFloat16VecFieldDataPtr data_ptr = std::static_pointer_cast(arg_vectors); + auto& mutable_data = *data_array->mutable_bfloat16_vector(); + auto& vectors = data_ptr->Data(); + auto dimensions = arguments.Dimension() > 0 ? arguments.Dimension() : vectors.front().size(); + mutable_data.reserve(dimensions * vectors.size() * 2); + + for (auto& vector : vectors) { + mutable_data.append(vector); + } + + data_array->set_dim(static_cast(dimensions)); + break; + } + default: + assert(false); } - } else if (arg_vectors->Type() == DataType::INT64) { auto id_array = rpc_vectors->mutable_id_array(); id_array->set_collection_name(is_left ? arguments.LeftCollection() : arguments.RightCollection()); diff --git a/src/impl/TypeUtils.cpp b/src/impl/TypeUtils.cpp index ae213d0..9dba1be 100644 --- a/src/impl/TypeUtils.cpp +++ b/src/impl/TypeUtils.cpp @@ -16,6 +16,20 @@ #include "TypeUtils.h" +#include + +#include +#include +#include +#include +#include +#include + +#include "Eigen/src/Core/arch/Default/BFloat16.h" +#include "Eigen/src/Core/arch/Default/Half.h" +#include "milvus/types/DataType.h" +#include "schema.pb.h" + namespace milvus { bool @@ -180,7 +194,7 @@ operator==(const proto::schema::FieldData& lhs, const BinaryVecFieldData& rhs) { } const auto& vectors = lhs.vectors(); - if (vectors.has_float_vector()) { + if (!vectors.has_binary_vector()) { return false; } @@ -199,6 +213,85 @@ operator==(const proto::schema::FieldData& lhs, const BinaryVecFieldData& rhs) { return it == vectors_data.end(); } +template +struct Fp16OrBf16VecFieldDataTraits; + +template <> +struct Fp16OrBf16VecFieldDataTraits { + static bool + has_vector(const proto::schema::FieldData& field_data) { + return field_data.vectors().has_bfloat16_vector(); + } + + static const std::string& + get_vector(const proto::schema::FieldData& field_data) { + return field_data.vectors().bfloat16_vector(); + } + + static std::string& + mutable_vector(proto::schema::VectorField* field) { + return *(field->mutable_bfloat16_vector()); + } +}; + +template <> +struct Fp16OrBf16VecFieldDataTraits { + static bool + has_vector(const proto::schema::FieldData& field_data) { + return field_data.vectors().has_float16_vector(); + } + + static const std::string& + get_vector(const proto::schema::FieldData& field_data) { + return field_data.vectors().float16_vector(); + } + + static std::string& + mutable_vector(proto::schema::VectorField* field) { + return *(field->mutable_float16_vector()); + } +}; + +template +static bool +compareFp16OrBf16VecFieldData(const proto::schema::FieldData& lhs, const T& rhs) { + if (lhs.field_name() != rhs.Name()) { + return false; + } + if (!lhs.has_vectors()) { + return false; + } + + if (!Fp16OrBf16VecFieldDataTraits::has_vector(lhs)) { + return false; + } + + const std::string& vectors_data = Fp16OrBf16VecFieldDataTraits::get_vector(lhs); + const auto& field_data_vecs = rhs.Data(); + if (field_data_vecs.empty() && !vectors_data.empty()) { + return false; + } + const char* vector_data_begin = vectors_data.data(); + for (const auto& field_data_vec : field_data_vecs) { + const std::string_view vec_data(vector_data_begin, field_data_vec.size()); + if (vec_data != field_data_vec) { + return false; + } + vector_data_begin += field_data_vec.size(); + } + return true; +} + +bool +operator==(const proto::schema::FieldData& lhs, const BFloat16VecFieldData& rhs) { + return compareFp16OrBf16VecFieldData(lhs, rhs); +} + +bool +operator==(const proto::schema::FieldData& lhs, const Float16VecFieldData& rhs) { + return compareFp16OrBf16VecFieldData(lhs, rhs); +} + bool operator==(const proto::schema::FieldData& lhs, const FloatVecFieldData& rhs) { if (lhs.field_name() != rhs.Name()) { @@ -255,6 +348,10 @@ operator==(const proto::schema::FieldData& lhs, const Field& rhs) { return lhs == dynamic_cast(rhs); case DataType::FLOAT_VECTOR: return lhs == dynamic_cast(rhs); + case DataType::FLOAT16_VECTOR: + return lhs == dynamic_cast(rhs); + case DataType::BFLOAT16_VECTOR: + return lhs == dynamic_cast(rhs); default: return false; } @@ -296,6 +393,10 @@ DataTypeCast(DataType type) { return proto::schema::DataType::BinaryVector; case DataType::FLOAT_VECTOR: return proto::schema::DataType::FloatVector; + case DataType::FLOAT16_VECTOR: + return proto::schema::DataType::Float16Vector; + case DataType::BFLOAT16_VECTOR: + return proto::schema::DataType::BFloat16Vector; default: return proto::schema::DataType::None; } @@ -324,6 +425,10 @@ DataTypeCast(proto::schema::DataType type) { return DataType::BINARY_VECTOR; case proto::schema::DataType::FloatVector: return DataType::FLOAT_VECTOR; + case proto::schema::DataType::Float16Vector: + return DataType::FLOAT16_VECTOR; + case proto::schema::DataType::BFloat16Vector: + return DataType::BFLOAT16_VECTOR; default: return DataType::UNKNOWN; } @@ -424,6 +529,48 @@ CreateProtoFieldData(const FloatVecFieldData& field) { return ret; } +template +proto::schema::VectorField* +CreateFp16OrBf16VecProtoFieldData(const T& field) { + auto ret = new proto::schema::VectorField{}; + auto& data = field.Data(); + auto dim = data.front().size() / 2; + auto& vectors_data = Fp16OrBf16VecFieldDataTraits::mutable_vector(ret); + vectors_data.reserve(data.size() * dim * 2); + for (const auto& item : data) { + if constexpr (BYTE_ORDER == LITTLE_ENDIAN) { + std::copy(item.begin(), item.end(), std::back_inserter(vectors_data)); + } else { + // Big endian + for (auto fp : item) { + union { + uint16_t u16; + char bytes[2]; + } value; + if constexpr (std::is_same_v) { + value.u16 = Eigen::bfloat16_impl::raw_bfloat16_as_uint16(fp); + } else { + value.u16 = Eigen::half_impl::raw_half_as_uint16(fp); + } + vectors_data.push_back(value.bytes[1]); + vectors_data.push_back(value.bytes[0]); + } + } + } + ret->set_dim(static_cast(dim)); + return ret; +} + +proto::schema::VectorField* +CreateProtoFieldData(const Float16VecFieldData& field) { + return CreateFp16OrBf16VecProtoFieldData(field); +} + +proto::schema::VectorField* +CreateProtoFieldData(const BFloat16VecFieldData& field) { + return CreateFp16OrBf16VecProtoFieldData(field); +} + proto::schema::ScalarField* CreateProtoFieldData(const BoolFieldData& field) { auto ret = new proto::schema::ScalarField{}; @@ -512,6 +659,12 @@ CreateProtoFieldData(const Field& field) { case DataType::FLOAT_VECTOR: field_data.set_allocated_vectors(CreateProtoFieldData(dynamic_cast(field))); break; + case DataType::FLOAT16_VECTOR: + field_data.set_allocated_vectors(CreateProtoFieldData(dynamic_cast(field))); + break; + case DataType::BFLOAT16_VECTOR: + field_data.set_allocated_vectors(CreateProtoFieldData(dynamic_cast(field))); + break; case DataType::BOOL: field_data.set_allocated_scalars(CreateProtoFieldData(dynamic_cast(field))); break; @@ -559,6 +712,16 @@ CreateMilvusFieldData(const milvus::proto::schema::FieldData& field_data, size_t name, BuildFieldDataVectors>( field_data.vectors().dim(), field_data.vectors().float_vector().data(), offset, count)); + case proto::schema::DataType::Float16Vector: + return std::make_shared( + name, BuildFieldDataVectors(field_data.vectors().dim() * 2, + field_data.vectors().float16_vector(), offset, count)); + + case proto::schema::DataType::BFloat16Vector: + return std::make_shared( + name, BuildFieldDataVectors(field_data.vectors().dim() * 2, + field_data.vectors().bfloat16_vector(), offset, count)); + case proto::schema::DataType::Bool: return std::make_shared( name, BuildFieldDataScalars(field_data.scalars().bool_data().data(), offset, count)); @@ -611,6 +774,16 @@ CreateMilvusFieldData(const milvus::proto::schema::FieldData& field_data) { name, BuildFieldDataVectors>(field_data.vectors().dim(), field_data.vectors().float_vector().data())); + case proto::schema::DataType::Float16Vector: + return std::make_shared( + name, BuildFieldDataVectors(field_data.vectors().dim() * 2, + field_data.vectors().float16_vector())); + + case proto::schema::DataType::BFloat16Vector: + return std::make_shared( + name, BuildFieldDataVectors(field_data.vectors().dim() * 2, + field_data.vectors().bfloat16_vector())); + case proto::schema::DataType::Bool: return std::make_shared( name, BuildFieldDataScalars(field_data.scalars().bool_data().data())); @@ -800,11 +973,6 @@ IndexStateCast(proto::common::IndexState state) { } } -bool -IsVectorType(DataType type) { - return (DataType::BINARY_VECTOR == type || DataType::FLOAT_VECTOR == type); -} - std::string Base64Encode(const std::string& val) { const char* base64_chars = { diff --git a/src/impl/TypeUtils.h b/src/impl/TypeUtils.h index 8073f65..a7d4537 100644 --- a/src/impl/TypeUtils.h +++ b/src/impl/TypeUtils.h @@ -24,6 +24,7 @@ #include "milvus/types/IndexType.h" #include "milvus/types/MetricType.h" #include "milvus/types/SegmentInfo.h" + namespace milvus { bool @@ -56,6 +57,12 @@ operator==(const proto::schema::FieldData& lhs, const BinaryVecFieldData& rhs); bool operator==(const proto::schema::FieldData& lhs, const FloatVecFieldData& rhs); +bool +operator==(const proto::schema::FieldData& lhs, const Float16VecFieldData& rhs); + +bool +operator==(const proto::schema::FieldData& lhs, const BFloat16VecFieldData& rhs); + bool operator==(const proto::schema::FieldData& lhs, const proto::schema::FieldData& rhs); @@ -92,6 +99,12 @@ CreateProtoFieldData(const BinaryVecFieldData& field); proto::schema::VectorField* CreateProtoFieldData(const FloatVecFieldData& field); +proto::schema::VectorField* +CreateProtoFieldData(const Float16VecFieldData& field); + +proto::schema::VectorField* +CreateProtoFieldData(const BFloat16VecFieldData& field); + proto::schema::ScalarField* CreateProtoFieldData(const BoolFieldData& field); @@ -196,8 +209,16 @@ SegmentStateCast(SegmentState state); IndexStateCode IndexStateCast(proto::common::IndexState state); -bool -IsVectorType(DataType type); +constexpr bool +IsDenseVectorType(DataType type) { + return type == DataType::BINARY_VECTOR || type == DataType::FLOAT_VECTOR || type == DataType::FLOAT16_VECTOR || + type == DataType::BFLOAT16_VECTOR; +} + +constexpr bool +IsVectorType(DataType type) { + return IsDenseVectorType(type) || type == DataType::SPARSE_FLOAT_VECTOR; +} std::string Base64Encode(const std::string& val); diff --git a/src/impl/types/CalcDistanceArguments.cpp b/src/impl/types/CalcDistanceArguments.cpp index 202aa50..022457e 100644 --- a/src/impl/types/CalcDistanceArguments.cpp +++ b/src/impl/types/CalcDistanceArguments.cpp @@ -65,6 +65,16 @@ CalcDistanceArguments::SetLeftVectors(BinaryVecFieldDataPtr vectors) { return validate_vectors_with_oper(vectors, [this, &vectors]() { this->vectors_left_ = std::move(vectors); }); } +Status +CalcDistanceArguments::SetLeftVectors(Float16VecFieldDataPtr vectors) { + return validate_vectors_with_oper(vectors, [this, &vectors]() { this->vectors_left_ = std::move(vectors); }); +} + +Status +CalcDistanceArguments::SetLeftVectors(BFloat16VecFieldDataPtr vectors) { + return validate_vectors_with_oper(vectors, [this, &vectors]() { this->vectors_left_ = std::move(vectors); }); +} + Status CalcDistanceArguments::SetLeftVectors(Int64FieldDataPtr ids, std::string collection_name, const std::vector& partition_names) { @@ -100,6 +110,16 @@ CalcDistanceArguments::SetRightVectors(BinaryVecFieldDataPtr vectors) { return validate_vectors_with_oper(vectors, [this, &vectors]() { this->vectors_right_ = std::move(vectors); }); } +Status +CalcDistanceArguments::SetRightVectors(Float16VecFieldDataPtr vectors) { + return validate_vectors_with_oper(vectors, [this, &vectors]() { this->vectors_right_ = std::move(vectors); }); +} + +Status +CalcDistanceArguments::SetRightVectors(BFloat16VecFieldDataPtr vectors) { + return validate_vectors_with_oper(vectors, [this, &vectors]() { this->vectors_right_ = std::move(vectors); }); +} + Status CalcDistanceArguments::SetRightVectors(Int64FieldDataPtr ids, std::string collection_name, const std::vector& partition_names) { diff --git a/src/impl/types/CollectionSchema.cpp b/src/impl/types/CollectionSchema.cpp index 531cd04..6a1f281 100644 --- a/src/impl/types/CollectionSchema.cpp +++ b/src/impl/types/CollectionSchema.cpp @@ -18,6 +18,8 @@ #include +#include "../TypeUtils.h" + namespace milvus { CollectionSchema::CollectionSchema() = default; @@ -90,7 +92,7 @@ CollectionSchema::AnnsFieldNames() const { std::unordered_set ret; for (const auto& field : fields_) { auto data_type = field.FieldDataType(); - if (data_type == DataType::BINARY_VECTOR || data_type == DataType::FLOAT_VECTOR) { + if (IsVectorType(data_type)) { ret.emplace(field.Name()); } } diff --git a/src/impl/types/FieldData.cpp b/src/impl/types/FieldData.cpp index 5e0e9c6..9336b10 100644 --- a/src/impl/types/FieldData.cpp +++ b/src/impl/types/FieldData.cpp @@ -16,33 +16,27 @@ #include "milvus/types/FieldData.h" -namespace milvus { - -namespace { +#include +#include +#include +#include -template -struct DataTypeTraits { - static const bool is_vector = false; -}; +#include "../TypeUtils.h" +#include "milvus/types/DataType.h" +#include "milvus/types/FloatUtils.h" -template <> -struct DataTypeTraits { - static const bool is_vector = true; -}; +namespace milvus { -template <> -struct DataTypeTraits { - static const bool is_vector = true; -}; +namespace { -template ::is_vector, bool> = true> +template = true> StatusCode AddElement(const T& element, std::vector& array) { array.push_back(element); return StatusCode::OK; } -template ::is_vector, bool> = true> +template = true> StatusCode AddElement(const T& element, std::vector& array) { if (element.empty()) { @@ -119,62 +113,61 @@ FieldData::Data() { return data_; } -BinaryVecFieldData::BinaryVecFieldData() : FieldData() { -} - -BinaryVecFieldData::BinaryVecFieldData(std::string name) - : FieldData(std::move(name)) { -} - -BinaryVecFieldData::BinaryVecFieldData(std::string name, const std::vector& data) - : FieldData(std::move(name), data) { +template +BinaryVecFieldDataImpl
::BinaryVecFieldDataImpl() : FieldData() { } -BinaryVecFieldData::BinaryVecFieldData(std::string name, std::vector&& data) - : FieldData(std::move(name), std::move(data)) { +template +BinaryVecFieldDataImpl
::BinaryVecFieldDataImpl(std::string name) : FieldData(std::move(name)) { } -BinaryVecFieldData::BinaryVecFieldData(std::string name, const std::vector>& data) - : FieldData(std::move(name), CreateBinaryStrings(data)) { +template +BinaryVecFieldDataImpl
::BinaryVecFieldDataImpl(std::string name, const std::vector& data) + : FieldData(std::move(name), data) { } -const std::vector& -BinaryVecFieldData::Data() const { - return data_; +template +BinaryVecFieldDataImpl
::BinaryVecFieldDataImpl(std::string name, std::vector&& data) + : FieldData(std::move(name), std::move(data)) { } -std::vector& -BinaryVecFieldData::Data() { - return data_; +template +BinaryVecFieldDataImpl
::BinaryVecFieldDataImpl(std::string name, const std::vector>& data) + : FieldData(std::move(name), CreateBinaryStrings(data)) { } +template std::vector> -BinaryVecFieldData::DataAsUnsignedChars() const { +BinaryVecFieldDataImpl
::DataAsUnsignedChars() const { std::vector> ret; - ret.reserve(data_.size()); - for (const auto& item : data_) { + ret.reserve(this->data_.size()); + for (const auto& item : this->data_) { ret.emplace_back(item.begin(), item.end()); } return ret; } +template StatusCode -BinaryVecFieldData::Add(const std::string& element) { - return AddElement(element, data_); +BinaryVecFieldDataImpl
::Add(const std::string& element) { + return AddElement(element, this->data_); } +template StatusCode -BinaryVecFieldData::Add(std::string&& element) { - return AddElement(element, data_); +BinaryVecFieldDataImpl
::Add(std::string&& element) { + return AddElement(std::move(element), this->data_); } +template StatusCode -BinaryVecFieldData::Add(const std::vector& element) { +BinaryVecFieldDataImpl
::Add(const std::vector& element) { return Add(std::string{element.begin(), element.end()}); } +template std::vector -BinaryVecFieldData::CreateBinaryStrings(const std::vector>& data) { +BinaryVecFieldDataImpl
::CreateBinaryStrings(const std::vector>& data) { std::vector ret; ret.reserve(data.size()); for (const auto& item : data) { @@ -183,9 +176,71 @@ BinaryVecFieldData::CreateBinaryStrings(const std::vector>& return ret; } -std::string -BinaryVecFieldData::CreateBinaryString(const std::vector& data) { - return std::string{data.begin(), data.end()}; +template +Fp16VecFieldData::Fp16VecFieldData(std::string name, const std::vector>& data) + : Fp16VecFieldData(std::move(name), CreateBinaryStringsFromFloats(data)) { +} + +template +Fp16VecFieldData::Fp16VecFieldData(std::string name, std::vector>&& data) + : Fp16VecFieldData(std::move(name), CreateBinaryStringsFromFloats(std::move(data))) { +} + +template +Fp16VecFieldData::Fp16VecFieldData(std::string name, const std::vector>& data) + : Fp16VecFieldData(std::move(name), CreateBinaryStringsFromFloats(data)) { +} + +template +Fp16VecFieldData::Fp16VecFieldData(std::string name, const std::vector>& data) + : Fp16VecFieldData(std::move(name), CreateBinaryStringsFromFloats(data)) { +} + +template +template +std::vector> +Fp16VecFieldData::DataAsFloats() const { + std::vector> result; + result.reserve(this->data_.size()); + for (const typename Fp16VecFieldData::ElementT& str : this->data_) { + std::vector float_vec = Float16NumVecBytesToFloatNumVec(str); + result.push_back(std::move(float_vec)); + } + return result; +} + +template +template +std::vector +Fp16VecFieldData::CreateBinaryStringsFromFloats(const std::vector>& data) { + std::vector result; + result.reserve(data.size()); + for (const auto& item : data) { + std::string bytes = FloatNumVecToFloat16NumVecBytes(item); + result.push_back(std::move(bytes)); + } + return result; +} + +template +StatusCode +Fp16VecFieldData::Add(const std::vector& element) { + std::string bytes = FloatNumVecToFloat16NumVecBytes(element); + return AddElement(bytes, this->data_); +} + +template +StatusCode +Fp16VecFieldData::Add(const std::vector& element) { + std::string bytes = FloatNumVecToFloat16NumVecBytes(element); + return AddElement(bytes, this->data_); +} + +template +StatusCode +Fp16VecFieldData::Add(const std::vector& element) { + std::string bytes = FloatNumVecToFloat16NumVecBytes(element); + return AddElement(bytes, this->data_); } // explicit declare FieldData @@ -199,5 +254,23 @@ template class FieldData; template class FieldData; template class FieldData; template class FieldData, DataType::FLOAT_VECTOR>; - +template class BinaryVecFieldDataImpl; +template class BinaryVecFieldDataImpl; +template class BinaryVecFieldDataImpl; +template class Fp16VecFieldData; +template class Fp16VecFieldData; + +template std::vector> +Fp16VecFieldData::DataAsFloats() const; +template std::vector> +Fp16VecFieldData::DataAsFloats() const; +template std::vector> +Fp16VecFieldData::DataAsFloats() const; + +template std::vector> +Fp16VecFieldData::DataAsFloats() const; +template std::vector> +Fp16VecFieldData::DataAsFloats() const; +template std::vector> +Fp16VecFieldData::DataAsFloats() const; } // namespace milvus diff --git a/src/impl/types/FloatUtils.cpp b/src/impl/types/FloatUtils.cpp new file mode 100644 index 0000000..3c96e2b --- /dev/null +++ b/src/impl/types/FloatUtils.cpp @@ -0,0 +1,176 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "milvus/types/FloatUtils.h" + +#include +#include + +#include "Eigen/Core" + +namespace milvus { + +template +struct Fp16ToFloatTraits; + +template <> +struct Fp16ToFloatTraits { + static inline Eigen::half + convert(uint16_t val) { + return Eigen::half_impl::raw_uint16_to_half(val); + } +}; + +template <> +struct Fp16ToFloatTraits { + static inline Eigen::bfloat16 + convert(uint16_t val) { + return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(val); + } +}; + +template +struct Fp16ToFloatTraits { + static_assert(std::is_same_v || std::is_same_v, "FloatT should be float or double"); + static FloatT + convert(uint16_t val) { + Eigen::half h = Eigen::half_impl::raw_uint16_to_half(val); + return Eigen::half_impl::half_to_float(h); + } +}; + +template +struct Fp16ToFloatTraits { + static_assert(std::is_same_v || std::is_same_v, "FloatT should be float or double"); + static FloatT + convert(uint16_t val) { + Eigen::bfloat16 b = Eigen::bfloat16_impl::raw_uint16_to_bfloat16(val); + return Eigen::bfloat16_impl::bfloat16_to_float(b); + } +}; + +template +std::vector +Float16NumVecBytesToFloatNumVec(const std::string& val) { + static_assert(sizeof(Fp16T) == 2, "Fp16T should be 2 bytes"); + std::vector result; + assert(val.size() % 2 == 0); + result.reserve(val.size() / 2); + for (size_t i = 0; i < val.size(); i += 2) { + union { + uint16_t u16; + char bytes[2]; + } value; + if constexpr (BYTE_ORDER == LITTLE_ENDIAN) { + value.bytes[0] = val[i]; + value.bytes[1] = val[i + 1]; + } else { + value.bytes[0] = val[i + 1]; + value.bytes[1] = val[i]; + } + FloatT fp = Fp16ToFloatTraits::convert(value.u16); + result.push_back(fp); + } + return result; +} + +template +std::string +FloatNumVecToFloat16NumVecBytes(const std::vector& data) { + static_assert(std::is_same_v || std::is_same_v, + "Fp16T should be Eigen::half or Eigen::bfloat16"); + + if constexpr (std::is_same_v || std::is_same_v) { + std::string ret; + ret.reserve(data.size() * 2); + for (const auto item : data) { + Fp16T fp16(static_cast(item)); + union { + uint16_t u16; + char bytes[2]; + } value; + if constexpr (std::is_same_v) { + value.u16 = Eigen::half_impl::raw_half_as_uint16(fp16); + } else { + value.u16 = Eigen::bfloat16_impl::raw_bfloat16_as_uint16(fp16); + } + if constexpr (BYTE_ORDER == LITTLE_ENDIAN) { + ret.push_back(value.bytes[0]); + ret.push_back(value.bytes[1]); + } else { + ret.push_back(value.bytes[1]); + ret.push_back(value.bytes[0]); + } + } + return ret; + } else if constexpr (std::is_same_v || std::is_same_v) { + static_assert(sizeof(Fp16T) == 2, "Fp16T should be 2 bytes"); + if constexpr (BYTE_ORDER == LITTLE_ENDIAN) { + return std::string(reinterpret_cast(data.data()), data.size() * sizeof(Fp16T)); + } else { + // Big endian + std::string ret; + ret.reserve(data.size() * sizeof(Fp16T)); + for (const Fp16T item : data) { + union { + uint16_t u16; + char bytes[2]; + } value; + if constexpr (std::is_same_v) { + value.u16 = Eigen::half_impl::raw_half_as_uint16(item); + } else { + value.u16 = Eigen::bfloat16_impl::raw_bfloat16_as_uint16(item); + } + ret.push_back(value.bytes[1]); + ret.push_back(value.bytes[0]); + } + return ret; + } + } else { + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v, + "FloatT should be float, double, Eigen::half or Eigen::bfloat16"); + } +} + +template std::vector +Float16NumVecBytesToFloatNumVec(const std::string& val); +template std::vector +Float16NumVecBytesToFloatNumVec(const std::string& val); +template std::vector +Float16NumVecBytesToFloatNumVec(const std::string& val); + +template std::vector +Float16NumVecBytesToFloatNumVec(const std::string& val); +template std::vector +Float16NumVecBytesToFloatNumVec(const std::string& val); +template std::vector +Float16NumVecBytesToFloatNumVec(const std::string& val); + +template std::string +FloatNumVecToFloat16NumVecBytes(const std::vector& data); +template std::string +FloatNumVecToFloat16NumVecBytes(const std::vector& data); +template std::string +FloatNumVecToFloat16NumVecBytes(const std::vector& data); + +template std::string +FloatNumVecToFloat16NumVecBytes(const std::vector& data); +template std::string +FloatNumVecToFloat16NumVecBytes(const std::vector& data); +template std::string +FloatNumVecToFloat16NumVecBytes(const std::vector& data); +} // namespace milvus diff --git a/src/impl/types/SearchArguments.cpp b/src/impl/types/SearchArguments.cpp index 31474c6..f5230c3 100644 --- a/src/impl/types/SearchArguments.cpp +++ b/src/impl/types/SearchArguments.cpp @@ -17,7 +17,12 @@ #include "milvus/types/SearchArguments.h" #include +#include #include +#include + +#include "milvus/Status.h" +#include "milvus/types/FieldData.h" namespace milvus { namespace { @@ -118,74 +123,91 @@ SearchArguments::SetExpression(std::string expression) { FieldDataPtr SearchArguments::TargetVectors() const { - if (binary_vectors_ != nullptr) { - return binary_vectors_; - } else if (float_vectors_ != nullptr) { - return float_vectors_; - } - - return nullptr; -} - + return field_data_; +} + +TemplateAddTargetVectorDeclaration(, FloatVecFieldData); +TemplateAddTargetVectorDeclaration(, BinaryVecFieldData); +TemplateAddTargetVectorDeclaration(, Float16VecFieldData); +TemplateAddTargetVectorDeclaration(, BFloat16VecFieldData); + +template Status +SearchArguments::AddTargetVector(std::string field_name, const std::vector& vector); + +template Status +SearchArguments::AddTargetVector(std::string field_name, const std::vector& vector); +template Status +SearchArguments::AddTargetVector(std::string field_name, const std::vector& vector); +template Status +SearchArguments::AddTargetVector(std::string field_name, const std::vector& vector); +template Status +SearchArguments::AddTargetVector(std::string field_name, const std::vector& vector); + +template Status +SearchArguments::AddTargetVector(std::string field_name, + const std::vector& vector); +template Status +SearchArguments::AddTargetVector(std::string field_name, const std::vector& vector); +template Status +SearchArguments::AddTargetVector(std::string field_name, const std::vector& vector); +template Status +SearchArguments::AddTargetVector(std::string field_name, const std::vector& vector); + +template Status -SearchArguments::AddTargetVector(std::string field_name, const std::string& vector) { - return AddTargetVector(std::move(field_name), std::string{vector}); -} - -Status -SearchArguments::AddTargetVector(std::string field_name, const std::vector& vector) { - return AddTargetVector(std::move(field_name), milvus::BinaryVecFieldData::CreateBinaryString(vector)); -} - -Status -SearchArguments::AddTargetVector(std::string field_name, std::string&& vector) { - if (float_vectors_ != nullptr) { - return {StatusCode::INVALID_AGUMENT, "Target vector must be float type!"}; - } - - if (nullptr == binary_vectors_) { - binary_vectors_ = std::make_shared(std::move(field_name)); +SearchArguments::DoAddTargetVector(std::string field_name, const T& vector) { + static_assert( + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v, + "Only FloatVecFieldData, BinaryVecFieldData, Float16VecFieldData or BFloat16VecFieldData is supported"); + if (nullptr == field_data_) { + field_data_ = std::make_shared(std::move(field_name)); } - auto code = binary_vectors_->Add(std::move(vector)); - if (code != StatusCode::OK) { - return {code, "Failed to add vector"}; + if (auto vectors = std::dynamic_pointer_cast(field_data_)) { + auto code = vectors->Add(vector); + if (code != StatusCode::OK) { + return {code, "Failed to add vector"}; + } + } else { + return {StatusCode::INVALID_AGUMENT, "Invalid vector type!"}; } return Status::OK(); } +template Status -SearchArguments::AddTargetVector(std::string field_name, const FloatVecFieldData::ElementT& vector) { - if (binary_vectors_ != nullptr) { - return {StatusCode::INVALID_AGUMENT, "Target vector must be binary type!"}; - } - - if (nullptr == float_vectors_) { - float_vectors_ = std::make_shared(std::move(field_name)); - } - - auto code = float_vectors_->Add(vector); - if (code != StatusCode::OK) { - return {code, "Failed to add vector"}; +SearchArguments::AddTargetVector(std::string field_name, const ElementT& vector) { + if constexpr (std::is_same_v> && + (std::is_same_v || + std::is_same_v || + std::is_same_v)) { + return DoAddTargetVector(std::move(field_name), + std::move(std::string{vector.begin(), vector.end()})); + } else { + return DoAddTargetVector(std::move(field_name), vector); } - - return Status::OK(); } +template Status -SearchArguments::AddTargetVector(std::string field_name, FloatVecFieldData::ElementT&& vector) { - if (binary_vectors_ != nullptr) { - return {StatusCode::INVALID_AGUMENT, "Target vector must be binary type!"}; - } - - if (nullptr == float_vectors_) { - float_vectors_ = std::make_shared(std::move(field_name)); +SearchArguments::AddTargetVector(std::string field_name, typename VecFieldDataT::ElementT&& vector) { + static_assert( + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v, + "Only FloatVecFieldData, BinaryVecFieldData, Float16VecFieldData or BFloat16VecFieldData is supported"); + if (nullptr == field_data_) { + field_data_ = std::make_shared(std::move(field_name)); } - auto code = float_vectors_->Add(std::move(vector)); - if (code != StatusCode::OK) { - return {code, "Failed to add vector"}; + if (auto vectors = std::dynamic_pointer_cast(field_data_)) { + auto code = vectors->Add(std::move(vector)); + if (code != StatusCode::OK) { + return {code, "Failed to add vector"}; + } + } else { + return {StatusCode::INVALID_AGUMENT, "Invalid vector type!"}; } return Status::OK(); diff --git a/src/include/milvus/types/CalcDistanceArguments.h b/src/include/milvus/types/CalcDistanceArguments.h index 2b3d5e2..4c407e0 100644 --- a/src/include/milvus/types/CalcDistanceArguments.h +++ b/src/include/milvus/types/CalcDistanceArguments.h @@ -47,6 +47,18 @@ class CalcDistanceArguments { Status SetLeftVectors(BinaryVecFieldDataPtr vectors); + /** + * @brief Set the float16 vectors on the left of operator, without field name. + */ + Status + SetLeftVectors(Float16VecFieldDataPtr vectors); + + /** + * @brief Set the bfloat16 vectors on the left of operator, without field name. + */ + Status + SetLeftVectors(BFloat16VecFieldDataPtr vectors); + /** * @brief Set id array of the vectors on the left of operator, must specify field name and collection name. * Partition names is optinal, to narrow down the query scope. @@ -80,6 +92,18 @@ class CalcDistanceArguments { Status SetRightVectors(BinaryVecFieldDataPtr vectors); + /** + * @brief Set the float16 vectors on the right of operator, without field name. + */ + Status + SetRightVectors(Float16VecFieldDataPtr vectors); + + /** + * @brief Set the bfloat16 vectors on the right of operator, without field name. + */ + Status + SetRightVectors(BFloat16VecFieldDataPtr vectors); + /** * @brief Set id array of the vectors on the right of operator, must specify field name and collection name. * Partition names is optinal, to narrow down the query scope. diff --git a/src/include/milvus/types/DataType.h b/src/include/milvus/types/DataType.h index ec98f5a..1b995fc 100644 --- a/src/include/milvus/types/DataType.h +++ b/src/include/milvus/types/DataType.h @@ -20,6 +20,7 @@ namespace milvus { /** * @brief Data type of field + * @refer https://github.com/milvus-io/milvus-proto/blob/master/proto/schema.proto */ enum class DataType { UNKNOWN = 0, @@ -39,6 +40,9 @@ enum class DataType { BINARY_VECTOR = 100, FLOAT_VECTOR = 101, + FLOAT16_VECTOR = 102, + BFLOAT16_VECTOR = 103, + SPARSE_FLOAT_VECTOR = 104, }; } // namespace milvus diff --git a/src/include/milvus/types/FieldData.h b/src/include/milvus/types/FieldData.h index bf827a3..09e5759 100644 --- a/src/include/milvus/types/FieldData.h +++ b/src/include/milvus/types/FieldData.h @@ -17,14 +17,18 @@ #pragma once #include +#include #include #include "../Status.h" #include "DataType.h" +#include "Eigen/Core" namespace milvus { class Field { public: + virtual ~Field() = default; + /** * @brief Get field name */ @@ -123,49 +127,37 @@ class FieldData : public Field { virtual std::vector& Data(); - private: - friend class BinaryVecFieldData; + protected: std::vector data_; }; -class BinaryVecFieldData : public FieldData { +template +class BinaryVecFieldDataImpl : public FieldData { public: /** * @brief Constructor */ - BinaryVecFieldData(); + BinaryVecFieldDataImpl(); /** * @brief Constructor */ - explicit BinaryVecFieldData(std::string name); + explicit BinaryVecFieldDataImpl(std::string name); /** * @brief Constructor */ - BinaryVecFieldData(std::string name, const std::vector& data); + BinaryVecFieldDataImpl(std::string name, const std::vector& data); /** * @brief Constructor */ - BinaryVecFieldData(std::string name, std::vector&& data); + BinaryVecFieldDataImpl(std::string name, std::vector&& data); /** * @brief Constructor */ - BinaryVecFieldData(std::string name, const std::vector>& data); - - /** - * @brief Field elements array - */ - const std::vector& - Data() const override; - - /** - * @brief Field elements array - */ - std::vector& - Data() override; + BinaryVecFieldDataImpl(std::string name, const std::vector>& data); /** * @brief Data export as uint8_t's vector @@ -197,12 +189,67 @@ class BinaryVecFieldData : public FieldData CreateBinaryStrings(const std::vector>& data); +}; + +template +class Fp16VecFieldData : public BinaryVecFieldDataImpl
{ + public: + /** + * @brief Constructor + */ + Fp16VecFieldData() : BinaryVecFieldDataImpl
() { + } + + /** + * @brief Constructor + */ + explicit Fp16VecFieldData(std::string name) : BinaryVecFieldDataImpl
(std::move(name)) { + } + + Fp16VecFieldData(std::string name, const std::vector& data) + : BinaryVecFieldDataImpl
(std::move(name), data) { + } + + Fp16VecFieldData(std::string name, std::vector&& data) + : BinaryVecFieldDataImpl
(std::move(name), std::move(data)) { + } + + Fp16VecFieldData(std::string name, const std::vector>& data) + : BinaryVecFieldDataImpl
(std::move(name), data) { + } /** - * @brief Create binary vector string from uint8_t vector + * @brief Constructor */ - static std::string - CreateBinaryString(const std::vector& data); + Fp16VecFieldData(std::string name, const std::vector>& data); + + Fp16VecFieldData(std::string name, const std::vector>& data); + + Fp16VecFieldData(std::string name, const std::vector>& data); + + Fp16VecFieldData(std::string name, std::vector>&& data); + + /** + * @brief Data export as T vector + */ + template + std::vector> + DataAsFloats() const; + + template + static std::vector + CreateBinaryStringsFromFloats(const std::vector>& data); + + using BinaryVecFieldDataImpl
::Add; + + StatusCode + Add(const std::vector& element); + + StatusCode + Add(const std::vector& element); + + StatusCode + Add(const std::vector& element); }; /** @@ -232,6 +279,9 @@ using FloatFieldData = FieldData; using DoubleFieldData = FieldData; using VarCharFieldData = FieldData; using FloatVecFieldData = FieldData, DataType::FLOAT_VECTOR>; +using BinaryVecFieldData = BinaryVecFieldDataImpl; +using Float16VecFieldData = Fp16VecFieldData; +using BFloat16VecFieldData = Fp16VecFieldData; using BoolFieldDataPtr = std::shared_ptr; using Int8FieldDataPtr = std::shared_ptr; @@ -243,6 +293,8 @@ using DoubleFieldDataPtr = std::shared_ptr; using VarCharFieldDataPtr = std::shared_ptr; using BinaryVecFieldDataPtr = std::shared_ptr; using FloatVecFieldDataPtr = std::shared_ptr; +using Float16VecFieldDataPtr = std::shared_ptr; +using BFloat16VecFieldDataPtr = std::shared_ptr; extern template class FieldData; extern template class FieldData; @@ -254,5 +306,24 @@ extern template class FieldData; extern template class FieldData; extern template class FieldData; extern template class FieldData, DataType::FLOAT_VECTOR>; - +extern template class FieldData; +extern template class BinaryVecFieldDataImpl; +extern template class BinaryVecFieldDataImpl; +extern template class BinaryVecFieldDataImpl; +extern template class Fp16VecFieldData; +extern template class Fp16VecFieldData; + +extern template std::vector> +Fp16VecFieldData::DataAsFloats() const; +extern template std::vector> +Fp16VecFieldData::DataAsFloats() const; +extern template std::vector> +Fp16VecFieldData::DataAsFloats() const; + +extern template std::vector> +Fp16VecFieldData::DataAsFloats() const; +extern template std::vector> +Fp16VecFieldData::DataAsFloats() const; +extern template std::vector> +Fp16VecFieldData::DataAsFloats() const; } // namespace milvus diff --git a/src/include/milvus/types/FloatUtils.h b/src/include/milvus/types/FloatUtils.h new file mode 100644 index 0000000..30d2e15 --- /dev/null +++ b/src/include/milvus/types/FloatUtils.h @@ -0,0 +1,65 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +namespace Eigen { +struct half; +struct bfloat16; +} // namespace Eigen + +namespace milvus { + +template +std::vector +Float16NumVecBytesToFloatNumVec(const std::string& val); + +template +std::string +FloatNumVecToFloat16NumVecBytes(const std::vector& data); + +extern template std::vector +Float16NumVecBytesToFloatNumVec(const std::string& val); +extern template std::vector +Float16NumVecBytesToFloatNumVec(const std::string& val); +extern template std::vector +Float16NumVecBytesToFloatNumVec(const std::string& val); + +extern template std::vector +Float16NumVecBytesToFloatNumVec(const std::string& val); +extern template std::vector +Float16NumVecBytesToFloatNumVec(const std::string& val); +extern template std::vector +Float16NumVecBytesToFloatNumVec(const std::string& val); + +extern template std::string +FloatNumVecToFloat16NumVecBytes(const std::vector& data); +extern template std::string +FloatNumVecToFloat16NumVecBytes(const std::vector& data); +extern template std::string +FloatNumVecToFloat16NumVecBytes(const std::vector& data); + +extern template std::string +FloatNumVecToFloat16NumVecBytes(const std::vector& data); +extern template std::string +FloatNumVecToFloat16NumVecBytes(const std::vector& data); +extern template std::string +FloatNumVecToFloat16NumVecBytes(const std::vector& data); + +} // namespace milvus diff --git a/src/include/milvus/types/SearchArguments.h b/src/include/milvus/types/SearchArguments.h index eeede41..a794591 100644 --- a/src/include/milvus/types/SearchArguments.h +++ b/src/include/milvus/types/SearchArguments.h @@ -87,34 +87,18 @@ class SearchArguments { TargetVectors() const; /** - * @brief Add a binary vector to search + * @brief Add a vector to search */ + template Status - AddTargetVector(std::string field_name, const std::string& vector); + AddTargetVector(std::string field_name, const ElementT& vector); /** - * @brief Add a binary vector to search with uint8_t vectors + * @brief Add a vector to search */ + template Status - AddTargetVector(std::string field_name, const std::vector& vector); - - /** - * @brief Add a binary vector to search - */ - Status - AddTargetVector(std::string field_name, std::string&& vector); - - /** - * @brief Add a float vector to search - */ - Status - AddTargetVector(std::string field_name, const FloatVecFieldData::ElementT& vector); - - /** - * @brief Add a float vector to search - */ - Status - AddTargetVector(std::string field_name, FloatVecFieldData::ElementT&& vector); + AddTargetVector(std::string field_name, typename VecFieldDataT::ElementT&& vector); /** * @brief Get travel timestamp. @@ -249,13 +233,15 @@ class SearchArguments { RangeSearch() const; private: + template + Status + DoAddTargetVector(std::string field_name, const T& vector); + std::string collection_name_; std::set partition_names_; std::set output_field_names_; std::string filter_expression_; - - BinaryVecFieldDataPtr binary_vectors_; - FloatVecFieldDataPtr float_vectors_; + FieldDataPtr field_data_; std::set output_fields_; std::unordered_map extra_params_; @@ -272,4 +258,37 @@ class SearchArguments { ::milvus::MetricType metric_type_{::milvus::MetricType::L2}; }; +#define TemplateAddTargetVectorDeclaration(E, VecFieldDataT) \ + E template Status SearchArguments::AddTargetVector(std::string field_name, \ + const typename VecFieldDataT::ElementT& vector); \ + E template Status SearchArguments::AddTargetVector(std::string field_name, \ + typename VecFieldDataT::ElementT && vector); + +TemplateAddTargetVectorDeclaration(extern, FloatVecFieldData); +TemplateAddTargetVectorDeclaration(extern, BinaryVecFieldData); +TemplateAddTargetVectorDeclaration(extern, Float16VecFieldData); +TemplateAddTargetVectorDeclaration(extern, BFloat16VecFieldData); + +extern template Status +SearchArguments::AddTargetVector(std::string field_name, const std::vector& vector); + +extern template Status +SearchArguments::AddTargetVector(std::string field_name, const std::vector& vector); +extern template Status +SearchArguments::AddTargetVector(std::string field_name, const std::vector& vector); +extern template Status +SearchArguments::AddTargetVector(std::string field_name, const std::vector& vector); +extern template Status +SearchArguments::AddTargetVector(std::string field_name, const std::vector& vector); + +extern template Status +SearchArguments::AddTargetVector(std::string field_name, + const std::vector& vector); +extern template Status +SearchArguments::AddTargetVector(std::string field_name, const std::vector& vector); +extern template Status +SearchArguments::AddTargetVector(std::string field_name, const std::vector& vector); +extern template Status +SearchArguments::AddTargetVector(std::string field_name, const std::vector& vector); + } // namespace milvus diff --git a/test/it/TestInsert.cpp b/test/it/TestInsert.cpp index 779ed6b..856ac75 100644 --- a/test/it/TestInsert.cpp +++ b/test/it/TestInsert.cpp @@ -67,10 +67,18 @@ auto floats_field_ptr = std::make_shared<::milvus::FloatVecFieldData>( "floats", std::vector>{ {0.1f, 0.2f, 0.3f, 0.4f}, {0.1f, 0.2f, 0.3f, 0.4f}, {0.1f, 0.2f, 0.3f, 0.4f}, {0.1f, 0.2f, 0.3f, 0.4f}}); +auto float16s_field_ptr = std::make_shared<::milvus::Float16VecFieldData>( + "float16s", + std::vector>{ + {0.1f, 0.2f, 0.3f, 0.4f}, {0.1f, 0.2f, 0.3f, 0.4f}, {0.1f, 0.2f, 0.3f, 0.4f}, {0.1f, 0.2f, 0.3f, 0.4f}}); +auto bfloat16_field_ptr = std::make_shared<::milvus::BFloat16VecFieldData>( + "bfloat16s", + std::vector>{ + {0.1f, 0.2f, 0.3f, 0.4f}, {0.1f, 0.2f, 0.3f, 0.4f}, {0.1f, 0.2f, 0.3f, 0.4f}, {0.1f, 0.2f, 0.3f, 0.4f}}); -std::vector fields{bool_field_ptr, int8_field_ptr, int16_field_ptr, int32_field_ptr, - int64_field_ptr, float_field_ptr, double_field_ptr, string_field_ptr, - bins_field_ptr, floats_field_ptr}; +std::vector fields{bool_field_ptr, int8_field_ptr, int16_field_ptr, int32_field_ptr, + int64_field_ptr, float_field_ptr, double_field_ptr, string_field_ptr, + bins_field_ptr, floats_field_ptr, float16s_field_ptr, bfloat16_field_ptr}; } // namespace TEST_F(MilvusMockedTest, InsertFoo) { @@ -79,17 +87,18 @@ TEST_F(MilvusMockedTest, InsertFoo) { std::vector ret_ids{1000, 10001, 1002, 1003}; - EXPECT_CALL(service_, - Insert(_, - AllOf(Property(&InsertRequest::collection_name, collection), - Property(&InsertRequest::partition_name, partition), - Property(&InsertRequest::num_rows, num_of_rows), - Property(&InsertRequest::fields_data_size, fields.size()), - Property(&InsertRequest::fields_data, - ElementsAre(*bool_field_ptr, *int8_field_ptr, *int16_field_ptr, *int32_field_ptr, - *int64_field_ptr, *float_field_ptr, *double_field_ptr, - *string_field_ptr, *bins_field_ptr, *floats_field_ptr))), - _)) + EXPECT_CALL( + service_, + Insert( + _, + AllOf(Property(&InsertRequest::collection_name, collection), + Property(&InsertRequest::partition_name, partition), Property(&InsertRequest::num_rows, num_of_rows), + Property(&InsertRequest::fields_data_size, fields.size()), + Property(&InsertRequest::fields_data, + ElementsAre(*bool_field_ptr, *int8_field_ptr, *int16_field_ptr, *int32_field_ptr, + *int64_field_ptr, *float_field_ptr, *double_field_ptr, *string_field_ptr, + *bins_field_ptr, *floats_field_ptr, *float16s_field_ptr, *bfloat16_field_ptr))), + _)) .WillOnce([&ret_ids](::grpc::ServerContext*, const InsertRequest* request, MutationResult* response) { // ret for (const auto ret_id : ret_ids) { diff --git a/test/it/TestSearch.cpp b/test/it/TestSearch.cpp index 3771b3c..2e87eb2 100644 --- a/test/it/TestSearch.cpp +++ b/test/it/TestSearch.cpp @@ -53,10 +53,10 @@ operator==(const milvus::proto::common::KeyValuePair& lhs, const TestKv& rhs) { } } // namespace -template +template milvus::Status DoSearchVectors(testing::StrictMock<::milvus::MilvusMockedService>& service_, - std::shared_ptr<::milvus::MilvusClient>& client_, std::vector vectors, + std::shared_ptr<::milvus::MilvusClient>& client_, std::vector vectors, milvus::SearchResults& search_results, int simulate_timeout = 0, int search_timeout = 0) { milvus::SearchArguments search_arguments{}; search_arguments.SetCollectionName("foo"); @@ -66,7 +66,7 @@ DoSearchVectors(testing::StrictMock<::milvus::MilvusMockedService>& service_, search_arguments.AddOutputField("f2"); search_arguments.SetExpression("dummy expression"); for (const auto& vec : vectors) { - search_arguments.AddTargetVector("anns_dummy", vec); + search_arguments.AddTargetVector("anns_dummy", vec); } search_arguments.SetTravelTimestamp(10000); search_arguments.SetGuaranteeTimestamp(milvus::GuaranteeStrongTs()); @@ -125,7 +125,7 @@ DoSearchVectors(testing::StrictMock<::milvus::MilvusMockedService>& service_, for (int i = 0; i < vectors.size(); ++i) { const auto& vector = vectors.at(i); const auto& placeholder = placeholders.values(i); - T test_vector(vector.size()); + VecT test_vector(vector.size()); std::copy_n(placeholder.data(), placeholder.size(), reinterpret_cast(test_vector.data())); EXPECT_EQ(test_vector, vector); } @@ -176,13 +176,13 @@ DoSearchVectors(testing::StrictMock<::milvus::MilvusMockedService>& service_, return client_->Search(search_arguments, search_results, search_timeout); } -template +template void TestSearchVectors(testing::StrictMock<::milvus::MilvusMockedService>& service_, - std::shared_ptr<::milvus::MilvusClient>& client_, std::vector vectors) { + std::shared_ptr<::milvus::MilvusClient>& client_, std::vector vectors) { milvus::SearchResults search_results{}; - auto status = DoSearchVectors(service_, client_, vectors, search_results); + auto status = DoSearchVectors(service_, client_, vectors, search_results); EXPECT_TRUE(status.IsOk()); auto& results = search_results.Results(); EXPECT_EQ(results.size(), 2); @@ -204,14 +204,30 @@ TestSearchVectors(testing::StrictMock<::milvus::MilvusMockedService>& service_, TEST_F(MilvusMockedTest, SearchFoo) { milvus::ConnectParam connect_param{"127.0.0.1", server_.ListenPort()}; client_->Connect(connect_param); - - std::vector> float_vectors = {std::vector{0.1f, 0.2f, 0.3f, 0.4f}, - std::vector{0.2f, 0.3f, 0.4f, 0.5f}}; - TestSearchVectors>(service_, client_, float_vectors); - - std::vector> bin_vectors = {std::vector{1, 2, 3, 4}, - std::vector{2, 3, 4, 5}}; - TestSearchVectors>(service_, client_, bin_vectors); + { + std::vector> float_vectors = {std::vector{0.1f, 0.2f, 0.3f, 0.4f}, + std::vector{0.2f, 0.3f, 0.4f, 0.5f}}; + TestSearchVectors(service_, client_, float_vectors); + } + { + std::vector> float16_vectors = { + std::vector{Eigen::half(0.1f), Eigen::half(0.2f), Eigen::half(0.3f), Eigen::half(0.4f)}, + std::vector{Eigen::half(0.2f), Eigen::half(0.3f), Eigen::half(0.4f), Eigen::half(0.5f)}}; + TestSearchVectors(service_, client_, float16_vectors); + } + { + std::vector> bfloat16_vectors = { + std::vector{Eigen::bfloat16(0.1f), Eigen::bfloat16(0.2f), Eigen::bfloat16(0.3f), + Eigen::bfloat16(0.4f)}, + std::vector{Eigen::bfloat16(0.2f), Eigen::bfloat16(0.3f), Eigen::bfloat16(0.4f), + Eigen::bfloat16(0.5f)}}; + TestSearchVectors(service_, client_, bfloat16_vectors); + } + { + std::vector> bin_vectors = {std::vector{1, 2, 3, 4}, + std::vector{2, 3, 4, 5}}; + TestSearchVectors(service_, client_, bin_vectors); + } } TEST_F(MilvusMockedTest, SearchFooWithTimeoutExpired) { @@ -223,7 +239,8 @@ TEST_F(MilvusMockedTest, SearchFooWithTimeoutExpired) { milvus::SearchResults search_results{}; auto t0 = std::chrono::system_clock::now(); - auto status = DoSearchVectors(service_, client_, float_vectors, search_results, 1000, 500); + auto status = + DoSearchVectors(service_, client_, float_vectors, search_results, 1000, 500); auto duration = std::chrono::duration_cast(std::chrono::system_clock::now() - t0).count(); @@ -241,7 +258,8 @@ TEST_F(MilvusMockedTest, SearchFooWithTimeoutOk) { milvus::SearchResults search_results{}; auto t0 = std::chrono::system_clock::now(); - auto status = DoSearchVectors(service_, client_, float_vectors, search_results, 500, 1000); + auto status = + DoSearchVectors(service_, client_, float_vectors, search_results, 500, 1000); auto duration = std::chrono::duration_cast(std::chrono::system_clock::now() - t0).count(); @@ -277,7 +295,7 @@ TEST_F(MilvusMockedTest, SearchWithMismatchedAnnsField) { search_arguments.AddOutputField("f2"); search_arguments.SetExpression("dummy expression"); for (const auto& floats : floats_vec) { - search_arguments.AddTargetVector("anns_dummy", floats); + search_arguments.AddTargetVector("anns_dummy", floats); } search_arguments.SetTravelTimestamp(10000); search_arguments.SetGuaranteeTimestamp(10001); diff --git a/test/st/PythonMilvusServer.cpp b/test/st/PythonMilvusServer.cpp index 78f356f..cc5a88c 100644 --- a/test/st/PythonMilvusServer.cpp +++ b/test/st/PythonMilvusServer.cpp @@ -52,6 +52,9 @@ PythonMilvusServer::SetTls(int mode, const std::string& server_cert, const std:: void PythonMilvusServer::Start() { + if (Started()) { + return; + } // install command std::string cmd = std::string("pip3 install ") + kPythonMilvusServerVersion; auto ret = system(cmd.c_str()); @@ -63,10 +66,14 @@ PythonMilvusServer::Start() { cmd = "rm -fr " + base_dir_; system(cmd.c_str()); thread_ = std::thread([this]() { this->run(); }); + started_ = true; } void PythonMilvusServer::Stop() { + if (!Started()) { + return; + } if (pid_) { kill(pid_, SIGINT); } @@ -75,6 +82,7 @@ PythonMilvusServer::Stop() { } // sleep 5s to wait for release port std::this_thread::sleep_for(std::chrono::seconds(5)); + started_ = false; } void @@ -129,4 +137,4 @@ PythonMilvusServer::TestClientParam() const { return param; } } // namespace test -} // namespace milvus \ No newline at end of file +} // namespace milvus diff --git a/test/st/PythonMilvusServer.h b/test/st/PythonMilvusServer.h index 776f54e..8d882b0 100644 --- a/test/st/PythonMilvusServer.h +++ b/test/st/PythonMilvusServer.h @@ -42,6 +42,7 @@ class PythonMilvusServer { int status_{0}; int pid_{0}; + bool started_{false}; void run(); @@ -58,6 +59,11 @@ class PythonMilvusServer { void Start(); + bool + Started() const { + return started_; + } + void Stop(); @@ -69,4 +75,4 @@ class PythonMilvusServer { }; } // namespace test -} // namespace milvus \ No newline at end of file +} // namespace milvus diff --git a/test/st/TestCollection.cpp b/test/st/TestCollection.cpp index e960ad3..fdd0c81 100644 --- a/test/st/TestCollection.cpp +++ b/test/st/TestCollection.cpp @@ -16,66 +16,90 @@ #include "MilvusServerTest.h" -using milvus::test::MilvusServerTestWithParam; +using milvus::test::MilvusServerTest; -using MilvusServerTestCollection = MilvusServerTestWithParam; +struct Params { + bool using_string_primary_key; + milvus::DataType vector_data_type; +}; -TEST_P(MilvusServerTestCollection, CreateAndDeleteCollection) { - auto using_string_primary_key = GetParam(); - milvus::ConnectParam connect_param{"127.0.0.1", server_.ListenPort()}; - client_->Connect(connect_param); +class MilvusServerTestCollection : public MilvusServerTest { + protected: + void + SetUp() override { + MilvusServerTest::SetUp(); + milvus::ConnectParam connect_param{"127.0.0.1", server_.ListenPort()}; + client_->Connect(connect_param); + } - milvus::CollectionSchema collection_schema("Foo"); - if (using_string_primary_key) { - collection_schema.AddField( - // string as primary key, no auto-id - milvus::FieldSchema("name", milvus::DataType::VARCHAR, "name", true, false).WithMaxLength(64)); - } else { - collection_schema.AddField(milvus::FieldSchema("id", milvus::DataType::INT64, "id", true, true)); - collection_schema.AddField(milvus::FieldSchema("name", milvus::DataType::VARCHAR, "name").WithMaxLength(64)); + void + TearDown() override { + client_->Disconnect(); + MilvusServerTest::TearDown(); } - collection_schema.AddField(milvus::FieldSchema("age", milvus::DataType::INT16, "age")); - collection_schema.AddField( - milvus::FieldSchema("face", milvus::DataType::FLOAT_VECTOR, "face signature").WithDimension(1024)); - auto status = client_->CreateCollection(collection_schema); - EXPECT_EQ(status.Message(), "OK"); - EXPECT_TRUE(status.IsOk()); + void + DoCreateAndDeleteCollection(bool using_string_primary_key, milvus::DataType vector_data_type) { + milvus::CollectionSchema collection_schema("Foo"); + if (using_string_primary_key) { + collection_schema.AddField( + // string as primary key, no auto-id + milvus::FieldSchema("name", milvus::DataType::VARCHAR, "name", true, false).WithMaxLength(64)); + } else { + collection_schema.AddField(milvus::FieldSchema("id", milvus::DataType::INT64, "id", true, true)); + collection_schema.AddField( + milvus::FieldSchema("name", milvus::DataType::VARCHAR, "name").WithMaxLength(64)); + } + collection_schema.AddField(milvus::FieldSchema("age", milvus::DataType::INT16, "age")); + collection_schema.AddField(milvus::FieldSchema("face", vector_data_type, "face signature").WithDimension(1024)); - // create index needed after 2.2.0 - milvus::IndexDesc index_desc("face", "", milvus::IndexType::FLAT, milvus::MetricType::L2, 0); - status = client_->CreateIndex("Foo", index_desc); - EXPECT_TRUE(status.IsOk()); + auto status = client_->CreateCollection(collection_schema); + EXPECT_EQ(status.Message(), "OK"); + EXPECT_TRUE(status.IsOk()) << status.Message(); - // test for https://github.com/milvus-io/milvus-sdk-cpp/issues/188 - std::vector names; - std::vector collection_infos; - status = client_->ShowCollections(names, collection_infos); - EXPECT_TRUE(status.IsOk()); - EXPECT_EQ(collection_infos.size(), 1); - EXPECT_EQ(collection_infos.front().MemoryPercentage(), 0); - EXPECT_EQ(collection_infos.front().Name(), "Foo"); + // create index needed after 2.2.0 + milvus::IndexDesc index_desc("face", "", milvus::IndexType::FLAT, milvus::MetricType::L2, 0); + status = client_->CreateIndex("Foo", index_desc); + EXPECT_TRUE(status.IsOk()) << status.Message(); - // test for https://github.com/milvus-io/milvus-sdk-cpp/issues/246 - milvus::PartitionsInfo partitionsInfo{}; - status = client_->ShowPartitions("Foo", std::vector{}, partitionsInfo); - EXPECT_TRUE(status.IsOk()); + // test for https://github.com/milvus-io/milvus-sdk-cpp/issues/188 + std::vector names; + std::vector collection_infos; + status = client_->ShowCollections(names, collection_infos); + EXPECT_TRUE(status.IsOk()); + EXPECT_EQ(collection_infos.size(), 1); + EXPECT_EQ(collection_infos.front().MemoryPercentage(), 0); + EXPECT_EQ(collection_infos.front().Name(), "Foo"); - names.emplace_back("Foo"); - collection_infos.clear(); - status = client_->LoadCollection("Foo"); - EXPECT_TRUE(status.IsOk()); + // test for https://github.com/milvus-io/milvus-sdk-cpp/issues/246 + milvus::PartitionsInfo partitionsInfo{}; + status = client_->ShowPartitions("Foo", std::vector{}, partitionsInfo); + EXPECT_TRUE(status.IsOk()); - status = client_->ShowCollections(names, collection_infos); - EXPECT_TRUE(status.IsOk()); - EXPECT_EQ(collection_infos.size(), 1); - EXPECT_EQ(collection_infos.front().MemoryPercentage(), 100); + names.emplace_back("Foo"); + collection_infos.clear(); + status = client_->LoadCollection("Foo"); + EXPECT_TRUE(status.IsOk()); - status = client_->RenameCollection("Foo", "Bar"); - EXPECT_TRUE(status.IsOk()); + status = client_->ShowCollections(names, collection_infos); + EXPECT_TRUE(status.IsOk()); + EXPECT_EQ(collection_infos.size(), 1); + EXPECT_EQ(collection_infos.front().MemoryPercentage(), 100); - status = client_->DropCollection("Bar"); - EXPECT_TRUE(status.IsOk()); -} + status = client_->RenameCollection("Foo", "Bar"); + EXPECT_TRUE(status.IsOk()); + + status = client_->DropCollection("Bar"); + EXPECT_TRUE(status.IsOk()); + } +}; -INSTANTIATE_TEST_SUITE_P(SystemTest, MilvusServerTestCollection, ::testing::Values(false, true)); +TEST_F(MilvusServerTestCollection, CreateAndDeleteCollection) { + for (auto using_string_primary_key : {false, true}) { + // TODO: create index on bfloat16 vector is not supported in milvus-lite 2.3.x, + // add BFLOAT16_VECTOR in future + for (auto vector_data_type : {milvus::DataType::FLOAT_VECTOR, milvus::DataType::FLOAT16_VECTOR}) { + DoCreateAndDeleteCollection(using_string_primary_key, vector_data_type); + } + } +} diff --git a/test/st/TestSearch.cpp b/test/st/TestSearch.cpp index c41406a..ca060b5 100644 --- a/test/st/TestSearch.cpp +++ b/test/st/TestSearch.cpp @@ -39,14 +39,14 @@ class MilvusServerTestSearch : public MilvusServerTest { MilvusServerTest::TearDown(); } + template void createCollectionAndPartitions(bool create_flat_index) { milvus::CollectionSchema collection_schema(collection_name); collection_schema.AddField(milvus::FieldSchema("id", milvus::DataType::INT64, "id", true, true)); collection_schema.AddField(milvus::FieldSchema("age", milvus::DataType::INT16, "age")); collection_schema.AddField(milvus::FieldSchema("name", milvus::DataType::VARCHAR, "name").WithMaxLength(64)); - collection_schema.AddField( - milvus::FieldSchema("face", milvus::DataType::FLOAT_VECTOR, "face signature").WithDimension(4)); + collection_schema.AddField(milvus::FieldSchema("face", VECTOR_DT, "face signature").WithDimension(4)); auto status = client_->CreateCollection(collection_schema); EXPECT_EQ(status.Message(), "OK"); @@ -86,57 +86,65 @@ class MilvusServerTestSearch : public MilvusServerTest { auto status = client_->DropCollection(collection_name); EXPECT_TRUE(status.IsOk()); } -}; - -TEST_F(MilvusServerTestSearch, SearchWithoutIndex) { - std::vector fields{ - std::make_shared("age", std::vector{12, 13}), - std::make_shared("name", std::vector{"Tom", "Jerry"}), - std::make_shared( - "face", std::vector>{std::vector{0.1f, 0.2f, 0.3f, 0.4f}, - std::vector{0.5f, 0.6f, 0.7f, 0.8f}})}; - - createCollectionAndPartitions(true); - auto dml_results = insertRecords(fields); - loadCollection(); - - milvus::SearchArguments arguments{}; - arguments.SetCollectionName(collection_name); - arguments.AddPartitionName(partition_name); - arguments.SetTopK(10); - arguments.AddOutputField("age"); - arguments.AddOutputField("name"); - arguments.SetExpression("id > 0"); - arguments.AddTargetVector("face", std::vector{0.f, 0.f, 0.f, 0.f}); - arguments.AddTargetVector("face", std::vector{1.f, 1.f, 1.f, 1.f}); - milvus::SearchResults search_results{}; - auto status = client_->Search(arguments, search_results); - EXPECT_EQ(status.Message(), "OK"); - EXPECT_TRUE(status.IsOk()); - const auto& results = search_results.Results(); - EXPECT_EQ(results.size(), 2); - EXPECT_THAT(results.at(0).Ids().IntIDArray(), UnorderedElementsAreArray(dml_results.IdArray().IntIDArray())); - EXPECT_THAT(results.at(1).Ids().IntIDArray(), UnorderedElementsAreArray(dml_results.IdArray().IntIDArray())); - - EXPECT_EQ(results.at(0).Scores().size(), 2); - EXPECT_EQ(results.at(1).Scores().size(), 2); + template + void + TestSearchWithoutIndex() { + std::vector fields{ + std::make_shared("age", std::vector{12, 13}), + std::make_shared("name", std::vector{"Tom", "Jerry"}), + std::make_shared( + "face", std::vector>{std::vector{0.1f, 0.2f, 0.3f, 0.4f}, + std::vector{0.5f, 0.6f, 0.7f, 0.8f}})}; + + createCollectionAndPartitions(true); + auto dml_results = insertRecords(fields); + loadCollection(); + + milvus::SearchArguments arguments{}; + arguments.SetCollectionName(collection_name); + arguments.AddPartitionName(partition_name); + arguments.SetTopK(10); + arguments.AddOutputField("age"); + arguments.AddOutputField("name"); + arguments.SetExpression("id > 0"); + arguments.AddTargetVector("face", std::vector{0.1f, 0.2f, 0.2f, 0.4f}); + arguments.AddTargetVector("face", std::vector{0.5f, 0.6f, 0.7f, 1.f}); + milvus::SearchResults search_results{}; + auto status = client_->Search(arguments, search_results); + EXPECT_EQ(status.Message(), "OK"); + EXPECT_TRUE(status.IsOk()); - EXPECT_LT(results.at(0).Scores().at(0), results.at(0).Scores().at(1)); - EXPECT_LT(results.at(1).Scores().at(0), results.at(1).Scores().at(1)); + const auto& results = search_results.Results(); + EXPECT_EQ(results.size(), 2); + EXPECT_THAT(results.at(0).Ids().IntIDArray(), UnorderedElementsAreArray(dml_results.IdArray().IntIDArray())); + EXPECT_THAT(results.at(1).Ids().IntIDArray(), UnorderedElementsAreArray(dml_results.IdArray().IntIDArray())); + + EXPECT_EQ(results.at(0).Scores().size(), 2); + EXPECT_EQ(results.at(1).Scores().size(), 2); + + EXPECT_LT(results.at(0).Scores().at(0), results.at(0).Scores().at(1)); + EXPECT_LT(results.at(1).Scores().at(0), results.at(1).Scores().at(1)); + + // match fields: age + EXPECT_EQ(results.at(0).OutputFields().size(), 2); + EXPECT_EQ(results.at(1).OutputFields().size(), 2); + EXPECT_THAT(dynamic_cast(*results.at(0).OutputField("age")).Data(), + UnorderedElementsAre(12, 13)); + EXPECT_THAT(dynamic_cast(*results.at(1).OutputField("age")).Data(), + UnorderedElementsAre(12, 13)); + EXPECT_THAT(dynamic_cast(*results.at(0).OutputField("name")).Data(), + UnorderedElementsAre("Tom", "Jerry")); + EXPECT_THAT(dynamic_cast(*results.at(1).OutputField("name")).Data(), + UnorderedElementsAre("Tom", "Jerry")); + dropCollection(); + } +}; - // match fields: age - EXPECT_EQ(results.at(0).OutputFields().size(), 2); - EXPECT_EQ(results.at(1).OutputFields().size(), 2); - EXPECT_THAT(dynamic_cast(*results.at(0).OutputField("age")).Data(), - UnorderedElementsAre(12, 13)); - EXPECT_THAT(dynamic_cast(*results.at(1).OutputField("age")).Data(), - UnorderedElementsAre(12, 13)); - EXPECT_THAT(dynamic_cast(*results.at(0).OutputField("name")).Data(), - UnorderedElementsAre("Tom", "Jerry")); - EXPECT_THAT(dynamic_cast(*results.at(1).OutputField("name")).Data(), - UnorderedElementsAre("Tom", "Jerry")); - dropCollection(); +TEST_F(MilvusServerTestSearch, SearchWithoutIndex) { + TestSearchWithoutIndex(); + TestSearchWithoutIndex(); + // TestSearchWithoutIndex(); } TEST_F(MilvusServerTestSearch, RangeSearch) { @@ -154,7 +162,7 @@ TEST_F(MilvusServerTestSearch, RangeSearch) { std::vector{0.7f, 0.8f, 0.9f, 1.0f}, })}; - createCollectionAndPartitions(true); + createCollectionAndPartitions(true); auto dml_results = insertRecords(fields); loadCollection(); @@ -165,8 +173,8 @@ TEST_F(MilvusServerTestSearch, RangeSearch) { arguments.SetTopK(10); arguments.AddOutputField("age"); arguments.AddOutputField("name"); - arguments.AddTargetVector("face", std::vector{0.f, 0.f, 0.f, 0.f}); - arguments.AddTargetVector("face", std::vector{1.f, 1.f, 1.f, 1.f}); + arguments.AddTargetVector("face", std::vector{0.f, 0.f, 0.f, 0.f}); + arguments.AddTargetVector("face", std::vector{1.f, 1.f, 1.f, 1.f}); milvus::SearchResults search_results{}; auto status = client_->Search(arguments, search_results); EXPECT_EQ(status.Message(), "OK"); @@ -217,7 +225,7 @@ TEST_F(MilvusServerTestSearch, SearchWithStringFilter) { "face", std::vector>{std::vector{0.1f, 0.2f, 0.3f, 0.4f}, std::vector{0.5f, 0.6f, 0.7f, 0.8f}})}; - createCollectionAndPartitions(true); + createCollectionAndPartitions(true); auto dml_results = insertRecords(fields); loadCollection(); @@ -228,8 +236,8 @@ TEST_F(MilvusServerTestSearch, SearchWithStringFilter) { arguments.AddOutputField("age"); arguments.AddOutputField("name"); arguments.SetExpression("name like \"To%\""); // Tom match To% - arguments.AddTargetVector("face", std::vector{0.f, 0.f, 0.f, 0.f}); - arguments.AddTargetVector("face", std::vector{1.f, 1.f, 1.f, 1.f}); + arguments.AddTargetVector("face", std::vector{0.f, 0.f, 0.f, 0.f}); + arguments.AddTargetVector("face", std::vector{1.f, 1.f, 1.f, 1.f}); milvus::SearchResults search_results{}; auto status = client_->Search(arguments, search_results); EXPECT_EQ(status.Message(), "OK"); @@ -274,7 +282,7 @@ TEST_F(MilvusServerTestSearch, SearchWithIVFIndex) { std::make_shared("name", names), std::make_shared("face", faces)}; - createCollectionAndPartitions(false); + createCollectionAndPartitions(false); auto dml_results = insertRecords(fields); milvus::IndexDesc index_desc("face", "", milvus::IndexType::IVF_FLAT, milvus::MetricType::L2, 0); @@ -290,8 +298,8 @@ TEST_F(MilvusServerTestSearch, SearchWithIVFIndex) { arguments.SetTopK(10); arguments.SetMetricType(milvus::MetricType::L2); arguments.AddExtraParam("nprobe", 10); - arguments.AddTargetVector("face", std::vector{0.f, 0.f, 0.f, 0.f}); - arguments.AddTargetVector("face", std::vector{1.f, 1.f, 1.f, 1.f}); + arguments.AddTargetVector("face", std::vector{0.f, 0.f, 0.f, 0.f}); + arguments.AddTargetVector("face", std::vector{1.f, 1.f, 1.f, 1.f}); milvus::SearchResults search_results{}; status = client_->Search(arguments, search_results); EXPECT_EQ(status.Message(), "OK"); diff --git a/test/st/TestSearchWithBinaryVectors.cpp b/test/st/TestSearchWithBinaryVectors.cpp index 17a2d35..2141da1 100644 --- a/test/st/TestSearchWithBinaryVectors.cpp +++ b/test/st/TestSearchWithBinaryVectors.cpp @@ -22,10 +22,11 @@ using milvus::test::MilvusServerTest; using testing::UnorderedElementsAre; using testing::UnorderedElementsAreArray; -class MilvusServerTestSearchWithBinaryVectors : public MilvusServerTest { +class MilvusServerTestSearchWithVectors : public MilvusServerTest { protected: std::string collection_name{"Foo"}; std::string partition_name{"Bar"}; + static constexpr int DIMENSION = 32; void SetUp() override { @@ -36,16 +37,17 @@ class MilvusServerTestSearchWithBinaryVectors : public MilvusServerTest { void TearDown() override { + client_->Disconnect(); MilvusServerTest::TearDown(); } void - createCollectionAndPartitions() { + createCollectionAndPartitions(milvus::DataType vector_data_type) { milvus::CollectionSchema collection_schema(collection_name); collection_schema.AddField(milvus::FieldSchema("id", milvus::DataType::INT64, "id", true, true)); collection_schema.AddField(milvus::FieldSchema("age", milvus::DataType::INT16, "age")); collection_schema.AddField( - milvus::FieldSchema("face", milvus::DataType::BINARY_VECTOR, "face signature").WithDimension(32)); + milvus::FieldSchema("face", vector_data_type, "face signature").WithDimension(DIMENSION)); auto status = client_->CreateCollection(collection_schema); EXPECT_EQ(status.Message(), "OK"); @@ -81,7 +83,7 @@ class MilvusServerTestSearchWithBinaryVectors : public MilvusServerTest { }; // for issue #194 -TEST_F(MilvusServerTestSearchWithBinaryVectors, RegressionIssue194) { +TEST_F(MilvusServerTestSearchWithVectors, RegressionIssue194) { std::mt19937 rng(std::random_device{}()); std::uniform_int_distribution age_gen{10, 30}; std::uniform_int_distribution face_gen{0, 255}; @@ -96,7 +98,7 @@ TEST_F(MilvusServerTestSearchWithBinaryVectors, RegressionIssue194) { std::vector fields{std::make_shared("age", ages), std::make_shared("face", faces)}; - createCollectionAndPartitions(); + createCollectionAndPartitions(milvus::DataType::BINARY_VECTOR); auto dml_results = insertRecords(fields); milvus::IndexDesc index_desc("face", "", milvus::IndexType::BIN_FLAT, milvus::MetricType::HAMMING, 0); @@ -110,8 +112,54 @@ TEST_F(MilvusServerTestSearchWithBinaryVectors, RegressionIssue194) { arguments.SetCollectionName(collection_name); arguments.SetTopK(10); arguments.SetMetricType(milvus::MetricType::HAMMING); - arguments.AddTargetVector("face", std::vector{255, 255, 255, 255}); - arguments.AddTargetVector("face", std::vector{0, 0, 0, 0}); + arguments.AddTargetVector("face", std::vector{255, 255, 255, 255}); + arguments.AddTargetVector("face", std::vector{0, 0, 0, 0}); + milvus::SearchResults search_results{}; + status = client_->Search(arguments, search_results); + EXPECT_EQ(status.Message(), "OK"); + EXPECT_TRUE(status.IsOk()); + + const auto& results = search_results.Results(); + EXPECT_EQ(results.size(), 2); + + EXPECT_EQ(results.at(0).Scores().size(), 10); + EXPECT_EQ(results.at(1).Scores().size(), 10); + + dropCollection(); +} + +// milvus lite does not support bfloat16 vector +TEST_F(MilvusServerTestSearchWithVectors, Float16Vector) { + std::mt19937 rng(std::random_device{}()); + std::uniform_int_distribution age_gen{10, 30}; + std::uniform_real_distribution face_gen{0, 255}; + size_t test_count = 10; + std::vector ages{}; + std::vector> faces{}; + for (auto i = test_count; i > 0; --i) { + ages.emplace_back(age_gen(rng)); + faces.emplace_back(std::vector(DIMENSION, face_gen(rng))); + } + auto faces_field = std::make_shared("face", faces); + ASSERT_EQ(faces_field->DataAsFloats()[0].size(), DIMENSION); + std::vector fields{std::make_shared("age", ages), faces_field}; + + createCollectionAndPartitions(milvus::DataType::FLOAT16_VECTOR); + auto dml_results = insertRecords(fields); + + milvus::IndexDesc index_desc("face", "", milvus::IndexType::FLAT, milvus::MetricType::L2, 0); + auto status = client_->CreateIndex(collection_name, index_desc); + EXPECT_EQ(status.Message(), "OK"); + EXPECT_TRUE(status.IsOk()); + + loadCollection(); + + milvus::SearchArguments arguments{}; + arguments.SetCollectionName(collection_name); + arguments.SetTopK(10); + arguments.SetMetricType(milvus::MetricType::L2); + arguments.AddTargetVector("face", std::vector(DIMENSION, 255)); + arguments.AddTargetVector("face", std::vector(DIMENSION, 0)); milvus::SearchResults search_results{}; status = client_->Search(arguments, search_results); EXPECT_EQ(status.Message(), "OK"); diff --git a/test/ut/TestCalcDistanceArguments.cpp b/test/ut/TestCalcDistanceArguments.cpp index ee7d519..7923149 100644 --- a/test/ut/TestCalcDistanceArguments.cpp +++ b/test/ut/TestCalcDistanceArguments.cpp @@ -58,23 +58,25 @@ TEST_F(CalcDistanceArgumentsTest, GeneralTesting) { EXPECT_EQ(sqrt, arguments.Sqrt()); } -TEST_F(CalcDistanceArgumentsTest, FloatVectors) { - milvus::FloatVecFieldDataPtr vectors_1 = std::make_shared(); - std::vector element_1 = {1.0, 2.0}; - std::vector element_2 = {3.0, 4.0}; +template +static void +DoTestFloatVectors() { + auto vectors_1 = std::make_shared(); + std::vector element_1 = {FloatT(1.0), FloatT(2.0)}; + std::vector element_2 = {FloatT(3.0), FloatT(4.0)}; vectors_1->Add(element_1); vectors_1->Add(element_2); - milvus::FloatVecFieldDataPtr vectors_2 = std::make_shared(); - std::vector element_3 = {1.0, 2.0}; - std::vector element_4 = {3.0, 4.0}; - std::vector element_5 = {5.0, 6.0}; + auto vectors_2 = std::make_shared(); + std::vector element_3 = {FloatT(1.0), FloatT(2.0)}; + std::vector element_4 = {FloatT(3.0), FloatT(4.0)}; + std::vector element_5 = {FloatT(5.0), FloatT(6.0)}; vectors_2->Add(element_3); vectors_2->Add(element_4); vectors_2->Add(element_5); milvus::CalcDistanceArguments arguments; - milvus::FloatVecFieldDataPtr vectors_3 = nullptr; + decltype(vectors_1) vectors_3 = nullptr; auto status = arguments.SetLeftVectors(std::move(vectors_3)); EXPECT_FALSE(status.IsOk()); status = arguments.SetRightVectors(std::move(vectors_3)); @@ -90,6 +92,16 @@ TEST_F(CalcDistanceArgumentsTest, FloatVectors) { EXPECT_EQ(arguments.RightVectors()->Count(), 3); } +TEST_F(CalcDistanceArgumentsTest, FloatVectors) { + DoTestFloatVectors(); + DoTestFloatVectors(); + DoTestFloatVectors(); + DoTestFloatVectors(); + DoTestFloatVectors(); + DoTestFloatVectors(); + DoTestFloatVectors(); +} + TEST_F(CalcDistanceArgumentsTest, BinaryVectors) { milvus::BinaryVecFieldDataPtr vectors_1 = std::make_shared(); std::vector element_1 = {1, 2}; diff --git a/test/ut/TestCollectionSchema.cpp b/test/ut/TestCollectionSchema.cpp index d04d376..a0c75a7 100644 --- a/test/ut/TestCollectionSchema.cpp +++ b/test/ut/TestCollectionSchema.cpp @@ -32,8 +32,17 @@ TEST_F(CollectionSchemaTest, GeneralTesting) { EXPECT_TRUE(schema.AddField(milvus::FieldSchema("bar", milvus::DataType::FLOAT_VECTOR, "bar"))); EXPECT_FALSE(schema.AddField(milvus::FieldSchema("bar", milvus::DataType::FLOAT_VECTOR, "bar"))); + EXPECT_TRUE(schema.AddField(milvus::FieldSchema("fp16", milvus::DataType::FLOAT16_VECTOR, "fp16"))); + EXPECT_FALSE(schema.AddField(milvus::FieldSchema("fp16", milvus::DataType::FLOAT16_VECTOR, "fp16"))); + + EXPECT_TRUE(schema.AddField(milvus::FieldSchema("bf16", milvus::DataType::BFLOAT16_VECTOR, "bf16"))); + EXPECT_FALSE(schema.AddField(milvus::FieldSchema("bf16", milvus::DataType::BFLOAT16_VECTOR, "bf16"))); + EXPECT_EQ(schema.ShardsNum(), 2); - EXPECT_EQ(schema.Fields().size(), 2); - EXPECT_EQ(schema.AnnsFieldNames().size(), 1); - EXPECT_EQ(*schema.AnnsFieldNames().begin(), "bar"); + EXPECT_EQ(schema.Fields().size(), 4); + auto anns_field_names = schema.AnnsFieldNames(); + EXPECT_EQ(anns_field_names.size(), 3); + EXPECT_TRUE(anns_field_names.find("bar") != anns_field_names.end()); + EXPECT_TRUE(anns_field_names.find("fp16") != anns_field_names.end()); + EXPECT_TRUE(anns_field_names.find("bf16") != anns_field_names.end()); } diff --git a/test/ut/TestFieldData.cpp b/test/ut/TestFieldData.cpp index 86a38a2..81c7f65 100644 --- a/test/ut/TestFieldData.cpp +++ b/test/ut/TestFieldData.cpp @@ -18,7 +18,35 @@ #include "milvus/types/FieldData.h" -class FieldDataTest : public ::testing::Test {}; +class FieldDataTest : public ::testing::Test { + protected: + template + static void + TestFloat16Vec(const std::string& name) { + FieldDataT data{name}; + std::vector element_1 = {FloatT(1.0), FloatT(2.0)}; + std::vector element_2 = {FloatT(3.0), FloatT(4.0)}; + data.Add(element_1); + data.Add(element_2); + + EXPECT_EQ(data.template DataAsFloats().size(), 2); + + std::vector element_3 = {FloatT(5.0)}; + EXPECT_EQ(data.Add(element_3), milvus::StatusCode::DIMENSION_NOT_EQUAL); + std::vector element_4; + EXPECT_EQ(data.Add(element_4), milvus::StatusCode::VECTOR_IS_EMPTY); + EXPECT_EQ(data.Name(), name); + EXPECT_EQ(data.Type(), expected_data_type); + EXPECT_EQ(data.Count(), 2); + EXPECT_EQ(data.Data().size(), 2); + EXPECT_EQ(data.Data().at(0).size(), element_1.size() * 2); + EXPECT_EQ(data.Data().at(1).size(), element_2.size() * 2); + + std::vector> values = {element_1, element_2}; + FieldDataT cp(name, values); + EXPECT_EQ(cp.Data().size(), 2); + } +}; TEST_F(FieldDataTest, GeneralTesting) { std::string name = "dummy"; @@ -212,4 +240,12 @@ TEST_F(FieldDataTest, GeneralTesting) { expected = std::vector>{{1, 2}, {3, 4}, {0, 0}}; EXPECT_EQ(data.DataAsUnsignedChars(), expected); } + + TestFloat16Vec(name); + TestFloat16Vec(name); + TestFloat16Vec(name); + + TestFloat16Vec(name); + TestFloat16Vec(name); + TestFloat16Vec(name); } diff --git a/test/ut/TestFieldSchema.cpp b/test/ut/TestFieldSchema.cpp index f2df1b1..6ab19ba 100644 --- a/test/ut/TestFieldSchema.cpp +++ b/test/ut/TestFieldSchema.cpp @@ -53,11 +53,19 @@ TEST_F(FieldSchemaTest, GeneralTesting) { EXPECT_EQ("256", type_params.at(milvus::FieldDim())); } -TEST_F(FieldSchemaTest, TestWithDimention) { +TEST_F(FieldSchemaTest, TestWithDimension) { EXPECT_EQ("1024", milvus::FieldSchema("vectors", milvus::DataType::FLOAT_VECTOR, "") .WithDimension(1024) .TypeParams() .at(milvus::FieldDim())); + EXPECT_EQ("1024", milvus::FieldSchema("vectors", milvus::DataType::FLOAT16_VECTOR, "") + .WithDimension(1024) + .TypeParams() + .at(milvus::FieldDim())); + EXPECT_EQ("1024", milvus::FieldSchema("vectors", milvus::DataType::BFLOAT16_VECTOR, "") + .WithDimension(1024) + .TypeParams() + .at(milvus::FieldDim())); } TEST_F(FieldSchemaTest, TestForMaxLength) { @@ -68,4 +76,4 @@ TEST_F(FieldSchemaTest, TestForMaxLength) { schema.SetMaxLength(300); EXPECT_EQ("300", schema.TypeParams().at(milvus::FieldMaxLength())); EXPECT_EQ(300, schema.MaxLength()); -} \ No newline at end of file +} diff --git a/test/ut/TestFloatUtils.cpp b/test/ut/TestFloatUtils.cpp new file mode 100644 index 0000000..00aa9a5 --- /dev/null +++ b/test/ut/TestFloatUtils.cpp @@ -0,0 +1,48 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "Eigen/Core" +#include "milvus/types/FloatUtils.h" + +class FloatUtilsTest : public ::testing::Test { + protected: + template + static void + FloatConversion() { + std::vector data; + for (size_t i = 0; i < 100; ++i) { + data.push_back(FloatT(static_cast(i) * 0.111)); + } + std::string bytes = milvus::FloatNumVecToFloat16NumVecBytes(data); + ASSERT_EQ(bytes.size(), 100 * 2); + std::vector result = milvus::Float16NumVecBytesToFloatNumVec(bytes); + ASSERT_EQ(data.size(), result.size()); + for (size_t i = 0; i < data.size(); ++i) { + ASSERT_NEAR(data[i], result[i], 4e-2); + } + } +}; + +TEST_F(FloatUtilsTest, Float16NumVecBytesToFloatNumVec) { + FloatConversion(); + FloatConversion(); + FloatConversion(); + FloatConversion(); + FloatConversion(); + FloatConversion(); +} diff --git a/test/ut/TestSearchArguments.cpp b/test/ut/TestSearchArguments.cpp index bd1bd92..764211b 100644 --- a/test/ut/TestSearchArguments.cpp +++ b/test/ut/TestSearchArguments.cpp @@ -16,6 +16,7 @@ #include +#include "milvus/types/FieldData.h" #include "milvus/types/SearchArguments.h" class SearchArgumentsTest : public ::testing::Test {}; @@ -59,17 +60,22 @@ TEST_F(SearchArgumentsTest, VectorTesting) { { milvus::SearchArguments arguments; - auto status = arguments.AddTargetVector("dummy", binary_vector); + auto status = arguments.AddTargetVector("dummy", binary_vector); EXPECT_TRUE(status.IsOk()); - status = arguments.AddTargetVector("dummy", float_vector); + status = arguments.AddTargetVector("dummy", float_vector); EXPECT_FALSE(status.IsOk()); + // inconsistent types std::vector new_vector = {1, 2}; - status = arguments.AddTargetVector("dummy", new_vector); + status = arguments.AddTargetVector("dummy", new_vector); + EXPECT_FALSE(status.IsOk()); + status = arguments.AddTargetVector("dummy", new_vector); + EXPECT_FALSE(status.IsOk()); + status = arguments.AddTargetVector("dummy", new_vector); EXPECT_FALSE(status.IsOk()); - status = arguments.AddTargetVector("dummy", std::vector{}); + status = arguments.AddTargetVector("dummy", std::vector{}); EXPECT_FALSE(status.IsOk()); auto target_vectors = arguments.TargetVectors(); @@ -77,7 +83,7 @@ TEST_F(SearchArgumentsTest, VectorTesting) { EXPECT_EQ(milvus::DataType::BINARY_VECTOR, target_vectors->Type()); EXPECT_EQ(1, target_vectors->Count()); - status = arguments.AddTargetVector("dummy", binary_string); + status = arguments.AddTargetVector("dummy", binary_string); EXPECT_TRUE(status.IsOk()); target_vectors = arguments.TargetVectors(); @@ -86,17 +92,21 @@ TEST_F(SearchArgumentsTest, VectorTesting) { { milvus::SearchArguments arguments; - auto status = arguments.AddTargetVector("dummy", float_vector); + auto status = arguments.AddTargetVector("dummy", float_vector); EXPECT_TRUE(status.IsOk()); - status = arguments.AddTargetVector("dummy", binary_vector); + status = arguments.AddTargetVector("dummy", binary_vector); EXPECT_FALSE(status.IsOk()); std::vector new_vector = {1.0, 2.0, 3.0}; - status = arguments.AddTargetVector("dummy", new_vector); + status = arguments.AddTargetVector("dummy", new_vector); + EXPECT_FALSE(status.IsOk()); + status = arguments.AddTargetVector("dummy", new_vector); + EXPECT_FALSE(status.IsOk()); + status = arguments.AddTargetVector("dummy", new_vector); EXPECT_FALSE(status.IsOk()); - status = arguments.AddTargetVector("dummy", std::vector{}); + status = arguments.AddTargetVector("dummy", std::vector{}); EXPECT_FALSE(status.IsOk()); auto target_vectors = arguments.TargetVectors(); @@ -107,17 +117,21 @@ TEST_F(SearchArgumentsTest, VectorTesting) { { milvus::SearchArguments arguments; - auto status = arguments.AddTargetVector("dummy", std::vector{1, 2, 3}); + auto status = arguments.AddTargetVector("dummy", std::vector{1, 2, 3}); EXPECT_TRUE(status.IsOk()); - status = arguments.AddTargetVector("dummy", std::vector{1.f, 2.f}); + status = arguments.AddTargetVector("dummy", std::vector{1.f, 2.f}); EXPECT_FALSE(status.IsOk()); std::vector new_vector = {1, 2}; - status = arguments.AddTargetVector("dummy", new_vector); + status = arguments.AddTargetVector("dummy", new_vector); + EXPECT_FALSE(status.IsOk()); + status = arguments.AddTargetVector("dummy", new_vector); + EXPECT_FALSE(status.IsOk()); + status = arguments.AddTargetVector("dummy", new_vector); EXPECT_FALSE(status.IsOk()); - status = arguments.AddTargetVector("dummy", std::vector{}); + status = arguments.AddTargetVector("dummy", std::vector{}); EXPECT_FALSE(status.IsOk()); auto target_vectors = arguments.TargetVectors(); @@ -128,17 +142,21 @@ TEST_F(SearchArgumentsTest, VectorTesting) { { milvus::SearchArguments arguments; - auto status = arguments.AddTargetVector("dummy", std::vector{1.f, 2.f}); + auto status = arguments.AddTargetVector("dummy", std::vector{1.f, 2.f}); EXPECT_TRUE(status.IsOk()); - status = arguments.AddTargetVector("dummy", std::vector{1, 2, 3}); + status = arguments.AddTargetVector("dummy", std::vector{1, 2, 3}); EXPECT_FALSE(status.IsOk()); std::vector new_vector = {1.0, 2.0, 3.0}; - status = arguments.AddTargetVector("dummy", new_vector); + status = arguments.AddTargetVector("dummy", new_vector); + EXPECT_FALSE(status.IsOk()); + status = arguments.AddTargetVector("dummy", new_vector); + EXPECT_FALSE(status.IsOk()); + status = arguments.AddTargetVector("dummy", new_vector); EXPECT_FALSE(status.IsOk()); - status = arguments.AddTargetVector("dummy", std::vector{}); + status = arguments.AddTargetVector("dummy", std::vector{}); EXPECT_FALSE(status.IsOk()); auto target_vectors = arguments.TargetVectors(); @@ -146,6 +164,48 @@ TEST_F(SearchArgumentsTest, VectorTesting) { EXPECT_EQ(milvus::DataType::FLOAT_VECTOR, target_vectors->Type()); EXPECT_EQ(1, target_vectors->Count()); } + + { + milvus::SearchArguments arguments; + auto status = arguments.AddTargetVector("dummy", std::vector{1.f, 2.f}); + EXPECT_TRUE(status.IsOk()); + + std::vector new_vector = {1.0, 2.0, 3.0}; + status = arguments.AddTargetVector("dummy", new_vector); + EXPECT_FALSE(status.IsOk()); + + status = arguments.AddTargetVector("dummy", std::vector{1.f, 2.f, 3.f}); + EXPECT_FALSE(status.IsOk()); + + status = arguments.AddTargetVector("dummy", std::vector{1.f, 2.f}); + EXPECT_TRUE(status.IsOk()); + + auto target_vectors = arguments.TargetVectors(); + EXPECT_TRUE(target_vectors != nullptr); + EXPECT_EQ(milvus::DataType::FLOAT16_VECTOR, target_vectors->Type()); + EXPECT_EQ(2, target_vectors->Count()); + } + + { + milvus::SearchArguments arguments; + auto status = arguments.AddTargetVector("dummy", std::vector{1.f, 2.f}); + EXPECT_TRUE(status.IsOk()); + + std::vector new_vector = {1.0, 2.0, 3.0}; + status = arguments.AddTargetVector("dummy", new_vector); + EXPECT_FALSE(status.IsOk()); + + status = arguments.AddTargetVector("dummy", std::vector{1.f, 2.f, 3.f}); + EXPECT_FALSE(status.IsOk()); + + status = arguments.AddTargetVector("dummy", std::vector{1.f, 2.f}); + EXPECT_TRUE(status.IsOk()); + + auto target_vectors = arguments.TargetVectors(); + EXPECT_TRUE(target_vectors != nullptr); + EXPECT_EQ(milvus::DataType::BFLOAT16_VECTOR, target_vectors->Type()); + EXPECT_EQ(2, target_vectors->Count()); + } } TEST_F(SearchArgumentsTest, ValidateTesting) { diff --git a/test/ut/TestTypeUtils.cpp b/test/ut/TestTypeUtils.cpp index 8a47839..6eab27e 100644 --- a/test/ut/TestTypeUtils.cpp +++ b/test/ut/TestTypeUtils.cpp @@ -19,6 +19,7 @@ #include #include "TypeUtils.h" +#include "milvus/types/FloatUtils.h" using milvus::CreateIDArray; using milvus::CreateMilvusFieldData; @@ -265,6 +266,7 @@ TEST_F(TypeUtilsTest, BinaryVecFieldEqualsAndCast) { std::vector{3, 4}, }}; auto proto_field_data = CreateProtoFieldData(static_cast(bins_field_data)); + EXPECT_EQ(proto_field_data.vectors().dim(), 16); auto bins_field_data_ptr = CreateMilvusFieldData(proto_field_data); EXPECT_EQ(proto_field_data, bins_field_data); EXPECT_EQ(proto_field_data, *bins_field_data_ptr); @@ -299,18 +301,28 @@ TEST_F(TypeUtilsTest, BinaryVecFieldNotEquals) { EXPECT_FALSE(proto_field == bins_field); } -TEST_F(TypeUtilsTest, FloatVecFieldEqualsAndCast) { - milvus::FloatVecFieldData floats_field_data{"foo", std::vector>{ - std::vector{0.1f, 0.2f}, - std::vector{0.3f, 0.4f}, - }}; +template +static void +DoFloatVecFieldEqualsAndCast() { + FloatVecFieldT floats_field_data{"foo", std::vector>{ + std::vector{0.1f, 0.2f}, + std::vector{0.3f, 0.4f}, + }}; auto proto_field_data = CreateProtoFieldData(static_cast(floats_field_data)); auto floats_field_data_ptr = CreateMilvusFieldData(proto_field_data); + EXPECT_EQ(proto_field_data.vectors().dim(), 2); + EXPECT_EQ(floats_field_data_ptr->Count(), 2); EXPECT_EQ(proto_field_data, floats_field_data); EXPECT_EQ(proto_field_data, *floats_field_data_ptr); EXPECT_EQ(floats_field_data, *floats_field_data_ptr); } +TEST_F(TypeUtilsTest, FloatVecFieldEqualsAndCast) { + DoFloatVecFieldEqualsAndCast(); + DoFloatVecFieldEqualsAndCast(); + DoFloatVecFieldEqualsAndCast(); +} + TEST_F(TypeUtilsTest, FloatVecFieldNotEquals) { const std::string field_name = "foo"; milvus::FloatVecFieldData floats_field{"foo", std::vector>{ @@ -325,11 +337,10 @@ TEST_F(TypeUtilsTest, FloatVecFieldNotEquals) { proto_field.mutable_scalars(); EXPECT_FALSE(proto_field == floats_field); - auto scalars = proto_field.mutable_vectors(); - scalars->mutable_binary_vector(); + auto* scalars = proto_field.mutable_vectors(); EXPECT_FALSE(proto_field == floats_field); - auto floats_scalars = scalars->mutable_float_vector(); + auto* floats_scalars = scalars->mutable_float_vector(); floats_scalars->add_data(0.1f); EXPECT_FALSE(proto_field == floats_field); @@ -339,6 +350,50 @@ TEST_F(TypeUtilsTest, FloatVecFieldNotEquals) { EXPECT_FALSE(proto_field == floats_field); } +template +static void +DoFloat16VecFieldNotEquals() { + const std::string field_name = "foo"; + FloatVecFieldT floats_field{"foo", std::vector>{ + std::vector{FloatT(0.1), FloatT(0.2)}, + std::vector{FloatT(0.3), FloatT(0.4)}, + }}; + milvus::proto::schema::FieldData proto_field; + proto_field.set_field_name("_"); + EXPECT_FALSE(proto_field == floats_field); + + proto_field.set_field_name(field_name); + proto_field.mutable_scalars(); + EXPECT_FALSE(proto_field == floats_field); + + auto scalars = proto_field.mutable_vectors(); + EXPECT_FALSE(proto_field == floats_field); + + std::string* floats_scalars = nullptr; + if constexpr (std::is_same_v) { + floats_scalars = scalars->mutable_float16_vector(); + } else { + floats_scalars = scalars->mutable_bfloat16_vector(); + } + floats_scalars->push_back('a'); + EXPECT_FALSE(proto_field == floats_field); + + floats_scalars->push_back('a'); + floats_scalars->push_back('a'); + floats_scalars->push_back('a'); + EXPECT_FALSE(proto_field == floats_field); +} + +TEST_F(TypeUtilsTest, Float16VecFieldNotEquals) { + DoFloat16VecFieldNotEquals(); + DoFloat16VecFieldNotEquals(); + DoFloat16VecFieldNotEquals(); + + DoFloat16VecFieldNotEquals(); + DoFloat16VecFieldNotEquals(); + DoFloat16VecFieldNotEquals(); +} + TEST_F(TypeUtilsTest, IDArray) { milvus::proto::schema::IDs ids; ids.mutable_int_id()->add_data(10000); @@ -418,19 +473,53 @@ TEST_F(TypeUtilsTest, CreateMilvusFieldDataWithRange) { CreateMilvusFieldData(CreateProtoFieldData(static_cast(string_field_data)), 1, 2)); EXPECT_THAT(string_field_data_ptr->Data(), ElementsAre("b", "c")); - milvus::BinaryVecFieldData bins_field_data{"foo", - std::vector>{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}}; - const auto bins_field_data_ptr = std::dynamic_pointer_cast( - CreateMilvusFieldData(CreateProtoFieldData(static_cast(bins_field_data)), 1, 2)); - EXPECT_THAT(bins_field_data_ptr->DataAsUnsignedChars(), - ElementsAre(std::vector{4, 5, 6}, std::vector{7, 8, 9})); + { + milvus::BinaryVecFieldData bins_field_data{"foo", + std::vector>{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}}; + const auto bins_field_data_ptr = std::dynamic_pointer_cast( + CreateMilvusFieldData(CreateProtoFieldData(static_cast(bins_field_data)), 1, 2)); + EXPECT_THAT(bins_field_data_ptr->DataAsUnsignedChars(), + ElementsAre(std::vector{4, 5, 6}, std::vector{7, 8, 9})); + } + + { + milvus::FloatVecFieldData floats_field_data{ + "foo", std::vector>{{0.1f, 0.2f, 0.3f}, {0.4f, 0.5f, 0.6f}, {0.7f, 0.8f, 0.9f}}}; + const auto floats_field_data_ptr = std::dynamic_pointer_cast( + CreateMilvusFieldData(CreateProtoFieldData(static_cast(floats_field_data)), 1, 2)); + EXPECT_THAT(floats_field_data_ptr->Data(), + ElementsAre(std::vector{0.4f, 0.5f, 0.6f}, std::vector{0.7f, 0.8f, 0.9f})); + } + + { + milvus::Float16VecFieldData float16s_field_data{ + "foo", std::vector>{{0.1f, 0.2f, 0.3f}, {0.4f, 0.5f, 0.6f}, {0.7f, 0.8f, 0.9f}}}; + const auto float16s_field_data_ptr = std::dynamic_pointer_cast( + CreateMilvusFieldData(CreateProtoFieldData(static_cast(float16s_field_data)), 1, 2)); + + std::string fp16_vec1 = + milvus::FloatNumVecToFloat16NumVecBytes(std::vector{0.4f, 0.5f, 0.6f}); + std::string fp16_vec2 = + milvus::FloatNumVecToFloat16NumVecBytes(std::vector{0.7f, 0.8f, 0.9f}); + EXPECT_THAT(float16s_field_data_ptr->DataAsUnsignedChars(), + ElementsAre(std::vector(fp16_vec1.begin(), fp16_vec1.end()), + std::vector(fp16_vec2.begin(), fp16_vec2.end()))); + } - milvus::FloatVecFieldData floats_field_data{ - "foo", std::vector>{{0.1f, 0.2f, 0.3f}, {0.4f, 0.5f, 0.6f}, {0.7f, 0.8f, 0.9f}}}; - const auto floats_field_data_ptr = std::dynamic_pointer_cast( - CreateMilvusFieldData(CreateProtoFieldData(static_cast(floats_field_data)), 1, 2)); - EXPECT_THAT(floats_field_data_ptr->Data(), - ElementsAre(std::vector{0.4f, 0.5f, 0.6f}, std::vector{0.7f, 0.8f, 0.9f})); + { + milvus::BFloat16VecFieldData bfloat16s_field_data{ + "foo", std::vector>{{0.1f, 0.2f, 0.3f}, {0.4f, 0.5f, 0.6f}, {0.7f, 0.8f, 0.9f}}}; + const auto bfloat16s_field_data_ptr = std::dynamic_pointer_cast( + CreateMilvusFieldData(CreateProtoFieldData(static_cast(bfloat16s_field_data)), 1, 2)); + + std::string bf16_vec1 = + milvus::FloatNumVecToFloat16NumVecBytes(std::vector{0.4f, 0.5f, 0.6f}); + std::string bf16_vec2 = + milvus::FloatNumVecToFloat16NumVecBytes(std::vector{0.7f, 0.8f, 0.9f}); + EXPECT_THAT(bfloat16s_field_data_ptr->DataAsUnsignedChars(), + ElementsAre(std::vector(bf16_vec1.begin(), bf16_vec1.end()), + std::vector(bf16_vec2.begin(), bf16_vec2.end()))); + } } TEST_F(TypeUtilsTest, MetricTypeCastTest) { @@ -459,7 +548,9 @@ TEST_F(TypeUtilsTest, DataTypeCast) { {milvus::DataType::DOUBLE, milvus::proto::schema::DataType::Double}, {milvus::DataType::VARCHAR, milvus::proto::schema::DataType::VarChar}, {milvus::DataType::FLOAT_VECTOR, milvus::proto::schema::DataType::FloatVector}, - {milvus::DataType::BINARY_VECTOR, milvus::proto::schema::DataType::BinaryVector}}; + {milvus::DataType::BINARY_VECTOR, milvus::proto::schema::DataType::BinaryVector}, + {milvus::DataType::FLOAT16_VECTOR, milvus::proto::schema::DataType::Float16Vector}, + {milvus::DataType::BFLOAT16_VECTOR, milvus::proto::schema::DataType::BFloat16Vector}}; for (auto& pair : data_types) { auto dt = milvus::DataTypeCast(milvus::DataType(pair.first)); @@ -533,4 +624,4 @@ TEST_F(TypeUtilsTest, TestB64EncodeGeneric) { EXPECT_EQ(milvus::Base64Encode("abc"), "YWJj"); EXPECT_EQ(milvus::Base64Encode("abcd"), "YWJjZA=="); EXPECT_EQ(milvus::Base64Encode("abcde"), "YWJjZGU="); -} \ No newline at end of file +}