diff --git a/internal/proxy/cgo_util.go b/internal/proxy/cgo_util.go new file mode 100644 index 0000000000000..ec91c8d3b2d1d --- /dev/null +++ b/internal/proxy/cgo_util.go @@ -0,0 +1,38 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proxy + +/* +#cgo pkg-config: milvus_segcore +#include "segcore/check_vec_index_c.h" +#include +*/ +import "C" + +import ( + "unsafe" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) + +func CheckVecIndexWithDataTypeExist(name string, dType schemapb.DataType) bool { + cIndexName := C.CString(name) + cType := uint32(dType) + defer C.free(unsafe.Pointer(cIndexName)) + check := bool(C.CheckVecIndexWithDataType(cIndexName, cType)) + return check +} diff --git a/internal/proxy/cgo_util_test.go b/internal/proxy/cgo_util_test.go new file mode 100644 index 0000000000000..363ee644f9027 --- /dev/null +++ b/internal/proxy/cgo_util_test.go @@ -0,0 +1,58 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proxy + +import ( + "testing" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/util/indexparamcheck" +) + +func Test_CheckVecIndexWithDataTypeExist(t *testing.T) { + cases := []struct { + indexType string + dataType schemapb.DataType + want bool + }{ + {indexparamcheck.IndexHNSW, schemapb.DataType_FloatVector, true}, + {indexparamcheck.IndexHNSW, schemapb.DataType_BinaryVector, false}, + {indexparamcheck.IndexHNSW, schemapb.DataType_Float16Vector, true}, + + {indexparamcheck.IndexSparseWand, schemapb.DataType_SparseFloatVector, true}, + {indexparamcheck.IndexSparseWand, schemapb.DataType_FloatVector, false}, + {indexparamcheck.IndexSparseWand, schemapb.DataType_Float16Vector, false}, + + {indexparamcheck.IndexGpuBF, schemapb.DataType_FloatVector, true}, + {indexparamcheck.IndexGpuBF, schemapb.DataType_Float16Vector, false}, + {indexparamcheck.IndexGpuBF, schemapb.DataType_BinaryVector, false}, + + {indexparamcheck.IndexFaissBinIvfFlat, schemapb.DataType_BinaryVector, true}, + {indexparamcheck.IndexFaissBinIvfFlat, schemapb.DataType_FloatVector, false}, + + {indexparamcheck.IndexDISKANN, schemapb.DataType_FloatVector, true}, + {indexparamcheck.IndexDISKANN, schemapb.DataType_Float16Vector, true}, + {indexparamcheck.IndexDISKANN, schemapb.DataType_BFloat16Vector, true}, + {indexparamcheck.IndexDISKANN, schemapb.DataType_BinaryVector, false}, + } + + for _, test := range cases { + if got := CheckVecIndexWithDataTypeExist(test.indexType, test.dataType); got != test.want { + t.Errorf("CheckVecIndexWithDataTypeExist(%v, %v) = %v", test.indexType, test.dataType, test.want) + } + } +} diff --git a/internal/proxy/task_index.go b/internal/proxy/task_index.go index 17dde3ebc44c0..54207f1869e94 100644 --- a/internal/proxy/task_index.go +++ b/internal/proxy/task_index.go @@ -370,18 +370,19 @@ func checkTrain(field *schemapb.FieldSchema, indexParams map[string]string) erro indexParams[common.BitmapCardinalityLimitKey] = paramtable.Get().CommonCfg.BitmapIndexCardinalityBound.GetValue() } } - if typeutil.IsVectorType(field.DataType) && indexType != indexparamcheck.AutoIndex { - exist := indexparamcheck.CheckVecIndexWithDataTypeExist(indexType, field.DataType) - if !exist { - return fmt.Errorf("data type %d can't build with this index %s", field.DataType, indexType) - } - } checker, err := indexparamcheck.GetIndexCheckerMgrInstance().GetChecker(indexType) if err != nil { log.Warn("Failed to get index checker", zap.String(common.IndexTypeKey, indexType)) return fmt.Errorf("invalid index type: %s", indexType) } + if typeutil.IsVectorType(field.DataType) && indexType != indexparamcheck.AutoIndex { + exist := CheckVecIndexWithDataTypeExist(indexType, field.DataType) + if !exist { + return fmt.Errorf("data type %d can't build with this index %s", field.DataType, indexType) + } + } + if !typeutil.IsSparseFloatVectorType(field.DataType) { if err := fillDimension(field, indexParams); err != nil { return err diff --git a/pkg/util/indexparamcheck/utils.go b/pkg/util/indexparamcheck/utils.go index e1d2eccc17b1f..adca93aeb362a 100644 --- a/pkg/util/indexparamcheck/utils.go +++ b/pkg/util/indexparamcheck/utils.go @@ -16,19 +16,10 @@ package indexparamcheck -/* -#cgo pkg-config: milvus_segcore -#include "segcore/check_vec_index_c.h" -#include -*/ -import "C" - import ( "fmt" "strconv" - "unsafe" - "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/pkg/util/funcutil" ) @@ -78,11 +69,3 @@ func setDefaultIfNotExist(params map[string]string, key string, defaultValue str params[key] = defaultValue } } - -func CheckVecIndexWithDataTypeExist(name string, dType schemapb.DataType) bool { - cIndexName := C.CString(name) - cType := uint32(dType) - defer C.free(unsafe.Pointer(cIndexName)) - check := bool(C.CheckVecIndexWithDataType(cIndexName, cType)) - return check -}