Skip to content

Commit

Permalink
address feedback from pr, mostly moving things around
Browse files Browse the repository at this point in the history
  • Loading branch information
Madrigal committed Jun 6, 2024
1 parent 49e7140 commit 6d365fb
Show file tree
Hide file tree
Showing 13 changed files with 407 additions and 378 deletions.
56 changes: 0 additions & 56 deletions aws/middleware/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package middleware
import (
"context"
"fmt"
"sync/atomic"
"time"

"github.com/aws/aws-sdk-go-v2/internal/rand"
Expand Down Expand Up @@ -125,19 +124,6 @@ func setAttemptSkew(metadata *middleware.Metadata, v time.Duration) {
metadata.Set(attemptSkewKey{}, v)
}

type clockSkew struct{}

// SetAttemptSkewContext sets the clock skew value on the context
func SetAttemptSkewContext(ctx context.Context, v time.Duration) context.Context {
return middleware.WithStackValue(ctx, clockSkew{}, v)
}

// GetAttemptSkewContext gets the clock skew value from the context
func GetAttemptSkewContext(ctx context.Context) time.Duration {
x, _ := middleware.GetStackValue(ctx, clockSkew{}).(time.Duration)
return x
}

// AddClientRequestIDMiddleware adds ClientRequestID to the middleware stack
func AddClientRequestIDMiddleware(stack *middleware.Stack) error {
return stack.Build.Add(&ClientRequestID{}, middleware.After)
Expand Down Expand Up @@ -180,45 +166,3 @@ func AddRawResponseToMetadata(stack *middleware.Stack) error {
func GetRawResponse(metadata middleware.Metadata) interface{} {
return metadata.Get(rawResponseKey{})
}

// AddTimeOffsetBuildMiddleware sets a value representing clock skew on the request context.
// This can be read by other operations (such as signing) to correct the date value they send
// on the request
type AddTimeOffsetBuildMiddleware struct {
Offset *atomic.Int64
}

// ID the identifier for AddTimeOffsetBuildMiddleware
func (m *AddTimeOffsetBuildMiddleware) ID() string { return "AddTimeOffsetMiddleware" }

// HandleBuild sets a value for attemptSkew on the request context if one is set on the client.
func (m AddTimeOffsetBuildMiddleware) HandleBuild(ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler) (
out middleware.BuildOutput, metadata middleware.Metadata, err error,
) {
if m.Offset != nil {
offset := time.Duration(m.Offset.Load())
ctx = SetAttemptSkewContext(ctx, offset)
}
return next.HandleBuild(ctx, in)
}

// AddTimeOffsetDeserializeMiddleware sets the clock skew on the client if it's present on the context
// at the end of the request
type AddTimeOffsetDeserializeMiddleware struct {
Offset *atomic.Int64
}

// ID the identifier for AddTimeOffsetDeserializeMiddleware
func (m *AddTimeOffsetDeserializeMiddleware) ID() string { return "AddTimeOffsetDeserializeMiddleware" }

// HandleDeserialize gets the clock skew context from the context, and if set, sets it on the pointer
// held by AddTimeOffsetDeserializeMiddleware
func (m *AddTimeOffsetDeserializeMiddleware) HandleDeserialize(ctx context.Context, in middleware.DeserializeInput, next middleware.DeserializeHandler) (
out middleware.DeserializeOutput, metadata middleware.Metadata, err error,
) {
v := GetAttemptSkewContext(ctx)
if v != 0 {
m.Offset.Store(v.Nanoseconds())
}
return next.HandleDeserialize(ctx, in)
}
240 changes: 0 additions & 240 deletions aws/middleware/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,9 @@ package middleware_test
import (
"bytes"
"context"
"fmt"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/aws/retry"
"net/http"
"reflect"
"strings"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -191,239 +187,3 @@ func TestAttemptClockSkewHandler(t *testing.T) {
})
}
}

type HTTPClient interface {
Do(*http.Request) (*http.Response, error)
}

type Options struct {
HTTPClient HTTPClient
RetryMode aws.RetryMode
Retryer aws.Retryer
Offset *atomic.Int64
}

type MockClient struct {
options Options
}

func addRetry(stack *smithymiddleware.Stack, o Options) error {
attempt := retry.NewAttemptMiddleware(o.Retryer, smithyhttp.RequestCloner, func(m *retry.Attempt) {
m.LogAttempts = false
})
return stack.Finalize.Add(attempt, smithymiddleware.After)
}

func addOffset(stack *smithymiddleware.Stack, o Options) error {
buildOffset := middleware.AddTimeOffsetBuildMiddleware{Offset: o.Offset}
deserializeOffset := middleware.AddTimeOffsetDeserializeMiddleware{Offset: o.Offset}
err := stack.Build.Add(&buildOffset, smithymiddleware.After)
if err != nil {
return err
}
err = stack.Deserialize.Add(&deserializeOffset, smithymiddleware.Before)
if err != nil {
return err
}
return nil
}

// Middleware to set a `Date` object that includes sdk time and offset
type MockAddDateHeader struct {
}

func (l *MockAddDateHeader) ID() string {
return "MockAddDateHeader"
}

func (l *MockAddDateHeader) HandleFinalize(
ctx context.Context, in smithymiddleware.FinalizeInput, next smithymiddleware.FinalizeHandler,
) (
out smithymiddleware.FinalizeOutput, metadata smithymiddleware.Metadata, attemptError error,
) {
req := in.Request.(*smithyhttp.Request)
date := sdk.NowTime()
skew := middleware.GetAttemptSkewContext(ctx)
date = date.Add(skew)
req.Header.Set("Date", date.Format(time.RFC850))
return next.HandleFinalize(ctx, in)
}

// Middleware to deserialize the response which just says "OK" if the response is 200
type DeserializeFailIfNotHTTP200 struct {
}

func (*DeserializeFailIfNotHTTP200) ID() string {
return "DeserializeFailIfNotHTTP200"
}

func (m *DeserializeFailIfNotHTTP200) HandleDeserialize(ctx context.Context, in smithymiddleware.DeserializeInput, next smithymiddleware.DeserializeHandler) (
out smithymiddleware.DeserializeOutput, metadata smithymiddleware.Metadata, err error,
) {
out, metadata, err = next.HandleDeserialize(ctx, in)
if err != nil {
return out, metadata, err
}
response, ok := out.RawResponse.(*smithyhttp.Response)
if !ok {
return out, metadata, fmt.Errorf("expected raw response to be set on testing")
}
if response.StatusCode != 200 {
return out, metadata, mockRetryableError{true}
}
return out, metadata, err
}

func (c *MockClient) setupMiddleware(stack *smithymiddleware.Stack) error {
err := error(nil)
if c.options.Retryer != nil {
err = addRetry(stack, c.options)
if err != nil {
return err
}
}
if c.options.Offset != nil {
err = addOffset(stack, c.options)
if err != nil {
return err
}
}
err = stack.Finalize.Add(&MockAddDateHeader{}, smithymiddleware.After)
if err != nil {
return err
}
err = middleware.AddRecordResponseTiming(stack)
if err != nil {
return err
}
err = stack.Deserialize.Add(&DeserializeFailIfNotHTTP200{}, smithymiddleware.After)
if err != nil {
return err
}
return nil
}

func (c *MockClient) Do(ctx context.Context) (interface{}, error) {
// setup middlewares
ctx = smithymiddleware.ClearStackValues(ctx)
stack := smithymiddleware.NewStack("stack", smithyhttp.NewStackRequest)
err := c.setupMiddleware(stack)
if err != nil {
return nil, err
}
handler := smithymiddleware.DecorateHandler(smithyhttp.NewClientHandler(c.options.HTTPClient), stack)
result, _, err := handler.Handle(ctx, 1)
if err != nil {
return nil, err
}
return result, err
}

type mockRetryableError struct{ b bool }

func (m mockRetryableError) RetryableError() bool { return m.b }
func (m mockRetryableError) Error() string {
return fmt.Sprintf("mock retryable %t", m.b)
}

func failRequestIfSkewed() smithyhttp.ClientDoFunc {
return func(req *http.Request) (*http.Response, error) {
dateHeader := req.Header.Get("Date")
if dateHeader == "" {
return nil, fmt.Errorf("expected `Date` header to be set")
}
reqDate, err := time.Parse(time.RFC850, dateHeader)
if err != nil {
return nil, err
}
parsedReqTime := time.Now().Sub(reqDate)
parsedReqTime = time.Duration.Abs(parsedReqTime)
thresholdForSkewError := 4 * time.Minute
if thresholdForSkewError-parsedReqTime <= 0 {
return &http.Response{
StatusCode: 403,
Header: http.Header{
"Date": {time.Now().Format(time.RFC850)},
},
}, nil
}
// else, return OK
return &http.Response{
StatusCode: 200,
Header: http.Header{},
}, nil
}
}

func TestSdkOffsetIsSet(t *testing.T) {
nowTime := sdk.NowTime
defer func() {
sdk.NowTime = nowTime
}()
fiveMinuteSkew := func() time.Time {
return time.Now().Add(5 * time.Minute)
}
sdk.NowTime = fiveMinuteSkew
c := MockClient{
Options{
HTTPClient: failRequestIfSkewed(),
},
}
resp, err := c.Do(context.Background())
if err == nil {
t.Errorf("Expected first request to fail since clock skew logic has not run. Got %v and err %v", resp, err)
}
}

func TestRetrySetsSkewInContext(t *testing.T) {
defer resetDefaults(sdk.TestingUseNopSleep())
fiveMinuteSkew := func() time.Time {
return time.Now().Add(5 * time.Minute)
}
sdk.NowTime = fiveMinuteSkew
c := MockClient{
Options{
HTTPClient: failRequestIfSkewed(),
Retryer: retry.NewStandard(func(s *retry.StandardOptions) {
}),
},
}
resp, err := c.Do(context.Background())
if err != nil {
t.Errorf("Expected request to succeed on retry. Got %v and err %v", resp, err)
}
}

func TestSkewIsSetOnTheWholeClient(t *testing.T) {
defer resetDefaults(sdk.TestingUseNopSleep())
fiveMinuteSkew := func() time.Time {
return time.Now().Add(5 * time.Minute)
}
sdk.NowTime = fiveMinuteSkew
var offset atomic.Int64
offset.Store(0)
c := MockClient{
Options{
HTTPClient: failRequestIfSkewed(),
Retryer: retry.NewStandard(func(s *retry.StandardOptions) {
}),
Offset: &offset,
},
}
resp, err := c.Do(context.Background())
if err != nil {
t.Errorf("Expected request to succeed on retry. Got %v and err %v", resp, err)
}
// Remove retryer so it has to succeed on first call
c.options.Retryer = nil
// same client, new request
resp, err = c.Do(context.Background())
if err != nil {
t.Errorf("Expected second request to succeed since the skew should be set on the client. Got %v and err %v", resp, err)
}
}

func resetDefaults(restoreSleepFunc func()) {
sdk.NowTime = time.Now
restoreSleepFunc()
}
Loading

0 comments on commit 6d365fb

Please sign in to comment.