diff --git a/extendtimeout/connection.go b/extendtimeout/connection.go index 288a872..8c4fc4f 100644 --- a/extendtimeout/connection.go +++ b/extendtimeout/connection.go @@ -21,7 +21,6 @@ import ( "context" "time" - "github.com/edwarnicke/log" "go.fd.io/govpp/api" ) @@ -30,15 +29,6 @@ type extendedConnection struct { contextTimeout time.Duration } -type extendedContext struct { - context.Context - valuesContext context.Context -} - -func (ec *extendedContext) Value(key interface{}) interface{} { - return ec.valuesContext.Value(key) -} - // NewConnection - creates a wrapper for vpp connection that uses extended context timeout for all operations func NewConnection(vppConn api.Connection, contextTimeout time.Duration) api.Connection { return &extendedConnection{ @@ -48,27 +38,26 @@ func NewConnection(vppConn api.Connection, contextTimeout time.Duration) api.Con } func (c *extendedConnection) Invoke(ctx context.Context, req, reply api.Message) error { - ctx, cancel := c.withExtendedTimeoutCtx(ctx) + ctx, cancel := c.withExtendedTimeoutContext(ctx) err := c.Connection.Invoke(ctx, req, reply) cancel() return err } -func (c *extendedConnection) withExtendedTimeoutCtx(ctx context.Context) (extendedCtx context.Context, cancel func()) { - deadline, ok := ctx.Deadline() - if !ok { - return ctx, func() {} - } - - minDeadline := time.Now().Add(c.contextTimeout) - if minDeadline.Before(deadline) { - return ctx, func() {} - } - log.Entry(ctx).Warnf("Context deadline has been extended by extendtimeout from %v to %v", deadline, minDeadline) - deadline = minDeadline - postponedCtx, cancel := context.WithDeadline(context.Background(), deadline) - return &extendedContext{ - Context: postponedCtx, - valuesContext: ctx, - }, cancel +func (c *extendedConnection) withExtendedTimeoutContext(ctx context.Context) (context.Context, context.CancelFunc) { + var cancelContext, cancel = context.WithCancel(context.Background()) + var timeoutContext, timeoutCancel = context.WithTimeout(cancelContext, c.contextTimeout) + go func() { + <-timeoutContext.Done() + timeoutCancel() + select { + case <-cancelContext.Done(): + return + case <-ctx.Done(): + cancel() + return + } + }() + + return cancelContext, cancel } diff --git a/extendtimeout/connection_test.go b/extendtimeout/connection_test.go index 4d0025e..1ce342f 100644 --- a/extendtimeout/connection_test.go +++ b/extendtimeout/connection_test.go @@ -39,56 +39,7 @@ func (c *testConn) Invoke(ctx context.Context, req, reply api.Message) error { return nil } -type key struct{} - -const value = "value" - -func TestSmallTimeout(t *testing.T) { - testConn := &testConn{invokeBody: func(ctx context.Context) { - deadline, ok := ctx.Deadline() - require.True(t, ok) - timeout := time.Until(deadline) - require.Greater(t, timeout, time.Second) - require.Equal(t, ctx.Value(&key{}), value) - }} - - smallCtx, cancel := context.WithTimeout(context.Background(), time.Millisecond*10) - smallCtx = context.WithValue(smallCtx, &key{}, value) - defer cancel() - - err := extendtimeout.NewConnection(testConn, 2*time.Second).Invoke(smallCtx, nil, nil) - require.NoError(t, err) -} - -func TestBigTimeout(t *testing.T) { - testConn := &testConn{invokeBody: func(ctx context.Context) { - deadline, ok := ctx.Deadline() - require.True(t, ok) - timeout := time.Until(deadline) - require.Greater(t, timeout, 7*time.Second) - require.Equal(t, ctx.Value(&key{}), value) - }} - - bigCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - bigCtx = context.WithValue(bigCtx, &key{}, value) - defer cancel() - - err := extendtimeout.NewConnection(testConn, 2*time.Second).Invoke(bigCtx, nil, nil) - require.NoError(t, err) -} - -func TestNoTimeout(t *testing.T) { - testConn := &testConn{invokeBody: func(ctx context.Context) { - _, ok := ctx.Deadline() - require.False(t, ok) - }} - - emptyCtx := context.Background() - err := extendtimeout.NewConnection(testConn, 2*time.Second).Invoke(emptyCtx, nil, nil) - require.NoError(t, err) -} - -func TestCanceledContext(t *testing.T) { +func TestOriginalContextCanceled(t *testing.T) { t.Cleanup(func() { goleak.VerifyNone(t) }) @@ -105,16 +56,18 @@ func TestCanceledContext(t *testing.T) { } }} - cancelCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + cancelCtx, cancel := context.WithCancel(context.Background()) + go func() { - err := extendtimeout.NewConnection(testConn, 20*time.Second).Invoke(cancelCtx, nil, nil) + err := extendtimeout.NewConnection(testConn, 10*time.Second).Invoke(cancelCtx, nil, nil) require.NoError(t, err) }() cancel() + time.Sleep(50 * time.Millisecond) ch <- struct{}{} require.Eventually(t, func() bool { return counter.Load() == 1 - }, time.Second, 100*time.Millisecond) + }, 200*time.Millisecond, 10*time.Millisecond) }