Skip to content

Commit

Permalink
feat: client support credentialsProvider (#216)
Browse files Browse the repository at this point in the history
* feat: client support credentialsProvider

* fix build, remove 1.20 func usage

* fix conflict

* refine to code review

* refine readme

* use atomic, add ecs ram fetcher

* add ram fetcher ut
  • Loading branch information
crimson-gao authored Aug 29, 2023
1 parent 2c71d01 commit 386a0d0
Show file tree
Hide file tree
Showing 15 changed files with 753 additions and 43 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ go get -u github.com/aliyun/aliyun-log-go-sdk
```go
AccessKeyID = "your ak id"
AccessKeySecret = "your ak secret"
credentialsProvider := NewStaticCredentialsProvider(AccessKeyID, AccessKeySecret, "")
Endpoint = "your endpoint" // just like cn-hangzhou.log.aliyuncs.com
Client = sls.CreateNormalInterface(Endpoint,AccessKeyID,AccessKeySecret,"")
Client = sls.CreateNormalInterfaceV2(Endpoint, credentialsProvider)
```

2. **创建project**
Expand Down
24 changes: 19 additions & 5 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,17 +96,18 @@ func IsTokenError(err error) bool {
// Client ...
type Client struct {
Endpoint string // IP or hostname of SLS endpoint
AccessKeyID string
AccessKeySecret string
SecurityToken string
AccessKeyID string // Deprecated: use credentialsProvider instead
AccessKeySecret string // Deprecated: use credentialsProvider instead
SecurityToken string // Deprecated: use credentialsProvider instead
UserAgent string // default defaultLogUserAgent
RequestTimeOut time.Duration
RetryTimeOut time.Duration
HTTPClient *http.Client
Region string
AuthVersion AuthVersionType // v1 or v4 signature,default is v1

accessKeyLock sync.RWMutex
accessKeyLock sync.RWMutex
credentialsProvider CredentialsProvider
// User defined common headers.
// When conflict with sdk pre-defined headers, the value will
// be ignored
Expand All @@ -120,7 +121,13 @@ func convert(c *Client, projName string) *LogProject {
}

func convertLocked(c *Client, projName string) *LogProject {
p, _ := NewLogProject(projName, c.Endpoint, c.AccessKeyID, c.AccessKeySecret)
var p *LogProject
if c.credentialsProvider != nil {
p, _ = NewLogProjectV2(projName, c.Endpoint, c.credentialsProvider)
} else { // back compatible
p, _ = NewLogProject(projName, c.Endpoint, c.AccessKeyID, c.AccessKeySecret)
}

p.SecurityToken = c.SecurityToken
p.UserAgent = c.UserAgent
p.AuthVersion = c.AuthVersion
Expand All @@ -139,6 +146,12 @@ func convertLocked(c *Client, projName string) *LogProject {
return p
}

// Set credentialsProvider for client and returns the same client.
func (c *Client) WithCredentialsProvider(provider CredentialsProvider) *Client {
c.credentialsProvider = provider
return c
}

// SetUserAgent set a custom userAgent
func (c *Client) SetUserAgent(userAgent string) {
c.UserAgent = userAgent
Expand Down Expand Up @@ -169,6 +182,7 @@ func (c *Client) ResetAccessKeyToken(accessKeyID, accessKeySecret, securityToken
c.AccessKeyID = accessKeyID
c.AccessKeySecret = accessKeySecret
c.SecurityToken = securityToken
c.credentialsProvider = NewStaticCredentialsProvider(accessKeyID, accessKeySecret, securityToken)
c.accessKeyLock.Unlock()
}

Expand Down
40 changes: 37 additions & 3 deletions client_interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,54 @@ import (
"time"
)

// CreateNormalInterface create a normal client
// CreateNormalInterface create a normal client.
//
// Deprecated: use CreateNormalInterfaceV2 instead.
// If you keep using long-lived AccessKeyID and AccessKeySecret,
// use the example code below.
//
// provider := NewStaticCredProvider(accessKeyID, accessKeySecret, securityToken)
// client := CreateNormalInterfaceV2(endpoint, provider)
func CreateNormalInterface(endpoint, accessKeyID, accessKeySecret, securityToken string) ClientInterface {
return &Client{
Endpoint: endpoint,
AccessKeyID: accessKeyID,
AccessKeySecret: accessKeySecret,
SecurityToken: securityToken,

credentialsProvider: NewStaticCredentialsProvider(
accessKeyID,
accessKeySecret,
securityToken,
),
}
}

// CreateNormalInterfaceV2 create a normal client, with a CredentialsProvider.
//
// It is highly recommended to use a CredentialsProvider that provides dynamic
// expirable credentials for security.
//
// See [credentials_provider.go] for more details.
func CreateNormalInterfaceV2(endpoint string, credentialsProvider CredentialsProvider) ClientInterface {
return &Client{
Endpoint: endpoint,
credentialsProvider: credentialsProvider,
}
}

type UpdateTokenFunction func() (accessKeyID, accessKeySecret, securityToken string, expireTime time.Time, err error)
type UpdateTokenFunction = func() (accessKeyID, accessKeySecret, securityToken string, expireTime time.Time, err error)

// CreateTokenAutoUpdateClient crate a TokenAutoUpdateClient
// CreateTokenAutoUpdateClient create a TokenAutoUpdateClient,
// this client will auto fetch security token and retry when operation is `Unauthorized`
//
// Deprecated: Use CreateNormalInterfaceV2 and UpdateFuncProviderAdapter instead.
//
// Example:
//
// provider := NewUpdateFuncProviderAdapter(updateStsTokenFunc)
// client := CreateNormalInterfaceV2(endpoint, provider)
//
// @note TokenAutoUpdateClient will destroy when shutdown channel is closed
func CreateTokenAutoUpdateClient(endpoint string, tokenUpdateFunc UpdateTokenFunction, shutdown <-chan struct{}) (client ClientInterface, err error) {
accessKeyID, accessKeySecret, securityToken, expireTime, err := tokenUpdateFunc()
Expand Down
10 changes: 10 additions & 0 deletions client_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,16 @@ func (c *Client) request(project, method, uri string, headers map[string]string,
authVersion := c.AuthVersion
c.accessKeyLock.RUnlock()

if c.credentialsProvider != nil {
res, err := c.credentialsProvider.GetCredentials()
if err != nil {
return nil, fmt.Errorf("fail to fetch credentials: %w", err)
}
accessKeyID = res.AccessKeyID
accessKeySecret = res.AccessKeySecret
stsToken = res.SecurityToken
}

// Access with token
if stsToken != "" {
headers[HTTPHeaderAcsSecurityToken] = stsToken
Expand Down
69 changes: 69 additions & 0 deletions credentials.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package sls

import (
"time"
)

type Credentials struct {
AccessKeyID string
AccessKeySecret string
SecurityToken string
}

const DEFAULT_EXPIRED_FACTOR = 0.8

// Expirable credentials with an expiration.
type TempCredentials struct {
Credentials
expiredFactor float64
expirationInMills int64 // The time when the credentials expires, unix timestamp in millis
lastUpdatedInMills int64
}

func NewTempCredentials(accessKeyId, accessKeySecret, securityToken string,
expirationInMills, lastUpdatedInMills int64) *TempCredentials {

return &TempCredentials{
Credentials: Credentials{
AccessKeyID: accessKeyId,
AccessKeySecret: accessKeySecret,
SecurityToken: securityToken,
},
expirationInMills: expirationInMills,
lastUpdatedInMills: lastUpdatedInMills,
expiredFactor: DEFAULT_EXPIRED_FACTOR,
}
}

// @param factor must > 0.0 and <= 1.0, the less the factor is,
// the more frequently the credentials will be updated.
//
// If factor is set to 0, the credentials will be fetched every time
// [GetCredentials] is called.
//
// If factor is set to 1, the credentials will be fetched only when expired .
func (t *TempCredentials) WithExpiredFactor(factor float64) *TempCredentials {
if factor > 0.0 && factor <= 1.0 {
t.expiredFactor = factor
}
return t
}

// Returns true if credentials has expired already or will expire soon.
func (t *TempCredentials) ShouldRefresh() bool {
now := time.Now().UnixMilli()
if now >= t.expirationInMills {
return true
}
duration := (float64)(t.expirationInMills-t.lastUpdatedInMills) * t.expiredFactor
if duration < 0.0 { // check here
duration = 0
}
return (now - t.lastUpdatedInMills) >= int64(duration)
}

// Returns true if credentials has expired already.
func (t *TempCredentials) HasExpired() bool {
now := time.Now().UnixMilli()
return now >= t.expirationInMills
}
Loading

0 comments on commit 386a0d0

Please sign in to comment.