From f197a8bae0c87e1b7cc2e32e399a40665a82f077 Mon Sep 17 00:00:00 2001 From: wei Date: Mon, 6 Jun 2022 18:43:53 +0800 Subject: [PATCH] feat(context): add ContextWithFallback feature flag (#3166) (#3172) Enable fallback Context.Deadline(), Context.Done(), Context.Err() and Context.Value() --- context.go | 8 ++-- context_test.go | 115 +++++++++++++++++++++++++++++++++++++++++++----- gin.go | 3 ++ 3 files changed, 110 insertions(+), 16 deletions(-) diff --git a/context.go b/context.go index 6b25b3a31d..b1ad95e623 100644 --- a/context.go +++ b/context.go @@ -1158,7 +1158,7 @@ func (c *Context) SetAccepted(formats ...string) { // Deadline returns that there is no deadline (ok==false) when c.Request has no Context. func (c *Context) Deadline() (deadline time.Time, ok bool) { - if c.Request == nil || c.Request.Context() == nil { + if !c.engine.ContextWithFallback || c.Request == nil || c.Request.Context() == nil { return } return c.Request.Context().Deadline() @@ -1166,7 +1166,7 @@ func (c *Context) Deadline() (deadline time.Time, ok bool) { // Done returns nil (chan which will wait forever) when c.Request has no Context. func (c *Context) Done() <-chan struct{} { - if c.Request == nil || c.Request.Context() == nil { + if !c.engine.ContextWithFallback || c.Request == nil || c.Request.Context() == nil { return nil } return c.Request.Context().Done() @@ -1174,7 +1174,7 @@ func (c *Context) Done() <-chan struct{} { // Err returns nil when c.Request has no Context. func (c *Context) Err() error { - if c.Request == nil || c.Request.Context() == nil { + if !c.engine.ContextWithFallback || c.Request == nil || c.Request.Context() == nil { return nil } return c.Request.Context().Err() @@ -1195,7 +1195,7 @@ func (c *Context) Value(key any) any { return val } } - if c.Request == nil || c.Request.Context() == nil { + if !c.engine.ContextWithFallback || c.Request == nil || c.Request.Context() == nil { return nil } return c.Request.Context().Value(key) diff --git a/context_test.go b/context_test.go index f9f0c1dc19..6c0be544ef 100644 --- a/context_test.go +++ b/context_test.go @@ -2097,12 +2097,18 @@ func TestRemoteIPFail(t *testing.T) { } func TestContextWithFallbackDeadlineFromRequestContext(t *testing.T) { - c := &Context{} + c, _ := CreateTestContext(httptest.NewRecorder()) + // enable ContextWithFallback feature flag + c.engine.ContextWithFallback = true + deadline, ok := c.Deadline() assert.Zero(t, deadline) assert.False(t, ok) - c2 := &Context{} + c2, _ := CreateTestContext(httptest.NewRecorder()) + // enable ContextWithFallback feature flag + c2.engine.ContextWithFallback = true + c2.Request, _ = http.NewRequest(http.MethodGet, "/", nil) d := time.Now().Add(time.Second) ctx, cancel := context.WithDeadline(context.Background(), d) @@ -2114,10 +2120,16 @@ func TestContextWithFallbackDeadlineFromRequestContext(t *testing.T) { } func TestContextWithFallbackDoneFromRequestContext(t *testing.T) { - c := &Context{} + c, _ := CreateTestContext(httptest.NewRecorder()) + // enable ContextWithFallback feature flag + c.engine.ContextWithFallback = true + assert.Nil(t, c.Done()) - c2 := &Context{} + c2, _ := CreateTestContext(httptest.NewRecorder()) + // enable ContextWithFallback feature flag + c2.engine.ContextWithFallback = true + c2.Request, _ = http.NewRequest(http.MethodGet, "/", nil) ctx, cancel := context.WithCancel(context.Background()) c2.Request = c2.Request.WithContext(ctx) @@ -2126,10 +2138,16 @@ func TestContextWithFallbackDoneFromRequestContext(t *testing.T) { } func TestContextWithFallbackErrFromRequestContext(t *testing.T) { - c := &Context{} + c, _ := CreateTestContext(httptest.NewRecorder()) + // enable ContextWithFallback feature flag + c.engine.ContextWithFallback = true + assert.Nil(t, c.Err()) - c2 := &Context{} + c2, _ := CreateTestContext(httptest.NewRecorder()) + // enable ContextWithFallback feature flag + c2.engine.ContextWithFallback = true + c2.Request, _ = http.NewRequest(http.MethodGet, "/", nil) ctx, cancel := context.WithCancel(context.Background()) c2.Request = c2.Request.WithContext(ctx) @@ -2138,9 +2156,9 @@ func TestContextWithFallbackErrFromRequestContext(t *testing.T) { assert.EqualError(t, c2.Err(), context.Canceled.Error()) } -type contextKey string - func TestContextWithFallbackValueFromRequestContext(t *testing.T) { + type contextKey string + tests := []struct { name string getContextAndKey func() (*Context, any) @@ -2150,7 +2168,9 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) { name: "c with struct context key", getContextAndKey: func() (*Context, any) { var key struct{} - c := &Context{} + c, _ := CreateTestContext(httptest.NewRecorder()) + // enable ContextWithFallback feature flag + c.engine.ContextWithFallback = true c.Request, _ = http.NewRequest("POST", "/", nil) c.Request = c.Request.WithContext(context.WithValue(context.TODO(), key, "value")) return c, key @@ -2160,7 +2180,9 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) { { name: "c with string context key", getContextAndKey: func() (*Context, any) { - c := &Context{} + c, _ := CreateTestContext(httptest.NewRecorder()) + // enable ContextWithFallback feature flag + c.engine.ContextWithFallback = true c.Request, _ = http.NewRequest("POST", "/", nil) c.Request = c.Request.WithContext(context.WithValue(context.TODO(), contextKey("key"), "value")) return c, contextKey("key") @@ -2170,7 +2192,10 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) { { name: "c with nil http.Request", getContextAndKey: func() (*Context, any) { - c := &Context{} + c, _ := CreateTestContext(httptest.NewRecorder()) + // enable ContextWithFallback feature flag + c.engine.ContextWithFallback = true + c.Request = nil return c, "key" }, value: nil, @@ -2178,7 +2203,9 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) { { name: "c with nil http.Request.Context()", getContextAndKey: func() (*Context, any) { - c := &Context{} + c, _ := CreateTestContext(httptest.NewRecorder()) + // enable ContextWithFallback feature flag + c.engine.ContextWithFallback = true c.Request, _ = http.NewRequest("POST", "/", nil) return c, "key" }, @@ -2193,6 +2220,70 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) { } } +func TestContextCopyShouldNotCancel(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + ensureRequestIsOver := make(chan struct{}) + + wg := &sync.WaitGroup{} + + r := New() + r.GET("/", func(ginctx *Context) { + wg.Add(1) + + ginctx = ginctx.Copy() + + // start async goroutine for calling srv + go func() { + defer wg.Done() + + <-ensureRequestIsOver // ensure request is done + + req, err := http.NewRequestWithContext(ginctx, http.MethodGet, srv.URL, nil) + must(err) + + res, err := http.DefaultClient.Do(req) + if err != nil { + t.Error(fmt.Errorf("request error: %w", err)) + return + } + + if res.StatusCode != http.StatusOK { + t.Error(fmt.Errorf("unexpected status code: %s", res.Status)) + } + }() + }) + + l, err := net.Listen("tcp", ":0") + must(err) + go func() { + s := &http.Server{ + Handler: r, + } + + must(s.Serve(l)) + }() + + addr := strings.Split(l.Addr().String(), ":") + res, err := http.Get(fmt.Sprintf("http://127.0.0.1:%s/", addr[len(addr)-1])) + if err != nil { + t.Error(fmt.Errorf("request error: %w", err)) + return + } + + close(ensureRequestIsOver) + + if res.StatusCode != http.StatusOK { + t.Error(fmt.Errorf("unexpected status code: %s", res.Status)) + return + } + + wg.Wait() +} + func TestContextAddParam(t *testing.T) { c := &Context{} id := "id" diff --git a/gin.go b/gin.go index e516360caa..f9324299b7 100644 --- a/gin.go +++ b/gin.go @@ -147,6 +147,9 @@ type Engine struct { // UseH2C enable h2c support. UseH2C bool + // ContextWithFallback enable fallback Context.Deadline(), Context.Done(), Context.Err() and Context.Value() when Context.Request.Context() is not nil. + ContextWithFallback bool + delims render.Delims secureJSONPrefix string HTMLRender render.HTMLRender