diff --git a/codegen/testserver/generated_test.go b/codegen/testserver/generated_test.go index 33b31564dd8..6b56582e20d 100644 --- a/codegen/testserver/generated_test.go +++ b/codegen/testserver/generated_test.go @@ -41,16 +41,26 @@ type testTracer struct { append func(string) } +func (tt *testTracer) StartRequestTracing(ctx context.Context) (context.Context, error) { + tt.append(fmt.Sprintf("request:start:%d", tt.id)) + return ctx, nil +} + +func (tt *testTracer) EndRequestTracing(ctx context.Context) error { + tt.append(fmt.Sprintf("request:end:%d", tt.id)) + return nil +} + func (tt *testTracer) StartFieldTracing(ctx context.Context) (context.Context, error) { tracerIDs, _ := ctx.Value("tracer").([]int) ctx = context.WithValue(ctx, "tracer", append(tracerIDs, tt.id)) rc := graphql.GetResolverContext(ctx) - tt.append(fmt.Sprintf("start:%d:%v", tt.id, rc.Path())) + tt.append(fmt.Sprintf("resolver:start:%d:%v", tt.id, rc.Path())) return ctx, nil } func (tt *testTracer) EndFieldTracing(ctx context.Context) error { - tt.append(fmt.Sprintf("end:%d", tt.id)) + tt.append(fmt.Sprintf("resolver:end:%d", tt.id)) return nil } @@ -178,18 +188,22 @@ func TestGeneratedServer(t *testing.T) { require.NoError(t, err) require.True(t, called) assert.Equal(t, []string{ - "start:1:[user]", - "start:2:[user]", - "start:1:[user id]", - "start:2:[user id]", - "end:2", - "end:1", - "start:1:[user friends]", - "start:2:[user friends]", - "end:2", - "end:1", - "end:2", - "end:1", + "request:start:1", + "request:start:2", + "resolver:start:1:[user]", + "resolver:start:2:[user]", + "resolver:start:1:[user id]", + "resolver:start:2:[user id]", + "resolver:end:2", + "resolver:end:1", + "resolver:start:1:[user friends]", + "resolver:start:2:[user friends]", + "resolver:end:2", + "resolver:end:1", + "resolver:end:2", + "resolver:end:1", + "request:end:2", + "request:end:1", }, tracerLog) }) diff --git a/graphql/tracer.go b/graphql/tracer.go index 580a4761fef..e0afda46b64 100644 --- a/graphql/tracer.go +++ b/graphql/tracer.go @@ -5,12 +5,22 @@ import "context" var _ Tracer = (*NopTracer)(nil) type Tracer interface { + StartRequestTracing(ctx context.Context) (context.Context, error) + EndRequestTracing(ctx context.Context) error StartFieldTracing(ctx context.Context) (context.Context, error) EndFieldTracing(ctx context.Context) error } type NopTracer struct{} +func (NopTracer) StartRequestTracing(ctx context.Context) (context.Context, error) { + return ctx, nil +} + +func (NopTracer) EndRequestTracing(ctx context.Context) error { + return nil +} + func (NopTracer) StartFieldTracing(ctx context.Context) (context.Context, error) { return ctx, nil } diff --git a/handler/graphql.go b/handler/graphql.go index 431cac076cf..71bf51f5c73 100644 --- a/handler/graphql.go +++ b/handler/graphql.go @@ -135,14 +135,23 @@ func Tracer(tracer graphql.Tracer) Option { return func(cfg *Config) { if cfg.tracer == nil { cfg.tracer = tracer - return - } - lastResolve := cfg.tracer - cfg.tracer = &tracerWrapper{ - tracer1: lastResolve, - tracer2: tracer, + } else { + lastResolve := cfg.tracer + cfg.tracer = &tracerWrapper{ + tracer1: lastResolve, + tracer2: tracer, + } } + + opt := RequestMiddleware(func(ctx context.Context, next func(ctx context.Context) []byte) []byte { + ctx, _ = tracer.StartRequestTracing(ctx) + resp := next(ctx) + _ = tracer.EndRequestTracing(ctx) + + return resp + }) + opt(cfg) } } @@ -151,6 +160,23 @@ type tracerWrapper struct { tracer2 graphql.Tracer } +func (tw *tracerWrapper) StartRequestTracing(ctx context.Context) (context.Context, error) { + ctx, err := tw.tracer1.StartRequestTracing(ctx) + if err != nil { + return ctx, err + } + return tw.tracer2.StartRequestTracing(ctx) +} + +func (tw *tracerWrapper) EndRequestTracing(ctx context.Context) error { + err2 := tw.tracer2.EndRequestTracing(ctx) + err1 := tw.tracer1.EndRequestTracing(ctx) + if err2 != nil { + return err2 + } + return err1 +} + func (tw *tracerWrapper) StartFieldTracing(ctx context.Context) (context.Context, error) { ctx, err := tw.tracer1.StartFieldTracing(ctx) if err != nil {