From f8bb536090e306f96b9ebaabf537e8565ad94ad4 Mon Sep 17 00:00:00 2001 From: wei liu Date: Mon, 3 Jun 2024 19:13:40 +0800 Subject: [PATCH 01/11] enhance: refine param name for AlterDatabase (#757) Signed-off-by: Wei Liu --- client/client.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/client.go b/client/client.go index f7db4d3c..1ae9dcda 100644 --- a/client/client.go +++ b/client/client.go @@ -44,7 +44,7 @@ type Client interface { // DropDatabase drop database with the given db name. DropDatabase(ctx context.Context, dbName string, opts ...DropDatabaseOption) error // AlterDatabase alter database props with given db name. - AlterDatabase(ctx context.Context, collName string, attrs ...entity.DatabaseAttribute) error + AlterDatabase(ctx context.Context, dbName string, attrs ...entity.DatabaseAttribute) error DescribeDatabase(ctx context.Context, dbName string) (*entity.Database, error) // -- collection -- From 0f2b1f29bf47db9e071a82b8aa048f26e1a540d7 Mon Sep 17 00:00:00 2001 From: congqixia Date: Tue, 4 Jun 2024 11:33:41 +0800 Subject: [PATCH 02/11] fix: Add parentheses for query iterator origin expr (#761) Related to #758 #759 Signed-off-by: Congqi Xia --- client/iterator.go | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/client/iterator.go b/client/iterator.go index 34ebd552..9957141a 100644 --- a/client/iterator.go +++ b/client/iterator.go @@ -116,15 +116,12 @@ func (itr *QueryIterator) composeIteratorExpr() string { } expr := strings.TrimSpace(itr.expr) - if expr != "" { - expr += " and " - } switch itr.pkField.DataType { case entity.FieldTypeInt64: - expr += fmt.Sprintf("%s > %d", itr.pkField.Name, itr.lastPK) + expr = fmt.Sprintf("(%s) and %s > %d", expr, itr.pkField.Name, itr.lastPK) case entity.FieldTypeVarChar: - expr += fmt.Sprintf(`%s > "%s"`, itr.pkField.Name, itr.lastPK) + expr += fmt.Sprintf(`(%s) and %s > "%s"`, expr, itr.pkField.Name, itr.lastPK) default: return itr.expr } From 0dde68b6b9a25107af88a51890c8aee406668911 Mon Sep 17 00:00:00 2001 From: congqixia Date: Thu, 6 Jun 2024 11:31:44 +0800 Subject: [PATCH 03/11] fix: Check batch size when initializing iterator (#762) See also #754 Signed-off-by: Congqi Xia --- client/iterator.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/client/iterator.go b/client/iterator.go index 9957141a..e313fc71 100644 --- a/client/iterator.go +++ b/client/iterator.go @@ -102,6 +102,10 @@ type QueryIterator struct { // init fetches the first batch of data and put it into cache. // this operation could be used to check all the parameters before returning the iterator. 
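// Note: the batch-size guard introduced below validates the value at construction time,
// so a non-positive batch size is rejected when the iterator is created rather than
// failing later on the first Next() call.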
func (itr *QueryIterator) init(ctx context.Context) error { + if itr.batchSize <= 0 { + return errors.New("batch size cannot less than 1") + } + rs, err := itr.fetchNextBatch(ctx) if err != nil { return err From 195148310b22504a93056deddb7670e02133a7fc Mon Sep 17 00:00:00 2001 From: congqixia Date: Thu, 6 Jun 2024 15:15:45 +0800 Subject: [PATCH 04/11] fix: Check empty expr before compose iterator next batch (#764) See also #763 Signed-off-by: Congqi Xia --- client/iterator.go | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/client/iterator.go b/client/iterator.go index e313fc71..2d7bc8f5 100644 --- a/client/iterator.go +++ b/client/iterator.go @@ -123,9 +123,17 @@ func (itr *QueryIterator) composeIteratorExpr() string { switch itr.pkField.DataType { case entity.FieldTypeInt64: - expr = fmt.Sprintf("(%s) and %s > %d", expr, itr.pkField.Name, itr.lastPK) + if len(expr) == 0 { + expr = fmt.Sprintf("%s > %d", itr.pkField.Name, itr.lastPK) + } else { + expr = fmt.Sprintf("(%s) and %s > %d", expr, itr.pkField.Name, itr.lastPK) + } case entity.FieldTypeVarChar: - expr += fmt.Sprintf(`(%s) and %s > "%s"`, expr, itr.pkField.Name, itr.lastPK) + if len(expr) == 0 { + expr = fmt.Sprintf(`%s > "%s"`, itr.pkField.Name, itr.lastPK) + } else { + expr = fmt.Sprintf(`(%s) and %s > "%s"`, expr, itr.pkField.Name, itr.lastPK) + } default: return itr.expr } From dd7078cfa86c2fafd3c5236df74cee739a038796 Mon Sep 17 00:00:00 2001 From: congqixia Date: Thu, 6 Jun 2024 17:49:45 +0800 Subject: [PATCH 05/11] fix: Make FP16 & BF16 column use correct byte len when parsing (#765) See also #756 Signed-off-by: Congqi Xia --- client/results.go | 4 +- entity/columns.go | 13 ++++--- entity/columns_array.go | 9 +++-- entity/columns_array_gen.go | 66 +++++++++++++++++++-------------- entity/columns_json.go | 9 +++-- entity/columns_scalar_gen.go | 72 ++++++++++++++++++++---------------- entity/columns_sparse.go | 9 +++-- entity/columns_varchar.go | 9 +++-- entity/columns_vector_gen.go | 52 +++++++++++++++++--------- 9 files changed, 142 insertions(+), 101 deletions(-) diff --git a/client/results.go b/client/results.go index 677a0b93..869b312c 100644 --- a/client/results.go +++ b/client/results.go @@ -1,6 +1,8 @@ package client -import "github.com/milvus-io/milvus-sdk-go/v2/entity" +import ( + "github.com/milvus-io/milvus-sdk-go/v2/entity" +) // SearchResult contains the result from Search api of client // IDs is the auto generated id values for the entities diff --git a/entity/columns.go b/entity/columns.go index 5ef09b76..cf6e3a90 100644 --- a/entity/columns.go +++ b/entity/columns.go @@ -363,12 +363,12 @@ func FieldDataColumn(fd *schemapb.FieldData, begin, end int) (Column, error) { data := x.Float16Vector dim := int(vectors.GetDim()) if end < 0 { - end = int(len(data) / dim) + end = int(len(data) / dim / 2) } vector := make([][]byte, 0, end-begin) for i := begin; i < end; i++ { - v := make([]byte, dim) - copy(v, data[i*dim:(i+1)*dim]) + v := make([]byte, dim*2) + copy(v, data[i*dim*2:(i+1)*dim*2]) vector = append(vector, v) } return NewColumnFloat16Vector(fd.GetFieldName(), dim, vector), nil @@ -381,13 +381,14 @@ func FieldDataColumn(fd *schemapb.FieldData, begin, end int) (Column, error) { } data := x.Bfloat16Vector dim := int(vectors.GetDim()) + if end < 0 { - end = int(len(data) / dim) + end = int(len(data) / dim / 2) } vector := make([][]byte, 0, end-begin) // shall not have remanunt for i := begin; i < end; i++ { - v := make([]byte, dim) - copy(v, data[i*dim:(i+1)*dim]) + v := make([]byte, 
dim*2) + copy(v, data[i*dim*2:(i+1)*dim*2]) vector = append(vector, v) } return NewColumnBFloat16Vector(fd.GetFieldName(), dim, vector), nil diff --git a/entity/columns_array.go b/entity/columns_array.go index e2726add..9de845a8 100644 --- a/entity/columns_array.go +++ b/entity/columns_array.go @@ -30,11 +30,12 @@ func (c *ColumnVarCharArray) Len() int { } func (c *ColumnVarCharArray) Slice(start, end int) Column { - if start > c.Len() { - start = c.Len() + l := c.Len() + if start > l { + start = l } - if end == -1 || end > c.Len() { - end = c.Len() + if end == -1 || end > l { + end = l } return &ColumnVarCharArray{ ColumnBase: c.ColumnBase, diff --git a/entity/columns_array_gen.go b/entity/columns_array_gen.go index 8da7bb03..6b80da21 100755 --- a/entity/columns_array_gen.go +++ b/entity/columns_array_gen.go @@ -34,11 +34,15 @@ func (c *ColumnBoolArray) Len() int { } func (c *ColumnBoolArray) Slice(start, end int) Column { - if start > c.Len() { - start = c.Len() + l := c.Len() + if start > l { + start = l } - if end == -1 || end > c.Len() { - end = c.Len() + if end == -1 || end > l { + end = l + } + if end == -1 || end > l { + end = l } return &ColumnBoolArray{ ColumnBase: c.ColumnBase, @@ -147,11 +151,12 @@ func (c *ColumnInt8Array) Len() int { } func (c *ColumnInt8Array) Slice(start, end int) Column { - if start > c.Len() { - start = c.Len() + l := c.Len() + if start > l { + start = l } - if end == -1 || end > c.Len() { - end = c.Len() + if end == -1 || end > l { + end = l } return &ColumnInt8Array{ ColumnBase: c.ColumnBase, @@ -260,11 +265,12 @@ func (c *ColumnInt16Array) Len() int { } func (c *ColumnInt16Array) Slice(start, end int) Column { - if start > c.Len() { - start = c.Len() + l := c.Len() + if start > l { + start = l } - if end == -1 || end > c.Len() { - end = c.Len() + if end == -1 || end > l { + end = l } return &ColumnInt16Array{ ColumnBase: c.ColumnBase, @@ -373,11 +379,12 @@ func (c *ColumnInt32Array) Len() int { } func (c *ColumnInt32Array) Slice(start, end int) Column { - if start > c.Len() { - start = c.Len() + l := c.Len() + if start > l { + start = l } - if end == -1 || end > c.Len() { - end = c.Len() + if end == -1 || end > l { + end = l } return &ColumnInt32Array{ ColumnBase: c.ColumnBase, @@ -486,11 +493,12 @@ func (c *ColumnInt64Array) Len() int { } func (c *ColumnInt64Array) Slice(start, end int) Column { - if start > c.Len() { - start = c.Len() + l := c.Len() + if start > l { + start = l } - if end == -1 || end > c.Len() { - end = c.Len() + if end == -1 || end > l { + end = l } return &ColumnInt64Array{ ColumnBase: c.ColumnBase, @@ -599,11 +607,12 @@ func (c *ColumnFloatArray) Len() int { } func (c *ColumnFloatArray) Slice(start, end int) Column { - if start > c.Len() { - start = c.Len() + l := c.Len() + if start > l { + start = l } - if end == -1 || end > c.Len() { - end = c.Len() + if end == -1 || end > l { + end = l } return &ColumnFloatArray{ ColumnBase: c.ColumnBase, @@ -712,11 +721,12 @@ func (c *ColumnDoubleArray) Len() int { } func (c *ColumnDoubleArray) Slice(start, end int) Column { - if start > c.Len() { - start = c.Len() + l := c.Len() + if start > l { + start = l } - if end == -1 || end > c.Len() { - end = c.Len() + if end == -1 || end > l { + end = l } return &ColumnDoubleArray{ ColumnBase: c.ColumnBase, diff --git a/entity/columns_json.go b/entity/columns_json.go index 88301b83..4c598219 100644 --- a/entity/columns_json.go +++ b/entity/columns_json.go @@ -36,11 +36,12 @@ func (c *ColumnJSONBytes) Len() int { } func (c *ColumnJSONBytes) 
Slice(start, end int) Column { - if start > c.Len() { - start = c.Len() + l := c.Len() + if start > l { + start = l } - if end == -1 || end > c.Len() { - end = c.Len() + if end == -1 || end > l { + end = l } return &ColumnJSONBytes{ ColumnBase: c.ColumnBase, diff --git a/entity/columns_scalar_gen.go b/entity/columns_scalar_gen.go index 33462cad..f97cbde6 100755 --- a/entity/columns_scalar_gen.go +++ b/entity/columns_scalar_gen.go @@ -33,11 +33,12 @@ func (c *ColumnBool) Len() int { } func (c *ColumnBool) Slice(start, end int) Column { - if start > c.Len() { - start = c.Len() + l := c.Len() + if start > l { + start = l } - if end == -1 || end > c.Len() { - end = c.Len() + if end == -1 || end > l { + end = l } return &ColumnBool{ ColumnBase: c.ColumnBase, @@ -134,11 +135,12 @@ func (c *ColumnInt8) Len() int { } func (c *ColumnInt8) Slice(start, end int) Column { - if start > c.Len() { - start = c.Len() + l := c.Len() + if start > l { + start = l } - if end == -1 || end > c.Len() { - end = c.Len() + if end == -1 || end > l { + end = l } return &ColumnInt8{ ColumnBase: c.ColumnBase, @@ -235,11 +237,12 @@ func (c *ColumnInt16) Len() int { } func (c *ColumnInt16) Slice(start, end int) Column { - if start > c.Len() { - start = c.Len() + l := c.Len() + if start > l { + start = l } - if end == -1 || end > c.Len() { - end = c.Len() + if end == -1 || end > l { + end = l } return &ColumnInt16{ ColumnBase: c.ColumnBase, @@ -336,11 +339,12 @@ func (c *ColumnInt32) Len() int { } func (c *ColumnInt32) Slice(start, end int) Column { - if start > c.Len() { - start = c.Len() + l := c.Len() + if start > l { + start = l } - if end == -1 || end > c.Len() { - end = c.Len() + if end == -1 || end > l { + end = l } return &ColumnInt32{ ColumnBase: c.ColumnBase, @@ -437,11 +441,12 @@ func (c *ColumnInt64) Len() int { } func (c *ColumnInt64) Slice(start, end int) Column { - if start > c.Len() { - start = c.Len() + l := c.Len() + if start > l { + start = l } - if end == -1 || end > c.Len() { - end = c.Len() + if end == -1 || end > l { + end = l } return &ColumnInt64{ ColumnBase: c.ColumnBase, @@ -538,11 +543,12 @@ func (c *ColumnFloat) Len() int { } func (c *ColumnFloat) Slice(start, end int) Column { - if start > c.Len() { - start = c.Len() + l := c.Len() + if start > l { + start = l } - if end == -1 || end > c.Len() { - end = c.Len() + if end == -1 || end > l { + end = l } return &ColumnFloat{ ColumnBase: c.ColumnBase, @@ -639,11 +645,12 @@ func (c *ColumnDouble) Len() int { } func (c *ColumnDouble) Slice(start, end int) Column { - if start > c.Len() { - start = c.Len() + l := c.Len() + if start > l { + start = l } - if end == -1 || end > c.Len() { - end = c.Len() + if end == -1 || end > l { + end = l } return &ColumnDouble{ ColumnBase: c.ColumnBase, @@ -740,11 +747,12 @@ func (c *ColumnString) Len() int { } func (c *ColumnString) Slice(start, end int) Column { - if start > c.Len() { - start = c.Len() + l := c.Len() + if start > l { + start = l } - if end == -1 || end > c.Len() { - end = c.Len() + if end == -1 || end > l { + end = l } return &ColumnString{ ColumnBase: c.ColumnBase, diff --git a/entity/columns_sparse.go b/entity/columns_sparse.go index 2df3a11e..86693f0d 100644 --- a/entity/columns_sparse.go +++ b/entity/columns_sparse.go @@ -144,11 +144,12 @@ func (c *ColumnSparseFloatVector) Len() int { } func (c *ColumnSparseFloatVector) Slice(start, end int) Column { - if start > c.Len() { - start = c.Len() + l := c.Len() + if start > l { + start = l } - if end == -1 || end > c.Len() { - end = c.Len() + if end == 
-1 || end > l { + end = l } return &ColumnSparseFloatVector{ ColumnBase: c.ColumnBase, diff --git a/entity/columns_varchar.go b/entity/columns_varchar.go index 72b12d99..f86b5d68 100644 --- a/entity/columns_varchar.go +++ b/entity/columns_varchar.go @@ -46,11 +46,12 @@ func (c *ColumnVarChar) GetAsString(idx int) (string, error) { } func (c *ColumnVarChar) Slice(start, end int) Column { - if start > c.Len() { - start = c.Len() + l := c.Len() + if start > l { + start = l } - if end == -1 || end > c.Len() { - end = c.Len() + if end == -1 || end > l { + end = l } return &ColumnVarChar{ ColumnBase: c.ColumnBase, diff --git a/entity/columns_vector_gen.go b/entity/columns_vector_gen.go index 90e5cb01..f690d36d 100755 --- a/entity/columns_vector_gen.go +++ b/entity/columns_vector_gen.go @@ -35,11 +35,15 @@ func (c * ColumnBinaryVector) Len() int { } func (c *ColumnBinaryVector) Slice(start, end int) Column { - if start > c.Len() { - start = c.Len() + l := c.Len() + if start > l { + start = l } - if end == -1 || end > c.Len() { - end = c.Len() + if end == -1 || end > l { + end = l + } + if end == -1 || end > l { + end = l } return &ColumnBinaryVector{ ColumnBase: c.ColumnBase, @@ -150,11 +154,15 @@ func (c *ColumnFloatVector) Get(idx int) (interface{}, error) { } func (c *ColumnFloatVector) Slice(start, end int) Column { - if start > c.Len() { - start = c.Len() + l := c.Len() + if start > l { + start = l + } + if end == -1 || end > l { + end = l } - if end == -1 || end > c.Len() { - end = c.Len() + if end == -1 || end > l { + end = l } return &ColumnFloatVector{ ColumnBase: c.ColumnBase, @@ -241,11 +249,15 @@ func (c * ColumnFloat16Vector) Len() int { } func (c *ColumnFloat16Vector) Slice(start, end int) Column { - if start > c.Len() { - start = c.Len() + l := c.Len() + if start > l { + start = l + } + if end == -1 || end > l { + end = l } - if end == -1 || end > c.Len() { - end = c.Len() + if end == -1 || end > l { + end = l } return &ColumnFloat16Vector{ ColumnBase: c.ColumnBase, @@ -291,7 +303,7 @@ func (c *ColumnFloat16Vector) FieldData() *schema.FieldData { FieldName: c.name, } - data := make([]byte, 0, len(c.values)* c.dim) + data := make([]byte, 0, len(c.values)* c.dim *2) for _, vector := range c.values { data = append(data, vector...) @@ -343,11 +355,15 @@ func (c * ColumnBFloat16Vector) Len() int { } func (c *ColumnBFloat16Vector) Slice(start, end int) Column { - if start > c.Len() { - start = c.Len() + l := c.Len() + if start > l { + start = l + } + if end == -1 || end > l { + end = l } - if end == -1 || end > c.Len() { - end = c.Len() + if end == -1 || end > l { + end = l } return &ColumnBFloat16Vector{ ColumnBase: c.ColumnBase, @@ -393,7 +409,7 @@ func (c *ColumnBFloat16Vector) FieldData() *schema.FieldData { FieldName: c.name, } - data := make([]byte, 0, len(c.values)* c.dim) + data := make([]byte, 0, len(c.values)* c.dim*2) for _, vector := range c.values { data = append(data, vector...) 
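For context on the FP16/BF16 fix above: float16 and bfloat16 elements are 2 bytes wide, so a flat buffer holding n vectors of dimension dim has length n*dim*2, and row i spans data[i*dim*2 : (i+1)*dim*2]. Below is a minimal standalone sketch of that slicing logic, assuming dim > 0; splitFP16Rows is an illustrative helper, not an SDK function.

package main

import "fmt"

// splitFP16Rows slices a flat float16/bfloat16 buffer into per-row vectors.
// Each element is 2 bytes, so a vector of dimension dim occupies dim*2 bytes.
func splitFP16Rows(data []byte, dim int) [][]byte {
	bytesPerRow := dim * 2
	rows := make([][]byte, 0, len(data)/bytesPerRow)
	for off := 0; off+bytesPerRow <= len(data); off += bytesPerRow {
		row := make([]byte, bytesPerRow)
		copy(row, data[off:off+bytesPerRow])
		rows = append(rows, row)
	}
	return rows
}

func main() {
	buf := make([]byte, 3*4*2) // 3 vectors, dim = 4, 2 bytes per element
	fmt.Println(len(splitFP16Rows(buf, 4))) // prints 3
}

The same 2-bytes-per-element accounting explains the other hunks in the patch: row count is len(data)/dim/2 when end is unspecified, and FieldData() preallocates len(values)*dim*2 bytes.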
From 302e5648edeaaf2adfd5344b06e7739d6747f1e0 Mon Sep 17 00:00:00 2001 From: congqixia Date: Tue, 11 Jun 2024 10:45:48 +0800 Subject: [PATCH 06/11] enhance: Support InMemory option for ListCollection (#768) Support specifiy check in-memory status when use `ListCollections` Signed-off-by: Congqi Xia --- client/client.go | 2 +- client/collection.go | 13 ++++++++++++- client/options.go | 12 ++++++++++++ 3 files changed, 25 insertions(+), 2 deletions(-) diff --git a/client/client.go b/client/client.go index 1ae9dcda..1811781c 100644 --- a/client/client.go +++ b/client/client.go @@ -52,7 +52,7 @@ type Client interface { // NewCollection intializeds a new collection with pre defined attributes NewCollection(ctx context.Context, collName string, dimension int64, opts ...CreateCollectionOption) error // ListCollections list collections from connection - ListCollections(ctx context.Context) ([]*entity.Collection, error) + ListCollections(ctx context.Context, opts ...ListCollectionOption) ([]*entity.Collection, error) // CreateCollection create collection using provided schema CreateCollection(ctx context.Context, schema *entity.Schema, shardsNum int32, opts ...CreateCollectionOption) error // DescribeCollection describe collection meta diff --git a/client/collection.go b/client/collection.go index 98b0cb4b..bf5fdec7 100644 --- a/client/collection.go +++ b/client/collection.go @@ -45,14 +45,25 @@ func handleRespStatus(status *commonpb.Status) error { // ListCollections list collections from connection // Note that schema info are not provided in collection list -func (c *GrpcClient) ListCollections(ctx context.Context) ([]*entity.Collection, error) { +func (c *GrpcClient) ListCollections(ctx context.Context, opts ...ListCollectionOption) ([]*entity.Collection, error) { if c.Service == nil { return []*entity.Collection{}, ErrClientNotReady } + + o := &listCollectionOpt{} + for _, opt := range opts { + opt(o) + } + req := &milvuspb.ShowCollectionsRequest{ DbName: "", TimeStamp: 0, // means now } + + if o.showInMemory { + req.Type = milvuspb.ShowType_InMemory + } + resp, err := c.Service.ShowCollections(ctx, req) if err != nil { return []*entity.Collection{}, err diff --git a/client/options.go b/client/options.go index dc82454c..6ac2a962 100644 --- a/client/options.go +++ b/client/options.go @@ -289,6 +289,18 @@ func GetWithOutputFields(outputFields ...string) GetOption { } } +type listCollectionOpt struct { + showInMemory bool +} + +type ListCollectionOption func(*listCollectionOpt) + +func WithShowInMemory(value bool) ListCollectionOption { + return func(opt *listCollectionOpt) { + opt.showInMemory = value + } +} + type DropCollectionOption func(*milvuspb.DropCollectionRequest) type ReleaseCollectionOption func(*milvuspb.ReleaseCollectionRequest) From 3a0f12b0eb8cd27f8727b165bcfb56572218805b Mon Sep 17 00:00:00 2001 From: wayblink Date: Tue, 11 Jun 2024 17:19:48 +0800 Subject: [PATCH 07/11] Support CDC (#766) #767 Signed-off-by: wayblink --- client/client.go | 2 +- client/insert.go | 27 +++++++++++++++++++-------- 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/client/client.go b/client/client.go index 1811781c..cfc1a883 100644 --- a/client/client.go +++ b/client/client.go @@ -134,7 +134,7 @@ type Client interface { Flush(ctx context.Context, collName string, async bool, opts ...FlushOption) error // FlushV2 flush collection, specified, return newly sealed segmentIds, all flushed segmentIds of the collection, seal time and error // currently it is only used in 
milvus-backup(https://github.com/zilliztech/milvus-backup) - FlushV2(ctx context.Context, collName string, async bool, opts ...FlushOption) ([]int64, []int64, int64, error) + FlushV2(ctx context.Context, collName string, async bool, opts ...FlushOption) ([]int64, []int64, int64, map[string]msgpb.MsgPosition, error) // DeleteByPks deletes entries related to provided primary keys DeleteByPks(ctx context.Context, collName string, partitionName string, ids entity.Column) error // Delete deletes entries match expression diff --git a/client/insert.go b/client/insert.go index d7b91dfa..652601d7 100644 --- a/client/insert.go +++ b/client/insert.go @@ -21,6 +21,7 @@ import ( "github.com/golang/protobuf/proto" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-sdk-go/v2/entity" @@ -185,18 +186,18 @@ func (c *GrpcClient) mergeDynamicColumns(dynamicName string, rowSize int, column // Flush force collection to flush memory records into storage // in sync mode, flush will wait all segments to be flushed func (c *GrpcClient) Flush(ctx context.Context, collName string, async bool, opts ...FlushOption) error { - _, _, _, err := c.FlushV2(ctx, collName, async, opts...) + _, _, _, _, err := c.FlushV2(ctx, collName, async, opts...) return err } // Flush force collection to flush memory records into storage // in sync mode, flush will wait all segments to be flushed -func (c *GrpcClient) FlushV2(ctx context.Context, collName string, async bool, opts ...FlushOption) ([]int64, []int64, int64, error) { +func (c *GrpcClient) FlushV2(ctx context.Context, collName string, async bool, opts ...FlushOption) ([]int64, []int64, int64, map[string]msgpb.MsgPosition, error) { if c.Service == nil { - return nil, nil, 0, ErrClientNotReady + return nil, nil, 0, nil, ErrClientNotReady } if err := c.checkCollectionExists(ctx, collName); err != nil { - return nil, nil, 0, err + return nil, nil, 0, nil, err } req := &milvuspb.FlushRequest{ DbName: "", // reserved, @@ -207,11 +208,12 @@ func (c *GrpcClient) FlushV2(ctx context.Context, collName string, async bool, o } resp, err := c.Service.Flush(ctx, req) if err != nil { - return nil, nil, 0, err + return nil, nil, 0, nil, err } if err := handleRespStatus(resp.GetStatus()); err != nil { - return nil, nil, 0, err + return nil, nil, 0, nil, err } + channelCPs := resp.GetChannelCps() if !async { segmentIDs, has := resp.GetCollSegIDs()[collName] ids := segmentIDs.GetData() @@ -232,14 +234,23 @@ func (c *GrpcClient) FlushV2(ctx context.Context, collName string, async bool, o // respect context deadline/cancel select { case <-ctx.Done(): - return nil, nil, 0, errors.New("deadline exceeded") + return nil, nil, 0, nil, errors.New("deadline exceeded") default: } time.Sleep(200 * time.Millisecond) } } } - return resp.GetCollSegIDs()[collName].GetData(), resp.GetFlushCollSegIDs()[collName].GetData(), resp.GetCollSealTimes()[collName], nil + channelCPEntities := make(map[string]msgpb.MsgPosition, len(channelCPs)) + for k, v := range channelCPs { + channelCPEntities[k] = msgpb.MsgPosition{ + ChannelName: v.GetChannelName(), + MsgID: v.GetMsgID(), + MsgGroup: v.GetMsgGroup(), + Timestamp: v.GetTimestamp(), + } + } + return resp.GetCollSegIDs()[collName].GetData(), resp.GetFlushCollSegIDs()[collName].GetData(), resp.GetCollSealTimes()[collName], channelCPEntities, nil } // DeleteByPks deletes entries 
related to provided primary keys From b091d819088ac9c73c15cfde5b4df3b949af8238 Mon Sep 17 00:00:00 2001 From: ThreadDao Date: Wed, 12 Jun 2024 07:33:49 +0800 Subject: [PATCH 08/11] feat: add test cases for query iterator (#760) Signed-off-by: ThreadDao --- entity/schema.go | 2 + test/base/milvus_client.go | 9 + test/common/response_check.go | 107 +++++- test/common/utils.go | 24 +- test/testcases/configure_test.go | 2 +- test/testcases/delete_test.go | 12 +- test/testcases/highlevel_test.go | 12 +- test/testcases/main_test.go | 2 +- test/testcases/query_test.go | 590 ++++++++++++++++++++++++++++--- test/testcases/search_test.go | 51 ++- 10 files changed, 735 insertions(+), 76 deletions(-) diff --git a/entity/schema.go b/entity/schema.go index 6b31bc08..03086bfa 100644 --- a/entity/schema.go +++ b/entity/schema.go @@ -419,6 +419,8 @@ func (t FieldType) String() string { return "[]byte" case FieldTypeBFloat16Vector: return "[]byte" + case FieldTypeSparseVector: + return "[]SparseEmbedding" default: return "undefined" } diff --git a/test/base/milvus_client.go b/test/base/milvus_client.go index 4505c29f..e3a41e89 100644 --- a/test/base/milvus_client.go +++ b/test/base/milvus_client.go @@ -487,6 +487,15 @@ func (mc *MilvusClient) Get(ctx context.Context, collName string, ids entity.Col return queryResults, err } +// QueryIterator QueryIterator from collection +func (mc *MilvusClient) QueryIterator(ctx context.Context, opt *client.QueryIteratorOption) (*client.QueryIterator, error) { + funcName := "QueryIterator" + preRequest(funcName, ctx, opt) + itr, err := mc.mClient.QueryIterator(ctx, opt) + postResponse(funcName, err, itr) + return itr, err +} + // -- row based apis -- // CreateCollectionByRow Create Collection By Row diff --git a/test/common/response_check.go b/test/common/response_check.go index 869d6109..aa118e99 100644 --- a/test/common/response_check.go +++ b/test/common/response_check.go @@ -1,7 +1,9 @@ package common import ( + "context" "fmt" + "io" "log" "strings" "testing" @@ -160,14 +162,48 @@ func EqualColumn(t *testing.T, columnA entity.Column, columnB entity.Column) { require.ElementsMatch(t, columnA.(*entity.ColumnFloatVector).Data(), columnB.(*entity.ColumnFloatVector).Data()) case entity.FieldTypeBinaryVector: require.ElementsMatch(t, columnA.(*entity.ColumnBinaryVector).Data(), columnB.(*entity.ColumnBinaryVector).Data()) + case entity.FieldTypeFloat16Vector: + require.ElementsMatch(t, columnA.(*entity.ColumnFloat16Vector).Data(), columnB.(*entity.ColumnFloat16Vector).Data()) + case entity.FieldTypeBFloat16Vector: + require.ElementsMatch(t, columnA.(*entity.ColumnBFloat16Vector).Data(), columnB.(*entity.ColumnBFloat16Vector).Data()) + case entity.FieldTypeSparseVector: + require.ElementsMatch(t, columnA.(*entity.ColumnSparseFloatVector).Data(), columnB.(*entity.ColumnSparseFloatVector).Data()) case entity.FieldTypeArray: - log.Println("TODO support column element type") + EqualArrayColumn(t, columnA, columnB) default: log.Printf("The column type not in: [%v, %v, %v, %v, %v, %v, %v, %v, %v, %v, %v, %v]", entity.FieldTypeBool, entity.FieldTypeInt8, entity.FieldTypeInt16, entity.FieldTypeInt32, entity.FieldTypeInt64, entity.FieldTypeFloat, entity.FieldTypeDouble, entity.FieldTypeString, entity.FieldTypeVarChar, entity.FieldTypeArray, entity.FieldTypeFloatVector, entity.FieldTypeBinaryVector) + } +} +// EqualColumn assert field data is equal of two columns +func EqualArrayColumn(t *testing.T, columnA entity.Column, columnB entity.Column) { + require.Equal(t, columnA.Name(), 
columnB.Name()) + require.IsType(t, columnA.Type(), entity.FieldTypeArray) + require.IsType(t, columnB.Type(), entity.FieldTypeArray) + switch columnA.(type) { + case *entity.ColumnBoolArray: + require.ElementsMatch(t, columnA.(*entity.ColumnBoolArray).Data(), columnB.(*entity.ColumnBoolArray).Data()) + case *entity.ColumnInt8Array: + require.ElementsMatch(t, columnA.(*entity.ColumnInt8Array).Data(), columnB.(*entity.ColumnInt8Array).Data()) + case *entity.ColumnInt16Array: + require.ElementsMatch(t, columnA.(*entity.ColumnInt16Array).Data(), columnB.(*entity.ColumnInt16Array).Data()) + case *entity.ColumnInt32Array: + require.ElementsMatch(t, columnA.(*entity.ColumnInt32Array).Data(), columnB.(*entity.ColumnInt32Array).Data()) + case *entity.ColumnInt64Array: + require.ElementsMatch(t, columnA.(*entity.ColumnInt64Array).Data(), columnB.(*entity.ColumnInt64Array).Data()) + case *entity.ColumnFloatArray: + require.ElementsMatch(t, columnA.(*entity.ColumnFloatArray).Data(), columnB.(*entity.ColumnFloatArray).Data()) + case *entity.ColumnDoubleArray: + require.ElementsMatch(t, columnA.(*entity.ColumnDoubleArray).Data(), columnB.(*entity.ColumnDoubleArray).Data()) + case *entity.ColumnVarCharArray: + require.ElementsMatch(t, columnA.(*entity.ColumnVarCharArray).Data(), columnB.(*entity.ColumnVarCharArray).Data()) + default: + log.Printf("Now support array type: [%v, %v, %v, %v, %v, %v, %v, %v]", + entity.FieldTypeBool, entity.FieldTypeInt8, entity.FieldTypeInt16, entity.FieldTypeInt32, + entity.FieldTypeInt64, entity.FieldTypeFloat, entity.FieldTypeDouble, entity.FieldTypeVarChar) } } @@ -203,6 +239,75 @@ func CheckSearchResult(t *testing.T, actualSearchResults []client.SearchResult, } +func EqualIntSlice(a []int, b []int) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} + +type CheckIteratorOption func(opt *checkIteratorOpt) + +type checkIteratorOpt struct { + expBatchSize []int + expOutputFields []string +} + +func WithExpBatchSize(expBatchSize []int) CheckIteratorOption { + return func(opt *checkIteratorOpt) { + opt.expBatchSize = expBatchSize + } +} + +func WithExpOutputFields(expOutputFields []string) CheckIteratorOption { + return func(opt *checkIteratorOpt) { + opt.expOutputFields = expOutputFields + } +} + +// check queryIterator: result limit, each batch size, output fields +func CheckQueryIteratorResult(ctx context.Context, t *testing.T, itr *client.QueryIterator, expLimit int, opts ...CheckIteratorOption) { + opt := &checkIteratorOpt{} + for _, o := range opts { + o(opt) + } + actualLimit := 0 + var actualBatchSize []int + for { + rs, err := itr.Next(ctx) + if err != nil { + if err == io.EOF { + break + } + log.Fatalf("QueryIterator next gets error: %v", err) + } + //log.Printf("QueryIterator result len: %d", rs.Len()) + //log.Printf("QueryIterator result data: %d", rs.GetColumn("int64")) + + if opt.expBatchSize != nil { + actualBatchSize = append(actualBatchSize, rs.Len()) + } + var actualOutputFields []string + if opt.expOutputFields != nil { + for _, column := range rs { + actualOutputFields = append(actualOutputFields, column.Name()) + } + require.ElementsMatch(t, opt.expOutputFields, actualOutputFields) + } + actualLimit = actualLimit + rs.Len() + } + require.Equal(t, expLimit, actualLimit) + if opt.expBatchSize != nil { + log.Printf("QueryIterator result len: %v", actualBatchSize) + require.True(t, EqualIntSlice(opt.expBatchSize, actualBatchSize)) + } +} + // CheckPersistentSegments check 
persistent segments func CheckPersistentSegments(t *testing.T, actualSegments []*entity.Segment, expNb int64) { actualNb := int64(0) diff --git a/test/common/utils.go b/test/common/utils.go index c0fc80eb..6df077ae 100644 --- a/test/common/utils.go +++ b/test/common/utils.go @@ -78,6 +78,7 @@ const ( DefaultPartitionNum = 16 // default num_partitions MaxTopK = 16384 MaxVectorFieldNum = 4 + DefaultBatchSize = 1000 ) var IndexStateValue = map[string]int32{ @@ -842,7 +843,7 @@ func GenDefaultJSONRows(start int, nb int, dim int64, enableDynamicField bool) [ } for i := start; i < start+nb; i++ { - // jsonStruct row and dynamic row + //jsonStruct row and dynamic row var jsonStruct JSONStruct if i%2 == 0 { jsonStruct = JSONStruct{ @@ -1386,6 +1387,8 @@ type InvalidExprStruct struct { var InvalidExpressions = []InvalidExprStruct{ {Expr: "id in [0]", ErrNil: true, ErrMsg: "fieldName(id) not found"}, // not exist field but no error {Expr: "int64 in not [0]", ErrNil: false, ErrMsg: "cannot parse expression"}, // wrong term expr keyword + {Expr: "int64 > 10 AND int64 < 100", ErrNil: false, ErrMsg: "cannot parse expression"}, // AND isn't supported + {Expr: "int64 < 10 OR int64 > 100", ErrNil: false, ErrMsg: "cannot parse expression"}, // OR isn't supported {Expr: "int64 < floatVec", ErrNil: false, ErrMsg: "not supported"}, // unsupported compare field {Expr: "floatVec in [0]", ErrNil: false, ErrMsg: "cannot be casted to FloatVector"}, // value and field type mismatch {Expr: fmt.Sprintf("%s == 1", DefaultJSONFieldName), ErrNil: true, ErrMsg: ""}, // hist empty @@ -1406,4 +1409,23 @@ var InvalidExpressions = []InvalidExprStruct{ {Expr: fmt.Sprintf(fmt.Sprintf("%s[-1] > 1", DefaultJSONFieldName)), ErrNil: false, ErrMsg: "invalid expression"}, // json[-1] > } +func GenBatchSizes(limit int, batch int) []int { + if batch == 0 { + log.Fatal("Batch should be larger than 0") + } + if limit == 0 { + return []int{} + } + _loop := limit / batch + _last := limit % batch + batchSizes := make([]int, 0, _loop+1) + for i := 0; i < _loop; i++ { + batchSizes = append(batchSizes, batch) + } + if _last > 0 { + batchSizes = append(batchSizes, _last) + } + return batchSizes +} + // --- search utils --- diff --git a/test/testcases/configure_test.go b/test/testcases/configure_test.go index 115f76d6..79ace802 100644 --- a/test/testcases/configure_test.go +++ b/test/testcases/configure_test.go @@ -74,7 +74,7 @@ func TestCompactAfterDelete(t *testing.T) { common.CheckErr(t, err, true) // delete half ids - deleteIds := entity.NewColumnInt64(common.DefaultIntFieldName, ids.(*entity.ColumnInt64).Data()[:common.DefaultNb/2]) + deleteIds := ids.Slice(0, common.DefaultNb/2) errDelete := mc.DeleteByPks(ctx, collName, "", deleteIds) common.CheckErr(t, errDelete, true) diff --git a/test/testcases/delete_test.go b/test/testcases/delete_test.go index 6ba19f7b..0ae6b24e 100644 --- a/test/testcases/delete_test.go +++ b/test/testcases/delete_test.go @@ -28,7 +28,7 @@ func TestDelete(t *testing.T) { common.CheckErr(t, errLoad, true) // delete - deleteIds := entity.NewColumnInt64(common.DefaultIntFieldName, ids.(*entity.ColumnInt64).Data()[:10]) + deleteIds := ids.Slice(0, 10) errDelete := mc.DeleteByPks(ctx, collName, common.DefaultPartition, deleteIds) common.CheckErr(t, errDelete, true) @@ -48,7 +48,7 @@ func TestDeleteStringPks(t *testing.T) { collName, ids := createVarcharCollectionWithDataIndex(ctx, t, mc, true, client.WithConsistencyLevel(entity.ClStrong)) // delete - deleteIds := entity.NewColumnVarChar(common.DefaultVarcharFieldName, 
ids.(*entity.ColumnVarChar).Data()[:10]) + deleteIds := ids.Slice(0, 10) errDelete := mc.DeleteByPks(ctx, collName, common.DefaultPartition, deleteIds) common.CheckErr(t, errDelete, true) @@ -103,7 +103,7 @@ func TestDeleteNotExistPartition(t *testing.T) { common.CheckErr(t, errLoad, true) // delete - deleteIds := entity.NewColumnInt64(common.DefaultIntFieldName, ids.(*entity.ColumnInt64).Data()[:10]) + deleteIds := ids.Slice(0, 10) errDelete := mc.DeleteByPks(ctx, collName, "p1", deleteIds) common.CheckErr(t, errDelete, false, fmt.Sprintf("partition p1 of collection %s does not exist", collName)) } @@ -125,7 +125,7 @@ func TestDeleteEmptyPartitionNames(t *testing.T) { mc.Flush(ctx, collName, false) // delete - deleteIds := entity.NewColumnInt64(common.DefaultIntFieldName, intColumn.(*entity.ColumnInt64).Data()[:10]) + deleteIds := intColumn.Slice(0, 10) errDelete := mc.DeleteByPks(ctx, collName, emptyPartitionName, deleteIds) common.CheckErr(t, errDelete, true) @@ -160,7 +160,7 @@ func TestDeleteEmptyPartition(t *testing.T) { common.CheckErr(t, errLoad, true) // delete from empty partition p1 - deleteIds := entity.NewColumnInt64(common.DefaultIntFieldName, ids.(*entity.ColumnInt64).Data()[:10]) + deleteIds := ids.Slice(0, 10) errDelete := mc.DeleteByPks(ctx, collName, "p1", deleteIds) common.CheckErr(t, errDelete, true) @@ -186,7 +186,7 @@ func TestDeletePartitionIdsNotMatch(t *testing.T) { partitionName, vecColumnDefault, _ := createInsertTwoPartitions(ctx, t, mc, collName, common.DefaultNb) // delete [0:10) from new partition -> delete nothing - deleteIds := entity.NewColumnInt64(common.DefaultIntFieldName, vecColumnDefault.IdsColumn.(*entity.ColumnInt64).Data()[:10]) + deleteIds := vecColumnDefault.IdsColumn.Slice(0, 10) errDelete := mc.DeleteByPks(ctx, collName, partitionName, deleteIds) common.CheckErr(t, errDelete, true) diff --git a/test/testcases/highlevel_test.go b/test/testcases/highlevel_test.go index 0cf02b50..03b6e8c6 100644 --- a/test/testcases/highlevel_test.go +++ b/test/testcases/highlevel_test.go @@ -65,13 +65,13 @@ func TestNewCollection(t *testing.T) { queryResult, err := mc.Get( ctx, collName, - entity.NewColumnInt64(DefaultPkFieldName, pkColumn.(*entity.ColumnInt64).Data()[:10]), + pkColumn.Slice(0, 10), ) common.CheckErr(t, err, true) common.CheckOutputFields(t, queryResult, []string{DefaultPkFieldName, DefaultVectorFieldName}) common.CheckQueryResult(t, queryResult, []entity.Column{ - entity.NewColumnInt64(DefaultPkFieldName, pkColumn.(*entity.ColumnInt64).Data()[:10]), - entity.NewColumnFloatVector(DefaultVectorFieldName, int(common.DefaultDim), vecColumn.(*entity.ColumnFloatVector).Data()[:10]), + pkColumn.Slice(0, 10), + vecColumn.Slice(0, 10), }) // search @@ -142,13 +142,13 @@ func TestNewCollectionCustomize(t *testing.T) { queryResult, err := mc.Get( ctx, collName, - entity.NewColumnVarChar(pkFieldName, pkColumn.(*entity.ColumnVarChar).Data()[:10]), + pkColumn.Slice(0, 10), ) common.CheckErr(t, err, true) common.CheckOutputFields(t, queryResult, []string{pkFieldName, vectorFieldName}) common.CheckQueryResult(t, queryResult, []entity.Column{ - entity.NewColumnVarChar(pkFieldName, pkColumn.(*entity.ColumnVarChar).Data()[:10]), - entity.NewColumnFloatVector(vectorFieldName, int(common.DefaultDim), vecColumn.(*entity.ColumnFloatVector).Data()[:10]), + pkColumn.Slice(0, 10), + vecColumn.Slice(0, 10), }) // search diff --git a/test/testcases/main_test.go b/test/testcases/main_test.go index db7e16fc..a75d3532 100644 --- a/test/testcases/main_test.go +++ 
b/test/testcases/main_test.go @@ -248,7 +248,7 @@ const ( Int64FloatVecJSON CollectionFieldsType = "PkInt64FloatVecJson" // int64 + float + floatVec + json Int64FloatVecArray CollectionFieldsType = "Int64FloatVecArray" // int64 + float + floatVec + all array Int64VarcharSparseVec CollectionFieldsType = "Int64VarcharSparseVec" // int64 + varchar + float32Vec + sparseVec - AllVectors CollectionFieldsType = "AllVectors" // int64 + fp32Vec + fp16Vec + binaryVec + AllVectors CollectionFieldsType = "AllVectors" // int64 + fp32Vec + fp16Vec + bf16Vec + binaryVec AllFields CollectionFieldsType = "AllFields" // all scalar fields + floatVec ) diff --git a/test/testcases/query_test.go b/test/testcases/query_test.go index 1aa41fdc..55bd6ec1 100644 --- a/test/testcases/query_test.go +++ b/test/testcases/query_test.go @@ -5,6 +5,7 @@ package testcases import ( "encoding/json" "fmt" + "io" "log" "strconv" "testing" @@ -32,16 +33,15 @@ func TestQueryDefaultPartition(t *testing.T) { common.CheckErr(t, errLoad, true) //query - pks := ids.(*entity.ColumnInt64).Data() + pks := ids.Slice(0, 10) var queryResult, _ = mc.QueryByPks( ctx, collName, []string{common.DefaultPartition}, - entity.NewColumnInt64(common.DefaultIntFieldName, pks[:10]), + pks, []string{common.DefaultIntFieldName}, ) - expColumn := entity.NewColumnInt64(common.DefaultIntFieldName, pks[:10]) - common.CheckQueryResult(t, queryResult, []entity.Column{expColumn}) + common.CheckQueryResult(t, queryResult, []entity.Column{pks}) } // test query with varchar field filter @@ -58,16 +58,15 @@ func TestQueryVarcharField(t *testing.T) { common.CheckErr(t, errLoad, true) //query - pks := ids.(*entity.ColumnVarChar).Data() + pks := ids.Slice(0, 10) queryResult, _ := mc.QueryByPks( ctx, collName, []string{common.DefaultPartition}, - entity.NewColumnVarChar(common.DefaultVarcharFieldName, pks[:10]), + pks, []string{common.DefaultVarcharFieldName}, ) - expColumn := entity.NewColumnVarChar(common.DefaultVarcharFieldName, pks[:10]) - common.CheckQueryResult(t, queryResult, []entity.Column{expColumn}) + common.CheckQueryResult(t, queryResult, []entity.Column{pks}) } // query from not existed collection @@ -84,12 +83,12 @@ func TestQueryNotExistCollection(t *testing.T) { common.CheckErr(t, errLoad, true) //query - pks := ids.(*entity.ColumnInt64).Data() + pks := ids.Slice(0, 10) _, errQuery := mc.QueryByPks( ctx, "collName", []string{common.DefaultPartition}, - entity.NewColumnInt64(common.DefaultIntFieldName, pks[:10]), + pks, []string{common.DefaultIntFieldName}, ) common.CheckErr(t, errQuery, false, "can't find collection") @@ -109,12 +108,12 @@ func TestQueryNotExistPartition(t *testing.T) { common.CheckErr(t, errLoad, true) //query - pks := ids.(*entity.ColumnInt64).Data() + pks := ids.Slice(0, 10) _, errQuery := mc.QueryByPks( ctx, collName, []string{"aaa"}, - entity.NewColumnInt64(common.DefaultIntFieldName, pks[:10]), + pks, []string{common.DefaultIntFieldName}, ) common.CheckErr(t, errQuery, false, "partition name aaa not found") @@ -150,7 +149,7 @@ func TestQueryEmptyPartitionName(t *testing.T) { ctx, collName, []string{emptyPartitionName}, - entity.NewColumnInt64(common.DefaultIntFieldName, intColumn.(*entity.ColumnInt64).Data()[:10]), + intColumn.Slice(0, 10), []string{common.DefaultIntFieldName}, ) common.CheckErr(t, errQuery, false, "Partition name should not be empty") @@ -286,7 +285,7 @@ func TestQueryEmptyOutputFields(t *testing.T) { //query with empty output fields []string{}-> output "int64" queryEmptyOutputs, _ := mc.QueryByPks( ctx, 
collName, []string{common.DefaultPartition}, - entity.NewColumnInt64(common.DefaultIntFieldName, ids.(*entity.ColumnInt64).Data()[:10]), + ids.Slice(0, 10), []string{}, ) common.CheckOutputFields(t, queryEmptyOutputs, []string{common.DefaultIntFieldName}) @@ -294,7 +293,7 @@ func TestQueryEmptyOutputFields(t *testing.T) { //query with empty output fields []string{""}-> output "int64" and dynamic field queryEmptyOutputs, err := mc.QueryByPks( ctx, collName, []string{common.DefaultPartition}, - entity.NewColumnInt64(common.DefaultIntFieldName, ids.(*entity.ColumnInt64).Data()[:10]), + ids.Slice(0, 10), []string{""}, ) if enableDynamic { @@ -306,7 +305,7 @@ func TestQueryEmptyOutputFields(t *testing.T) { // query with "float" output fields -> output "int64, float" queryFloatOutputs, _ := mc.QueryByPks( ctx, collName, []string{common.DefaultPartition}, - entity.NewColumnInt64(common.DefaultIntFieldName, ids.(*entity.ColumnInt64).Data()[:10]), + ids.Slice(0, 10), []string{common.DefaultFloatFieldName}, ) common.CheckOutputFields(t, queryFloatOutputs, []string{common.DefaultIntFieldName, common.DefaultFloatFieldName}) @@ -341,14 +340,13 @@ func TestQueryOutputFields(t *testing.T) { queryResult, _ := mc.QueryByPks( ctx, collName, []string{common.DefaultPartition}, - entity.NewColumnInt64(common.DefaultIntFieldName, intColumn.(*entity.ColumnInt64).Data()[:pos]), + intColumn.Slice(0, pos), []string{common.DefaultIntFieldName, common.DefaultFloatFieldName, common.DefaultFloatVecFieldName}, ) common.CheckQueryResult(t, queryResult, []entity.Column{ - entity.NewColumnInt64(common.DefaultIntFieldName, intColumn.(*entity.ColumnInt64).Data()[:pos]), - entity.NewColumnFloat(common.DefaultFloatFieldName, floatColumn.(*entity.ColumnFloat).Data()[:pos]), - entity.NewColumnFloatVector(common.DefaultFloatVecFieldName, int(common.DefaultDim), vecColumn.(*entity.ColumnFloatVector).Data()[:pos]), - }) + intColumn.Slice(0, pos), + floatColumn.Slice(0, pos), + vecColumn.Slice(0, pos)}) common.CheckOutputFields(t, queryResult, []string{common.DefaultIntFieldName, common.DefaultFloatFieldName, common.DefaultFloatVecFieldName}) } @@ -382,33 +380,78 @@ func TestQueryOutputBinaryAndVarchar(t *testing.T) { ctx, collName, []string{common.DefaultPartition}, - entity.NewColumnVarChar(common.DefaultVarcharFieldName, varcharColumn.(*entity.ColumnVarChar).Data()[:pos]), + varcharColumn.Slice(0, pos), []string{common.DefaultBinaryVecFieldName}, ) common.CheckQueryResult(t, queryResult, []entity.Column{ - entity.NewColumnVarChar(common.DefaultVarcharFieldName, varcharColumn.(*entity.ColumnVarChar).Data()[:pos]), - entity.NewColumnBinaryVector(common.DefaultBinaryVecFieldName, int(common.DefaultDim), vecColumn.(*entity.ColumnBinaryVector).Data()[:pos]), - }) + varcharColumn.Slice(0, pos), + vecColumn.Slice(0, pos)}) common.CheckOutputFields(t, queryResult, []string{common.DefaultBinaryVecFieldName, common.DefaultVarcharFieldName}) } // test query output all fields -func TestOutputAllFields(t *testing.T) { +func TestOutputAllFieldsRows(t *testing.T) { ctx := createContext(t, time.Second*common.DefaultTimeout) // connect mc := createMilvusClient(ctx, t) - for _, withRows := range []bool{true, false} { - // create collection - var capacity int64 = common.TestCapacity - cp := CollectionParams{CollectionFieldsType: AllFields, AutoID: false, EnableDynamicField: true, + // create collection + var capacity int64 = common.TestCapacity + cp := CollectionParams{CollectionFieldsType: AllFields, AutoID: false, EnableDynamicField: true, + 
ShardsNum: common.DefaultShards, Dim: common.DefaultDim, MaxCapacity: capacity} + collName := createCollection(ctx, t, mc, cp) + + // prepare and insert data + dp := DataParams{CollectionName: collName, PartitionName: "", CollectionFieldsType: AllFields, + start: 0, nb: common.DefaultNb, dim: common.DefaultDim, EnableDynamicField: true, WithRows: true} + _, _ = insertData(ctx, t, mc, dp, common.WithArrayCapacity(capacity)) + + // flush and check row count + errFlush := mc.Flush(ctx, collName, false) + common.CheckErr(t, errFlush, true) + + idx, _ := entity.NewIndexHNSW(entity.L2, 8, 96) + for _, fieldName := range []string{"floatVec", "fp16Vec", "bf16Vec"} { + _ = mc.CreateIndex(ctx, collName, fieldName, idx, false) + } + binIdx, _ := entity.NewIndexBinFlat(entity.JACCARD, 16) + _ = mc.CreateIndex(ctx, collName, "binaryVec", binIdx, false) + + // Load collection + errLoad := mc.LoadCollection(ctx, collName, false) + common.CheckErr(t, errLoad, true) + + // query output all fields -> output all fields, includes vector and $meta field + allFieldsName := append(common.AllArrayFieldsName, "int64", "bool", "int8", "int16", "int32", "float", + "double", "varchar", "json", "floatVec", "fp16Vec", "bf16Vec", "binaryVec", common.DefaultDynamicFieldName) + queryResultAll, errQuery := mc.Query(ctx, collName, []string{}, + fmt.Sprintf("%s == 0", common.DefaultIntFieldName), []string{"*"}) + common.CheckErr(t, errQuery, true) + common.CheckOutputFields(t, queryResultAll, allFieldsName) +} + +// test query output all fields and verify data +func TestOutputAllFieldsColumn(t *testing.T) { + ctx := createContext(t, time.Second*common.DefaultTimeout) + // connect + mc := createMilvusClient(ctx, t) + + // create collection + var capacity int64 = common.TestCapacity + for _, isDynamic := range [2]bool{true, false} { + cp := CollectionParams{CollectionFieldsType: AllFields, AutoID: false, EnableDynamicField: isDynamic, ShardsNum: common.DefaultShards, Dim: common.DefaultDim, MaxCapacity: capacity} collName := createCollection(ctx, t, mc, cp) // prepare and insert data - dp := DataParams{CollectionName: collName, PartitionName: "", CollectionFieldsType: AllFields, - start: 0, nb: common.DefaultNb, dim: common.DefaultDim, EnableDynamicField: true, WithRows: withRows} - _, _ = insertData(ctx, t, mc, dp, common.WithArrayCapacity(capacity)) + data := common.GenAllFieldsData(0, common.DefaultNb, common.DefaultDim, common.WithArrayCapacity(10)) + _data := data + if isDynamic { + _data = append(_data, common.GenDynamicFieldData(0, common.DefaultNb)...) + } + ids, err := mc.Insert(ctx, collName, "", _data...) 
+ common.CheckErr(t, err, true) + require.Equal(t, common.DefaultNb, ids.Len()) // flush and check row count errFlush := mc.Flush(ctx, collName, false) @@ -426,12 +469,24 @@ func TestOutputAllFields(t *testing.T) { common.CheckErr(t, errLoad, true) // query output all fields -> output all fields, includes vector and $meta field + pos := 10 allFieldsName := append(common.AllArrayFieldsName, "int64", "bool", "int8", "int16", "int32", "float", - "double", "varchar", "json", "floatVec", "fp16Vec", "bf16Vec", "binaryVec", common.DefaultDynamicFieldName) - queryResultAll, errQuery := mc.Query(ctx, collName, []string{}, - fmt.Sprintf("%s == 0", common.DefaultIntFieldName), []string{"*"}) + "double", "varchar", "json", "floatVec", "fp16Vec", "bf16Vec", "binaryVec") + if isDynamic { + allFieldsName = append(allFieldsName, common.DefaultDynamicFieldName) + } + queryResultAll, errQuery := mc.Query(ctx, collName, []string{}, fmt.Sprintf("%s < %d", common.DefaultIntFieldName, pos), []string{"*"}) common.CheckErr(t, errQuery, true) common.CheckOutputFields(t, queryResultAll, allFieldsName) + + expColumns := make([]entity.Column, 0, len(data)+1) + for _, column := range data { + expColumns = append(expColumns, column.Slice(0, pos)) + } + if isDynamic { + expColumns = append(expColumns, common.MergeColumnsToDynamic(pos, common.GenDynamicFieldData(0, pos))) + } + common.CheckQueryResult(t, queryResultAll, expColumns) } } @@ -453,7 +508,7 @@ func TestQueryOutputNotExistField(t *testing.T) { ctx, collName, []string{common.DefaultPartition}, - entity.NewColumnInt64(common.DefaultIntFieldName, ids.(*entity.ColumnInt64).Data()[:10]), + ids.Slice(0, 10), []string{common.DefaultIntFieldName, "varchar"}, ) common.CheckErr(t, errQuery, false, "field varchar not exist") @@ -708,6 +763,7 @@ func TestQueryArrayFieldExpr(t *testing.T) { log.Println(_exprCount.expr) countRes, err := mc.Query(ctx, collName, []string{}, _exprCount.expr, []string{common.QueryCountFieldName}) + log.Println(countRes.GetColumn(common.QueryCountFieldName).FieldData()) common.CheckErr(t, err, true) require.Equal(t, _exprCount.count, countRes.GetColumn(common.QueryCountFieldName).(*entity.ColumnInt64).Data()[0]) } @@ -875,14 +931,13 @@ func TestQueryJsonDynamicFieldRows(t *testing.T) { []string{common.DefaultIntFieldName, common.DefaultJSONFieldName, common.DefaultDynamicFieldName}, ) common.CheckErr(t, err, true) - //jsonColumn := common.GenDefaultJSONData(common.DefaultJSONFieldName, 0, 2) m0 := common.JSONStruct{String: strconv.Itoa(0), Bool: true} j0, _ := json.Marshal(&m0) m1 := common.JSONStruct{Number: int32(1), String: strconv.Itoa(1), Bool: false, List: []int64{int64(1), int64(2)}} j1, _ := json.Marshal(&m1) jsonValues := [][]byte{j0, j1} jsonColumn := entity.NewColumnJSONBytes(common.DefaultJSONFieldName, jsonValues) - dynamicColumn := common.MergeColumnsToDynamic(2, common.GenDynamicFieldData(0, 2)) + dynamicColumn := common.MergeColumnsToDynamic(10, common.GenDynamicFieldData(0, 10)) // gen dynamic json column for _, column := range queryResult { @@ -899,24 +954,17 @@ func TestQueryJsonDynamicFieldRows(t *testing.T) { log.Println(jsonData) } } - common.CheckQueryResult(t, queryResult, []entity.Column{pkColumn, jsonColumn, dynamicColumn}) + common.CheckQueryResult(t, queryResult, []entity.Column{pkColumn, jsonColumn, dynamicColumn.Slice(0, 2)}) // query with different expr and count expr := fmt.Sprintf("%s['number'] < 10 && %s < 10", common.DefaultJSONFieldName, common.DefaultDynamicNumberField) queryRes, _ := mc.Query(ctx, collName, 
- []string{common.DefaultPartition}, - expr, []string{common.DefaultJSONFieldName, common.DefaultDynamicNumberField}) + []string{common.DefaultPartition}, expr, []string{common.DefaultDynamicNumberField}) // verify output fields and count, dynamicNumber value - common.CheckOutputFields(t, queryRes, []string{common.DefaultIntFieldName, common.DefaultJSONFieldName, common.DefaultDynamicNumberField}) - require.Equal(t, 10, queryRes.GetColumn(common.DefaultJSONFieldName).Len()) - dynamicNumColumn := queryRes.GetColumn(common.DefaultDynamicNumberField) - var numberData []int64 - for i := 0; i < dynamicNumColumn.Len(); i++ { - line, _ := dynamicNumColumn.GetAsInt64(i) - numberData = append(numberData, line) - } - require.Equal(t, numberData, []int64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}) + common.CheckOutputFields(t, queryRes, []string{common.DefaultIntFieldName, common.DefaultDynamicNumberField}) + pkColumn2 := common.GenColumnData(0, 10, entity.FieldTypeInt64, common.DefaultIntFieldName) + common.CheckQueryResult(t, queryRes, []entity.Column{pkColumn2, dynamicColumn}) } // test query with invalid expr @@ -1097,7 +1145,7 @@ func TestQuerySparseVector(t *testing.T) { // insert intColumn, _, floatColumn := common.GenDefaultColumnData(0, common.DefaultNb, common.DefaultDim) varColumn := common.GenColumnData(0, common.DefaultNb, entity.FieldTypeVarChar, common.DefaultVarcharFieldName) - sparseColumn := common.GenColumnData(0, common.DefaultNb, entity.FieldTypeSparseVector, common.DefaultSparseVecFieldName) + sparseColumn := common.GenColumnData(0, common.DefaultNb, entity.FieldTypeSparseVector, common.DefaultSparseVecFieldName, common.WithSparseVectorLen(20)) mc.Insert(ctx, collName, "", intColumn, varColumn, floatColumn, sparseColumn) mc.Flush(ctx, collName, false) mc.LoadCollection(ctx, collName, false) @@ -1107,14 +1155,440 @@ func TestQuerySparseVector(t *testing.T) { require.Equal(t, int64(common.DefaultNb), countRes.GetColumn(common.QueryCountFieldName).(*entity.ColumnInt64).Data()[0]) // query - queryResult, err := mc.Query(ctx, collName, []string{}, fmt.Sprintf("%s == 0", common.DefaultIntFieldName), []string{"*"}) + queryResult, err := mc.Query(ctx, collName, []string{}, fmt.Sprintf("%s in [0, 1]", common.DefaultIntFieldName), []string{"*"}) common.CheckErr(t, err, true) - expIntColumn := entity.NewColumnInt64(common.DefaultIntFieldName, intColumn.(*entity.ColumnInt64).Data()[:1]) - expVarcharColumn := entity.NewColumnVarChar(common.DefaultVarcharFieldName, varColumn.(*entity.ColumnVarChar).Data()[:1]) - expVecColumn := entity.NewColumnFloatVector(common.DefaultFloatVecFieldName, int(common.DefaultDim), floatColumn.(*entity.ColumnFloatVector).Data()[:1]) - expSparseColumn := entity.NewColumnSparseVectors(common.DefaultSparseVecFieldName, sparseColumn.(*entity.ColumnSparseFloatVector).Data()[:1]) common.CheckOutputFields(t, queryResult, []string{common.DefaultIntFieldName, common.DefaultVarcharFieldName, common.DefaultFloatVecFieldName, common.DefaultSparseVecFieldName}) - common.CheckQueryResult(t, queryResult, []entity.Column{expIntColumn, expVarcharColumn, expVecColumn, expSparseColumn}) + t.Log("https://github.com/milvus-io/milvus-sdk-go/issues/769") + //common.CheckQueryResult(t, queryResult, []entity.Column{intColumn.Slice(0, 2), varColumn.Slice(0, 2), floatColumn.Slice(0, 2), sparseColumn.Slice(0, 2)}) + } +} + +// test query iterator default +func TestQueryIteratorDefault(t *testing.T) { + ctx := createContext(t, time.Second*common.DefaultTimeout) + // connect + mc := 
createMilvusClient(ctx, t) + + // create collection + cp := CollectionParams{CollectionFieldsType: Int64FloatVec, AutoID: false, EnableDynamicField: true, + ShardsNum: common.DefaultShards, Dim: common.DefaultDim} + collName := createCollection(ctx, t, mc, cp) + + // insert + dp := DataParams{CollectionName: collName, PartitionName: "", CollectionFieldsType: Int64FloatVec, + start: 0, nb: common.DefaultNb, dim: common.DefaultDim, EnableDynamicField: true, WithRows: false} + _, _ = insertData(ctx, t, mc, dp) + + dp2 := DataParams{CollectionName: collName, PartitionName: "", CollectionFieldsType: Int64FloatVec, + start: common.DefaultNb, nb: common.DefaultNb * 2, dim: common.DefaultDim, EnableDynamicField: true, WithRows: true} + _, _ = insertData(ctx, t, mc, dp2) + + idx, _ := entity.NewIndexHNSW(entity.L2, 8, 96) + _ = mc.CreateIndex(ctx, collName, common.DefaultFloatVecFieldName, idx, false) + + // Load collection + errLoad := mc.LoadCollection(ctx, collName, false) + common.CheckErr(t, errLoad, true) + + // query iterator with default batch + itr, err := mc.QueryIterator(ctx, client.NewQueryIteratorOption(collName)) + common.CheckErr(t, err, true) + common.CheckQueryIteratorResult(ctx, t, itr, common.DefaultNb*3, common.WithExpBatchSize(common.GenBatchSizes(common.DefaultNb*3, common.DefaultBatchSize))) +} + +// test query iterator default +func TestQueryIteratorHitEmpty(t *testing.T) { + ctx := createContext(t, time.Second*common.DefaultTimeout) + // connect + mc := createMilvusClient(ctx, t) + + // create collection + cp := CollectionParams{CollectionFieldsType: Int64FloatVec, AutoID: false, EnableDynamicField: true, + ShardsNum: common.DefaultShards, Dim: common.DefaultDim} + collName := createCollection(ctx, t, mc, cp) + + idx, _ := entity.NewIndexHNSW(entity.L2, 8, 96) + _ = mc.CreateIndex(ctx, collName, common.DefaultFloatVecFieldName, idx, false) + + // Load collection + errLoad := mc.LoadCollection(ctx, collName, false) + common.CheckErr(t, errLoad, true) + + // query iterator with default batch + itr, err := mc.QueryIterator(ctx, client.NewQueryIteratorOption(collName)) + common.CheckErr(t, err, true) + rs, err := itr.Next(ctx) + require.Empty(t, rs) + require.Error(t, err, io.EOF) + common.CheckQueryIteratorResult(ctx, t, itr, 0, common.WithExpBatchSize(common.GenBatchSizes(0, common.DefaultBatchSize))) +} + +func TestQueryIteratorBatchSize(t *testing.T) { + ctx := createContext(t, time.Second*common.DefaultTimeout) + // connect + mc := createMilvusClient(ctx, t) + // create collection + cp := CollectionParams{CollectionFieldsType: Int64FloatVec, AutoID: false, EnableDynamicField: true, + ShardsNum: common.DefaultShards, Dim: common.DefaultDim} + collName := createCollection(ctx, t, mc, cp) + + // insert + nb := 201 + dp := DataParams{CollectionName: collName, PartitionName: "", CollectionFieldsType: Int64FloatVec, + start: 0, nb: nb, dim: common.DefaultDim, EnableDynamicField: true, WithRows: false} + _, _ = insertData(ctx, t, mc, dp) + + idx, _ := entity.NewIndexHNSW(entity.L2, 8, 96) + _ = mc.CreateIndex(ctx, collName, common.DefaultFloatVecFieldName, idx, false) + + // Load collection + errLoad := mc.LoadCollection(ctx, collName, false) + common.CheckErr(t, errLoad, true) + + type batchStruct struct { + batch int + expBatchSize []int + } + batchStructs := []batchStruct{ + {batch: nb / 2, expBatchSize: common.GenBatchSizes(nb, nb/2)}, + {batch: nb, expBatchSize: common.GenBatchSizes(nb, nb)}, + {batch: nb + 1, expBatchSize: common.GenBatchSizes(nb, nb+1)}, + } + + for _, 
_batchStruct := range batchStructs { + // query iterator with default batch + itr, err := mc.QueryIterator(ctx, client.NewQueryIteratorOption(collName).WithBatchSize(_batchStruct.batch)) + common.CheckErr(t, err, true) + common.CheckQueryIteratorResult(ctx, t, itr, nb, common.WithExpBatchSize(_batchStruct.expBatchSize)) + } +} + +func TestQueryIteratorOutputAllFields(t *testing.T) { + t.Parallel() + ctx := createContext(t, time.Second*common.DefaultTimeout) + // connect + mc := createMilvusClient(ctx, t) + for _, dynamic := range [2]bool{false, true} { + // create collection + cp := CollectionParams{CollectionFieldsType: AllFields, AutoID: false, EnableDynamicField: dynamic, + ShardsNum: common.DefaultShards, Dim: common.DefaultDim} + collName := createCollection(ctx, t, mc, cp, client.WithConsistencyLevel(entity.ClStrong)) + + // insert + nb := 2501 + dp := DataParams{CollectionName: collName, PartitionName: "", CollectionFieldsType: AllFields, + start: 0, nb: nb, dim: common.DefaultDim, EnableDynamicField: dynamic, WithRows: false} + insertData(ctx, t, mc, dp) + + indexHnsw, _ := entity.NewIndexHNSW(entity.L2, 8, 96) + indexBinary, _ := entity.NewIndexBinIvfFlat(entity.JACCARD, 64) + for _, fieldName := range common.AllVectorsFieldsName { + if fieldName == common.DefaultBinaryVecFieldName { + mc.CreateIndex(ctx, collName, fieldName, indexBinary, false) + } else { + mc.CreateIndex(ctx, collName, fieldName, indexHnsw, false) + } + } + + // Load collection + errLoad := mc.LoadCollection(ctx, collName, false) + common.CheckErr(t, errLoad, true) + + // output * fields + nbFilter := 1001 + batch := 500 + expr := fmt.Sprintf("%s < %d", common.DefaultIntFieldName, nbFilter) + + itr, err := mc.QueryIterator(ctx, client.NewQueryIteratorOption(collName).WithBatchSize(batch).WithOutputFields("*").WithExpr(expr)) + common.CheckErr(t, err, true) + allFields := common.GetAllFieldsName(dynamic, false) + common.CheckQueryIteratorResult(ctx, t, itr, nbFilter, common.WithExpBatchSize(common.GenBatchSizes(nbFilter, batch)), common.WithExpOutputFields(allFields)) + } +} + +func TestQueryIteratorOutputSparseFieldsRows(t *testing.T) { + t.Parallel() + ctx := createContext(t, time.Second*common.DefaultTimeout) + // connect + mc := createMilvusClient(ctx, t) + for _, withRows := range [2]bool{true, false} { + // create collection + cp := CollectionParams{CollectionFieldsType: Int64VarcharSparseVec, AutoID: false, EnableDynamicField: true, + ShardsNum: common.DefaultShards, Dim: common.DefaultDim, MaxLength: common.TestMaxLen} + collName := createCollection(ctx, t, mc, cp) + + // insert + nb := 2501 + dp := DataParams{CollectionName: collName, PartitionName: "", CollectionFieldsType: Int64VarcharSparseVec, + start: 0, nb: nb, dim: common.DefaultDim, EnableDynamicField: true, WithRows: withRows, maxLenSparse: 1000} + _, _ = insertData(ctx, t, mc, dp) + + indexHnsw, _ := entity.NewIndexHNSW(entity.L2, 8, 96) + indexSparse, _ := entity.NewIndexSparseInverted(entity.IP, 0.1) + mc.CreateIndex(ctx, collName, common.DefaultFloatVecFieldName, indexHnsw, false) + mc.CreateIndex(ctx, collName, common.DefaultSparseVecFieldName, indexSparse, false) + + // Load collection + errLoad := mc.LoadCollection(ctx, collName, false) + common.CheckErr(t, errLoad, true) + + // output * fields + itr, err := mc.QueryIterator(ctx, client.NewQueryIteratorOption(collName).WithBatchSize(400).WithOutputFields("*")) + common.CheckErr(t, err, true) + fields := []string{common.DefaultIntFieldName, common.DefaultVarcharFieldName, 
common.DefaultFloatVecFieldName, common.DefaultSparseVecFieldName, common.DefaultDynamicFieldName} + common.CheckQueryIteratorResult(ctx, t, itr, nb, common.WithExpBatchSize(common.GenBatchSizes(nb, 400)), common.WithExpOutputFields(fields)) + } +} + +// test query iterator with non-existed collection/partition name, invalid batch size +func TestQueryIteratorInvalid(t *testing.T) { + ctx := createContext(t, time.Second*common.DefaultTimeout) + // connect + mc := createMilvusClient(ctx, t) + // create collection + cp := CollectionParams{CollectionFieldsType: Int64FloatVec, AutoID: false, EnableDynamicField: false, + ShardsNum: common.DefaultShards, Dim: common.DefaultDim} + collName := createCollection(ctx, t, mc, cp) + + // insert + nb := 201 + dp := DataParams{CollectionName: collName, PartitionName: "", CollectionFieldsType: Int64FloatVec, + start: 0, nb: nb, dim: common.DefaultDim, EnableDynamicField: false, WithRows: false} + _, _ = insertData(ctx, t, mc, dp) + + idx, _ := entity.NewIndexHNSW(entity.L2, 8, 96) + _ = mc.CreateIndex(ctx, collName, common.DefaultFloatVecFieldName, idx, false) + + // Load collection + errLoad := mc.LoadCollection(ctx, collName, false) + common.CheckErr(t, errLoad, true) + + // query iterator with not existed collection name + _, err := mc.QueryIterator(ctx, client.NewQueryIteratorOption("aaa")) + common.CheckErr(t, err, false, "can't find collection") + + // query iterator with not existed partition name + _, errPar := mc.QueryIterator(ctx, client.NewQueryIteratorOption(collName).WithPartitions("aaa")) + common.CheckErr(t, errPar, false, "partition name aaa not found") + + // query iterator with not existed partition name + _, errPar = mc.QueryIterator(ctx, client.NewQueryIteratorOption(collName).WithPartitions("aaa", common.DefaultPartition)) + common.CheckErr(t, errPar, false, "partition name aaa not found") + + _, errOutput := mc.QueryIterator(ctx, client.NewQueryIteratorOption(collName).WithOutputFields(common.QueryCountFieldName)) + common.CheckErr(t, errOutput, false, "count entities with pagination is not allowed") + + // query iterator with invalid batch size + for _, batch := range []int{-1, 0} { + // query iterator with default batch + _, err := mc.QueryIterator(ctx, client.NewQueryIteratorOption(collName).WithBatchSize(batch)) + common.CheckErr(t, err, false, "batch size cannot less than 1") + } +} + +// query iterator with invalid expr +func TestQueryIteratorInvalidExpr(t *testing.T) { + ctx := createContext(t, time.Second*common.DefaultTimeout) + // connect + mc := createMilvusClient(ctx, t) + + // create collection + cp := CollectionParams{CollectionFieldsType: Int64FloatVecJSON, AutoID: false, EnableDynamicField: true, + ShardsNum: common.DefaultShards, Dim: common.DefaultDim, + } + collName := createCollection(ctx, t, mc, cp) + + // insert + dp := DataParams{CollectionName: collName, PartitionName: "", CollectionFieldsType: Int64FloatVecJSON, + start: 0, nb: common.DefaultNb, dim: common.DefaultDim, EnableDynamicField: true, + } + _, _ = insertData(ctx, t, mc, dp) + + idx, _ := entity.NewIndexHNSW(entity.L2, 8, 96) + _ = mc.CreateIndex(ctx, collName, common.DefaultFloatVecFieldName, idx, false) + + // Load collection + errLoad := mc.LoadCollection(ctx, collName, false) + common.CheckErr(t, errLoad, true) + + for _, _invalidExprs := range common.InvalidExpressions { + _, err := mc.QueryIterator(ctx, client.NewQueryIteratorOption(collName).WithExpr(_invalidExprs.Expr)) + common.CheckErr(t, err, _invalidExprs.ErrNil, _invalidExprs.ErrMsg) + } 
+} + +// test query iterator with non-existed field when dynamic or not +func TestQueryIteratorOutputFieldDynamic(t *testing.T) { + ctx := createContext(t, time.Second*common.DefaultTimeout) + // connect + mc := createMilvusClient(ctx, t) + for _, dynamic := range [2]bool{true, false} { + // create collection + cp := CollectionParams{CollectionFieldsType: Int64FloatVec, AutoID: false, EnableDynamicField: dynamic, + ShardsNum: common.DefaultShards, Dim: common.DefaultDim} + collName := createCollection(ctx, t, mc, cp) + // insert + nb := 201 + dp := DataParams{CollectionName: collName, PartitionName: "", CollectionFieldsType: Int64FloatVec, + start: 0, nb: nb, dim: common.DefaultDim, EnableDynamicField: dynamic, WithRows: false} + _, _ = insertData(ctx, t, mc, dp) + + idx, _ := entity.NewIndexHNSW(entity.L2, 8, 96) + _ = mc.CreateIndex(ctx, collName, common.DefaultFloatVecFieldName, idx, false) + + // Load collection + errLoad := mc.LoadCollection(ctx, collName, false) + common.CheckErr(t, errLoad, true) + + // query iterator with not existed output fields: if dynamic, non-existent field are equivalent to dynamic field + itr, errOutput := mc.QueryIterator(ctx, client.NewQueryIteratorOption(collName).WithOutputFields("aaa")) + if dynamic { + common.CheckErr(t, errOutput, true) + expFields := []string{common.DefaultIntFieldName, common.DefaultDynamicFieldName} + common.CheckQueryIteratorResult(ctx, t, itr, nb, common.WithExpBatchSize(common.GenBatchSizes(nb, common.DefaultBatchSize)), common.WithExpOutputFields(expFields)) + } else { + common.CheckErr(t, errOutput, false, "field aaa not exist") + } + } +} + +func TestQueryIteratorExpr(t *testing.T) { + //t.Log("https://github.com/milvus-io/milvus-sdk-go/issues/756") + type exprCount struct { + expr string + count int + } + capacity := common.TestCapacity + exprLimits := []exprCount{ + {expr: fmt.Sprintf("%s in [0, 1, 2]", common.DefaultIntFieldName), count: 3}, + {expr: fmt.Sprintf("%s >= 1000 || %s > 2000", common.DefaultIntFieldName, common.DefaultIntFieldName), count: 2000}, + {expr: fmt.Sprintf("%s >= 1000 and %s < 2000", common.DefaultIntFieldName, common.DefaultIntFieldName), count: 1000}, + + //json and dynamic field filter expr: == < in bool/ list/ int + {expr: fmt.Sprintf("%s['number'] == 0", common.DefaultJSONFieldName), count: 1500 / 2}, + {expr: fmt.Sprintf("%s['number'] < 100 and %s['number'] != 0", common.DefaultJSONFieldName, common.DefaultJSONFieldName), count: 50}, + {expr: fmt.Sprintf("%s < 100", common.DefaultDynamicNumberField), count: 100}, + {expr: "dynamicNumber % 2 == 0", count: 1500}, + {expr: fmt.Sprintf("%s == false", common.DefaultDynamicBoolField), count: 2000}, + {expr: fmt.Sprintf("%s in ['1', '2'] ", common.DefaultDynamicStringField), count: 2}, + {expr: fmt.Sprintf("%s['string'] in ['1', '2', '5'] ", common.DefaultJSONFieldName), count: 3}, + {expr: fmt.Sprintf("%s['list'] == [1, 2] ", common.DefaultJSONFieldName), count: 1}, + {expr: fmt.Sprintf("%s['list'][0] < 10 ", common.DefaultJSONFieldName), count: 5}, + {expr: fmt.Sprintf("%s[\"dynamicList\"] != [2, 3]", common.DefaultDynamicFieldName), count: 0}, + + // json contains + {expr: fmt.Sprintf("json_contains (%s['list'], 2)", common.DefaultJSONFieldName), count: 1}, + {expr: fmt.Sprintf("json_contains (%s['number'], 0)", common.DefaultJSONFieldName), count: 0}, + {expr: fmt.Sprintf("JSON_CONTAINS_ANY (%s['list'], [1, 3])", common.DefaultJSONFieldName), count: 2}, + // string like + {expr: "dynamicString like '1%' ", count: 1111}, + + // key exist + {expr: 
fmt.Sprintf("exists %s['list']", common.DefaultJSONFieldName), count: common.DefaultNb / 2}, + {expr: fmt.Sprintf("exists a "), count: 0}, + {expr: fmt.Sprintf("exists %s ", common.DefaultDynamicStringField), count: common.DefaultNb}, + + // data type not match and no error + {expr: fmt.Sprintf("%s['number'] == '0' ", common.DefaultJSONFieldName), count: 0}, + + // json field + {expr: fmt.Sprintf("%s >= 1500", common.DefaultJSONFieldName), count: 1500 / 2}, // json >= 1500 + {expr: fmt.Sprintf("%s > 1499.5", common.DefaultJSONFieldName), count: 1500 / 2}, // json >= 1500.0 + {expr: fmt.Sprintf("%s like '21%%'", common.DefaultJSONFieldName), count: 100 / 4}, // json like '21%' + {expr: fmt.Sprintf("%s == [1503, 1504]", common.DefaultJSONFieldName), count: 1}, // json == [1,2] + {expr: fmt.Sprintf("%s[0] > 1", common.DefaultJSONFieldName), count: 1500 / 4}, // json[0] > 1 + {expr: fmt.Sprintf("%s[0][0] > 1", common.DefaultJSONFieldName), count: 0}, // json == [1,2] + {expr: fmt.Sprintf("%s[0] == false", common.DefaultBoolArrayField), count: common.DefaultNb / 2}, // array[0] == + {expr: fmt.Sprintf("%s[0] > 0", common.DefaultInt64ArrayField), count: common.DefaultNb - 1}, // array[0] > + {expr: fmt.Sprintf("%s[0] > 0", common.DefaultInt8ArrayField), count: 1524}, // array[0] > int8 range: [-128, 127] + {expr: fmt.Sprintf("array_contains (%s, %d)", common.DefaultInt16ArrayField, capacity), count: capacity}, // array_contains(array, 1) + {expr: fmt.Sprintf("json_contains (%s, 1)", common.DefaultInt32ArrayField), count: 2}, // json_contains(array, 1) + {expr: fmt.Sprintf("array_contains (%s, 1000000)", common.DefaultInt32ArrayField), count: 0}, // array_contains(array, 1) + {expr: fmt.Sprintf("json_contains_all (%s, [90, 91])", common.DefaultInt64ArrayField), count: 91}, // json_contains_all(array, [x]) + {expr: fmt.Sprintf("json_contains_any (%s, [0, 100, 10])", common.DefaultFloatArrayField), count: 101}, // json_contains_any (array, [x]) + {expr: fmt.Sprintf("%s == [0, 1]", common.DefaultDoubleArrayField), count: 0}, // array == + {expr: fmt.Sprintf("array_length(%s) == %d", common.DefaultDoubleArrayField, capacity), count: common.DefaultNb}, // array_length + } + + ctx := createContext(t, time.Second*common.DefaultTimeout) + // connect + mc := createMilvusClient(ctx, t) + + // create collection + cp := CollectionParams{CollectionFieldsType: AllFields, AutoID: false, EnableDynamicField: true, + ShardsNum: common.DefaultShards, Dim: common.DefaultDim, MaxCapacity: common.TestCapacity} + collName := createCollection(ctx, t, mc, cp, client.WithConsistencyLevel(entity.ClStrong)) + + // insert + dp := DataParams{CollectionName: collName, PartitionName: "", CollectionFieldsType: AllFields, + start: 0, nb: common.DefaultNb, dim: common.DefaultDim, EnableDynamicField: true, WithRows: false} + _, err := insertData(ctx, t, mc, dp, common.WithArrayCapacity(common.TestCapacity)) + common.CheckErr(t, err, true) + mc.Flush(ctx, collName, false) + + indexHnsw, _ := entity.NewIndexHNSW(entity.L2, 8, 96) + indexBinary, _ := entity.NewIndexBinIvfFlat(entity.JACCARD, 64) + for _, fieldName := range common.AllVectorsFieldsName { + if fieldName == common.DefaultBinaryVecFieldName { + mc.CreateIndex(ctx, collName, fieldName, indexBinary, false) + } else { + mc.CreateIndex(ctx, collName, fieldName, indexHnsw, false) + } + } + + // Load collection + errLoad := mc.LoadCollection(ctx, collName, false) + common.CheckErr(t, errLoad, true) + batch := 500 + + for _, exprLimit := range exprLimits { + log.Printf("case expr 
is: %s, limit=%d", exprLimit.expr, exprLimit.count) + itr, err := mc.QueryIterator(ctx, client.NewQueryIteratorOption(collName).WithBatchSize(batch).WithExpr(exprLimit.expr)) + common.CheckErr(t, err, true) + common.CheckQueryIteratorResult(ctx, t, itr, exprLimit.count, common.WithExpBatchSize(common.GenBatchSizes(exprLimit.count, batch))) + } +} + +// test query iterator with partition +func TestQueryIteratorPartitions(t *testing.T) { + ctx := createContext(t, time.Second*common.DefaultTimeout) + // connect + mc := createMilvusClient(ctx, t) + // create collection + cp := CollectionParams{CollectionFieldsType: Int64FloatVec, AutoID: false, EnableDynamicField: false, + ShardsNum: common.DefaultShards, Dim: common.DefaultDim} + collName := createCollection(ctx, t, mc, cp) + pName := "p1" + err := mc.CreatePartition(ctx, collName, pName) + common.CheckErr(t, err, true) + + // insert [0, nb) into partition: _default + nb := 1500 + dp := DataParams{CollectionName: collName, PartitionName: "", CollectionFieldsType: Int64FloatVec, + start: 0, nb: nb, dim: common.DefaultDim, EnableDynamicField: false, WithRows: false} + _, _ = insertData(ctx, t, mc, dp) + // insert [nb, nb*2) into partition: p1 + dp1 := DataParams{CollectionName: collName, PartitionName: pName, CollectionFieldsType: Int64FloatVec, + start: nb, nb: nb, dim: common.DefaultDim, EnableDynamicField: false, WithRows: false} + _, _ = insertData(ctx, t, mc, dp1) + + idx, _ := entity.NewIndexHNSW(entity.L2, 8, 96) + _ = mc.CreateIndex(ctx, collName, common.DefaultFloatVecFieldName, idx, false) + + // Load collection + errLoad := mc.LoadCollection(ctx, collName, false) + common.CheckErr(t, errLoad, true) + + // query iterator with partition + expr := fmt.Sprintf("%s < %d", common.DefaultIntFieldName, nb) + mParLimit := map[string]int{ + common.DefaultPartition: nb, + pName: 0, + } + for par, limit := range mParLimit { + itr, err := mc.QueryIterator(ctx, client.NewQueryIteratorOption(collName).WithExpr(expr).WithPartitions(par)) + common.CheckErr(t, err, true) + common.CheckQueryIteratorResult(ctx, t, itr, limit, common.WithExpBatchSize(common.GenBatchSizes(limit, common.DefaultBatchSize))) } } diff --git a/test/testcases/search_test.go b/test/testcases/search_test.go index fe5827c8..7cca2adb 100644 --- a/test/testcases/search_test.go +++ b/test/testcases/search_test.go @@ -1,4 +1,4 @@ -///go:build L0 +//go:build L0 package testcases @@ -667,7 +667,7 @@ func TestSearchInvalidVectors(t *testing.T) { // dim not match {vectors: common.GenSearchVectors(common.DefaultNq, 64, entity.FieldTypeFloatVector), errMsg: "vector dimension mismatch"}, - // vector type not match + //vector type not match {vectors: common.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeBinaryVector), errMsg: "vector type must be the same"}, // empty vectors @@ -693,6 +693,53 @@ func TestSearchInvalidVectors(t *testing.T) { } } +// test search with invalid vectors +func TestSearchInvalidVectorsEmptyCollection(t *testing.T) { + t.Skip("https://github.com/milvus-io/milvus/issues/33639") + t.Skip("https://github.com/milvus-io/milvus/issues/33637") + t.Parallel() + ctx := createContext(t, time.Second*common.DefaultTimeout*2) + // connect + mc := createMilvusClient(ctx, t) + + // create collection with data + collName := createDefaultCollection(ctx, t, mc, false, common.DefaultShards) + + // index + idx, _ := entity.NewIndexHNSW(entity.L2, 8, 96) + err := mc.CreateIndex(ctx, collName, common.DefaultFloatVecFieldName, idx, false, client.WithIndexName("")) + 
common.CheckErr(t, err, true) + + // load collection + errLoad := mc.LoadCollection(ctx, collName, false) + common.CheckErr(t, errLoad, true) + + type invalidVectorsStruct struct { + vectors []entity.Vector + errMsg string + } + + invalidVectors := []invalidVectorsStruct{ + // dim not match + {vectors: common.GenSearchVectors(common.DefaultNq, 64, entity.FieldTypeFloatVector), errMsg: "vector dimension mismatch"}, + + //vector type not match + {vectors: common.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeBinaryVector), errMsg: "vector type must be the same"}, + + // empty vectors + {vectors: []entity.Vector{}, errMsg: "nq [0] is invalid"}, + {vectors: []entity.Vector{entity.FloatVector{}}, errMsg: "vector dimension mismatch"}, + } + + sp, _ := entity.NewIndexHNSWSearchParam(74) + for _, invalidVector := range invalidVectors { + // search vectors empty slice + _, errSearchEmpty := mc.Search(ctx, collName, []string{}, "", []string{"*"}, invalidVector.vectors, + common.DefaultFloatVecFieldName, entity.L2, common.DefaultTopK, sp) + common.CheckErr(t, errSearchEmpty, false, invalidVector.errMsg) + } +} + // test search metric type isn't the same with index metric type func TestSearchNotMatchMetricType(t *testing.T) { ctx := createContext(t, time.Second*common.DefaultTimeout*2) From 95550a5cd8b75c55c694ae606e108408f2c47c35 Mon Sep 17 00:00:00 2001 From: congqixia Date: Wed, 12 Jun 2024 14:21:49 +0800 Subject: [PATCH 09/11] enhance: Bump sdk verison to v2.4.1 (#771) Signed-off-by: Congqi Xia --- common/common.go | 2 +- go.mod | 2 +- go.sum | 2 ++ 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/common/common.go b/common/common.go index afa93a79..985c34d5 100644 --- a/common/common.go +++ b/common/common.go @@ -2,5 +2,5 @@ package common const ( // SDKVersion const value for current version - SDKVersion = `v2.4.0` + SDKVersion = `v2.4.1` ) diff --git a/go.mod b/go.mod index 872ec746..980c22af 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/go-faker/faker/v4 v4.1.0 github.com/golang/protobuf v1.5.2 github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 - github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240430025921-135167be0694 + github.com/milvus-io/milvus-proto/go-api/v2 v2.4.3 github.com/stretchr/testify v1.8.1 github.com/tidwall/gjson v1.14.4 github.com/x448/float16 v0.8.4 diff --git a/go.sum b/go.sum index 3ffaa91a..988bf02e 100644 --- a/go.sum +++ b/go.sum @@ -159,6 +159,8 @@ github.com/mediocregopher/radix/v3 v3.4.2/go.mod h1:8FL3F6UQRXHXIBSPUs5h0RybMF8i github.com/microcosm-cc/bluemonday v1.0.2/go.mod h1:iVP4YcDBq+n/5fb23BhYFvIMq/leAFZyRl6bYmGDlGc= github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240430025921-135167be0694 h1:iub0yx8peGNtnb9n11iuWNmhIhIXw3xfZooIDcrfeU8= github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240430025921-135167be0694/go.mod h1:1OIl0v5PQeNxIJhCvY+K55CBUOYDZevw9g9380u1Wek= +github.com/milvus-io/milvus-proto/go-api/v2 v2.4.3 h1:KUSaWVePVlHMIluAXf2qmNffI1CMlGFLLiP+4iy9014= +github.com/milvus-io/milvus-proto/go-api/v2 v2.4.3/go.mod h1:1OIl0v5PQeNxIJhCvY+K55CBUOYDZevw9g9380u1Wek= github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= From d346d3ed3898244bf17a324767b220b808c52c25 Mon Sep 17 00:00:00 2001 From: chyezh Date: Mon, 17 Jun 2024 
12:03:51 +0800 Subject: [PATCH 10/11] enhance: add resource group declarative api (#733) issue: https://github.com/milvus-io/milvus/issues/32282 - Add UpdateResourceGroups and modify AddResourceGroup api. - Add example for add resource group declarative api. --------- Signed-off-by: chyezh --- client/client.go | 4 +- client/options.go | 23 +++ client/options_test.go | 24 ++++ client/resource_group.go | 25 +++- client/resource_group_test.go | 64 ++++++++- entity/resource_group.go | 14 ++ examples/resourcegroup/resourcegroup.go | 180 ++++++++++++++++++++++++ test/testcases/resource_group_test.go | 7 +- 8 files changed, 335 insertions(+), 6 deletions(-) create mode 100644 examples/resourcegroup/resourcegroup.go diff --git a/client/client.go b/client/client.go index cfc1a883..ec3765d5 100644 --- a/client/client.go +++ b/client/client.go @@ -214,7 +214,9 @@ type Client interface { // ListResourceGroups returns list of resource group names in current Milvus instance. ListResourceGroups(ctx context.Context) ([]string, error) // CreateResourceGroup creates a resource group with provided name. - CreateResourceGroup(ctx context.Context, rgName string) error + CreateResourceGroup(ctx context.Context, rgName string, opts ...CreateResourceGroupOption) error + // UpdateResourceGroups updates resource groups with provided options. + UpdateResourceGroups(ctx context.Context, opts ...UpdateResourceGroupsOption) error // DescribeResourceGroup returns resource groups information. DescribeResourceGroup(ctx context.Context, rgName string) (*entity.ResourceGroup, error) // DropResourceGroup drops the resource group with provided name. diff --git a/client/options.go b/client/options.go index 6ac2a962..0254dd7b 100644 --- a/client/options.go +++ b/client/options.go @@ -322,3 +322,26 @@ type DropPartitionOption func(*milvuspb.DropPartitionRequest) type LoadPartitionsOption func(*milvuspb.LoadPartitionsRequest) type ReleasePartitionsOption func(*milvuspb.ReleasePartitionsRequest) + +// CreateResourceGroupOption is an option that is used in CreateResourceGroup API. +type CreateResourceGroupOption func(*milvuspb.CreateResourceGroupRequest) + +// WithCreateResourceGroupConfig returns a CreateResourceGroupOption that setup the config. +func WithCreateResourceGroupConfig(config *entity.ResourceGroupConfig) CreateResourceGroupOption { + return func(req *milvuspb.CreateResourceGroupRequest) { + req.Config = config + } +} + +// UpdateResourceGroupsOption is an option that is used in UpdateResourceGroups API. +type UpdateResourceGroupsOption func(*milvuspb.UpdateResourceGroupsRequest) + +// WithUpdateResourceGroupConfig returns an UpdateResourceGroupsOption that sets the new config to the specified resource group. 
+func WithUpdateResourceGroupConfig(resourceGroupName string, config *entity.ResourceGroupConfig) UpdateResourceGroupsOption { + return func(urgr *milvuspb.UpdateResourceGroupsRequest) { + if urgr.ResourceGroups == nil { + urgr.ResourceGroups = make(map[string]*entity.ResourceGroupConfig) + } + urgr.ResourceGroups[resourceGroupName] = config + } +} diff --git a/client/options_test.go b/client/options_test.go index bc9194cb..23f1dc7b 100644 --- a/client/options_test.go +++ b/client/options_test.go @@ -177,3 +177,27 @@ func TestMakeSearchQueryOption(t *testing.T) { assert.Error(t, err) }) } + +func TestWithUpdateResourceGroupConfig(t *testing.T) { + req := &milvuspb.UpdateResourceGroupsRequest{} + + WithUpdateResourceGroupConfig("rg1", &entity.ResourceGroupConfig{ + Requests: &entity.ResourceGroupLimit{NodeNum: 1}, + })(req) + WithUpdateResourceGroupConfig("rg2", &entity.ResourceGroupConfig{ + Requests: &entity.ResourceGroupLimit{NodeNum: 2}, + })(req) + + assert.Equal(t, 2, len(req.ResourceGroups)) + assert.Equal(t, int32(1), req.ResourceGroups["rg1"].Requests.NodeNum) + assert.Equal(t, int32(2), req.ResourceGroups["rg2"].Requests.NodeNum) +} + +func TestWithCreateResourceGroup(t *testing.T) { + req := &milvuspb.CreateResourceGroupRequest{} + + WithCreateResourceGroupConfig(&entity.ResourceGroupConfig{ + Requests: &entity.ResourceGroupLimit{NodeNum: 1}, + })(req) + assert.Equal(t, int32(1), req.Config.Requests.NodeNum) +} diff --git a/client/resource_group.go b/client/resource_group.go index fc3f90db..6cfd6552 100644 --- a/client/resource_group.go +++ b/client/resource_group.go @@ -38,7 +38,7 @@ func (c *GrpcClient) ListResourceGroups(ctx context.Context) ([]string, error) { } // CreateResourceGroup creates a resource group with provided name. -func (c *GrpcClient) CreateResourceGroup(ctx context.Context, rgName string) error { +func (c *GrpcClient) CreateResourceGroup(ctx context.Context, rgName string, opts ...CreateResourceGroupOption) error { if c.Service == nil { return ErrClientNotReady } @@ -46,6 +46,9 @@ func (c *GrpcClient) CreateResourceGroup(ctx context.Context, rgName string) err req := &milvuspb.CreateResourceGroupRequest{ ResourceGroup: rgName, } + for _, opt := range opts { + opt(req) + } resp, err := c.Service.CreateResourceGroup(ctx, req) if err != nil { @@ -54,6 +57,24 @@ func (c *GrpcClient) CreateResourceGroup(ctx context.Context, rgName string) err return handleRespStatus(resp) } +// UpdateResourceGroups updates resource groups with provided options. +func (c *GrpcClient) UpdateResourceGroups(ctx context.Context, opts ...UpdateResourceGroupsOption) error { + if c.Service == nil { + return ErrClientNotReady + } + + req := &milvuspb.UpdateResourceGroupsRequest{} + for _, opt := range opts { + opt(req) + } + + resp, err := c.Service.UpdateResourceGroups(ctx, req) + if err != nil { + return err + } + return handleRespStatus(resp) +} + // DescribeResourceGroup returns resource groups information. 
func (c *GrpcClient) DescribeResourceGroup(ctx context.Context, rgName string) (*entity.ResourceGroup, error) { if c.Service == nil { @@ -80,6 +101,8 @@ func (c *GrpcClient) DescribeResourceGroup(ctx context.Context, rgName string) ( LoadedReplica: rg.GetNumLoadedReplica(), OutgoingNodeNum: rg.GetNumOutgoingNode(), IncomingNodeNum: rg.GetNumIncomingNode(), + Config: rg.GetConfig(), + Nodes: rg.GetNodes(), } return result, nil diff --git a/client/resource_group_test.go b/client/resource_group_test.go index 7b4e9096..ce9096f9 100644 --- a/client/resource_group_test.go +++ b/client/resource_group_test.go @@ -23,6 +23,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus-sdk-go/v2/entity" ) type ResourceGroupSuite struct { @@ -121,6 +122,68 @@ func (s *ResourceGroupSuite) TestCreateResourceGroup() { }) } +func (s *ResourceGroupSuite) TestUpdateResourceGroups() { + c := s.client + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s.Run("normal_run", func() { + defer s.resetMock() + rgName := randStr(10) + + s.mock.EXPECT().UpdateResourceGroups(mock.Anything, mock.AnythingOfType("*milvuspb.UpdateResourceGroupsRequest")). + Run(func(_ context.Context, req *milvuspb.UpdateResourceGroupsRequest) { + s.Len(req.ResourceGroups, 1) + s.NotNil(req.ResourceGroups[rgName]) + s.Equal(int32(1), req.ResourceGroups[rgName].Requests.NodeNum) + }). + Return(&commonpb.Status{}, nil) + + err := c.UpdateResourceGroups(ctx, WithUpdateResourceGroupConfig(rgName, &entity.ResourceGroupConfig{ + Requests: &entity.ResourceGroupLimit{NodeNum: 1}, + })) + s.NoError(err) + }) + + s.Run("request_fails", func() { + defer s.resetMock() + + rgName := randStr(10) + + s.mock.EXPECT().UpdateResourceGroups(mock.Anything, mock.AnythingOfType("*milvuspb.UpdateResourceGroupsRequest")). + Run(func(_ context.Context, req *milvuspb.UpdateResourceGroupsRequest) { + s.Len(req.ResourceGroups, 1) + s.NotNil(req.ResourceGroups[rgName]) + s.Equal(int32(1), req.ResourceGroups[rgName].Requests.NodeNum) + }). + Return(nil, errors.New("mocked grpc error")) + + err := c.UpdateResourceGroups(ctx, WithUpdateResourceGroupConfig(rgName, &entity.ResourceGroupConfig{ + Requests: &entity.ResourceGroupLimit{NodeNum: 1}, + })) + s.Error(err) + }) + + s.Run("server_return_err", func() { + defer s.resetMock() + + rgName := randStr(10) + + s.mock.EXPECT().UpdateResourceGroups(mock.Anything, mock.AnythingOfType("*milvuspb.UpdateResourceGroupsRequest")). + Run(func(_ context.Context, req *milvuspb.UpdateResourceGroupsRequest) { + s.Len(req.ResourceGroups, 1) + s.NotNil(req.ResourceGroups[rgName]) + s.Equal(int32(1), req.ResourceGroups[rgName].Requests.NodeNum) + }). + Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError}, nil) + + err := c.UpdateResourceGroups(ctx, WithUpdateResourceGroupConfig(rgName, &entity.ResourceGroupConfig{ + Requests: &entity.ResourceGroupLimit{NodeNum: 1}, + })) + s.Error(err) + }) +} + func (s *ResourceGroupSuite) TestDescribeResourceGroup() { c := s.client ctx, cancel := context.WithCancel(context.Background()) @@ -153,7 +216,6 @@ func (s *ResourceGroupSuite) TestDescribeResourceGroup() { s.Equal(rgName, req.GetResourceGroup()) }). 
Call.Return(func(_ context.Context, req *milvuspb.DescribeResourceGroupRequest) *milvuspb.DescribeResourceGroupResponse { - return &milvuspb.DescribeResourceGroupResponse{ Status: &commonpb.Status{}, ResourceGroup: &milvuspb.ResourceGroup{ diff --git a/entity/resource_group.go b/entity/resource_group.go index 77eb070f..873edb89 100644 --- a/entity/resource_group.go +++ b/entity/resource_group.go @@ -1,5 +1,17 @@ package entity +import ( + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/rgpb" +) + +type ( + ResourceGroupConfig = rgpb.ResourceGroupConfig + ResourceGroupLimit = rgpb.ResourceGroupLimit + ResourceGroupTransfer = rgpb.ResourceGroupTransfer + NodeInfo = commonpb.NodeInfo +) + // ResourceGroup information model struct. type ResourceGroup struct { Name string @@ -8,4 +20,6 @@ type ResourceGroup struct { LoadedReplica map[string]int32 OutgoingNodeNum map[string]int32 IncomingNodeNum map[string]int32 + Config *ResourceGroupConfig + Nodes []*NodeInfo } diff --git a/examples/resourcegroup/resourcegroup.go b/examples/resourcegroup/resourcegroup.go new file mode 100644 index 00000000..169bcff9 --- /dev/null +++ b/examples/resourcegroup/resourcegroup.go @@ -0,0 +1,180 @@ +package main + +import ( + "context" + "encoding/json" + "log" + "time" + + "github.com/milvus-io/milvus-sdk-go/v2/client" + "github.com/milvus-io/milvus-sdk-go/v2/entity" +) + +const ( + milvusAddr = `localhost:19530` + recycleResourceGroup = `__recycle_resource_group` + defaultResourceGroup = `__default_resource_group` + rg1 = `rg1` + rg2 = `rg2` +) + +func main() { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + c, err := client.NewClient(ctx, client.Config{ + Address: milvusAddr, + }) + if err != nil { + log.Fatal("failed to connect to milvus, err: ", err.Error()) + } + defer c.Close() + + ctx = context.Background() + showAllResourceGroup(ctx, c) + + // query node count: 1 + // | RG | Request | Limit | Nodes | + // | -- | ------- | ----- | ----- | + // | __default__resource_group | 1 | 1 | 1 | + // | __recycle__resource_group | 0 | 10000 | 0 | + // | rg1 | 0 | 0 | 0 | + // | rg2 | 0 | 0 | 0 | + if err := initializeCluster(ctx, c); err != nil { + log.Fatal("failed to initialize cluster, err: ", err.Error()) + } + + showAllResourceGroup(ctx, c) + + // do some resource group managements. + if err := resourceGroupManagement(ctx, c); err != nil { + log.Fatal("failed to manage resource group, err: ", err.Error()) + } +} + +// initializeCluster initializes the cluster with 4 resource groups. +func initializeCluster(ctx context.Context, c client.Client) error { + // Use a huge resource group to hold the redundant query node. + if err := c.CreateResourceGroup(ctx, recycleResourceGroup, client.WithCreateResourceGroupConfig( + &entity.ResourceGroupConfig{ + Requests: &entity.ResourceGroupLimit{NodeNum: 0}, + Limits: &entity.ResourceGroupLimit{NodeNum: 10000}, + }, + )); err != nil { + return err + } + + if err := c.UpdateResourceGroups(ctx, client.WithUpdateResourceGroupConfig(defaultResourceGroup, newResourceGroupCfg(1, 1))); err != nil { + return err + } + + if err := c.CreateResourceGroup(ctx, rg1); err != nil { + return err + } + + return c.CreateResourceGroup(ctx, rg2) +} + +// resourceGroupManagement manages the resource groups. +func resourceGroupManagement(ctx context.Context, c client.Client) error { + // Update resource group config. 
+ // | RG | Request | Limit | Nodes | + // | -- | ------- | ----- | ----- | + // | __default__resource_group | 1 | 1 | 1 | + // | __recycle__resource_group | 0 | 10000 | 0 | + // | rg1 | 1 | 1 | 0 | + // | rg2 | 2 | 2 | 0 | + if err := c.UpdateResourceGroups(ctx, + client.WithUpdateResourceGroupConfig(rg1, newResourceGroupCfg(1, 1)), + client.WithUpdateResourceGroupConfig(rg2, newResourceGroupCfg(2, 2)), + ); err != nil { + return err + } + showAllResourceGroup(ctx, c) + + // scale out cluster, new query node will be added to rg1 and rg2. + // | RG | Request | Limit | Nodes | + // | -- | ------- | ----- | ----- | + // | __default__resource_group | 1 | 1 | 1 | + // | __recycle__resource_group | 0 | 10000 | 0 | + // | rg1 | 1 | 1 | 1 | + // | rg2 | 2 | 2 | 2 | + scaleTo(ctx, 4) + showAllResourceGroup(ctx, c) + + // scale out cluster, new query node will be added to __recycle__resource_group. + // | RG | Request | Limit | Nodes | + // | -- | ------- | ----- | ----- | + // | __default__resource_group | 1 | 1 | 1 | + // | __recycle__resource_group | 0 | 10000 | 1 | + // | rg1 | 1 | 1 | 1 | + // | rg2 | 2 | 2 | 2 | + scaleTo(ctx, 5) + showAllResourceGroup(ctx, c) + + // Update resource group config, redundant query node will be transferred to recycle resource group. + // | RG | Request | Limit | Nodes | + // | -- | ------- | ----- | ----- | + // | __default__resource_group | 1 | 1 | 1 | + // | __recycle__resource_group | 0 | 10000 | 2 | + // | rg1 | 1 | 1 | 1 | + // | rg2 | 1 | 1 | 1 | + if err := c.UpdateResourceGroups(ctx, + client.WithUpdateResourceGroupConfig(rg1, newResourceGroupCfg(1, 1)), + client.WithUpdateResourceGroupConfig(rg2, newResourceGroupCfg(1, 1)), + ); err != nil { + return err + } + showAllResourceGroup(ctx, c) + + // Update resource group config, rg1 and rg2 will transfer missing node from __recycle__resource_group. + // | RG | Request | Limit | Nodes | + // | -- | ------- | ----- | ----- | + // | __default__resource_group | 1 | 1 | 1 | + // | __recycle__resource_group | 0 | 10000 | 0 | + // | rg1 | 2 | 2 | 2 | + // | rg2 | 2 | 2 | 2 | + if err := c.UpdateResourceGroups(ctx, + client.WithUpdateResourceGroupConfig(rg1, newResourceGroupCfg(2, 2)), + client.WithUpdateResourceGroupConfig(rg2, newResourceGroupCfg(2, 2)), + ); err != nil { + return err + } + showAllResourceGroup(ctx, c) + + return nil +} + +// scaleTo scales the cluster to the specified node number. +func scaleTo(_ context.Context, _ int) { + // Cannot implement by milvus core and sdk, + // Need to be implement by orchestration system. +} + +func newResourceGroupCfg(request int32, limit int32) *entity.ResourceGroupConfig { + return &entity.ResourceGroupConfig{ + Requests: &entity.ResourceGroupLimit{NodeNum: request}, + Limits: &entity.ResourceGroupLimit{NodeNum: limit}, + TransferFrom: []*entity.ResourceGroupTransfer{{ResourceGroup: recycleResourceGroup}}, + TransferTo: []*entity.ResourceGroupTransfer{{ResourceGroup: recycleResourceGroup}}, + } +} + +// showAllResourceGroup shows all resource groups. 
+func showAllResourceGroup(ctx context.Context, c client.Client) { + rgs, err := c.ListResourceGroups(ctx) + if err != nil { + log.Fatal("failed to list resource groups, err: ", err.Error()) + } + log.Println("resource groups:") + for _, rg := range rgs { + rg, err := c.DescribeResourceGroup(ctx, rg) + if err != nil { + log.Fatal("failed to describe resource group, err: ", err.Error()) + } + results, err := json.Marshal(rg) + if err != nil { + log.Fatal("failed to marshal resource group, err: ", err.Error()) + } + log.Printf("%s\n", results) + } +} diff --git a/test/testcases/resource_group_test.go b/test/testcases/resource_group_test.go index d8b41629..6c3cf92c 100644 --- a/test/testcases/resource_group_test.go +++ b/test/testcases/resource_group_test.go @@ -18,8 +18,10 @@ import ( "github.com/milvus-io/milvus-sdk-go/v2/test/common" ) -const configQnNodes = int32(4) -const newRgNode = int32(2) +const ( + configQnNodes = int32(4) + newRgNode = int32(2) +) func resetRgs(t *testing.T, ctx context.Context, mc *base.MilvusClient) { // release and drop all collections @@ -371,7 +373,6 @@ func TestTransferReplicas(t *testing.T) { // check search result contains search vector, which from all partitions common.CheckErr(t, err, true) common.CheckSearchResult(t, searchRes, common.DefaultNq, common.DefaultTopK) - } // test transfer replica of not existed collection From 6404f1b617e5aa79831ad174e41d9d97821a7afe Mon Sep 17 00:00:00 2001 From: Sidharth Suvarna Date: Tue, 18 Jun 2024 09:21:52 +0530 Subject: [PATCH 11/11] fix: failed to collection error in examples (#773) Added create index before executing LoadCollection in examples --------- Signed-off-by: Sidharh Suvarna Signed-off-by: Sidharth Signed-off-by: chyezh Signed-off-by: Sidharth Co-authored-by: Sidharth Co-authored-by: chyezh --- examples/auth/auth.go | 10 ++++++++++ examples/insert/insert.go | 10 ++++++++++ examples/tls/tls.go | 10 ++++++++++ 3 files changed, 30 insertions(+) diff --git a/examples/auth/auth.go b/examples/auth/auth.go index b1ee23a1..38c9dd70 100644 --- a/examples/auth/auth.go +++ b/examples/auth/auth.go @@ -113,6 +113,16 @@ func main() { } log.Println("flush completed") + // Now add index + idx, err := entity.NewIndexIvfFlat(entity.L2, 2) + if err != nil { + log.Fatal("fail to create ivf flat index:", err.Error()) + } + err = c.CreateIndex(ctx, collectionName, "Vector", idx, false) + if err != nil { + log.Fatal("fail to create index:", err.Error()) + } + // load collection with async=false err = c.LoadCollection(ctx, collectionName, false) if err != nil { diff --git a/examples/insert/insert.go b/examples/insert/insert.go index 48d2594a..17fcbe91 100644 --- a/examples/insert/insert.go +++ b/examples/insert/insert.go @@ -91,6 +91,16 @@ func main() { } log.Println("flush completed") + // Now add index + idx, err := entity.NewIndexIvfFlat(entity.L2, 2) + if err != nil { + log.Fatal("fail to create ivf flat index:", err.Error()) + } + err = c.CreateIndex(ctx, collectionName, "Vector", idx, false) + if err != nil { + log.Fatal("fail to create index:", err.Error()) + } + // load collection with async=false err = c.LoadCollection(ctx, collectionName, false) if err != nil { diff --git a/examples/tls/tls.go b/examples/tls/tls.go index f6eb40d7..e483a9a7 100644 --- a/examples/tls/tls.go +++ b/examples/tls/tls.go @@ -142,6 +142,16 @@ func main() { } log.Println("flush completed") + // Now add index + idx, err := entity.NewIndexIvfFlat(entity.L2, 2) + if err != nil { + log.Fatal("fail to create ivf flat index:", err.Error()) + } + err = 
c.CreateIndex(ctx, collectionName, "Vector", idx, false) + if err != nil { + log.Fatal("fail to create index:", err.Error()) + } + // load collection with async=false err = c.LoadCollection(ctx, collectionName, false) if err != nil {
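
Usage note: the iterator patches above are exercised only through test helpers, so below is a minimal sketch of how a caller typically drains a QueryIterator directly. It is not part of any patch in this series; the collection name "my_collection", the "int64 >= 0" filter, and the batch size of 500 are illustrative assumptions, and it presumes an already-connected client with a loaded collection. The batch size must be at least 1 (the tests above expect the "batch size cannot less than 1" error otherwise), and iteration ends when Next returns io.EOF.

package main

import (
	"context"
	"errors"
	"io"
	"log"

	"github.com/milvus-io/milvus-sdk-go/v2/client"
)

// drainQueryIterator reads batches from a QueryIterator until io.EOF.
// All literal values below (collection name, expr, batch size) are placeholders.
func drainQueryIterator(ctx context.Context, c client.Client) (int, error) {
	itr, err := c.QueryIterator(ctx, client.NewQueryIteratorOption("my_collection").
		WithBatchSize(500).     // must be >= 1, otherwise the iterator refuses to initialize
		WithExpr("int64 >= 0"). // optional filter; an empty expr iterates the whole collection
		WithOutputFields("*"))
	if err != nil {
		return 0, err
	}

	batches := 0
	for {
		rs, err := itr.Next(ctx)
		if err != nil {
			if errors.Is(err, io.EOF) {
				break // no more data
			}
			return batches, err
		}
		_ = rs // rs holds one batch of output columns
		batches++
	}
	log.Printf("iterator returned %d batches", batches)
	return batches, nil
}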