Skip to content

Commit

Permalink
Cherry pick partition key fix and example (#541)
Browse files Browse the repository at this point in the history
Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
  • Loading branch information
congqixia authored Aug 4, 2023
1 parent 36d7a7b commit 1536418
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 3 deletions.
4 changes: 1 addition & 3 deletions client/insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,7 @@ func (c *GrpcClient) Insert(ctx context.Context, collName string, partitionName
PartitionName: partitionName,
FieldsData: fieldsData,
}
if req.PartitionName == "" {
req.PartitionName = "_default" // use default partition
}

req.NumRows = uint32(rowSize)

resp, err := c.Service.Insert(ctx, req)
Expand Down
119 changes: 119 additions & 0 deletions examples/partitionkey/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
package main

import (
"context"
"log"
"math/rand"
"time"

"github.com/milvus-io/milvus-sdk-go/v2/client"
"github.com/milvus-io/milvus-sdk-go/v2/entity"
)

// Settings shared by the partition key example below.
const (
// milvusAddr is the server address handed to client.NewClient.
milvusAddr = `localhost:19530`
// nEntities is the number of rows generated and inserted;
// dim is the dimension of the float vector field.
nEntities, dim = 10000, 128
// collectionName is the collection created (and dropped) by this example.
collectionName = "hello_partition_key"

// Field names of the schema: primary key, partition key and vector field.
idCol, keyCol, embeddingCol = "ID", "key", "embeddings"
// topK is the topK argument passed to Search.
topK = 3
)

// main walks through the partition key workflow end to end: connect to
// Milvus, (re)create a collection whose `key` field is the partition key,
// insert random rows, flush, build an HNSW index, load the collection, run a
// vector search and finally drop the collection again.
//
// Every step is fatal on error. Note that log.Fatalf exits the process via
// os.Exit, so the deferred client Close only runs on the success path.
func main() {
	ctx := context.Background()

	log.Println("start connecting to Milvus")
	c, err := client.NewClient(ctx, client.Config{
		Address: milvusAddr,
	})
	if err != nil {
		log.Fatalf("failed to connect to milvus, err: %v", err)
	}
	defer c.Close()

	// delete collection if exists
	has, err := c.HasCollection(ctx, collectionName)
	if err != nil {
		log.Fatalf("failed to check collection exists, err: %v", err)
	}
	if has {
		// A stale collection left behind would make the create step below
		// fail or reuse old data, so a failed drop must not go unnoticed.
		if err := c.DropCollection(ctx, collectionName); err != nil {
			log.Fatalf("failed to drop existing collection `%s`, err: %v", collectionName, err)
		}
	}

	// create collection
	log.Printf("create collection `%s`\n", collectionName)
	schema := entity.NewSchema().WithName(collectionName).WithDescription("hello_partition_key is the a demo to introduce the partition key related APIs").
		WithField(entity.NewField().WithName(idCol).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true).WithIsAutoID(true)).
		WithField(entity.NewField().WithName(keyCol).WithDataType(entity.FieldTypeInt64).WithIsPartitionKey(true)).
		WithField(entity.NewField().WithName(embeddingCol).WithDataType(entity.FieldTypeFloatVector).WithDim(dim))

	if err := c.CreateCollection(ctx, schema, entity.DefaultShardNumber, client.WithPartitionNum(32)); err != nil { // use default shard number
		log.Fatalf("create collection failed, err: %v", err)
	}

	// Keys are generated before vectors, matching the order in which the
	// random source is consumed.
	keyList := randomKeys(nEntities, 512)
	embeddingList := randomVectors(nEntities, dim)
	keyColData := entity.NewColumnInt64(keyCol, keyList)
	embeddingColData := entity.NewColumnFloatVector(embeddingCol, dim, embeddingList)

	log.Println("start to insert data into collection")

	// The partition name is left empty: with a partition key field the rows
	// are not addressed to a named partition by the caller.
	if _, err := c.Insert(ctx, collectionName, "", keyColData, embeddingColData); err != nil {
		log.Fatalf("failed to insert random data into `%s`, err: %v", collectionName, err)
	}

	log.Println("insert data done, start to flush")

	if err := c.Flush(ctx, collectionName, false); err != nil {
		log.Fatalf("failed to flush data, err: %v", err)
	}
	log.Println("flush data done")

	// build index
	log.Println("start creating index HNSW")
	idx, err := entity.NewIndexHNSW(entity.L2, 16, 256)
	if err != nil {
		log.Fatalf("failed to create HNSW index, err: %v", err)
	}
	if err := c.CreateIndex(ctx, collectionName, embeddingCol, idx, false); err != nil {
		log.Fatalf("failed to create index, err: %v", err)
	}

	log.Printf("build HNSW index done for collection `%s`\n", collectionName)
	log.Printf("start to load collection `%s`\n", collectionName)

	// load collection
	if err := c.LoadCollection(ctx, collectionName, false); err != nil {
		log.Fatalf("failed to load collection, err: %v", err)
	}

	log.Println("load collection done")

	// Search with the last two inserted vectors as queries.
	vec2search := []entity.Vector{
		entity.FloatVector(embeddingList[len(embeddingList)-2]),
		entity.FloatVector(embeddingList[len(embeddingList)-1]),
	}
	sp, err := entity.NewIndexHNSWSearchParam(30)
	if err != nil {
		log.Fatalf("failed to create HNSW search param, err: %v", err)
	}
	begin := time.Now()
	_, err = c.Search(ctx, collectionName, nil, "", []string{keyCol}, vec2search,
		embeddingCol, entity.L2, topK, sp)
	if err != nil {
		log.Fatalf("failed to search collection, err: %v", err)
	}

	log.Printf("search `%s` done, latency %v\n", collectionName, time.Since(begin))

	// Cleanup is best effort: report a failure, but do not turn an otherwise
	// successful run into a crash.
	if err := c.DropCollection(ctx, collectionName); err != nil {
		log.Printf("failed to drop collection `%s`, err: %v", collectionName, err)
	}
}

// randomKeys returns n pseudo-random partition key values in [0, limit).
func randomKeys(n int, limit int64) []int64 {
	keys := make([]int64, 0, n)
	for i := 0; i < n; i++ {
		keys = append(keys, rand.Int63()%limit)
	}
	return keys
}

// randomVectors returns n pseudo-random float32 vectors, each of length width.
func randomVectors(n, width int) [][]float32 {
	vectors := make([][]float32, 0, n)
	for i := 0; i < n; i++ {
		vec := make([]float32, 0, width)
		for j := 0; j < width; j++ {
			vec = append(vec, rand.Float32())
		}
		vectors = append(vectors, vec)
	}
	return vectors
}

0 comments on commit 1536418

Please sign in to comment.