Skip to content

Commit

Permalink
enhance: Update latest sdk update to client pkg (#33105)
Browse files Browse the repository at this point in the history
  • Loading branch information
congqixia authored May 17, 2024
1 parent f1c9986 commit 1ef975d
Show file tree
Hide file tree
Showing 18 changed files with 553 additions and 123 deletions.
1 change: 1 addition & 0 deletions client/OWNERS
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
reviewers:
- congqixia
- ThreadDao

approvers:
- maintainers
Expand Down
67 changes: 63 additions & 4 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,20 @@ package client

import (
"context"
"crypto/tls"
"fmt"
"math"
"os"
"strconv"
"sync"
"time"

"github.com/gogo/status"
grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"

"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
Expand All @@ -39,6 +45,11 @@ type Client struct {
service milvuspb.MilvusServiceClient
config *ClientConfig

// mutable status
stateMut sync.RWMutex
currentDB string
identifier string

collCache *CollectionCache
}

Expand All @@ -54,8 +65,10 @@ func New(ctx context.Context, config *ClientConfig) (*Client, error) {
// Parse remote address.
addr := c.config.getParsedAddress()

// parse authentication parameters
c.config.parseAuthentication()
// Parse grpc options
options := c.config.getDialOption()
options := c.dialOptions()

// Connect the grpc server.
if err := c.connect(ctx, addr, options...); err != nil {
Expand All @@ -69,6 +82,40 @@ func New(ctx context.Context, config *ClientConfig) (*Client, error) {
return c, nil
}

func (c *Client) dialOptions() []grpc.DialOption {
var options []grpc.DialOption
// Construct dial option.
if c.config.EnableTLSAuth {
options = append(options, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{})))
} else {
options = append(options, grpc.WithTransportCredentials(insecure.NewCredentials()))
}

if c.config.DialOptions == nil {
// Add default connection options.
options = append(options, DefaultGrpcOpts...)
} else {
options = append(options, c.config.DialOptions...)
}

options = append(options,
grpc.WithChainUnaryInterceptor(grpc_retry.UnaryClientInterceptor(
grpc_retry.WithMax(6),
grpc_retry.WithBackoff(func(attempt uint) time.Duration {
return 60 * time.Millisecond * time.Duration(math.Pow(3, float64(attempt)))
}),
grpc_retry.WithCodes(codes.Unavailable, codes.ResourceExhausted)),

// c.getRetryOnRateLimitInterceptor(),
))

options = append(options, grpc.WithChainUnaryInterceptor(
c.MetadataUnaryInterceptor(),
))

return options
}

func (c *Client) Close(ctx context.Context) error {
if c.conn == nil {
return nil
Expand All @@ -82,6 +129,18 @@ func (c *Client) Close(ctx context.Context) error {
return nil
}

func (c *Client) usingDatabase(dbName string) {
c.stateMut.Lock()
defer c.stateMut.Unlock()
c.currentDB = dbName
}

func (c *Client) setIdentifier(identifier string) {
c.stateMut.Lock()
defer c.stateMut.Unlock()
c.identifier = identifier
}

func (c *Client) connect(ctx context.Context, addr string, options ...grpc.DialOption) error {
if addr == "" {
return fmt.Errorf("address is empty")
Expand Down Expand Up @@ -112,7 +171,7 @@ func (c *Client) connectInternal(ctx context.Context) error {

req := &milvuspb.ConnectRequest{
ClientInfo: &commonpb.ClientInfo{
SdkType: "Golang",
SdkType: "GoMilvusClient",
SdkVersion: common.SDKVersion,
LocalTime: time.Now().String(),
User: c.config.Username,
Expand All @@ -131,8 +190,8 @@ func (c *Client) connectInternal(ctx context.Context) error {
disableJSON |
disableParitionKey |
disableDynamicSchema)
return nil
}
return nil
}
return err
}
Expand All @@ -142,7 +201,7 @@ func (c *Client) connectInternal(ctx context.Context) error {
}

c.config.setServerInfo(resp.GetServerInfo().GetBuildTags())
c.config.setIdentifier(strconv.FormatInt(resp.GetIdentifier(), 10))
c.setIdentifier(strconv.FormatInt(resp.GetIdentifier(), 10))

return nil
}
Expand Down
82 changes: 34 additions & 48 deletions client/client_config.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package client

import (
"crypto/tls"
"context"
"fmt"
"math"
"net/url"
Expand All @@ -10,12 +10,9 @@ import (
"time"

"github.com/cockroachdb/errors"
grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry"
"github.com/milvus-io/milvus/pkg/util/crypto"
"google.golang.org/grpc"
"google.golang.org/grpc/backoff"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"
)

Expand Down Expand Up @@ -59,16 +56,23 @@ type ClientConfig struct {

DialOptions []grpc.DialOption // Dial options for GRPC.

// RetryRateLimit *RetryRateLimitOption // option for retry on rate limit inteceptor
RetryRateLimit *RetryRateLimitOption // option for retry on rate limit inteceptor

DisableConn bool

metadataHeaders map[string]string

identifier string // Identifier for this connection
ServerVersion string // ServerVersion
parsedAddress *url.URL
flags uint64 // internal flags
}

type RetryRateLimitOption struct {
MaxRetry uint
MaxBackoff time.Duration
}

func (cfg *ClientConfig) parse() error {
// Prepend default fake tcp:// scheme for remote address.
address := cfg.Address
Expand Down Expand Up @@ -118,54 +122,36 @@ func (c *ClientConfig) setServerInfo(serverInfo string) {
c.ServerVersion = serverInfo
}

// Get parsed grpc dial options, should be called after parse was called.
func (c *ClientConfig) getDialOption() []grpc.DialOption {
options := c.DialOptions
if c.DialOptions == nil {
// Add default connection options.
options = make([]grpc.DialOption, len(DefaultGrpcOpts))
copy(options, DefaultGrpcOpts)
// parseAuthentication prepares authentication headers for grpc inteceptors based on the provided username, password or API key.
func (c *ClientConfig) parseAuthentication() {
c.metadataHeaders = make(map[string]string)
if c.Username != "" || c.Password != "" {
value := crypto.Base64Encode(fmt.Sprintf("%s:%s", c.Username, c.Password))
c.metadataHeaders[authorizationHeader] = value
}
// API overwrites username & passwd
if c.APIKey != "" {
value := crypto.Base64Encode(c.APIKey)
c.metadataHeaders[authorizationHeader] = value
}
}

// Construct dial option.
if c.EnableTLSAuth {
options = append(options, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{})))
} else {
options = append(options, grpc.WithTransportCredentials(insecure.NewCredentials()))
func (c *ClientConfig) getRetryOnRateLimitInterceptor() grpc.UnaryClientInterceptor {
if c.RetryRateLimit == nil {
c.RetryRateLimit = c.defaultRetryRateLimitOption()
}

options = append(options,
grpc.WithChainUnaryInterceptor(grpc_retry.UnaryClientInterceptor(
grpc_retry.WithMax(6),
grpc_retry.WithBackoff(func(attempt uint) time.Duration {
return 60 * time.Millisecond * time.Duration(math.Pow(3, float64(attempt)))
}),
grpc_retry.WithCodes(codes.Unavailable, codes.ResourceExhausted)),
// c.getRetryOnRateLimitInterceptor(),
))

// options = append(options, grpc.WithChainUnaryInterceptor(
// createMetaDataUnaryInterceptor(c),
// ))
return options
return RetryOnRateLimitInterceptor(c.RetryRateLimit.MaxRetry, c.RetryRateLimit.MaxBackoff, func(ctx context.Context, attempt uint) time.Duration {
return 10 * time.Millisecond * time.Duration(math.Pow(3, float64(attempt)))
})
}

// func (c *ClientConfig) getRetryOnRateLimitInterceptor() grpc.UnaryClientInterceptor {
// if c.RetryRateLimit == nil {
// c.RetryRateLimit = c.defaultRetryRateLimitOption()
// }

// return RetryOnRateLimitInterceptor(c.RetryRateLimit.MaxRetry, c.RetryRateLimit.MaxBackoff, func(ctx context.Context, attempt uint) time.Duration {
// return 10 * time.Millisecond * time.Duration(math.Pow(3, float64(attempt)))
// })
// }

// func (c *ClientConfig) defaultRetryRateLimitOption() *RetryRateLimitOption {
// return &RetryRateLimitOption{
// MaxRetry: 75,
// MaxBackoff: 3 * time.Second,
// }
// }
func (c *ClientConfig) defaultRetryRateLimitOption() *RetryRateLimitOption {
return &RetryRateLimitOption{
MaxRetry: 75,
MaxBackoff: 3 * time.Second,
}
}

// addFlags set internal flags
func (c *ClientConfig) addFlags(flags uint64) {
Expand Down
1 change: 1 addition & 0 deletions client/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ func (c *Client) DescribeCollection(ctx context.Context, option *describeCollect
VirtualChannels: resp.GetVirtualChannelNames(),
ConsistencyLevel: entity.ConsistencyLevel(resp.ConsistencyLevel),
ShardNum: resp.GetShardsNum(),
Properties: entity.KvPairsMap(resp.GetProperties()),
}
collection.Name = collection.Schema.CollectionName
return nil
Expand Down
8 changes: 5 additions & 3 deletions client/collection_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ func SimpleCreateCollectionOptions(name string, dim int64) *createCollectionOpti
autoID: true,
dim: dim,
enabledDynamicSchema: true,
consistencyLevel: entity.DefaultConsistencyLevel,

isFast: true,
metricType: entity.COSINE,
Expand All @@ -149,9 +150,10 @@ func SimpleCreateCollectionOptions(name string, dim int64) *createCollectionOpti
// NewCreateCollectionOption returns a CreateCollectionOption with customized collection schema
func NewCreateCollectionOption(name string, collectionSchema *entity.Schema) *createCollectionOption {
return &createCollectionOption{
name: name,
shardNum: 1,
schema: collectionSchema,
name: name,
shardNum: 1,
schema: collectionSchema,
consistencyLevel: entity.DefaultConsistencyLevel,

metricType: entity.COSINE,
}
Expand Down
32 changes: 22 additions & 10 deletions client/column/columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,26 +64,38 @@ var errFieldDataTypeNotMatch = errors.New("FieldData type not matched")

// IDColumns converts schemapb.IDs to corresponding column
// currently Int64 / string may be in IDs
func IDColumns(idField *schemapb.IDs, begin, end int) (Column, error) {
func IDColumns(schema *entity.Schema, ids *schemapb.IDs, begin, end int) (Column, error) {
var idColumn Column
if idField == nil {
pkField := schema.PKField()
if pkField == nil {
return nil, errors.New("PK Field not found")
}
if ids == nil {
return nil, errors.New("nil Ids from response")
}
switch field := idField.GetIdField().(type) {
case *schemapb.IDs_IntId:
switch pkField.DataType {
case entity.FieldTypeInt64:
data := ids.GetIntId().GetData()
if data == nil {
return NewColumnInt64(pkField.Name, nil), nil
}
if end >= 0 {
idColumn = NewColumnInt64("", field.IntId.GetData()[begin:end])
idColumn = NewColumnInt64(pkField.Name, data[begin:end])
} else {
idColumn = NewColumnInt64("", field.IntId.GetData()[begin:])
idColumn = NewColumnInt64(pkField.Name, data[begin:])
}
case entity.FieldTypeVarChar, entity.FieldTypeString:
data := ids.GetStrId().GetData()
if data == nil {
return NewColumnVarChar(pkField.Name, nil), nil
}
case *schemapb.IDs_StrId:
if end >= 0 {
idColumn = NewColumnVarChar("", field.StrId.GetData()[begin:end])
idColumn = NewColumnVarChar(pkField.Name, data[begin:end])
} else {
idColumn = NewColumnVarChar("", field.StrId.GetData()[begin:])
idColumn = NewColumnVarChar(pkField.Name, data[begin:])
}
default:
return nil, fmt.Errorf("unsupported id type %v", field)
return nil, fmt.Errorf("unsupported id type %v", pkField.DataType)
}
return idColumn, nil
}
Expand Down
Loading

0 comments on commit 1ef975d

Please sign in to comment.